Skip to content

automatic differentiation with Enzyme #452

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 105 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
105 commits
Select commit Hold shift + click to select a range
70f86f6
convert HLO to StableHLO
joelberkeley Dec 1, 2024
866d15a
working sometimes
joelberkeley Dec 1, 2024
793bfa1
working! thanks kevin
joelberkeley Dec 2, 2024
e07c7f8
wip
joelberkeley Dec 2, 2024
e1daab0
Merge branch 'master' into stablehlo-convert
joelberkeley Dec 2, 2024
c8bf80c
wip
joelberkeley Dec 2, 2024
b9a5f04
temp remove pack switch HEAD
joelberkeley Dec 2, 2024
fea8646
revert
joelberkeley Dec 3, 2024
364158b
wip
joelberkeley Dec 4, 2024
efab2a5
build mlir::ModuleOp
joelberkeley Dec 4, 2024
01c3cf2
working on linux
joelberkeley Dec 8, 2024
de00a64
wip
joelberkeley Dec 8, 2024
6cea6ba
wip
joelberkeley Dec 8, 2024
4f66569
tan
joelberkeley Dec 8, 2024
cf0d80f
wip
joelberkeley Dec 13, 2024
ed838f8
llvm
joelberkeley Dec 13, 2024
59a317d
tidy
joelberkeley Dec 13, 2024
b983e13
wip
joelberkeley Dec 13, 2024
189f40a
draft AD with enzyme
joelberkeley Dec 16, 2024
b2f7a65
Merge branch 'master' into stablehlo-ad
joelberkeley Dec 16, 2024
8abda25
shellcheck
joelberkeley Dec 16, 2024
edd3f39
shellcheck
joelberkeley Dec 16, 2024
c933e6a
update enzyme version
joelberkeley Dec 16, 2024
4e5a253
wip
joelberkeley Dec 19, 2024
8828373
first (almost e2e) draft
joelberkeley Dec 23, 2024
0a5f935
compiling with runtime errors
joelberkeley Dec 23, 2024
41bc6ad
wip
joelberkeley Dec 28, 2024
ff85828
moses suggestion
joelberkeley Jan 4, 2025
6bd295d
wip
joelberkeley Jan 4, 2025
768b5b5
everything
joelberkeley Jan 4, 2025
964c875
enz version
joelberkeley Jan 4, 2025
3460cb7
revert xla version
joelberkeley Jan 4, 2025
33843fc
wip
joelberkeley Jan 4, 2025
d50956a
wip
joelberkeley Jan 4, 2025
2b6d8dc
wip
joelberkeley Jan 4, 2025
857f3da
Merge branch 'master' into stablehlo-ad
joelberkeley Jan 20, 2025
69558cb
use loadDialect
joelberkeley Jan 20, 2025
8c319ee
wip
joelberkeley Jan 20, 2025
d4332df
start debugging properly
joelberkeley Jan 22, 2025
2e8f5d4
really solid progress
joelberkeley Jan 23, 2025
e7e94f9
most things there
joelberkeley Jan 26, 2025
351a0fe
almost there
joelberkeley Jan 27, 2025
44c7bdc
IR correct!!!!!!?!
joelberkeley Jan 28, 2025
272e122
wip
joelberkeley Jan 29, 2025
997bd92
wip
joelberkeley Jan 29, 2025
bdbd953
science, b****
joelberkeley Jan 31, 2025
3597b00
tidy
joelberkeley Feb 2, 2025
fa2e9a2
working, mostly
joelberkeley Feb 2, 2025
dd46941
remove unnecessary stuff
joelberkeley Feb 2, 2025
ca03fac
further simplify
joelberkeley Feb 2, 2025
5482770
wip
joelberkeley Feb 2, 2025
c204aa5
reverse mode
joelberkeley Feb 2, 2025
8796a9e
tidy
joelberkeley Feb 2, 2025
ee4bcea
add (broken) test for second derivative
joelberkeley Feb 3, 2025
d134f51
docs
joelberkeley Feb 3, 2025
fb810b1
wip
joelberkeley Feb 3, 2025
280fb2e
merge public XLADerivates
joelberkeley Feb 4, 2025
55a4be8
working to third order
joelberkeley Feb 4, 2025
4c08b5a
wip
joelberkeley Feb 5, 2025
b8c062a
wip
joelberkeley Feb 5, 2025
5a826e6
wip
joelberkeley Feb 5, 2025
afeef9a
wip
joelberkeley Feb 8, 2025
4a4855c
add reverse test
joelberkeley Feb 8, 2025
4a11594
working, with passes pulled out into individual functions
joelberkeley Feb 9, 2025
ede69bd
tidy
joelberkeley Feb 9, 2025
35a8d26
remove push_backs
joelberkeley Feb 9, 2025
a06dfbf
remove need for Region
joelberkeley Feb 9, 2025
e469973
wip
joelberkeley Feb 9, 2025
638fade
wip
joelberkeley Feb 9, 2025
2e3c47c
save migration progress to idris
joelberkeley Feb 21, 2025
be6b47b
wip
joelberkeley Mar 2, 2025
8525bf8
wip
joelberkeley Mar 18, 2025
9416a1f
wip
joelberkeley Mar 19, 2025
31a0055
wip
joelberkeley Mar 19, 2025
42f0452
wip
joelberkeley Mar 19, 2025
d1c6d44
wip
joelberkeley Mar 19, 2025
e2e779b
wip
joelberkeley Mar 21, 2025
356de13
idris compiling
joelberkeley Mar 22, 2025
f53161c
wip
joelberkeley Mar 22, 2025
62cac19
wip
joelberkeley Mar 23, 2025
263a1ca
wip
joelberkeley Mar 23, 2025
6c1ab18
wip
joelberkeley Mar 23, 2025
bfddbac
wip
joelberkeley Mar 23, 2025
d16c2db
wip
joelberkeley Mar 23, 2025
85d85a3
wip
joelberkeley Mar 23, 2025
d9a8ff9
working!
joelberkeley Mar 24, 2025
19fa52e
wip
joelberkeley Mar 24, 2025
f55af55
wip
joelberkeley Mar 24, 2025
5d037b6
tidy
joelberkeley Mar 24, 2025
0fd9a32
year
joelberkeley Mar 24, 2025
2c50467
year
joelberkeley Mar 24, 2025
6c2afc7
tidy
joelberkeley Mar 24, 2025
b950b1c
Merge branch 'master' into stablehlo-ad
joelberkeley Mar 25, 2025
7a55be2
delete some unused modules
joelberkeley Mar 25, 2025
4544f2f
wip
joelberkeley Mar 25, 2025
6544b6a
wip
joelberkeley Mar 25, 2025
6257e34
wip
joelberkeley Mar 25, 2025
b4355ec
wip
joelberkeley Mar 25, 2025
5aa3bd7
wip
joelberkeley Mar 25, 2025
cb5fa98
Merge branch 'master' into stablehlo-ad
joelberkeley Apr 26, 2025
9a569bc
Merge branch 'master' into stablehlo-ad
joelberkeley Apr 26, 2025
b0f0615
opd
joelberkeley Apr 26, 2025
6b9d12a
deletes
joelberkeley Apr 27, 2025
624feae
wip
joelberkeley Apr 27, 2025
1dad11a
fix gc
joelberkeley Apr 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion XLA_VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2fb20601f1cc6cab7f29f8bc73d90cd31e74bba0
4ec7e2a7721ace136bae967e6881fa7035d0b35a
17 changes: 10 additions & 7 deletions dev.sh
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
#!/bin/sh -e

