Skip to content

Commit

Permalink
visit thread declarator before variable declarator
Browse files Browse the repository at this point in the history
  • Loading branch information
tonghaining committed Oct 9, 2024
1 parent 65fb74a commit 32c2dc2
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 33 deletions.
16 changes: 8 additions & 8 deletions dartagnan/src/main/antlr4/LitmusC.g4
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import static com.dat3m.dartagnan.program.event.Tag.*;
}

main
: LitmusLanguage ~(LBrace)* variableDeclaratorList program variableList? assertionFilter? assertionList? comment? EOF
: LitmusLanguage ~(LBrace)* variableDeclaratorList (threadDeclarator threadContent)+ variableList? assertionFilter? assertionList? comment? EOF
;

variableDeclaratorList
Expand All @@ -24,16 +24,16 @@ globalDeclarator
| typeSpecifier? varName LBracket DigitSequence? RBracket (Equals initArray)? # globalDeclaratorArray
;

program
: thread+
threadDeclarator
: threadId (At threadScope)?
;

thread
: threadId (At threadScope)? LPar threadArguments? RPar LBrace expression* RBrace
threadContent
: LPar threadArguments? RPar LBrace expression* RBrace
;

threadScope
: OpenCLWG scopeID Comma OpenCLDEV scopeID # OpenCLThreadScope
: OpenCLWG scopeID Comma OpenCLDEV scopeID
;

threadArguments
Expand Down Expand Up @@ -215,9 +215,9 @@ nre locals [IntBinaryOp op, String mo, String name]

| C11AtomicFence LPar c11Mo RPar # nreC11Fence

| OpenCLAtomicFenceWI LPar openCLFenceFlags Comma c11Mo Comma openCLScope RPar # nreOpenCLFence
| OpenCLAtomicFenceWI LPar openCLFenceFlags Comma c11Mo Comma openCLScope RPar # nreOpenCLFence

| barrierId Colon OpenCLBarrier LPar openCLFenceFlags (Comma openCLScope)? RPar # nreOpenCLBarrier
| barrierId Colon OpenCLBarrier LPar openCLFenceFlags (Comma openCLScope)? RPar # nreOpenCLBarrier

