Skip to content

Commit

Permalink
Fixed spirv memory operands parser
Browse files Browse the repository at this point in the history
  • Loading branch information
natgavrilenko committed Sep 12, 2024
1 parent 951f326 commit a596a79
Show file tree
Hide file tree
Showing 56 changed files with 585 additions and 16,354 deletions.
18 changes: 11 additions & 7 deletions dartagnan/src/main/antlr4/Spirv.g4
Original file line number Diff line number Diff line change
Expand Up @@ -3231,13 +3231,17 @@ loopControl
;

memoryAccess
: AliasScopeINTELMask idRef
| Aligned literalInteger
| MakePointerAvailable idScope
| MakePointerAvailableKHR idScope
| MakePointerVisible idScope
| MakePointerVisibleKHR idScope
| NoAliasINTELMask idRef
: memoryAccessTag (Pipe memoryAccessTag)* literalInteger? idRef*
;

memoryAccessTag
: AliasScopeINTELMask
| Aligned
| MakePointerAvailable
| MakePointerAvailableKHR
| MakePointerVisible
| MakePointerVisibleKHR
| NoAliasINTELMask
| NonPrivatePointer
| NonPrivatePointerKHR
| None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,9 @@
import com.dat3m.dartagnan.program.event.Event;
import com.dat3m.dartagnan.program.event.EventFactory;
import com.dat3m.dartagnan.program.event.Tag;
import org.antlr.v4.runtime.tree.ParseTree;
import org.antlr.v4.runtime.RuleContext;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

Expand Down Expand Up @@ -187,32 +186,14 @@ private void visitOpAccessChain(String id, String typeId, String baseId,
}

