Skip to content

Commit

Permalink
[SelectionDAG] Add space-optimized forms of OPC_CheckPredicate
Browse files Browse the repository at this point in the history
We record the usage of each `Predicate` and sort them by usage.

For the top 8 `Predicate`s, we will emit a `PC_CheckPredicateN` to
save one byte.

Overall this reduces the llc binary size with all in-tree targets by
about 61K.

This PR is stacked on llvm#73310.
  • Loading branch information
wangpc-pp committed Jan 11, 2024
1 parent 5c8d123 commit e8c1533
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 45 deletions.
8 changes: 8 additions & 0 deletions llvm/include/llvm/CodeGen/SelectionDAGISel.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,14 @@ class SelectionDAGISel : public MachineFunctionPass {
OPC_CheckPatternPredicate7,
OPC_CheckPatternPredicateTwoByte,
OPC_CheckPredicate,
OPC_CheckPredicate0,
OPC_CheckPredicate1,
OPC_CheckPredicate2,
OPC_CheckPredicate3,
OPC_CheckPredicate4,
OPC_CheckPredicate5,
OPC_CheckPredicate6,
OPC_CheckPredicate7,
OPC_CheckPredicateWithOperands,
OPC_CheckOpcode,
OPC_SwitchOpcode,
Expand Down
30 changes: 25 additions & 5 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2712,9 +2712,13 @@ CheckPatternPredicate(unsigned Opcode, const unsigned char *MatcherTable,

/// CheckNodePredicate - Implements OP_CheckNodePredicate.
LLVM_ATTRIBUTE_ALWAYS_INLINE static bool
CheckNodePredicate(const unsigned char *MatcherTable, unsigned &MatcherIndex,
const SelectionDAGISel &SDISel, SDNode *N) {
return SDISel.CheckNodePredicate(N, MatcherTable[MatcherIndex++]);
CheckNodePredicate(unsigned Opcode, const unsigned char *MatcherTable,
unsigned &MatcherIndex, const SelectionDAGISel &SDISel,
SDNode *N) {
unsigned PredNo = Opcode == SelectionDAGISel::OPC_CheckPredicate
? MatcherTable[MatcherIndex++]
: Opcode - SelectionDAGISel::OPC_CheckPredicate0;
return SDISel.CheckNodePredicate(N, PredNo);
}

LLVM_ATTRIBUTE_ALWAYS_INLINE static bool
Expand Down Expand Up @@ -2868,7 +2872,15 @@ static unsigned IsPredicateKnownToFail(const unsigned char *Table,
Result = !::CheckPatternPredicate(Opcode, Table, Index, SDISel);
return Index;
case SelectionDAGISel::OPC_CheckPredicate:
Result = !::CheckNodePredicate(Table, Index, SDISel, N.getNode());
case SelectionDAGISel::OPC_CheckPredicate0:
case SelectionDAGISel::OPC_CheckPredicate1:
case SelectionDAGISel::OPC_CheckPredicate2:
case SelectionDAGISel::OPC_CheckPredicate3:
case SelectionDAGISel::OPC_CheckPredicate4:
case SelectionDAGISel::OPC_CheckPredicate5:
case SelectionDAGISel::OPC_CheckPredicate6:
case SelectionDAGISel::OPC_CheckPredicate7:
Result = !::CheckNodePredicate(Opcode, Table, Index, SDISel, N.getNode());
return Index;
case SelectionDAGISel::OPC_CheckOpcode:
Result = !::CheckOpcode(Table, Index, N.getNode());
Expand Down Expand Up @@ -3359,8 +3371,16 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch,
if (!::CheckPatternPredicate(Opcode, MatcherTable, MatcherIndex, *this))
break;
continue;
case SelectionDAGISel::OPC_CheckPredicate0:
case SelectionDAGISel::OPC_CheckPredicate1:
case SelectionDAGISel::OPC_CheckPredicate2:
case SelectionDAGISel::OPC_CheckPredicate3:
case SelectionDAGISel::OPC_CheckPredicate4:
case SelectionDAGISel::OPC_CheckPredicate5:
case SelectionDAGISel::OPC_CheckPredicate6:
case SelectionDAGISel::OPC_CheckPredicate7:
case OPC_CheckPredicate:
if (!::CheckNodePredicate(MatcherTable, MatcherIndex, *this,
if (!::CheckNodePredicate(Opcode, MatcherTable, MatcherIndex, *this,
N.getNode()))
break;
continue;
Expand Down
4 changes: 2 additions & 2 deletions llvm/test/TableGen/address-space-patfrags.td
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def inst_d : Instruction {
let InOperandList = (ins GPR32:$src0, GPR32:$src1);
}

// SDAG: case 2: {
// SDAG: case 1: {
// SDAG-NEXT: // Predicate_pat_frag_b
// SDAG-NEXT: // Predicate_truncstorei16_addrspace
// SDAG-NEXT: SDNode *N = Node;
Expand All @@ -69,7 +69,7 @@ def : Pat <
>;


// SDAG: case 3: {
// SDAG: case 6: {
// SDAG: // Predicate_pat_frag_a
// SDAG-NEXT: SDNode *N = Node;
// SDAG-NEXT: (void)N;
Expand Down
4 changes: 2 additions & 2 deletions llvm/test/TableGen/predicate-patfags.td
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ def TGTmul24_oneuse : PatFrag<
}

// SDAG: OPC_CheckOpcode, TARGET_VAL(ISD::INTRINSIC_W_CHAIN),
// SDAG: OPC_CheckPredicate, 0, // Predicate_TGTmul24_oneuse
// SDAG: OPC_CheckPredicate0, // Predicate_TGTmul24_oneuse

// SDAG: OPC_CheckOpcode, TARGET_VAL(TargetISD::MUL24),
// SDAG: OPC_CheckPredicate, 0, // Predicate_TGTmul24_oneuse
// SDAG: OPC_CheckPredicate0, // Predicate_TGTmul24_oneuse

// GISEL: GIM_CheckOpcode, /*MI*/1, GIMT_Encode2(TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS),
// GISEL: GIM_CheckIntrinsicID, /*MI*/1, /*Op*/1, GIMT_Encode2(Intrinsic::tgt_mul24),
Expand Down
93 changes: 57 additions & 36 deletions llvm/utils/TableGen/DAGISelMatcherEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,8 @@ class MatcherTableEmitter {

SmallVector<unsigned, Matcher::HighestKind+1> OpcodeCounts;

DenseMap<TreePattern *, unsigned> NodePredicateMap;
std::vector<TreePredicateFn> NodePredicates;
std::vector<TreePredicateFn> NodePredicatesWithOperands;
std::vector<TreePattern *> NodePredicates;
std::vector<TreePattern *> NodePredicatesWithOperands;

// We de-duplicate the predicates by code string, and use this map to track
// all the patterns with "identical" predicates.
Expand Down Expand Up @@ -88,6 +87,8 @@ class MatcherTableEmitter {
DenseMap<const ComplexPattern *, unsigned> ComplexPatternUsage;
// Record the usage of PatternPredicate.
std::map<StringRef, unsigned> PatternPredicateUsage;
// Record the usage of Predicate.
DenseMap<TreePattern *, unsigned> PredicateUsage;

// Iterate the whole MatcherTable once and do some statistics.
std::function<void(const Matcher *)> Statistic = [&](const Matcher *N) {
Expand All @@ -105,6 +106,8 @@ class MatcherTableEmitter {
++ComplexPatternUsage[&CPM->getPattern()];
else if (auto *CPPM = dyn_cast<CheckPatternPredicateMatcher>(N))
++PatternPredicateUsage[CPPM->getPredicate()];
else if (auto *PM = dyn_cast<CheckPredicateMatcher>(N))
++PredicateUsage[PM->getPredicate().getOrigPatFragRecord()];
N = N->getNext();
}
};
Expand All @@ -125,6 +128,39 @@ class MatcherTableEmitter {
[](const auto &A, const auto &B) { return A.second > B.second; });
for (const auto &PatternPredicate : PatternPredicateList)
PatternPredicates.push_back(PatternPredicate.first);

// Sort Predicates by usage.
// Merge predicates with same code.
for (const auto &Usage : PredicateUsage) {
TreePattern *TP = Usage.first;
TreePredicateFn Pred(TP);
NodePredicatesByCodeToRun[Pred.getCodeToRunOnSDNode()].push_back(TP);
}

std::vector<std::pair<TreePattern *, unsigned>> PredicateList;
// Sum the usage.
for (auto &Predicate : NodePredicatesByCodeToRun) {
TinyPtrVector<TreePattern *> &TPs = Predicate.second;
sort(TPs, [](const auto *A, const auto *B) {
return A->getRecord()->getName() < B->getRecord()->getName();
});
unsigned Uses = 0;
for (TreePattern *TP : TPs)
Uses += PredicateUsage.at(TP);

// We only add the first predicate here since they are with the same code.
PredicateList.push_back({TPs[0], Uses});
}

sort(PredicateList,
[](const auto &A, const auto &B) { return A.second > B.second; });
for (const auto &Predicate : PredicateList) {
TreePattern *TP = Predicate.first;
if (TreePredicateFn(TP).usesOperands())
NodePredicatesWithOperands.push_back(TP);
else
NodePredicates.push_back(TP);
}
}

unsigned EmitMatcherList(const Matcher *N, const unsigned Indent,
Expand All @@ -139,7 +175,7 @@ class MatcherTableEmitter {
void EmitPatternMatchTable(raw_ostream &OS);

private:
void EmitNodePredicatesFunction(const std::vector<TreePredicateFn> &Preds,
void EmitNodePredicatesFunction(const std::vector<TreePattern *> &Preds,
StringRef Decl, raw_ostream &OS);

unsigned SizeMatcher(Matcher *N, raw_ostream &OS);
Expand All @@ -148,33 +184,13 @@ class MatcherTableEmitter {
raw_ostream &OS);

unsigned getNodePredicate(TreePredicateFn Pred) {
TreePattern *TP = Pred.getOrigPatFragRecord();
unsigned &Entry = NodePredicateMap[TP];
if (Entry == 0) {
TinyPtrVector<TreePattern *> &SameCodePreds =
NodePredicatesByCodeToRun[Pred.getCodeToRunOnSDNode()];
if (SameCodePreds.empty()) {
// We've never seen a predicate with the same code: allocate an entry.
if (Pred.usesOperands()) {
NodePredicatesWithOperands.push_back(Pred);
Entry = NodePredicatesWithOperands.size();
} else {
NodePredicates.push_back(Pred);
Entry = NodePredicates.size();
}
} else {
// We did see an identical predicate: re-use it.
Entry = NodePredicateMap[SameCodePreds.front()];
assert(Entry != 0);
assert(TreePredicateFn(SameCodePreds.front()).usesOperands() ==
Pred.usesOperands() &&
"PatFrags with some code must have same usesOperands setting");
}
// In both cases, we've never seen this particular predicate before, so
// mark it in the list of predicates sharing the same code.
SameCodePreds.push_back(TP);
}
return Entry-1;
// We use the first predicate.
TreePattern *PredPat =
NodePredicatesByCodeToRun[Pred.getCodeToRunOnSDNode()][0];
return Pred.usesOperands()
? llvm::find(NodePredicatesWithOperands, PredPat) -
NodePredicatesWithOperands.begin()
: llvm::find(NodePredicates, PredPat) - NodePredicates.begin();
}

unsigned getPatternPredicate(StringRef PredName) {
Expand Down Expand Up @@ -529,6 +545,7 @@ EmitMatcher(const Matcher *N, const unsigned Indent, unsigned CurrentIdx,
case Matcher::CheckPredicate: {
TreePredicateFn Pred = cast<CheckPredicateMatcher>(N)->getPredicate();
unsigned OperandBytes = 0;
unsigned PredNo = getNodePredicate(Pred);

if (Pred.usesOperands()) {
unsigned NumOps = cast<CheckPredicateMatcher>(N)->getNumOperands();
Expand All @@ -537,10 +554,15 @@ EmitMatcher(const Matcher *N, const unsigned Indent, unsigned CurrentIdx,
OS << cast<CheckPredicateMatcher>(N)->getOperandNo(i) << ", ";
OperandBytes = 1 + NumOps;
} else {
OS << "OPC_CheckPredicate, ";
if (PredNo < 8) {
OperandBytes = -1;
OS << "OPC_CheckPredicate" << PredNo << ", ";
} else
OS << "OPC_CheckPredicate, ";
}

OS << getNodePredicate(Pred) << ',';
if (PredNo >= 8 || Pred.usesOperands())
OS << PredNo << ',';
if (!OmitComments)
OS << " // " << Pred.getFnName();
OS << '\n';
Expand Down Expand Up @@ -1029,8 +1051,7 @@ EmitMatcherList(const Matcher *N, const unsigned Indent, unsigned CurrentIdx,
}

void MatcherTableEmitter::EmitNodePredicatesFunction(
const std::vector<TreePredicateFn> &Preds, StringRef Decl,
raw_ostream &OS) {
const std::vector<TreePattern *> &Preds, StringRef Decl, raw_ostream &OS) {
if (Preds.empty())
return;

Expand All @@ -1040,7 +1061,7 @@ void MatcherTableEmitter::EmitNodePredicatesFunction(
OS << " default: llvm_unreachable(\"Invalid predicate in table?\");\n";
for (unsigned i = 0, e = Preds.size(); i != e; ++i) {
// Emit the predicate code corresponding to this pattern.
const TreePredicateFn PredFn = Preds[i];
TreePredicateFn PredFn(Preds[i]);
assert(!PredFn.isAlwaysTrue() && "No code in this predicate");
std::string PredFnCodeStr = PredFn.getCodeToRunOnSDNode();

Expand Down

0 comments on commit e8c1533

Please sign in to comment.