Skip to content

Commit

Permalink
Updated model in spirv tests
Browse files Browse the repository at this point in the history
  • Loading branch information
natgavrilenko committed Oct 4, 2024
1 parent b7e7fa9 commit 9b4e18a
Show file tree
Hide file tree
Showing 392 changed files with 1,340 additions and 702 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import java.util.Arrays;

public enum Arch implements OptionInterface {
C11, ARM8, POWER, PTX, TSO, IMM, LKMM, RISCV, VULKAN;
C11, ARM8, POWER, PTX, TSO, IMM, LKMM, RISCV, VULKAN, OPENCL;

// Used to display in UI
@Override
Expand All @@ -27,6 +27,8 @@ public String toString() {
return "RISCV";
case VULKAN:
return "VULKAN";
case OPENCL:
return "OpenCL";
}
throw new UnsupportedOperationException("Unrecognized architecture " + this);
}
Expand All @@ -37,7 +39,7 @@ public static Arch getDefault() {

// Used to decide the order shown by the selector in the UI
public static Arch[] orderedValues() {
Arch[] order = { C11, ARM8, IMM, LKMM, POWER, PTX, RISCV, TSO, VULKAN };
Arch[] order = { C11, ARM8, IMM, LKMM, POWER, PTX, RISCV, TSO, VULKAN, OPENCL };
// Be sure no element is missing
assert(Arrays.asList(order).containsAll(Arrays.asList(values())));
return order;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,8 @@ public Event visitOpStore(SpirvParser.OpStoreContext ctx) {
Set<String> tags = parseMemoryAccessTags(ctx.memoryAccess());
if (!tags.contains(Tag.Spirv.MEM_VISIBLE)) {
String storageClass = builder.getPointerStorageClass(ctx.pointer().getText());
String scope = getScope(storageClass);
event.addTags(tags);
event.addTags(storageClass);
if (scope != null) {
event.addTags(Tag.Spirv.MEM_NON_PRIVATE);
event.addTags(scope);
}
return builder.addEvent(event);
}
throw new ParsingException("OpStore cannot contain tag '%s'", Tag.Spirv.MEM_VISIBLE);
Expand All @@ -69,13 +64,8 @@ public Event visitOpLoad(SpirvParser.OpLoadContext ctx) {
Set<String> tags = parseMemoryAccessTags(ctx.memoryAccess());
if (!tags.contains(Tag.Spirv.MEM_AVAILABLE)) {
String storageClass = builder.getPointerStorageClass(ctx.pointer().getText());
String scope = getScope(storageClass);
event.addTags(tags);
event.addTags(storageClass);
if (scope != null) {
event.addTags(Tag.Spirv.MEM_NON_PRIVATE);
event.addTags(scope);
}
return builder.addEvent(event);
}
throw new ParsingException("OpLoad cannot contain tag '%s'", Tag.Spirv.MEM_AVAILABLE);
Expand Down Expand Up @@ -196,24 +186,6 @@ private Set<String> parseMemoryAccessTags(SpirvParser.MemoryAccessContext ctx) {
return Set.of();
}

private String getScope(String storageClass) {
return switch (storageClass) {
case Tag.Spirv.SC_UNIFORM_CONSTANT,
Tag.Spirv.SC_UNIFORM,
Tag.Spirv.SC_OUTPUT,
Tag.Spirv.SC_PUSH_CONSTANT,
Tag.Spirv.SC_STORAGE_BUFFER,
Tag.Spirv.SC_PHYS_STORAGE_BUFFER -> Tag.Spirv.DEVICE;
case Tag.Spirv.SC_PRIVATE,
Tag.Spirv.SC_FUNCTION,
Tag.Spirv.SC_INPUT -> null;
case Tag.Spirv.SC_WORKGROUP -> Tag.Spirv.WORKGROUP;
case Tag.Spirv.SC_CROSS_WORKGROUP -> Tag.Spirv.QUEUE_FAMILY;
default -> throw new UnsupportedOperationException(
"Unsupported storage class " + storageClass);
};
}

public Set<String> getSupportedOps() {
return Set.of(
"OpVariable",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package com.dat3m.dartagnan.parsers.program.visitors.spirv;

import com.dat3m.dartagnan.configuration.Arch;
import com.dat3m.dartagnan.exception.ParsingException;
import com.dat3m.dartagnan.parsers.SpirvBaseVisitor;
import com.dat3m.dartagnan.parsers.SpirvParser;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.builders.ProgramBuilder;
Expand All @@ -14,12 +16,27 @@ public VisitorOpsSetting(ProgramBuilder builder) {
this.builder = builder;
}

@Override
public Void visitOpMemoryModel(SpirvParser.OpMemoryModelContext ctx) {
builder.setArch(parseArch(ctx.memoryModel().getText()));
return null;
}

@Override
public Void visitOpEntryPoint(SpirvParser.OpEntryPointContext ctx) {
builder.setEntryPointId(ctx.entryPoint().getText());
return null;
}

private Arch parseArch(String memoryModel) {
return switch (memoryModel) {
case "Vulkan", "VulkanKHR" -> Arch.VULKAN;
case "OpenCL" -> Arch.OPENCL;
case "GLSL450", "Simple" -> throw new ParsingException("Unsupported memory model '%s'", memoryModel);
default -> throw new ParsingException("Illegal memory model '%s'", memoryModel);
};
}

public Set<String> getSupportedOps() {
return Set.of(
"OpCapability",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.dat3m.dartagnan.parsers.program.visitors.spirv.builders;

import com.dat3m.dartagnan.configuration.Arch;
import com.dat3m.dartagnan.exception.ParsingException;
import com.dat3m.dartagnan.expression.Expression;
import com.dat3m.dartagnan.expression.Type;
Expand Down Expand Up @@ -33,6 +34,7 @@ public class ProgramBuilder {
protected DecorationsBuilder decorationsBuilder;
protected Function currentFunction;
protected String entryPointId;
protected Arch arch;
protected Set<String> nextOps;

public ProgramBuilder(ThreadGrid grid) {
Expand Down Expand Up @@ -84,6 +86,13 @@ public void setEntryPointId(String id) {
entryPointId = id;
}

public void setArch(Arch arch) {
if (this.arch != null) {
throw new ParsingException("Illegal attempt to override memory model");
}
this.arch = arch;
}

public void setSpecification(Program.SpecificationType type, Expression condition) {
if (program.getSpecification() != null) {
throw new ParsingException("Attempt to override program specification");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ private VisitorBase getCompiler() {
case RISCV -> new VisitorRISCV(useRC11Scheme);
case PTX -> new VisitorPTX();
case VULKAN -> new VisitorVulkan();
case OPENCL -> throw new UnsupportedOperationException();
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,31 +151,6 @@ private Set<String> toVulkanTags(Set<String> tags) {
vTags.add(tag);
}
});
return adjustVulkanTags(tags, vTags);
}

private Set<String> adjustVulkanTags(Set<String> tags, Set<String> vTags) {
if (tags.contains(Tag.MEMORY) && toVulkanTag(Tag.Spirv.getStorageClassTag(tags)) != null) {
vTags.add(Tag.Vulkan.NON_PRIVATE);
if (vTags.contains(Tag.READ)) {
vTags.add(Tag.Vulkan.VISIBLE);
}
if (vTags.contains(Tag.WRITE)) {
vTags.add(Tag.Vulkan.AVAILABLE);
}
}
if (tags.contains(Tag.Spirv.MEM_AVAILABLE) && tags.contains(Tag.Spirv.DEVICE)) {
vTags.add(Tag.Vulkan.AVDEVICE);
}
if (tags.contains(Tag.Spirv.MEM_VISIBLE) && tags.contains(Tag.Spirv.DEVICE)) {
vTags.add(Tag.Vulkan.VISDEVICE);
}
if (tags.contains(Tag.Spirv.RELAXED)) {
vTags.remove(Tag.Vulkan.SEMSC0);
vTags.remove(Tag.Vulkan.SEMSC1);
vTags.remove(Tag.Vulkan.SEM_VISIBLE);
vTags.remove(Tag.Vulkan.SEM_AVAILABLE);
}
return vTags;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,10 @@ private void propagateTags(Event source, Event target) {
if (source.hasTag(Tag.Vulkan.SEM_VISIBLE)) {
target.addTags(Tag.Vulkan.SEM_VISIBLE);
}
// If a RMW is a release, we do not propagate semscX to the read
if (!(source.hasTag(Tag.Vulkan.ACQUIRE) || source.hasTag(Tag.Vulkan.ACQ_REL))) {
if (target.hasTag(Tag.Vulkan.SEMSC0)) {
target.removeTags(Tag.Vulkan.SEMSC0);
}
if (target.hasTag(Tag.Vulkan.SEMSC1)) {
target.removeTags(Tag.Vulkan.SEMSC1);
}
// Remove tag if it refers to the release write
if (!source.hasTag(Tag.Vulkan.ACQUIRE) && source.hasTag(Tag.Vulkan.RELEASE)) {
target.removeTags(Tag.Vulkan.SEMSC0);
target.removeTags(Tag.Vulkan.SEMSC1);
}
if (source.hasTag(Tag.Vulkan.VISDEVICE)) {
target.addTags(Tag.Vulkan.VISDEVICE);
Expand All @@ -129,14 +125,10 @@ private void propagateTags(Event source, Event target) {
if (source.hasTag(Tag.Vulkan.SEM_AVAILABLE)) {
target.addTags(Tag.Vulkan.SEM_AVAILABLE);
}
// If a RMW is an acquire, we do not propagate semscX to the write
if (!(source.hasTag(Tag.Vulkan.RELEASE) || source.hasTag(Tag.Vulkan.ACQ_REL))) {
if (target.hasTag(Tag.Vulkan.SEMSC0)) {
target.removeTags(Tag.Vulkan.SEMSC0);
}
if (target.hasTag(Tag.Vulkan.SEMSC1)) {
target.removeTags(Tag.Vulkan.SEMSC1);
}
// Remove tag if it refers to the acquire read
if (!source.hasTag(Tag.Vulkan.RELEASE) && source.hasTag(Tag.Vulkan.ACQUIRE)) {
target.removeTags(Tag.Vulkan.SEMSC0);
target.removeTags(Tag.Vulkan.SEMSC1);
}
if (source.hasTag(Tag.Vulkan.AVDEVICE)) {
target.addTags(Tag.Vulkan.AVDEVICE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ public void testLoad() {
assertNotNull(load);
assertEquals(pointer, load.getAddress());
assertEquals(iType, load.getAccessType());
assertEquals(Set.of(Tag.VISIBLE, Tag.MEMORY, Tag.READ, Tag.Spirv.SC_UNIFORM,
Tag.Spirv.MEM_NON_PRIVATE, Tag.Spirv.DEVICE), load.getTags());
assertEquals(Set.of(Tag.VISIBLE, Tag.MEMORY, Tag.READ, Tag.Spirv.SC_UNIFORM), load.getTags());

Register register = load.getResultRegister();
assertEquals("%result", register.getName());
Expand All @@ -60,7 +59,7 @@ public void testLoad() {
@Test
public void testLoadWithTags() {
// given
String input = "%result = OpLoad %int %ptr MakePointerVisible %scope";
String input = "%result = OpLoad %int %ptr MakePointerVisible|NonPrivatePointer %scope";
IntegerType iType = builder.mockIntType("%int", 32);
builder.mockPtrType("%int_ptr", "%int", "Workgroup");
ScopedPointerVariable pointer = builder.mockVariable("%ptr", "%int_ptr");
Expand Down Expand Up @@ -121,14 +120,13 @@ public void testStore() {
assertEquals(pointer, store.getAddress());
assertEquals(iType, store.getAccessType());
assertEquals(value, store.getMemValue());
assertEquals(Set.of(Tag.VISIBLE, Tag.MEMORY, Tag.WRITE, Tag.Spirv.SC_UNIFORM,
Tag.Spirv.MEM_NON_PRIVATE, Tag.Spirv.DEVICE), store.getTags());
assertEquals(Set.of(Tag.VISIBLE, Tag.MEMORY, Tag.WRITE, Tag.Spirv.SC_UNIFORM), store.getTags());
}

@Test
public void testStoreWithTags() {
// given
String input = "OpStore %ptr %value MakePointerAvailable %scope";
String input = "OpStore %ptr %value MakePointerAvailable|NonPrivatePointer %scope";
IntegerType iType = builder.mockIntType("%int", 32);
builder.mockPtrType("%int_ptr", "%int", "Workgroup");
ScopedPointerVariable pointer = builder.mockVariable("%ptr", "%int_ptr");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,15 @@ public void testLoad() {
);
doTestLoad(
Set.of(Tag.Spirv.SC_WORKGROUP),
Set.of(Tag.Vulkan.SC1, Tag.Vulkan.NON_PRIVATE, Tag.Vulkan.VISIBLE)
Set.of(Tag.Vulkan.SC1)
);
doTestLoad(
Set.of(Tag.Spirv.MEM_VISIBLE, Tag.Spirv.SC_UNIFORM),
Set.of(Tag.Vulkan.VISIBLE, Tag.Vulkan.SC0, Tag.Vulkan.NON_PRIVATE)
Set.of(Tag.Spirv.MEM_VISIBLE, Tag.Spirv.DEVICE, Tag.Spirv.SC_UNIFORM),
Set.of(Tag.Vulkan.VISIBLE, Tag.Vulkan.DEVICE, Tag.Vulkan.SC0)
);
doTestLoad(
Set.of(Tag.Spirv.MEM_VISIBLE, Tag.Spirv.DEVICE, Tag.Spirv.SC_UNIFORM),
Set.of(Tag.Vulkan.VISIBLE, Tag.Vulkan.DEVICE, Tag.Vulkan.VISDEVICE, Tag.Vulkan.SC0, Tag.Vulkan.NON_PRIVATE)
Set.of(Tag.Spirv.MEM_NON_PRIVATE, Tag.Spirv.MEM_VISIBLE, Tag.Spirv.DEVICE, Tag.Spirv.SC_UNIFORM),
Set.of(Tag.Vulkan.NON_PRIVATE, Tag.Vulkan.VISIBLE, Tag.Vulkan.DEVICE, Tag.Vulkan.SC0)
);
}

Expand Down Expand Up @@ -77,15 +77,15 @@ public void testStore() {
);
doTestStore(
Set.of(Tag.Spirv.SC_WORKGROUP),
Set.of(Tag.Vulkan.SC1, Tag.Vulkan.NON_PRIVATE, Tag.Vulkan.AVAILABLE)
Set.of(Tag.Vulkan.SC1)
);
doTestStore(
Set.of(Tag.Spirv.MEM_AVAILABLE, Tag.Spirv.SC_UNIFORM),
Set.of(Tag.Vulkan.AVAILABLE, Tag.Vulkan.SC0, Tag.Vulkan.NON_PRIVATE)
Set.of(Tag.Spirv.MEM_AVAILABLE, Tag.Spirv.DEVICE, Tag.Spirv.SC_UNIFORM),
Set.of(Tag.Vulkan.AVAILABLE, Tag.Vulkan.DEVICE, Tag.Vulkan.SC0)
);
doTestStore(
Set.of(Tag.Spirv.MEM_AVAILABLE, Tag.Spirv.DEVICE, Tag.Spirv.SC_UNIFORM),
Set.of(Tag.Vulkan.AVAILABLE, Tag.Vulkan.DEVICE, Tag.Vulkan.AVDEVICE, Tag.Vulkan.SC0, Tag.Vulkan.NON_PRIVATE)
Set.of(Tag.Spirv.MEM_NON_PRIVATE, Tag.Spirv.MEM_AVAILABLE, Tag.Spirv.DEVICE, Tag.Spirv.SC_UNIFORM),
Set.of(Tag.Vulkan.NON_PRIVATE, Tag.Vulkan.AVAILABLE, Tag.Vulkan.DEVICE, Tag.Vulkan.SC0)
);
}

Expand Down Expand Up @@ -126,7 +126,7 @@ public void testSpirvLoad() {
);
doTestSpirvLoad(
Set.of(Tag.Spirv.RELAXED, Tag.Spirv.DEVICE, Tag.Spirv.SC_UNIFORM),
Set.of(Tag.Vulkan.DEVICE, Tag.Vulkan.VISDEVICE, Tag.Vulkan.SC0, Tag.Vulkan.NON_PRIVATE)
Set.of(Tag.Vulkan.DEVICE, Tag.Vulkan.SC0, Tag.Vulkan.NON_PRIVATE)
);
}

Expand Down Expand Up @@ -169,7 +169,7 @@ public void testSpirvStore() {
);
doTestSpirvStore(
Set.of(Tag.Spirv.RELAXED, Tag.Spirv.DEVICE, Tag.Spirv.SC_UNIFORM),
Set.of(Tag.Vulkan.DEVICE, Tag.Vulkan.AVDEVICE, Tag.Vulkan.SC0, Tag.Vulkan.NON_PRIVATE)
Set.of(Tag.Vulkan.DEVICE, Tag.Vulkan.SC0, Tag.Vulkan.NON_PRIVATE)
);
}

Expand Down Expand Up @@ -221,8 +221,8 @@ public void testSpirvXchg() {
);
doTestSpirvXchg(
Set.of(Tag.Spirv.RELAXED, Tag.Spirv.DEVICE, Tag.Spirv.SC_WORKGROUP),
Set.of(Tag.Vulkan.DEVICE, Tag.Vulkan.VISDEVICE, Tag.Vulkan.SC1, Tag.Vulkan.NON_PRIVATE),
Set.of(Tag.Vulkan.DEVICE, Tag.Vulkan.AVDEVICE, Tag.Vulkan.SC1, Tag.Vulkan.NON_PRIVATE)
Set.of(Tag.Vulkan.DEVICE, Tag.Vulkan.SC1, Tag.Vulkan.NON_PRIVATE),
Set.of(Tag.Vulkan.DEVICE, Tag.Vulkan.SC1, Tag.Vulkan.NON_PRIVATE)
);
}

Expand Down Expand Up @@ -282,8 +282,8 @@ public void testSpirvRmw() {
);
doTestSpirvRmw(
Set.of(Tag.Spirv.RELAXED, Tag.Spirv.DEVICE, Tag.Spirv.SC_WORKGROUP),
Set.of(Tag.Vulkan.DEVICE, Tag.Vulkan.VISDEVICE, Tag.Vulkan.SC1, Tag.Vulkan.NON_PRIVATE),
Set.of(Tag.Vulkan.DEVICE, Tag.Vulkan.AVDEVICE, Tag.Vulkan.SC1, Tag.Vulkan.NON_PRIVATE)
Set.of(Tag.Vulkan.DEVICE, Tag.Vulkan.SC1, Tag.Vulkan.NON_PRIVATE),
Set.of(Tag.Vulkan.DEVICE, Tag.Vulkan.SC1, Tag.Vulkan.NON_PRIVATE)
);
}

Expand Down Expand Up @@ -350,8 +350,8 @@ public void testSpirvCmpXchg() {
Tag.Spirv.DEVICE,
Set.of(Tag.Spirv.RELAXED, Tag.Spirv.SC_WORKGROUP),
Set.of(Tag.Spirv.RELAXED, Tag.Spirv.SC_WORKGROUP),
Set.of(Tag.Vulkan.DEVICE, Tag.Vulkan.VISDEVICE, Tag.Vulkan.SC1, Tag.Vulkan.NON_PRIVATE),
Set.of(Tag.Vulkan.DEVICE, Tag.Vulkan.AVDEVICE, Tag.Vulkan.SC1, Tag.Vulkan.NON_PRIVATE));
Set.of(Tag.Vulkan.DEVICE, Tag.Vulkan.SC1, Tag.Vulkan.NON_PRIVATE),
Set.of(Tag.Vulkan.DEVICE, Tag.Vulkan.SC1, Tag.Vulkan.NON_PRIVATE));
}

private void doTestSpirvCmpXchg(String scope, Set<String> eqTags, Set<String> neqTags, Set<String> loadTags, Set<String> storeTags) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@ public static Iterable<Object[]> data() throws IOException {
{"ttaslock-dv2wg-1.1.2.spv.dis", 1, FAIL},

{"xf-barrier-2.1.2.spv.dis", 4, PASS},
// Slow test
// {"xf-barrier-3.1.3.spv.dis", 9, PASS},
{"xf-barrier-3.1.3.spv.dis", 9, PASS},
// TODO: IMO should pass (spinloop handling?)
// {"xf-barrier-1.1.2.spv.dis", 2, PASS},
{"xf-barrier-2.1.1.spv.dis", 2, PASS},
Expand All @@ -87,8 +86,7 @@ public static Iterable<Object[]> data() throws IOException {
{"xf-barrier-weakest.spv.dis", 4, FAIL},

{"xf-barrier-local-2.1.2.spv.dis", 4, FAIL},
// Slow test
// {"xf-barrier-local-3.1.3.spv.dis", 9, FAIL},
{"xf-barrier-local-3.1.3.spv.dis", 9, FAIL},
// TODO: ??
// {"xf-barrier-local-1.1.2.spv.dis", 2, FAIL},
{"xf-barrier-local-2.1.1.spv.dis", 2, FAIL},
Expand All @@ -99,8 +97,7 @@ public static Iterable<Object[]> data() throws IOException {
{"xf-barrier-local-weakest.spv.dis", 4, FAIL},

{"xf-barrier-zero-2.1.2.spv.dis", 4, FAIL},
// Slow test
// {"xf-barrier-zero-3.1.3.spv.dis", 9, FAIL},
{"xf-barrier-zero-3.1.3.spv.dis", 9, FAIL},
// TODO: ??
// {"xf-barrier-zero-1.1.2.spv.dis", 2, FAIL},
{"xf-barrier-zero-2.1.1.spv.dis", 2, FAIL},
Expand Down
Loading

0 comments on commit 9b4e18a

Please sign in to comment.