private Set<String> parseMemoryAccessTags(SpirvParser.MemoryAccessContext ctx) {
if (ctx == null || ctx.None() != null) {
return Set.of();
if (ctx != null) {
List<String> operands = ctx.memoryAccessTag().stream().map(RuleContext::getText).toList();
Integer alignment = ctx.literalInteger() != null ? Integer.parseInt(ctx.literalInteger().getText()) : null;
List<String> paramIds = ctx.idRef().stream().map(RuleContext::getText).toList();
List<Expression> params = ctx.idRef().stream().map(c -> builder.getExpression(c.getText())).toList();
return HelperTags.parseMemoryOperandsTags(operands, alignment, paramIds, params);
}
if (ctx.Volatile() != null) {
return Set.of(Tag.Spirv.MEM_VOLATILE);
}
if (ctx.Nontemporal() != null) {
return Set.of(Tag.Spirv.MEM_NON_TEMPORAL);
}
if (ctx.NonPrivatePointer() != null || ctx.NonPrivatePointerKHR() != null) {
return Set.of(Tag.Spirv.MEM_NON_PRIVATE);
}
if (ctx.idScope() != null) {
String scopeId = ctx.idScope().getText();
String scopeTag = HelperTags.parseScope(scopeId, builder.getExpression(scopeId));
Set<String> tags = new HashSet<>(Set.of(scopeTag, Tag.Spirv.MEM_NON_PRIVATE));
if (ctx.MakePointerAvailable() != null || ctx.MakePointerAvailableKHR() != null) {
tags.add(Tag.Spirv.MEM_AVAILABLE);
}
if (ctx.MakePointerVisible() != null || ctx.MakePointerVisibleKHR() != null) {
tags.add(Tag.Spirv.MEM_VISIBLE);
}
return tags;
}
throw new ParsingException("Unsupported memory access tag '%s'",
String.join(" ", ctx.children.stream().map(ParseTree::getText).toList()));
return Set.of();
}

private String getScope(String storageClass) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.dat3m.dartagnan.exception.ParsingException;
import com.dat3m.dartagnan.expression.Expression;
import com.dat3m.dartagnan.expression.integers.IntLiteral;
import com.dat3m.dartagnan.program.event.Tag;
import com.google.common.collect.Sets;

import java.util.*;
Expand Down Expand Up @@ -37,6 +38,89 @@ public static Set<String> parseMemorySemanticsTags(String id, Expression expr) {
return tags;
}

public static Set<String> parseMemoryOperandsTags(List<String> operands, Integer alignment,
List<String> paramIds, List<Expression> params) {
List<String> tagList = parseTagList(operands, alignment);
Set<String> tags = new HashSet<>(tagList);
if (tagList.size() != tags.size()) {
throwDuplicatesException(operands);
}
tags = parseMemoryAccessParameters(operands, tags, params, paramIds);
if (tags.contains(Tag.Spirv.MEM_VISIBLE) && tags.contains(Tag.Spirv.MEM_AVAILABLE)) {
// TODO: This is a legal combination for OpCopyMemory and OpCopyMemorySized.
// Refactor spirv tags to have av_scope and vis_scope.
throw new ParsingException("Unsupported combination of memory operands '%s'",
String.join("|", operands));
}
return tags;
}

private static List<String> parseTagList(List<String> operands, Integer alignment) {
boolean isNone = false;
boolean isAligned = false;
List<String> tagList = new LinkedList<>();
for (String tag : operands) {
switch (tag) {
case "None" -> {
if (isNone) {
throwDuplicatesException(operands);
}
isNone = true;
}
case "Volatile" -> tagList.add(Tag.Spirv.MEM_VOLATILE);
case "Aligned" -> {
if (isAligned) {
throwDuplicatesException(operands);
}
isAligned = true;
}
case "Nontemporal" -> tagList.add(Tag.Spirv.MEM_NON_TEMPORAL);
case "MakePointerAvailable", "MakePointerAvailableKHR" -> tagList.add(Tag.Spirv.MEM_AVAILABLE);
case "MakePointerVisible", "MakePointerVisibleKHR" -> tagList.add(Tag.Spirv.MEM_VISIBLE);
case "NonPrivatePointer", "NonPrivatePointerKHR" -> tagList.add(Tag.Spirv.MEM_NON_PRIVATE);
case "AliasScopeINTELMask", "NoAliasINTELMask" ->
throw new ParsingException("Unsupported memory operand '%s'", tag);
default -> throw new ParsingException("Unexpected memory operand '%s'", tag);
}
}
if (isNone && (isAligned || !tagList.isEmpty())) {
throw new ParsingException("Memory operand 'None' cannot be combined with other operands");
}
if (isAligned && alignment == null || !isAligned && alignment != null) {
throwIllegalParametersException(operands);
}
// TODO: Implement proper support for alignment
return tagList;
}

private static Set<String> parseMemoryAccessParameters(List<String> operands, Set<String> tags, List<Expression> parameters, List<String> parameterIds) {
int i = 0;
for (String tag : List.of(Tag.Spirv.MEM_AVAILABLE, Tag.Spirv.MEM_VISIBLE)) {
if (tags.contains(tag)) {
if (parameters.size() <= i) {
throwIllegalParametersException(operands);
}
String scopeTag = HelperTags.parseScope(parameterIds.get(i), parameters.get(i));
tags.add(scopeTag);
i++;
}
}
if (i != parameters.size()) {
throwIllegalParametersException(operands);
}
return tags;
}

private static void throwDuplicatesException(List<String> operands) {
throw new ParsingException("Duplicated memory operands definition(s) in '%s'",
String.join("|", operands));
}

private static void throwIllegalParametersException(List<String> operands) {
throw new ParsingException("Illegal parameter(s) in memory operands definition '%s'",
String.join("|", operands));
}

public static String parseScope(String id, Expression expr) {
int value = getIntValue(id, expr);
if (value >= 0 && value < scopes.size()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,35 @@ public class ParserSpirvTest {

@Test
public void testParsingProgram() throws IOException {
String path = Paths.get(getTestResourcePath("parsers/program/spirv/valid/fibonacci.spv.dis")).toString();
doTestParsingValidProgram("fibonacci.spv.dis");
doTestParsingValidProgram("mp-memory-operands.spv.dis");
}

@Test
public void testInvalidControlFlow() throws IOException {
String error = "Unexpected operation 'OpLogicalNot'";
doTestParsingInvalidProgram("control-flow/malformed-selection-merge-label.spv.dis", error);
doTestParsingInvalidProgram("control-flow/malformed-selection-merge.spv.dis", error);
doTestParsingInvalidProgram("control-flow/malformed-loop-merge.spv.dis", error);
doTestParsingInvalidProgram("control-flow/malformed-loop-merge-true-label.spv.dis", error);
}

@Test
public void testInvalidMemoryOperands() throws IOException {
doTestParsingInvalidProgram("memory-operands/illegal-parameter-order-1.spv.dis", null);
doTestParsingInvalidProgram("memory-operands/illegal-parameter-order-2.spv.dis", null);
doTestParsingInvalidProgram("memory-operands/missing-alignment.spv.dis", null);
doTestParsingInvalidProgram("memory-operands/missing-scope-1.spv.dis", null);
doTestParsingInvalidProgram("memory-operands/missing-scope-2.spv.dis", null);
doTestParsingInvalidProgram("memory-operands/unnecessary-alignment-1.spv.dis", null);
doTestParsingInvalidProgram("memory-operands/unnecessary-alignment-2.spv.dis", null);
doTestParsingInvalidProgram("memory-operands/unnecessary-alignment-3.spv.dis", null);
doTestParsingInvalidProgram("memory-operands/unnecessary-scope-1.spv.dis", null);
doTestParsingInvalidProgram("memory-operands/unnecessary-scope-2.spv.dis", null);
}

private void doTestParsingValidProgram(String file) throws IOException {
String path = Paths.get(getTestResourcePath("parsers/program/spirv/valid/" + file)).toString();
try (FileInputStream stream = new FileInputStream(path)) {
CharStream charStream = CharStreams.fromStream(stream);
ParserSpirv parser = new ParserSpirv();
Expand All @@ -26,15 +54,7 @@ public void testParsingProgram() throws IOException {
}
}

@Test
public void testParsingInvalidProgram() throws IOException {
doTestParsingInvalidProgram("malformed-selection-merge-label.spv.dis");
doTestParsingInvalidProgram("malformed-selection-merge.spv.dis");
doTestParsingInvalidProgram("malformed-loop-merge.spv.dis");
doTestParsingInvalidProgram("malformed-loop-merge-true-label.spv.dis");
}

private void doTestParsingInvalidProgram(String file) throws IOException {
private void doTestParsingInvalidProgram(String file, String error) throws IOException {
String path = Paths.get(getTestResourcePath("parsers/program/spirv/invalid/" + file)).toString();
try (FileInputStream stream = new FileInputStream(path)) {
CharStream charStream = CharStreams.fromStream(stream);
Expand All @@ -43,7 +63,9 @@ private void doTestParsingInvalidProgram(String file) throws IOException {
parser.parse(charStream);
fail("Should throw exception");
} catch (ParsingException e) {
assertEquals("Unexpected operation 'OpLogicalNot'", e.getMessage());
if (error != null) {
assertEquals(error, e.getMessage());
}
}
}
}
Expand Down
Loading

0 comments on commit a596a79

Please sign in to comment.