;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ public Label getEndOfThreadLabel(int tid) {

// ----------------------------------------------------------------------------------------------------------------
// GPU
public void setOrCreateScopedThread(Arch arch, String name, int id, int ...scopeIds) {
public void newScopedThread(Arch arch, String name, int id, int ...scopeIds) {
ScopeHierarchy scopeHierarchy = switch (arch) {
case PTX -> ScopeHierarchy.ScopeHierarchyForPTX(scopeIds[0], scopeIds[1]);
case VULKAN -> ScopeHierarchy.ScopeHierarchyForVulkan(scopeIds[0], scopeIds[1], scopeIds[2]);
Expand All @@ -296,8 +296,8 @@ public void setOrCreateScopedThread(Arch arch, String name, int id, int ...scope
}
}

public void setOrCreateScopedThread(Arch arch, int id, int ...ids) {
setOrCreateScopedThread(arch, String.valueOf(id), id, ids);
public void newScopedThread(Arch arch, int id, int ...ids) {
newScopedThread(arch, String.valueOf(id), id, ids);
}

// ----------------------------------------------------------------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ public class VisitorLitmusC extends LitmusCBaseVisitor<Object> {
private int whileId = 0;
private Register returnRegister;
private boolean isOpenCL = false;
private final List<Integer> threadIds = new ArrayList<>();

public VisitorLitmusC(){
}
Expand All @@ -45,11 +46,16 @@ public VisitorLitmusC(){

@Override
public Program visitMain(LitmusCParser.MainContext ctx) {
//FIXME: We should visit thread declarations before variable declarations
// because variable declaration refer to threads.
isOpenCL = ctx.LitmusLanguage().getText().equals("OPENCL");
isOpenCL = ctx.LitmusLanguage().getText().equals("OPENCL");
for (LitmusCParser.ThreadDeclaratorContext threadDeclaratorContext : ctx.threadDeclarator()) {
visitThreadDeclarator(threadDeclaratorContext);
}
visitVariableDeclaratorList(ctx.variableDeclaratorList());
visitProgram(ctx.program());
int threadIndex = 0;
for (LitmusCParser.ThreadContentContext threadContentContext : ctx.threadContent()) {
scope = currentThread = threadIds.get(threadIndex++);
visitThreadContent(threadContentContext);
}
VisitorLitmusAssertions.parseAssertions(programBuilder, ctx.assertionList(), ctx.assertionFilter());
return programBuilder.build();
}
Expand All @@ -70,8 +76,6 @@ public Object visitGlobalDeclaratorLocation(LitmusCParser.GlobalDeclaratorLocati
@Override
public Object visitGlobalDeclaratorRegister(LitmusCParser.GlobalDeclaratorRegisterContext ctx) {
if (ctx.initConstantValue() != null) {
// FIXME: We visit declarators before threads, so we need to create threads early
programBuilder.getOrNewThread(ctx.threadId().id);
IntLiteral value = expressions.parseValue(ctx.initConstantValue().constant().getText(), archType);
programBuilder.initRegEqConst(ctx.threadId().id,ctx.varName().getText(), value);
}
Expand All @@ -97,7 +101,6 @@ public Object visitGlobalDeclaratorLocationLocation(LitmusCParser.GlobalDeclarat
@Override
public Object visitGlobalDeclaratorRegisterLocation(LitmusCParser.GlobalDeclaratorRegisterLocationContext ctx) {
// FIXME: We visit declarators before threads, so we need to create threads early
programBuilder.getOrNewThread(ctx.threadId().id);
if(ctx.Ast() == null){
programBuilder.initRegEqLocPtr(ctx.threadId().id, ctx.varName(0).getText(), ctx.varName(1).getText(), archType);
} else {
Expand Down Expand Up @@ -154,28 +157,25 @@ public Object visitGlobalDeclaratorArray(LitmusCParser.GlobalDeclaratorArrayCont
// Threads (the program itself)

@Override
public Object visitThread(LitmusCParser.ThreadContext ctx) {
public Object visitThreadDeclarator(LitmusCParser.ThreadDeclaratorContext ctx) {
scope = currentThread = ctx.threadId().id;
// Declarations in the preamble may have created the thread already
threadIds.add(currentThread);
if (isOpenCL && ctx.threadScope() != null) {
ctx.threadScope().accept(this);
int wgID = ctx.threadScope().scopeID(0).id;
int devID = ctx.threadScope().scopeID(1).id;
programBuilder.newScopedThread(Arch.OPENCL, currentThread, devID, wgID);
} else {
programBuilder.getOrNewThread(currentThread);
programBuilder.newThread(currentThread);
}
visitThreadArguments(ctx.threadArguments());

for(LitmusCParser.ExpressionContext expressionContext : ctx.expression())
expressionContext.accept(this);

scope = currentThread = -1;
return null;
}

@Override
public Object visitOpenCLThreadScope(LitmusCParser.OpenCLThreadScopeContext ctx) {
int wgID = ctx.scopeID(0).id;
int devID = ctx.scopeID(1).id;
programBuilder.setOrCreateScopedThread(Arch.OPENCL, currentThread, devID, wgID);
public Object visitThreadContent(LitmusCParser.ThreadContentContext ctx) {
visitThreadArguments(ctx.threadArguments());
for(LitmusCParser.ExpressionContext expressionContext : ctx.expression())
expressionContext.accept(this);
return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ public Object visitThreadDeclaratorList(LitmusPTXParser.ThreadDeclaratorListCont
int ctaID = threadScopeContext.scopeID().ctaID().id;
int gpuID = threadScopeContext.scopeID().gpuID().id;
// NB: the order of scopeIDs is important
programBuilder.setOrCreateScopedThread(Arch.PTX, threadScopeContext.threadId().id, gpuID, ctaID);
programBuilder.newScopedThread(Arch.PTX, threadScopeContext.threadId().id, gpuID, ctaID);
threadCount++;
}
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ public Object visitThreadDeclaratorList(LitmusVulkanParser.ThreadDeclaratorListC
int workgroupID = threadScopeContext.workgroupScope().scopeID().id;
int queuefamilyID = threadScopeContext.queuefamilyScope().scopeID().id;
// NB: the order of scopeIDs is important
programBuilder.setOrCreateScopedThread(Arch.VULKAN, threadScopeContext.threadId().id,
programBuilder.newScopedThread(Arch.VULKAN, threadScopeContext.threadId().id,
queuefamilyID, workgroupID, subgroupID);
threadCount++;
}
Expand Down

0 comments on commit 32c2dc2

Please sign in to comment.