Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

symmetric difference operation for sets via xor #24286

Merged
merged 2 commits into from
Oct 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ rounding guarantees (via the

## Standard library additions and changes

[//]: # "Additions:"
- `setutils.symmetricDifference` along with its operator version
`` setutils.`-+-` `` and in-place version `setutils.toggle` have been added
to more efficiently calculate the symmetric difference of bitsets.

[//]: # "Changes:"
- `std/math` The `^` symbol now supports floating-point as exponent in addition to the Natural type.

Expand Down
4 changes: 2 additions & 2 deletions compiler/ast.nim
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ type
mAnd, mOr,
mImplies, mIff, mExists, mForall, mOld,
mEqStr, mLeStr, mLtStr,
mEqSet, mLeSet, mLtSet, mMulSet, mPlusSet, mMinusSet,
mEqSet, mLeSet, mLtSet, mMulSet, mPlusSet, mMinusSet, mXorSet,
mConStrStr, mSlice,
mDotDot, # this one is only necessary to give nice compile time warnings
mFields, mFieldPairs, mOmpParFor,
Expand Down Expand Up @@ -559,7 +559,7 @@ const
mStrToStr, mEnumToStr,
mAnd, mOr,
mEqStr, mLeStr, mLtStr,
mEqSet, mLeSet, mLtSet, mMulSet, mPlusSet, mMinusSet,
mEqSet, mLeSet, mLtSet, mMulSet, mPlusSet, mMinusSet, mXorSet,
mConStrStr, mAppendStrCh, mAppendStrStr, mAppendSeqElem,
mInSet, mRepr, mOpenArrayToSeq}

Expand Down
10 changes: 6 additions & 4 deletions compiler/ccgexprs.nim
Original file line number Diff line number Diff line change
Expand Up @@ -2044,7 +2044,7 @@ proc genInOp(p: BProc, e: PNode, d: var TLoc) =

proc genSetOp(p: BProc, e: PNode, d: var TLoc, op: TMagic) =
const
lookupOpr: array[mLeSet..mMinusSet, string] = [
lookupOpr: array[mLeSet..mXorSet, string] = [
"for ($1 = 0; $1 < $2; $1++) { $n" &
" $3 = (($4[$1] & ~ $5[$1]) == 0);$n" &
" if (!$3) break;}$n",
Expand All @@ -2054,7 +2054,8 @@ proc genSetOp(p: BProc, e: PNode, d: var TLoc, op: TMagic) =
"if ($3) $3 = (#nimCmpMem($4, $5, $2) != 0);$n",
"&",
"|",
"& ~"]
"& ~",
"^"]
var a, b: TLoc
var i: TLoc
var setType = skipTypes(e[1].typ, abstractVar)
Expand Down Expand Up @@ -2085,6 +2086,7 @@ proc genSetOp(p: BProc, e: PNode, d: var TLoc, op: TMagic) =
of mMulSet: binaryExpr(p, e, d, "($1 & $2)")
of mPlusSet: binaryExpr(p, e, d, "($1 | $2)")
of mMinusSet: binaryExpr(p, e, d, "($1 & ~ $2)")
of mXorSet: binaryExpr(p, e, d, "($1 ^ $2)")
of mInSet:
genInOp(p, e, d)
else: internalError(p.config, e.info, "genSetOp()")
Expand Down Expand Up @@ -2112,7 +2114,7 @@ proc genSetOp(p: BProc, e: PNode, d: var TLoc, op: TMagic) =
var a = initLocExpr(p, e[1])
var b = initLocExpr(p, e[2])
putIntoDest(p, d, e, ropecg(p.module, "(#nimCmpMem($1, $2, $3)==0)", [a.rdCharLoc, b.rdCharLoc, size]))
of mMulSet, mPlusSet, mMinusSet:
of mMulSet, mPlusSet, mMinusSet, mXorSet:
# we inline the simple for loop for better code generation:
i = getTemp(p, getSysType(p.module.g.graph, unknownLineInfo, tyInt)) # our counter
a = initLocExpr(p, e[1])
Expand Down Expand Up @@ -2548,7 +2550,7 @@ proc genMagicExpr(p: BProc, e: PNode, d: var TLoc, op: TMagic) =
of mSetLengthStr: genSetLengthStr(p, e, d)
of mSetLengthSeq: genSetLengthSeq(p, e, d)
of mIncl, mExcl, mCard, mLtSet, mLeSet, mEqSet, mMulSet, mPlusSet, mMinusSet,
mInSet:
mInSet, mXorSet:
genSetOp(p, e, d, op)
of mNewString, mNewStringOfCap, mExit, mParseBiggestFloat:
var opr = e[0].sym
Expand Down
1 change: 1 addition & 0 deletions compiler/condsyms.nim
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,4 @@ proc initDefines*(symbols: StringTableRef) =
defineSymbol("nimHasGenericsOpenSym3")
defineSymbol("nimHasJsNoLambdaLifting")
defineSymbol("nimHasDefaultFloatRoundtrip")
defineSymbol("nimHasXorSet")
1 change: 1 addition & 0 deletions compiler/jsgen.nim
Original file line number Diff line number Diff line change
Expand Up @@ -2458,6 +2458,7 @@ proc genMagic(p: PProc, n: PNode, r: var TCompRes) =
of mMulSet: binaryExpr(p, n, r, "SetMul", "SetMul($1, $2)")
of mPlusSet: binaryExpr(p, n, r, "SetPlus", "SetPlus($1, $2)")
of mMinusSet: binaryExpr(p, n, r, "SetMinus", "SetMinus($1, $2)")
of mXorSet: binaryExpr(p, n, r, "SetXor", "SetXor($1, $2)")
of mIncl: binaryExpr(p, n, r, "", "$1[$2] = true")
of mExcl: binaryExpr(p, n, r, "", "delete $1[$2]")
of mInSet:
Expand Down
3 changes: 3 additions & 0 deletions compiler/semfold.nim
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,9 @@ proc evalOp(m: TMagic, n, a, b, c: PNode; idgen: IdGenerator; g: ModuleGraph): P
of mMinusSet:
result = nimsets.diffSets(g.config, a, b)
result.info = n.info
of mXorSet:
result = nimsets.symdiffSets(g.config, a, b)
result.info = n.info
of mConStrStr: result = newStrNodeT(getStrOrChar(a) & getStrOrChar(b), n, g)
of mInSet: result = newIntNodeT(toInt128(ord(inSet(a, b))), n, idgen, g)
of mRepr:
Expand Down
5 changes: 5 additions & 0 deletions compiler/vm.nim
Original file line number Diff line number Diff line change
Expand Up @@ -1276,6 +1276,11 @@ proc rawExecute(c: PCtx, start: int, tos: PStackFrame): TFullReg =
createSet(regs[ra])
move(regs[ra].node.sons,
nimsets.diffSets(c.config, regs[rb].node, regs[rc].node).sons)
of opcXorSet:
decodeBC(rkNode)
createSet(regs[ra])
move(regs[ra].node.sons,
nimsets.symdiffSets(c.config, regs[rb].node, regs[rc].node).sons)
of opcConcatStr:
decodeBC(rkNode)
createStr regs[ra]
Expand Down
2 changes: 1 addition & 1 deletion compiler/vmdef.nim
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ type
opcEqRef, opcEqNimNode, opcSameNodeType,
opcXor, opcNot, opcUnaryMinusInt, opcUnaryMinusFloat, opcBitnotInt,
opcEqStr, opcEqCString, opcLeStr, opcLtStr, opcEqSet, opcLeSet, opcLtSet,
opcMulSet, opcPlusSet, opcMinusSet, opcConcatStr,
opcMulSet, opcPlusSet, opcMinusSet, opcXorSet, opcConcatStr,
opcContainsSet, opcRepr, opcSetLenStr, opcSetLenSeq,
opcIsNil, opcOf, opcIs,
opcParseFloat, opcConv, opcCast,
Expand Down
1 change: 1 addition & 0 deletions compiler/vmgen.nim
Original file line number Diff line number Diff line change
Expand Up @@ -1212,6 +1212,7 @@ proc genMagic(c: PCtx; n: PNode; dest: var TDest; flags: TGenFlags = {}, m: TMag
of mMulSet: genBinarySet(c, n, dest, opcMulSet)
of mPlusSet: genBinarySet(c, n, dest, opcPlusSet)
of mMinusSet: genBinarySet(c, n, dest, opcMinusSet)
of mXorSet: genBinarySet(c, n, dest, opcXorSet)
of mConStrStr: genVarargsABC(c, n, dest, opcConcatStr)
of mInSet: genBinarySet(c, n, dest, opcContainsSet)
of mRepr: genUnaryABC(c, n, dest, opcRepr)
Expand Down
28 changes: 28 additions & 0 deletions lib/std/setutils.nim
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,31 @@ func `[]=`*[T](t: var set[T], key: T, val: bool) {.inline.} =
s[a3] = true
assert s == {a2, a3}
if val: t.incl key else: t.excl key

when defined(nimHasXorSet):
func symmetricDifference*[T](x, y: set[T]): set[T] {.magic: "XorSet".} =
## This operator computes the symmetric difference of two sets,
## equivalent to but more efficient than `x + y - x * y` or
## `(x - y) + (y - x)`.
runnableExamples:
assert symmetricDifference({1, 2, 3}, {2, 3, 4}) == {1, 4}
else:
func symmetricDifference*[T](x, y: set[T]): set[T] {.inline.} =
result = x + y - (x * y)

proc `-+-`*[T](x, y: set[T]): set[T] {.inline.} =
## Operator alias for `symmetricDifference`.
runnableExamples:
assert {1, 2, 3} -+- {2, 3, 4} == {1, 4}
result = symmetricDifference(x, y)

proc toggle*[T](x: var set[T], y: set[T]) {.inline.} =
## Toggles the existence of each value of `y` in `x`.
## If any element in `y` is also in `x`, it is excluded from `x`;
## otherwise it is included.
## Equivalent to `x = symmetricDifference(x, y)`.
runnableExamples:
var x = {1, 2, 3}
x.toggle({2, 3, 4})
assert x == {1, 4}
x = symmetricDifference(x, y)
12 changes: 12 additions & 0 deletions lib/system/jssys.nim
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,18 @@ proc SetMinus(a, b: int): int {.compilerproc, asmNoStackFrame.} =
return result;
""".}

proc SetXor(a, b: int): int {.compilerproc, asmNoStackFrame.} =
{.emit: """
var result = {};
for (var elem in `a`) {
if (!`b`[elem]) { result[elem] = true; }
}
for (var elem in `b`) {
if (!`a`[elem]) { result[elem] = true; }
}
return result;
""".}

proc cmpStrings(a, b: string): int {.asmNoStackFrame, compilerproc.} =
{.emit: """
if (`a` == `b`) return 0;
Expand Down
20 changes: 20 additions & 0 deletions tests/stdlib/tsetutils.nim
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,26 @@ template main =
s[a2] = true
s[a3] = true
doAssert s == {a2, a3}

block: # set symmetric difference (xor), https://github.com/nim-lang/RFCs/issues/554
type T = set[range[0..15]]
let x: T = {1, 4, 5, 8, 9}
let y: T = {0, 2..6, 9}
let res = symmetricDifference(x, y)
doAssert res == {0, 1, 2, 3, 6, 8}
doAssert res == (x + y - x * y)
doAssert res == ((x - y) + (y - x))
var z = x
doAssert z == {1, 4, 5, 8, 9}
doAssert z == x
z.toggle(y)
doAssert z == res
z.toggle(y)
doAssert z == x
z.toggle({1, 5})
doAssert z == {4, 8, 9}
z.toggle({3, 8})
doAssert z == {3, 4, 9}

main()
static: main()