install_xla () {
if [ -z "$2" ]; then
echo "Usage: install_xla <xla-revision> <install-path>."
exit 1;
fi

install_git_repository () {
if [ "$(ls -A "$2")" ]; then
echo "Directory at path $2 is not empty, refusing to install XLA to this directory."
exit 1;
Expand All @@ -14,8 +9,16 @@ install_xla () {
(
cd "$2"
git init
git remote add origin https://github.com/openxla/xla
git remote add origin "$3"
git fetch --depth 1 origin "$1"
git checkout FETCH_HEAD
)
}

install_xla () {
install_git_repository "$1" "$2" https://github.com/openxla/xla
}

install_enzyme () {
install_git_repository "$1" "$2" https://github.com/EnzymeAD/Enzyme-JAX.git
}
1 change: 1 addition & 0 deletions pjrt-plugins/xla-cpu/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
xla/
3 changes: 2 additions & 1 deletion pjrt-plugins/xla-cpu/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ case $osu in
;;
esac

xla_dir=$(mktemp -d)
xla_dir=pjrt-plugins/xla-cpu/xla
mkdir "$xla_dir"
install_xla "$xla_version" "$xla_dir"
(
cd "$xla_dir"
Expand Down
3 changes: 2 additions & 1 deletion pjrt-plugins/xla-cuda/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ case $osu in
;;
esac

xla_dir=$(mktemp -d)
xla_dir=pjrt-plugins/xla-cuda/xla
mkdir "$xla_dir"
install_xla "$xla_version" "$xla_dir"
(
cd "$xla_dir"
Expand Down
1 change: 1 addition & 0 deletions spidr/backend/.gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
/Enzyme-JAX
/xla
16 changes: 16 additions & 0 deletions spidr/backend/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,37 @@ cc_binary(
linkshared = True,
linkstatic = True,
srcs = [
"//src/Enzyme-JAX",
"//src/Enzyme",
"//src/llvm",
"//src/mlir",
"//src/stablehlo",
"//src/xla",
"//src/xla/client",
"//src/xla/hlo/builder",
"//src/xla/hlo/builder/lib",
"//src/xla/hlo/translate",
"//src/xla/mlir_hlo/mhlo/IR",
"//src/xla/pjrt",
"//src/xla/pjrt/c",
"//src/xla/service",
"//src",
],
deps = [
"//src/Enzyme-JAX",
"//src/Enzyme",
"//src/llvm",
"//src/mlir",
"//src/stablehlo",
"//src/xla",
"//src/xla/client",
"//src/xla/hlo/builder",
"//src/xla/hlo/builder/lib",
"//src/xla/hlo/translate",
"//src/xla/mlir_hlo/mhlo/IR",
"//src/xla/pjrt",
"//src/xla/pjrt/c",
"//src/xla/service",
"//src",
],
)
15 changes: 15 additions & 0 deletions spidr/backend/BUILD.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
--- BUILD 2025-02-04 00:03:30.224943988 +0000
+++ BUILD.diff 2025-02-04 00:02:33.364805000 +0000
@@ -385,12 +385,6 @@
"Passes/*.h",
"Dialect/*.h",
]),
- copts = [
- "-Werror=unused-variable",
- "-Werror=unused-but-set-variable",
- "-Werror=return-type",
- "-Werror=unused-result",
- ],
deps = [
":EnzymeXLAOpsIncGen",
":EnzymeXLAPassesIncGen",
1 change: 1 addition & 0 deletions spidr/backend/ENZYME_JAX_VERSION
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ae84d009fe9385e8752a8ee8406115d662d24967
2 changes: 1 addition & 1 deletion spidr/backend/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.0.16
0.0.17
33 changes: 33 additions & 0 deletions spidr/backend/WORKSPACE
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
### xla

# this must be a local repository not http archive
# so we can run ./configure.py before invoking bazel
local_repository(name = "xla", path = "xla")
Expand Down Expand Up @@ -28,3 +30,34 @@ xla_workspace0()

load("@tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl", "cuda_configure")
cuda_configure(name = "local_config_cuda")

# Enzyme-JAX

# should we use http_archive for enzyme-jax?
# how best to use enzyme's XLA version?
local_repository(name = "enzyme-jax", path = "Enzyme-JAX")
load("@enzyme-jax//:workspace.bzl", "JAX_COMMIT", "JAX_SHA256", "ENZYME_COMMIT", "ENZYME_SHA256")

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

http_archive(
name = "hedron_compile_commands",
url = "https://github.com/hedronvision/bazel-compile-commands-extractor/archive/4f28899228fb3ad0126897876f147ca15026151e.tar.gz",
strip_prefix = "bazel-compile-commands-extractor-4f28899228fb3ad0126897876f147ca15026151e",
)

http_archive(
name = "jax",
sha256 = JAX_SHA256,
strip_prefix = "jax-" + JAX_COMMIT,
urls = ["https://github.com/google/jax/archive/{commit}.tar.gz".format(commit = JAX_COMMIT)],
patch_args = ["-p1"],
patches = ["@enzyme-jax//:patches/jax.patch"],
)

http_archive(
name = "enzyme",
sha256 = ENZYME_SHA256,
strip_prefix = "Enzyme-" + ENZYME_COMMIT + "/enzyme",
urls = ["https://github.com/EnzymeAD/Enzyme/archive/{commit}.tar.gz".format(commit = ENZYME_COMMIT)],
)
12 changes: 9 additions & 3 deletions spidr/backend/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
script_dir=$(CDPATH="" cd -- "$(dirname -- "$0")" && pwd)
cd "$script_dir/../.."
. ./dev.sh
rev="$(cat XLA_VERSION)"
xla_rev="$(cat XLA_VERSION)"
enzyme_rev="$(cat spidr/backend/ENZYME_JAX_VERSION)"

osu="$(uname)"
case $osu in
Expand All @@ -26,9 +27,14 @@ esac
(
cd spidr/backend
mkdir xla
install_xla "$rev" xla
install_xla "$xla_rev" xla
(cd xla; ./configure.py --backend=cpu --os=$os)
# depending on Enzyme-JAX is problematic as it fixes the XLA version. Can we only depend on enzyme?
# seems unlikely that they could decouple XLA entirely. They almost certainly can't decouple stablehlo
mkdir Enzyme-JAX
install_enzyme "$enzyme_rev" Enzyme-JAX
patch Enzyme-JAX/src/enzyme_ad/jax/BUILD < BUILD.patch
bazel build //:c_xla
rm -rf xla
rm -rf xla Enzyme-JAX
)
mv "spidr/backend/bazel-bin/libc_xla.$ext" "libc_xla-$os-$arch.$ext"
14 changes: 14 additions & 0 deletions spidr/backend/src/Enzyme-JAX/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
cc_library(
name = "Enzyme-JAX",
linkstatic = True,
alwayslink = True,
visibility = ["//visibility:public"],
srcs = glob(["**/*.cpp"]),
hdrs = glob(["**/*.h"]),
deps = [
"@enzyme//:EnzymeMLIR",
"@enzyme-jax//src/enzyme_ad/jax:XLADerivatives",
"@llvm-project//mlir:AllPassesAndDialects",
"//src/mlir",
],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
Copyright 2025 Joel Berkeley

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "Enzyme/MLIR/Dialect/Dialect.h"
#include "Enzyme/MLIR/Dialect/Ops.h"
#include "Enzyme/MLIR/Passes/Passes.h"
#include "src/enzyme_ad/jax/Implementations/XLADerivatives.h"

#include "../../../../../mlir/IR/DialectRegistry.h"

extern "C" {
void registerStableHLODialectAutoDiffInterface(DialectRegistry* registry) {
auto registry_ = reinterpret_cast<mlir::DialectRegistry*>(registry);
mlir::enzyme::registerStableHLODialectAutoDiffInterface(*registry_);
}

void registerCHLODialectAutoDiffInterface(DialectRegistry* registry) {
auto registry_ = reinterpret_cast<mlir::DialectRegistry*>(registry);
mlir::enzyme::registerCHLODialectAutoDiffInterface(*registry_);
}
}
26 changes: 26 additions & 0 deletions spidr/backend/src/Enzyme-JAX/src/enzyme_ad/jax/Passes/Passes.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
Copyright 2025 Joel Berkeley

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "mlir/Pass/PassManager.h"
#include "src/enzyme_ad/jax/Passes/Passes.h"

#include "../../../../../mlir/Pass/PassManager.h"

extern "C" {
void PassManager_addPass_ArithRaisingPass(PassManager& s) {
auto& s_ = reinterpret_cast<mlir::PassManager&>(s);
s_.addPass(mlir::enzyme::createArithRaisingPass());
}
}
13 changes: 13 additions & 0 deletions spidr/backend/src/Enzyme/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
cc_library(
name = "Enzyme",
linkstatic = True,
alwayslink = True,
srcs = glob(["**/*.cpp"]),
hdrs = glob(["**/*.h"]),
deps = [
"@enzyme//:EnzymeMLIR",
"@llvm-project//mlir:IR",
"//src/mlir",
],
visibility = ["//visibility:public"],
)
30 changes: 30 additions & 0 deletions spidr/backend/src/Enzyme/MLIR/Dialect/Dialect.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
Copyright 2025 Joel Berkeley

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "mlir/IR/DialectRegistry.h"
#include "Enzyme/MLIR/Dialect/Dialect.h"

#include "../../../mlir/IR/DialectRegistry.h"
#include "../../../mlir/IR/MLIRContext.h"

extern "C" {
void DialectRegistry_insert_EnzymeDialect(DialectRegistry& s) {
reinterpret_cast<mlir::DialectRegistry&>(s).insert<mlir::enzyme::EnzymeDialect>();
}

void MLIRContext_loadDialect_EnzymeDialect(MLIRContext& s) {
reinterpret_cast<mlir::MLIRContext&>(s).loadDialect<mlir::enzyme::EnzymeDialect>();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
Copyright 2025 Joel Berkeley

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h"

#include "../../../mlir/IR/DialectRegistry.h"

extern "C" {
void registerCoreDialectAutodiffInterfaces(DialectRegistry* registry) {
auto registry_ = reinterpret_cast<mlir::DialectRegistry*>(registry);
mlir::enzyme::registerCoreDialectAutodiffInterfaces(*registry_);
}
}
31 changes: 31 additions & 0 deletions spidr/backend/src/Enzyme/MLIR/Passes/Passes.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
Copyright 2025 Joel Berkeley

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "Enzyme/MLIR/Passes/Passes.h"
#include "mlir/Pass/PassManager.h"

#include "../../../mlir/Pass/PassManager.h"

extern "C" {
void registerenzymePasses() {
// where on earth does this come from?
mlir::registerenzymePasses();
}

void PassManager_addPass_RemoveUnusedEnzymeOpsPass(PassManager& s) {
auto& s_ = reinterpret_cast<mlir::PassManager&>(s);
s_.addPass(mlir::enzyme::createRemoveUnusedEnzymeOpsPass());
}
}
Loading
Loading