Skip to content

Commit

Permalink
Fix saturating add matching in associativity checking (#8220)
Browse files Browse the repository at this point in the history
* Fix saturating add matching in associativity checking

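The associative ops table defined saturating add as
saturating_narrow(widen(x + y)), where the sum is computed (and can wrap) in
the narrow type before being widened, instead of
saturating_narrow(widen(x) + y), where the sum is computed in the wide type.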
  • Loading branch information
abadams authored May 24, 2024
1 parent b5f5065 commit 33d5ba9
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 109 deletions.
122 changes: 44 additions & 78 deletions src/AssociativeOpsTable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,6 @@ using std::vector;

namespace {

// Kind of IR node found at the root of an expression; stored as the `root`
// field of a TableKey. Enumerators are numbered consecutively starting at
// zero, in the order listed.
enum class RootExpr {
    Add,      // 0
    Mul,      // 1
    Max,      // 2
    Min,      // 3
    Sub,      // 4
    Select,   // 5
    And,      // 6
    Or,       // 7
    Cast,     // 8
    Unknown,  // 9: Not supported IR type
};

enum class ValType {
UInt1 = 0,
UInt8 = 1,
Expand Down Expand Up @@ -93,12 +80,12 @@ vector<ValType> convert_halide_types_to_val_types(const vector<Type> &halide_typ

struct TableKey {
vector<ValType> types;
RootExpr root;
IRNodeType root;
size_t dim;
TableKey(ValType t, RootExpr r, size_t d)
TableKey(ValType t, IRNodeType r, size_t d)
: types({t}), root(r), dim(d) {
}
TableKey(const vector<ValType> &t, RootExpr r, size_t d)
TableKey(const vector<ValType> &t, IRNodeType r, size_t d)
: types(t), root(r), dim(d) {
}

Expand Down Expand Up @@ -169,6 +156,14 @@ void populate_ops_table_single_general_select(const vector<Type> &types, vector<
declare_vars_single(types);
}

void populate_ops_table_single_general_call(const vector<Type> &types, vector<AssociativePattern> &table) {
declare_vars_single(types);
if (types[0].code() == Type::UInt) {
table.emplace_back(saturating_add(x0, y0), zero_0, true);
table.emplace_back(saturating_cast(types[0], widening_add(x0, y0)), zero_0, true);
}
}

void populate_ops_table_double_general_add(const vector<Type> &types, vector<AssociativePattern> &table) {
declare_vars_double(types);
if (types[0] == types[1]) {
Expand Down Expand Up @@ -217,9 +212,9 @@ void populate_ops_table_single_uint8_cast(const vector<Type> &types, vector<Asso
Expr k0_uint16 = Variable::make(UInt(16), "k0");
Expr k0_uint32 = Variable::make(UInt(32), "k0");
Expr k0_uint64 = Variable::make(UInt(64), "k0");
table.emplace_back(cast<uint8_t>(min(cast<uint16_t>(x0 + y0), k0_uint16)), zero_0, true);
table.emplace_back(cast<uint8_t>(min(cast<uint32_t>(x0 + y0), k0_uint32)), zero_0, true);
table.emplace_back(cast<uint8_t>(min(cast<uint64_t>(x0 + y0), k0_uint64)), zero_0, true);
table.emplace_back(cast<uint8_t>(min(cast<uint16_t>(x0) + y0, k0_uint16)), zero_0, true);
table.emplace_back(cast<uint8_t>(min(cast<uint32_t>(x0) + y0, k0_uint32)), zero_0, true);
table.emplace_back(cast<uint8_t>(min(cast<uint64_t>(x0) + y0, k0_uint64)), zero_0, true);
}

void populate_ops_table_single_uint8_select(const vector<Type> &types, vector<AssociativePattern> &table) {
Expand All @@ -232,8 +227,8 @@ void populate_ops_table_single_uint16_cast(const vector<Type> &types, vector<Ass
declare_vars_single(types);
Expr k0_uint32 = Variable::make(UInt(32), "k0");
Expr k0_uint64 = Variable::make(UInt(64), "k0");
table.emplace_back(cast<uint16_t>(min(cast<uint32_t>(x0 + y0), k0_uint32)), zero_0, true);
table.emplace_back(cast<uint16_t>(min(cast<uint64_t>(x0 + y0), k0_uint64)), zero_0, true);
table.emplace_back(cast<uint16_t>(min(cast<uint32_t>(x0) + y0, k0_uint32)), zero_0, true);
table.emplace_back(cast<uint16_t>(min(cast<uint64_t>(x0) + y0, k0_uint64)), zero_0, true);
}

void populate_ops_table_single_uint16_select(const vector<Type> &types, vector<AssociativePattern> &table) {
Expand All @@ -255,33 +250,34 @@ void populate_ops_table_single_uint32_select(const vector<Type> &types, vector<A
}

const map<TableKey, void (*)(const vector<Type> &types, vector<AssociativePattern> &)> val_type_to_populate_luts_fn = {
{TableKey(ValType::All, RootExpr::Add, 1), &populate_ops_table_single_general_add},
{TableKey(ValType::All, RootExpr::Mul, 1), &populate_ops_table_single_general_mul},
{TableKey(ValType::All, RootExpr::Max, 1), &populate_ops_table_single_general_max},
{TableKey(ValType::All, RootExpr::Min, 1), &populate_ops_table_single_general_min},
{TableKey(ValType::All, RootExpr::Sub, 1), &populate_ops_table_single_general_sub},
{TableKey(ValType::All, RootExpr::Select, 1), &populate_ops_table_single_general_select},
{TableKey(ValType::All, RootExpr::Add, 2), &populate_ops_table_double_general_add},
{TableKey(ValType::All, RootExpr::Mul, 2), &populate_ops_table_double_general_mul},
{TableKey(ValType::All, RootExpr::Max, 2), &populate_ops_table_double_general_max},
{TableKey(ValType::All, RootExpr::Min, 2), &populate_ops_table_double_general_min},
{TableKey(ValType::All, RootExpr::Sub, 2), &populate_ops_table_double_general_sub},
{TableKey(ValType::All, RootExpr::Select, 2), &populate_ops_table_double_general_select},

{TableKey(ValType::UInt1, RootExpr::And, 1), &populate_ops_table_single_uint1_and},
{TableKey(ValType::UInt1, RootExpr::Or, 1), &populate_ops_table_single_uint1_or},

{TableKey(ValType::UInt8, RootExpr::Cast, 1), &populate_ops_table_single_uint8_cast},
{TableKey(ValType::UInt8, RootExpr::Select, 1), &populate_ops_table_single_uint8_select},

{TableKey(ValType::UInt16, RootExpr::Cast, 1), &populate_ops_table_single_uint16_cast},
{TableKey(ValType::UInt16, RootExpr::Select, 1), &populate_ops_table_single_uint16_select},

{TableKey(ValType::UInt32, RootExpr::Cast, 1), &populate_ops_table_single_uint32_cast},
{TableKey(ValType::UInt32, RootExpr::Select, 1), &populate_ops_table_single_uint32_select},
{TableKey(ValType::All, IRNodeType::Add, 1), &populate_ops_table_single_general_add},
{TableKey(ValType::All, IRNodeType::Mul, 1), &populate_ops_table_single_general_mul},
{TableKey(ValType::All, IRNodeType::Max, 1), &populate_ops_table_single_general_max},
{TableKey(ValType::All, IRNodeType::Min, 1), &populate_ops_table_single_general_min},
{TableKey(ValType::All, IRNodeType::Sub, 1), &populate_ops_table_single_general_sub},
{TableKey(ValType::All, IRNodeType::Select, 1), &populate_ops_table_single_general_select},
{TableKey(ValType::All, IRNodeType::Call, 1), &populate_ops_table_single_general_call},
{TableKey(ValType::All, IRNodeType::Add, 2), &populate_ops_table_double_general_add},
{TableKey(ValType::All, IRNodeType::Mul, 2), &populate_ops_table_double_general_mul},
{TableKey(ValType::All, IRNodeType::Max, 2), &populate_ops_table_double_general_max},
{TableKey(ValType::All, IRNodeType::Min, 2), &populate_ops_table_double_general_min},
{TableKey(ValType::All, IRNodeType::Sub, 2), &populate_ops_table_double_general_sub},
{TableKey(ValType::All, IRNodeType::Select, 2), &populate_ops_table_double_general_select},

{TableKey(ValType::UInt1, IRNodeType::And, 1), &populate_ops_table_single_uint1_and},
{TableKey(ValType::UInt1, IRNodeType::Or, 1), &populate_ops_table_single_uint1_or},

{TableKey(ValType::UInt8, IRNodeType::Cast, 1), &populate_ops_table_single_uint8_cast},
{TableKey(ValType::UInt8, IRNodeType::Select, 1), &populate_ops_table_single_uint8_select},

{TableKey(ValType::UInt16, IRNodeType::Cast, 1), &populate_ops_table_single_uint16_cast},
{TableKey(ValType::UInt16, IRNodeType::Select, 1), &populate_ops_table_single_uint16_select},

{TableKey(ValType::UInt32, IRNodeType::Cast, 1), &populate_ops_table_single_uint32_cast},
{TableKey(ValType::UInt32, IRNodeType::Select, 1), &populate_ops_table_single_uint32_select},
};

const vector<AssociativePattern> &get_ops_table_helper(const vector<Type> &types, RootExpr root, size_t dim) {
const vector<AssociativePattern> &get_ops_table_helper(const vector<Type> &types, IRNodeType root, size_t dim) {
TableKey gen_key(ValType::All, root, dim);
TableKey key(convert_halide_types_to_val_types(types), root, dim);

Expand Down Expand Up @@ -336,43 +332,13 @@ const vector<AssociativePattern> &get_ops_table(const vector<Expr> &exprs) {
types[i] = exprs[i].type();
}

RootExpr root = RootExpr::Unknown;
if (exprs[0].as<Halide::Internal::Add>()) {
debug(5) << "Returning Add root table for type " << print_types(types) << "\n";
root = RootExpr::Add;
} else if (exprs[0].as<Halide::Internal::Sub>()) {
debug(5) << "Returning Sub root table for type " << print_types(types) << "\n";
root = RootExpr::Sub;
} else if (exprs[0].as<Halide::Internal::Mul>()) {
debug(5) << "Returning Mul root table for type " << print_types(types) << "\n";
root = RootExpr::Mul;
} else if (exprs[0].as<Halide::Internal::Min>()) {
debug(5) << "Returning Min root table for type " << print_types(types) << "\n";
root = RootExpr::Min;
} else if (exprs[0].as<Halide::Internal::Max>()) {
debug(5) << "Returning Max root table for type " << print_types(types) << "\n";
root = RootExpr::Max;
} else if (exprs[0].as<Halide::Internal::Select>()) {
debug(5) << "Returning Select root table for type " << print_types(types) << "\n";
root = RootExpr::Select;
} else if (exprs[0].as<Halide::Internal::And>()) {
debug(5) << "Returning And root table for type " << print_types(types) << "\n";
root = RootExpr::And;
} else if (exprs[0].as<Halide::Internal::Or>()) {
debug(5) << "Returning Or root table for type " << print_types(types) << "\n";
root = RootExpr::Or;
} else if (exprs[0].as<Halide::Internal::Cast>()) {
debug(5) << "Returning Cast root table for type " << print_types(types) << "\n";
root = RootExpr::Cast;
}

if (root != RootExpr::Unknown) {
{
// get_ops_table_helper() lazily initializes the table, so ensure
// that multiple threads can't try to do so at the same time.
static std::mutex ops_table_lock;
std::lock_guard<std::mutex> lock_guard(ops_table_lock);

const vector<AssociativePattern> &table = get_ops_table_helper(types, root, exprs.size());
const vector<AssociativePattern> &table = get_ops_table_helper(types, exprs[0].node_type(), exprs.size());
debug(7) << "Table size: " << table.size() << "\n";
for (const auto &p : table) {
debug(7) << p;
Expand Down
45 changes: 14 additions & 31 deletions src/Associativity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -543,37 +543,20 @@ void associativity_test() {
Expr x_idx = Variable::make(Int(32), "x_idx");
Expr f_call_0 = Call::make(t, "f", {x_idx}, Call::CallType::Halide, FunctionPtr(), 0);

// f(x) = uint8(uint16(x + y), 255)
check_associativity("f", {x_idx}, {Cast::make(UInt(8), min(Cast::make(UInt(16), y + f_call_0), make_const(t, 255)))},
AssociativeOp(
AssociativePattern(Cast::make(UInt(8), min(Cast::make(UInt(16), x + y), make_const(t, 255))), make_const(t, 0), true),
{Replacement("x", f_call_0)},
{Replacement("y", y)},
true));

// f(x) = uint8(uint16(x + y), uint16(255))
check_associativity("f", {x_idx}, {Cast::make(UInt(8), min(Cast::make(UInt(16), y + f_call_0), Cast::make(UInt(16), make_const(t, 255))))},
AssociativeOp(
AssociativePattern(Cast::make(UInt(8), min(Cast::make(UInt(16), x + y), make_const(t, 255))), make_const(t, 0), true),
{Replacement("x", f_call_0)},
{Replacement("y", y)},
true));

// f(x) = select(x > 255 - y, 255, y)
check_associativity("f", {x_idx}, {select(f_call_0 > make_const(t, 255) - y, make_const(t, 255), y)},
AssociativeOp(
AssociativePattern(select(x > make_const(t, 255) - y, make_const(t, 255), y), make_const(t, 0), true),
{Replacement("x", f_call_0)},
{Replacement("y", y)},
true));

// f(x) = select(x >= -y, 255, y)
check_associativity("f", {x_idx}, {select(f_call_0 >= -y, make_const(t, 255), y)},
AssociativeOp(
AssociativePattern(select(x < -y, y, make_const(t, 255)), make_const(t, 0), true),
{Replacement("x", f_call_0)},
{Replacement("y", y)},
true));
for (const Expr &e : {cast<uint8_t>(min(cast<uint16_t>(x) + y, 255)),
select(x > 255 - y, cast<uint8_t>(255), y),
select(x < -y, y, cast<uint8_t>(255)),
saturating_add(x, y),
saturating_add(y, x),
saturating_cast<uint8_t>(widening_add(x, y))}) {
check_associativity("f", {x_idx}, {substitute("x", f_call_0, e)},
AssociativeOp(
AssociativePattern(solve_expression(e, "x").result,
make_const(t, 0), true),
{Replacement("x", f_call_0)},
{Replacement("y", y)},
true));
}
}

{
Expand Down
31 changes: 31 additions & 0 deletions src/Solve.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,37 @@ class SolveExpression : public IRMutator {
// Ignore intrinsics that shouldn't affect the results.
if (Call::as_tag(op)) {
return mutate(op->args[0]);
} else if (op->is_intrinsic({Call::absd, Call::bitwise_and, Call::bitwise_or,
Call::bitwise_xor, Call::halving_add, Call::rounding_halving_add,
Call::saturating_add, Call::widening_add, Call::widening_mul})) {
// It's a commutative intrinsic. We won't try to lift uses of the
// var out of the call, but we will reorder the args if it would
// help.
internal_assert(op->args.size() == 2);
bool old_uses_var = uses_var;
uses_var = false;
bool old_failed = failed;
failed = false;
Expr a = mutate(op->args[0]);
bool a_uses_var = uses_var;
bool a_failed = failed;
uses_var = false;
failed = false;
Expr b = mutate(op->args[1]);
bool b_uses_var = uses_var;
bool b_failed = failed;
uses_var = old_uses_var || a_uses_var || b_uses_var;
failed = old_failed || a_failed || b_failed;

failed |= a_uses_var && b_uses_var;

if (b_uses_var && !a_uses_var) {
return Call::make(op->type, op->name, {b, a}, op->call_type);
} else if (a.same_as(op->args[0]) && b.same_as(op->args[1])) {
return op;
} else {
return Call::make(op->type, op->name, {a, b}, op->call_type);
}
} else {
return IRMutator::visit(op);
}
Expand Down

0 comments on commit 33d5ba9

Please sign in to comment.