Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disable verifier in main pass manager pipeline #269

Merged
merged 4 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 24 additions & 31 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -244,14 +244,27 @@ const opt_passes::String = join(
',',
)

function run_pass_pipeline!(mod, pass_pipeline)
function run_pass_pipeline!(mod, pass_pipeline; enable_verifier=true)
pm = MLIR.IR.PassManager()
MLIR.IR.enable_verifier!(pm, enable_verifier)
opm = MLIR.IR.OpPassManager(pm)
MLIR.IR.add_pipeline!(opm, pass_pipeline)
MLIR.IR.run!(pm, mod)
return mod
end

# helper for debug purposes: String -> Text
function run_pass_pipeline_on_source(source, pass_pipeline; enable_verifier=true)
ctx = MLIR.IR.Context(Reactant.registry[], false)
@ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid
MLIR.IR.context!(ctx) do
mod = parse(MLIR.IR.Module, source)
run_pass_pipeline!(mod, pass_pipeline; enable_verifier)
MLIR.IR.verifyall(MLIR.IR.Operation(mod); debug=true)
Text(repr(mod))
end
end

function compile_mlir(f, args; kwargs...)
ctx = MLIR.IR.Context(Reactant.registry[], false)
@ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid
Expand Down Expand Up @@ -280,15 +293,12 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
optimize isa Bool && (optimize = ifelse(optimize, :all, :none))

if optimize === :all
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ","))
run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false)
run_pass_pipeline!(
mod,
join(
[
opt_passes,
"enzyme-batch",
opt_passes,
"enzyme",
"arith-raise{stablehlo=true}",
"canonicalize",
"remove-unnecessary-enzyme-ops",
"enzyme-simplify-math",
Expand All @@ -298,28 +308,22 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
),
)
elseif optimize === :only_enzyme
run_pass_pipeline!(mod, "enzyme-batch")
run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false)
run_pass_pipeline!(
mod,
join(
[
"enzyme-batch",
"enzyme",
"arith-raise{stablehlo=true}",
"canonicalize",
"remove-unnecessary-enzyme-ops",
"enzyme-simplify-math",
],
["canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math"],
',',
),
)
elseif optimize === :after_enzyme
run_pass_pipeline!(mod, "enzyme-batch")
run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false)
run_pass_pipeline!(
mod,
join(
[
"enzyme-batch",
"enzyme",
"arith-raise{stablehlo=true}",
"canonicalize",
"remove-unnecessary-enzyme-ops",
"enzyme-simplify-math",
Expand All @@ -329,21 +333,10 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
),
)
elseif optimize === :before_enzyme
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes]))
run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false)
run_pass_pipeline!(
mod,
join(
[
opt_passes,
"enzyme-batch",
opt_passes,
"enzyme",
"arith-raise{stablehlo=true}",
"canonicalize",
"remove-unnecessary-enzyme-ops",
"enzyme-simplify-math",
],
',',
),
mod, "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math"
)
elseif optimize !== :none
error("Invalid optimize option: $(Meta.quot(optimize))")
Expand Down
19 changes: 14 additions & 5 deletions src/mlir/IR/IR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,15 @@ Base.String(str::API.MlirIdentifier) = String(API.mlirIdentifierStr(str))
### Utils

function visit(f, op)
all_ok = true
for region in RegionIterator(op)
for block in BlockIterator(region)
for op in OperationIterator(block)
f(op)
all_ok &= f(op)
end
end
end
return all_ok
end

"""
Expand All @@ -115,13 +117,20 @@ end
Prints the operations which could not be verified.
"""
function verifyall(operation::Operation; debug=false)
io = IOContext(stdout, :debug => debug)
io = IOBuffer()
visit(operation) do op
if !verify(op)
show(io, op)
ok = verifyall(op; debug)
if !ok || !verify(op)
if ok
show(IOContext(io, :debug => debug), op)
error(String(take!(io)))
end
false
else
true
end
end
end
verifyall(module_::IR.Module) = verifyall(Operation(module_))
verifyall(module_::IR.Module; debug=false) = verifyall(Operation(module_); debug)

end # module IR
13 changes: 11 additions & 2 deletions src/mlir/IR/Pass.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,17 @@ Base.convert(::Core.Type{API.MlirPassManager}, pass::PassManager) = pass.pass

Enable mlir-print-ir-after-all.
"""
function enable_ir_printing!(pm)
API.mlirPassManagerEnableIRPrinting(pm)
function enable_ir_printing!(
pm;
before_all=false,
after_all=false,
module_scope=false,
after_only_on_change=false,
after_only_on_failure=false,
)
API.mlirPassManagerEnableIRPrinting(
pm, before_all, after_all, module_scope, after_only_on_change, after_only_on_failure
)
return pm
end

Expand Down
Loading