Skip to content

Commit

Permalink
symmetric difference operation for sets via xor (#24286)
Browse files Browse the repository at this point in the history
closes nim-lang/RFCs#554

Adds a symmetric difference operation to the language bitset type. This
maps to a simple `xor` operation on the backend and thus is likely
faster than the current alternatives, namely `(a - b) + (b - a)` or `a +
b - a * b`. The compiler VM implementation of bitsets already
implemented this via `symdiffSets` but it was never used.

The standalone binary operation is added to `setutils`, named
`symmetricDifference` in line with [hash
sets](https://nim-lang.org/docs/sets.html#symmetricDifference%2CHashSet%5BA%5D%2CHashSet%5BA%5D).
An operator version `-+-` and an in-place version like `toggle` as
described in the RFC are also added, implemented as trivial sugar.
  • Loading branch information
metagn authored Oct 19, 2024
1 parent 0a058a6 commit ae9287c
Show file tree
Hide file tree
Showing 12 changed files with 85 additions and 7 deletions.
5 changes: 5 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ rounding guarantees (via the

## Standard library additions and changes

[//]: # "Additions:"
- `setutils.symmetricDifference` along with its operator version
`` setutils.`-+-` `` and in-place version `setutils.toggle` have been added
to more efficiently calculate the symmetric difference of bitsets.

[//]: # "Changes:"
- `std/math` The `^` symbol now supports floating-point as exponent in addition to the Natural type.

Expand Down
4 changes: 2 additions & 2 deletions compiler/ast.nim
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ type
mAnd, mOr,
mImplies, mIff, mExists, mForall, mOld,
mEqStr, mLeStr, mLtStr,
mEqSet, mLeSet, mLtSet, mMulSet, mPlusSet, mMinusSet,
mEqSet, mLeSet, mLtSet, mMulSet, mPlusSet, mMinusSet, mXorSet,
mConStrStr, mSlice,
mDotDot, # this one is only necessary to give nice compile time warnings
mFields, mFieldPairs, mOmpParFor,
Expand Down Expand Up @@ -559,7 +559,7 @@ const
mStrToStr, mEnumToStr,
mAnd, mOr,
mEqStr, mLeStr, mLtStr,
mEqSet, mLeSet, mLtSet, mMulSet, mPlusSet, mMinusSet,
mEqSet, mLeSet, mLtSet, mMulSet, mPlusSet, mMinusSet, mXorSet,
mConStrStr, mAppendStrCh, mAppendStrStr, mAppendSeqElem,
mInSet, mRepr, mOpenArrayToSeq}

Expand Down
10 changes: 6 additions & 4 deletions compiler/ccgexprs.nim
Original file line number Diff line number Diff line change
Expand Up @@ -2044,7 +2044,7 @@ proc genInOp(p: BProc, e: PNode, d: var TLoc) =

proc genSetOp(p: BProc, e: PNode, d: var TLoc, op: TMagic) =
const
lookupOpr: array[mLeSet..mMinusSet, string] = [
lookupOpr: array[mLeSet..mXorSet, string] = [
"for ($1 = 0; $1 < $2; $1++) { $n" &
" $3 = (($4[$1] & ~ $5[$1]) == 0);$n" &
" if (!$3) break;}$n",
Expand All @@ -2054,7 +2054,8 @@ proc genSetOp(p: BProc, e: PNode, d: var TLoc, op: TMagic) =
"if ($3) $3 = (#nimCmpMem($4, $5, $2) != 0);$n",
"&",
"|",
"& ~"]
"& ~",
"^"]
var a, b: TLoc
var i: TLoc
var setType = skipTypes(e[1].typ, abstractVar)
Expand Down Expand Up @@ -2085,6 +2086,7 @@ proc genSetOp(p: BProc, e: PNode, d: var TLoc, op: TMagic) =
of mMulSet: binaryExpr(p, e, d, "($1 & $2)")
of mPlusSet: binaryExpr(p, e, d, "($1 | $2)")
of mMinusSet: binaryExpr(p, e, d, "($1 & ~ $2)")
of mXorSet: binaryExpr(p, e, d, "($1 ^ $2)")
of mInSet:
genInOp(p, e, d)
else: internalError(p.config, e.info, "genSetOp()")
Expand Down Expand Up @@ -2112,7 +2114,7 @@ proc genSetOp(p: BProc, e: PNode, d: var TLoc, op: TMagic) =
var a = initLocExpr(p, e[1])
var b = initLocExpr(p, e[2])
putIntoDest(p, d, e, ropecg(p.module, "(#nimCmpMem($1, $2, $3)==0)", [a.rdCharLoc, b.rdCharLoc, size]))
of mMulSet, mPlusSet, mMinusSet:
of mMulSet, mPlusSet, mMinusSet, mXorSet:
# we inline the simple for loop for better code generation:
i = getTemp(p, getSysType(p.module.g.graph, unknownLineInfo, tyInt)) # our counter
a = initLocExpr(p, e[1])
Expand Down Expand Up @@ -2548,7 +2550,7 @@ proc genMagicExpr(p: BProc, e: PNode, d: var TLoc, op: TMagic) =
of mSetLengthStr: genSetLengthStr(p, e, d)
of mSetLengthSeq: genSetLengthSeq(p, e, d)
of mIncl, mExcl, mCard, mLtSet, mLeSet, mEqSet, mMulSet, mPlusSet, mMinusSet,
mInSet:
mInSet, mXorSet:
genSetOp(p, e, d, op)
of mNewString, mNewStringOfCap, mExit, mParseBiggestFloat:
var opr = e[0].sym
Expand Down
1 change: 1 addition & 0 deletions compiler/condsyms.nim
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,4 @@ proc initDefines*(symbols: StringTableRef) =
defineSymbol("nimHasGenericsOpenSym3")
defineSymbol("nimHasJsNoLambdaLifting")
defineSymbol("nimHasDefaultFloatRoundtrip")
defineSymbol("nimHasXorSet")
1 change: 1 addition & 0 deletions compiler/jsgen.nim
Original file line number Diff line number Diff line change
Expand Up @@ -2458,6 +2458,7 @@ proc genMagic(p: PProc, n: PNode, r: var TCompRes) =
of mMulSet: binaryExpr(p, n, r, "SetMul", "SetMul($1, $2)")
of mPlusSet: binaryExpr(p, n, r, "SetPlus", "SetPlus($1, $2)")
of mMinusSet: binaryExpr(p, n, r, "SetMinus", "SetMinus($1, $2)")
of mXorSet: binaryExpr(p, n, r, "SetXor", "SetXor($1, $2)")
of mIncl: binaryExpr(p, n, r, "", "$1[$2] = true")
of mExcl: binaryExpr(p, n, r, "", "delete $1[$2]")
of mInSet:
Expand Down
3 changes: 3 additions & 0 deletions compiler/semfold.nim
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,9 @@ proc evalOp(m: TMagic, n, a, b, c: PNode; idgen: IdGenerator; g: ModuleGraph): P
of mMinusSet:
result = nimsets.diffSets(g.config, a, b)
result.info = n.info
of mXorSet:
result = nimsets.symdiffSets(g.config, a, b)
result.info = n.info
of mConStrStr: result = newStrNodeT(getStrOrChar(a) & getStrOrChar(b), n, g)
of mInSet: result = newIntNodeT(toInt128(ord(inSet(a, b))), n, idgen, g)
of mRepr:
Expand Down
5 changes: 5 additions & 0 deletions compiler/vm.nim
Original file line number Diff line number Diff line change
Expand Up @@ -1276,6 +1276,11 @@ proc rawExecute(c: PCtx, start: int, tos: PStackFrame): TFullReg =
createSet(regs[ra])
move(regs[ra].node.sons,
nimsets.diffSets(c.config, regs[rb].node, regs[rc].node).sons)
of opcXorSet:
decodeBC(rkNode)
createSet(regs[ra])
move(regs[ra].node.sons,
nimsets.symdiffSets(c.config, regs[rb].node, regs[rc].node).sons)
of opcConcatStr:
decodeBC(rkNode)
createStr regs[ra]
Expand Down
2 changes: 1 addition & 1 deletion compiler/vmdef.nim
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ type
opcEqRef, opcEqNimNode, opcSameNodeType,
opcXor, opcNot, opcUnaryMinusInt, opcUnaryMinusFloat, opcBitnotInt,
opcEqStr, opcEqCString, opcLeStr, opcLtStr, opcEqSet, opcLeSet, opcLtSet,
opcMulSet, opcPlusSet, opcMinusSet, opcConcatStr,
opcMulSet, opcPlusSet, opcMinusSet, opcXorSet, opcConcatStr,
opcContainsSet, opcRepr, opcSetLenStr, opcSetLenSeq,
opcIsNil, opcOf, opcIs,
opcParseFloat, opcConv, opcCast,
Expand Down
1 change: 1 addition & 0 deletions compiler/vmgen.nim
Original file line number Diff line number Diff line change
Expand Up @@ -1212,6 +1212,7 @@ proc genMagic(c: PCtx; n: PNode; dest: var TDest; flags: TGenFlags = {}, m: TMag
of mMulSet: genBinarySet(c, n, dest, opcMulSet)
of mPlusSet: genBinarySet(c, n, dest, opcPlusSet)
of mMinusSet: genBinarySet(c, n, dest, opcMinusSet)
of mXorSet: genBinarySet(c, n, dest, opcXorSet)
of mConStrStr: genVarargsABC(c, n, dest, opcConcatStr)
of mInSet: genBinarySet(c, n, dest, opcContainsSet)
of mRepr: genUnaryABC(c, n, dest, opcRepr)
Expand Down
28 changes: 28 additions & 0 deletions lib/std/setutils.nim
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,31 @@ func `[]=`*[T](t: var set[T], key: T, val: bool) {.inline.} =
s[a3] = true
assert s == {a2, a3}
if val: t.incl key else: t.excl key

when defined(nimHasXorSet):
  func symmetricDifference*[T](x, y: set[T]): set[T] {.magic: "XorSet".} =
    ## Returns the symmetric difference of `x` and `y`, i.e. the set of
    ## values that are contained in exactly one of the two operands.
    ##
    ## The result equals `x + y - x * y` and `(x - y) + (y - x)`; this
    ## version is handled by the compiler's `XorSet` magic and is intended
    ## to be more efficient than either spelled-out form.
    runnableExamples:
      assert symmetricDifference({1, 2, 3}, {2, 3, 4}) == {1, 4}
else:
  func symmetricDifference*[T](x, y: set[T]): set[T] {.inline.} =
    ## Returns the symmetric difference of `x` and `y`, i.e. the set of
    ## values that are contained in exactly one of the two operands.
    ##
    ## Fallback used when the compiler does not define `nimHasXorSet`:
    ## the result is assembled from the ordinary set operators.
    (x - y) + (y - x)

func `-+-`*[T](x, y: set[T]): set[T] {.inline.} =
  ## Operator alias for `symmetricDifference`: returns the set of values
  ## that are contained in exactly one of `x` and `y`.
  ##
  ## Declared as a `func` like the other routines of this module, since it
  ## only forwards to the side-effect free `symmetricDifference`.
  runnableExamples:
    assert {1, 2, 3} -+- {2, 3, 4} == {1, 4}
  result = symmetricDifference(x, y)

func toggle*[T](x: var set[T], y: set[T]) {.inline.} =
  ## Toggles the existence of each value of `y` in `x`.
  ## If any element in `y` is also in `x`, it is excluded from `x`;
  ## otherwise it is included.
  ## Equivalent to `x = symmetricDifference(x, y)`.
  ##
  ## Declared as a `func`: the only effect is the update of the `var`
  ## parameter `x`, in the same way as `[]=` in this module.
  runnableExamples:
    var x = {1, 2, 3}
    x.toggle({2, 3, 4})
    assert x == {1, 4}
  x = symmetricDifference(x, y)
12 changes: 12 additions & 0 deletions lib/system/jssys.nim
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,18 @@ proc SetMinus(a, b: int): int {.compilerproc, asmNoStackFrame.} =
return result;
""".}

# JS runtime helper that computes the symmetric difference of two sets.
# The emitted JavaScript treats `a` and `b` as objects indexed by element:
# it creates a fresh object, marks every key of `a` whose entry in `b` is
# falsy, then every key of `b` whose entry in `a` is falsy, and returns the
# new object. Neither `a` nor `b` is modified.
# NOTE(review): the parameters are declared as `int` although the emitted
# code indexes them like objects - presumably a placeholder signature shared
# with the sibling set helpers; confirm against `SetMinus`.
proc SetXor(a, b: int): int {.compilerproc, asmNoStackFrame.} =
{.emit: """
var result = {};
for (var elem in `a`) {
if (!`b`[elem]) { result[elem] = true; }
}
for (var elem in `b`) {
if (!`a`[elem]) { result[elem] = true; }
}
return result;
""".}

proc cmpStrings(a, b: string): int {.asmNoStackFrame, compilerproc.} =
{.emit: """
if (`a` == `b`) return 0;
Expand Down
20 changes: 20 additions & 0 deletions tests/stdlib/tsetutils.nim
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,26 @@ template main =
s[a2] = true
s[a3] = true
doAssert s == {a2, a3}

block: # set symmetric difference (xor), https://github.com/nim-lang/RFCs/issues/554
  type T = set[range[0..15]]
  let x: T = {1, 4, 5, 8, 9}
  let y: T = {0, 2..6, 9}
  let empty: T = {}
  let res = symmetricDifference(x, y)
  doAssert res == {0, 1, 2, 3, 6, 8}
  # must agree with both spelled-out formulations
  doAssert res == (x + y - x * y)
  doAssert res == ((x - y) + (y - x))
  # the operation is commutative
  doAssert res == symmetricDifference(y, x)
  # the operator form must agree with the named function
  doAssert res == (x -+- y)
  doAssert res == (y -+- x)
  # identities: a set xor itself is empty, xor with the empty set is a no-op
  doAssert symmetricDifference(x, x) == empty
  doAssert symmetricDifference(x, empty) == x
  var z = x
  doAssert z == {1, 4, 5, 8, 9}
  doAssert z == x
  z.toggle(y)
  doAssert z == res
  # toggling the same set twice restores the original value
  z.toggle(y)
  doAssert z == x
  z.toggle({1, 5})
  doAssert z == {4, 8, 9}
  z.toggle({3, 8})
  doAssert z == {3, 4, 9}

main()
static: main()

0 comments on commit ae9287c

Please sign in to comment.