diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/utils/ProgramBuilder.java b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/utils/ProgramBuilder.java index 9f1189ba44..a0cf8dd9ca 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/utils/ProgramBuilder.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/utils/ProgramBuilder.java @@ -7,7 +7,6 @@ import com.dat3m.dartagnan.expression.Type; import com.dat3m.dartagnan.expression.integers.IntLiteral; import com.dat3m.dartagnan.expression.type.FunctionType; -import com.dat3m.dartagnan.expression.type.IntegerType; import com.dat3m.dartagnan.expression.type.TypeFactory; import com.dat3m.dartagnan.program.*; import com.dat3m.dartagnan.program.Thread; @@ -46,10 +45,6 @@ public class ProgramBuilder { private final Map> fid2LabelsMap = new HashMap<>(); private final Map locations = new HashMap<>(); private final Map reg2LocMap = new HashMap<>(); - private final Map> id2RegTypeMap = new HashMap<>(); - private final Map> id2RegConstMap = new HashMap<>(); - private final Map> id2RegLocPtrMap = new HashMap<>(); - private final Map> id2RegLocValMap = new HashMap<>(); private final Program program; @@ -119,19 +114,13 @@ public void setAssertFilter(Expression ass) { // This method creates a "default" thread that has no parameters, no return value, and runs unconditionally. // It is only useful for creating threads of Litmus code. - public Thread newThread(int tid, Thread thread) { + public Thread newThread(String name, int tid) { + if(id2FunctionsMap.containsKey(tid)) { + throw new MalformedProgramException("Function or thread with id " + tid + " already exists."); + } + final Thread thread = new Thread(name, DEFAULT_THREAD_TYPE, List.of(), tid, EventFactory.newThreadStart(null)); id2FunctionsMap.put(tid, thread); program.addThread(thread); - if (id2RegConstMap.containsKey(tid)) { - id2RegConstMap.get(tid).forEach((regName, value) -> - initRegEqConst(tid, regName, value)); - } else if (id2RegLocPtrMap.containsKey(tid)) { - id2RegLocPtrMap.get(tid).forEach((regName, value) -> - initRegEqLocPtr(tid, regName, value, getRegType(tid, regName))); - } else if (id2RegLocValMap.containsKey(tid)) { - id2RegLocValMap.get(tid).forEach((regName, value) -> - initRegEqLocVal(tid, regName, value, getRegType(tid, regName))); - } return thread; } @@ -146,12 +135,8 @@ public Function newFunction(String name, int fid, FunctionType type, List new HashMap<>()).put(regName, type); - } - - public IntegerType getRegType(int tid, String regName) { - if (id2RegTypeMap.containsKey(tid) && id2RegTypeMap.get(tid).containsKey(regName)) { - return id2RegTypeMap.get(tid).get(regName); - } - throw new IllegalStateException("Register " + tid + ":" + regName + " is not initialised"); - } - - public void addRegToConstMap(int tid, String regName, Expression value) { - id2RegConstMap.computeIfAbsent(tid, k -> new HashMap<>()).put(regName, value); - } - - public void addRegToLocPtrMap(int tid, String regName, String locName) { - id2RegLocPtrMap.computeIfAbsent(tid, k -> new HashMap<>()).put(regName, locName); - } - - public void addRegToLocValMap(int tid, String regName, String locName) { - id2RegLocValMap.computeIfAbsent(tid, k -> new HashMap<>()).put(regName, locName); - } - private Expression getInitialValue(String name) { return getOrNewMemoryObject(name).getInitialValue(0); } @@ -314,26 +276,28 @@ public Label getEndOfThreadLabel(int tid) { // ---------------------------------------------------------------------------------------------------------------- // GPU - public void newScopedThread(Arch arch, String name, int id, int ...scopeIds) { - if(id2FunctionsMap.containsKey(id)) { - throw new MalformedProgramException("Function or thread with id " + id + " already exists."); - } - // Litmus threads run unconditionally (have no creator) and have no parameters/return types. - ThreadStart threadEntry = EventFactory.newThreadStart(null); - Thread scopedThread = switch (arch) { - case PTX -> new Thread(name, DEFAULT_THREAD_TYPE, List.of(), id, threadEntry, - ScopeHierarchy.ScopeHierarchyForPTX(scopeIds[0], scopeIds[1]), new HashSet<>()); - case VULKAN -> new Thread(name, DEFAULT_THREAD_TYPE, List.of(), id, threadEntry, - ScopeHierarchy.ScopeHierarchyForVulkan(scopeIds[0], scopeIds[1], scopeIds[2]), new HashSet<>()); - case OPENCL -> new Thread(name, DEFAULT_THREAD_TYPE, List.of(), id, threadEntry, - ScopeHierarchy.ScopeHierarchyForOpenCL(scopeIds[0], scopeIds[1]), new HashSet<>()); + public void setOrCreateScopedThread(Arch arch, String name, int id, int ...scopeIds) { + ScopeHierarchy scopeHierarchy = switch (arch) { + case PTX -> ScopeHierarchy.ScopeHierarchyForPTX(scopeIds[0], scopeIds[1]); + case VULKAN -> ScopeHierarchy.ScopeHierarchyForVulkan(scopeIds[0], scopeIds[1], scopeIds[2]); + case OPENCL -> ScopeHierarchy.ScopeHierarchyForOpenCL(scopeIds[0], scopeIds[1]); default -> throw new UnsupportedOperationException("Unsupported architecture: " + arch); }; - newThread(id, scopedThread); + + if(id2FunctionsMap.containsKey(id)) { + Thread thread = (Thread) id2FunctionsMap.get(id); + thread.setScopeHierarchy(scopeHierarchy); + } else { + // Litmus threads run unconditionally (have no creator) and have no parameters/return types. + ThreadStart threadEntry = EventFactory.newThreadStart(null); + Thread scopedThread = new Thread(name, DEFAULT_THREAD_TYPE, List.of(), id, threadEntry, scopeHierarchy, new HashSet<>()); + id2FunctionsMap.put(id, scopedThread); + program.addThread(scopedThread); + } } - public void newScopedThread(Arch arch, int id, int ...ids) { - newScopedThread(arch, String.valueOf(id), id, ids); + public void setOrCreateScopedThread(Arch arch, int id, int ...ids) { + setOrCreateScopedThread(arch, String.valueOf(id), id, ids); } // ---------------------------------------------------------------------------------------------------------------- diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorLitmusC.java b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorLitmusC.java index 83862cf514..711c308356 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorLitmusC.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorLitmusC.java @@ -2,7 +2,6 @@ import com.dat3m.dartagnan.configuration.Arch; import com.dat3m.dartagnan.exception.ParsingException; -import com.dat3m.dartagnan.expression.BinaryExpression; import com.dat3m.dartagnan.expression.Expression; import com.dat3m.dartagnan.expression.ExpressionFactory; import com.dat3m.dartagnan.expression.integers.IntLiteral; @@ -46,6 +45,8 @@ public VisitorLitmusC(){ @Override public Program visitMain(LitmusCParser.MainContext ctx) { + //FIXME: We should visit thread declarations before variable declarations + // because variable declaration refer to threads. visitVariableDeclaratorList(ctx.variableDeclaratorList()); visitProgram(ctx.program()); VisitorLitmusAssertions.parseAssertions(programBuilder, ctx.assertionList(), ctx.assertionFilter()); @@ -68,9 +69,10 @@ public Object visitGlobalDeclaratorLocation(LitmusCParser.GlobalDeclaratorLocati @Override public Object visitGlobalDeclaratorRegister(LitmusCParser.GlobalDeclaratorRegisterContext ctx) { if (ctx.initConstantValue() != null) { + // FIXME: We visit declarators before threads, so we need to create threads early + programBuilder.getOrNewThread(ctx.threadId().id); IntLiteral value = expressions.parseValue(ctx.initConstantValue().constant().getText(), archType); - programBuilder.addRegType(ctx.threadId().id, ctx.varName().getText(), archType); - programBuilder.addRegToConstMap(ctx.threadId().id, ctx.varName().getText(), value); + programBuilder.initRegEqConst(ctx.threadId().id,ctx.varName().getText(), value); } return null; } @@ -93,19 +95,17 @@ public Object visitGlobalDeclaratorLocationLocation(LitmusCParser.GlobalDeclarat @Override public Object visitGlobalDeclaratorRegisterLocation(LitmusCParser.GlobalDeclaratorRegisterLocationContext ctx) { - int threadId = ctx.threadId().id; - String regName = ctx.varName(0).getText(); - String locName = ctx.varName(1).getText(); - programBuilder.addRegType(threadId, regName, archType); + // FIXME: We visit declarators before threads, so we need to create threads early + programBuilder.getOrNewThread(ctx.threadId().id); if(ctx.Ast() == null){ - programBuilder.addRegToLocPtrMap(threadId, regName, locName); + programBuilder.initRegEqLocPtr(ctx.threadId().id, ctx.varName(0).getText(), ctx.varName(1).getText(), archType); } else { String rightName = ctx.varName(1).getText(); MemoryObject object = programBuilder.getMemoryObject(rightName); if(object != null){ - programBuilder.addRegToConstMap(threadId, regName, object); + programBuilder.initRegEqConst(ctx.threadId().id, ctx.varName(0).getText(), object); } else { - programBuilder.addRegToLocValMap(threadId, regName, locName); + programBuilder.initRegEqLocVal(ctx.threadId().id, ctx.varName(0).getText(), ctx.varName(1).getText(), archType); } } return null; @@ -158,7 +158,7 @@ public Object visitThread(LitmusCParser.ThreadContext ctx) { // Declarations in the preamble may have created the thread already if (ctx.threadScope() == null) { // Set dummy scope for C11 threads - programBuilder.newScopedThread(Arch.OPENCL, currentThread, 0, 0); + programBuilder.setOrCreateScopedThread(Arch.OPENCL, currentThread, 0, 0); } else { ctx.threadScope().accept(this); this.isOpenCL = true; @@ -176,7 +176,7 @@ public Object visitThread(LitmusCParser.ThreadContext ctx) { public Object visitOpenCLThreadScope(LitmusCParser.OpenCLThreadScopeContext ctx) { int wgID = ctx.scopeID(0).id; int devID = ctx.scopeID(1).id; - programBuilder.newScopedThread(Arch.OPENCL, currentThread, devID, wgID); + programBuilder.setOrCreateScopedThread(Arch.OPENCL, currentThread, devID, wgID); return null; } diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorLitmusPTX.java b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorLitmusPTX.java index cf99a696e1..520741b3b1 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorLitmusPTX.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorLitmusPTX.java @@ -100,7 +100,7 @@ public Object visitThreadDeclaratorList(LitmusPTXParser.ThreadDeclaratorListCont int ctaID = threadScopeContext.scopeID().ctaID().id; int gpuID = threadScopeContext.scopeID().gpuID().id; // NB: the order of scopeIDs is important - programBuilder.newScopedThread(Arch.PTX, threadScopeContext.threadId().id, gpuID, ctaID); + programBuilder.setOrCreateScopedThread(Arch.PTX, threadScopeContext.threadId().id, gpuID, ctaID); threadCount++; } return null; diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorLitmusVulkan.java b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorLitmusVulkan.java index cb50070057..9314d52b63 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorLitmusVulkan.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/VisitorLitmusVulkan.java @@ -117,7 +117,7 @@ public Object visitThreadDeclaratorList(LitmusVulkanParser.ThreadDeclaratorListC int workgroupID = threadScopeContext.workgroupScope().scopeID().id; int queuefamilyID = threadScopeContext.queuefamilyScope().scopeID().id; // NB: the order of scopeIDs is important - programBuilder.newScopedThread(Arch.VULKAN, threadScopeContext.threadId().id, + programBuilder.setOrCreateScopedThread(Arch.VULKAN, threadScopeContext.threadId().id, queuefamilyID, workgroupID, subgroupID); threadCount++; } diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/program/Thread.java b/dartagnan/src/main/java/com/dat3m/dartagnan/program/Thread.java index a4280b0b1b..aa3c718d7a 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/program/Thread.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/program/Thread.java @@ -15,7 +15,7 @@ public class Thread extends Function { // Scope hierarchy of the thread - private final Optional scopeHierarchy; + private Optional scopeHierarchy; // Threads that are system-synchronized-with this thread private final Optional> syncSet; @@ -56,6 +56,10 @@ public Set getSyncSet() { return syncSet.get(); } + public void setScopeHierarchy(ScopeHierarchy scopeHierarchy) { + this.scopeHierarchy = Optional.of(scopeHierarchy); + } + @Override public ThreadStart getEntry() { return (ThreadStart) entry;