Skip to content

Commit

Permalink
Add cross and dot product functions
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathanhogg committed Oct 10, 2024
1 parent 14002df commit a602bea
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 3 deletions.
7 changes: 7 additions & 0 deletions docs/builtins.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,16 @@ the length of the longest vector; short arguments will repeat, so
generalises to n-vectors, returning a vector the length of the longest of
*x*, *y* and *z*, and repeating items from any of the vectors as necessary.

`cross(` *x*, *y* `)`
: Compute the cross product of two 3-vectors.

`cos(` *x* `)`
: Return cosine of *x* (with *x* expressed in *turns*).

`dot(` *x*, *y* `)`
: Compute the dot product of two n-vectors. This is the equivalent of
`sum(x * y)`.

`exp(` *x* `)`
: Return $e$ raised to the power of *x*.

Expand Down
14 changes: 12 additions & 2 deletions src/flitter/language/functions.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,14 @@ def hypot(*args):
return ys


def cross(Vector xs not None, Vector ys not None):
return xs.cross(ys)


def dot(Vector xs not None, Vector ys not None):
return xs.dot(ys)


def normalize(Vector xs not None):
return xs.normalize()

Expand All @@ -815,9 +823,9 @@ def quaternion(Vector axis not None, Vector angle not None):


def qmul(Vector a not None, Vector b not None):
if a.numbers == NULL or a.length != 4 or b.numbers == NULL or b.length != 4:
if a.numbers == NULL or a.length != 4 or b.numbers == NULL or b.length not in (3, 4):
return null_
return Quaternion._coerce(a) @ Quaternion._coerce(b)
return Quaternion._coerce(a) @ (Quaternion._coerce(b) if b.length == 4 else b)


def qbetween(Vector a not None, Vector b not None):
Expand Down Expand Up @@ -1046,7 +1054,9 @@ STATIC_FUNCTIONS = {
'colortemp': Vector(colortemp),
'cos': Vector(cosv),
'count': Vector(count),
'cross': Vector(cross),
'cubic': Vector(cubic),
'dot': Vector(dot),
'exp': Vector(expv),
'floor': Vector(floorv),
'fract': Vector(fract),
Expand Down
17 changes: 16 additions & 1 deletion tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from flitter.language.functions import (uniform, normal, beta,
lenv, sumv, accumulate, minv, maxv, minindex, maxindex, mapv, clamp, zipv, count,
roundv, absv, expv, sqrtv, logv, log2v, log10v, ceilv, floorv, fract,
cosv, acosv, sinv, asinv, tanv, hypot, normalize, polar, angle, length,
cosv, acosv, sinv, asinv, tanv, hypot, normalize, polar, angle, length, cross, dot,
quaternion, qmul, qbetween, slerp,
split, ordv, chrv,
colortemp, oklab, oklch)
Expand Down Expand Up @@ -537,6 +537,21 @@ def test_normalize(self):
self.assertEqual(normalize(Vector([3, 4])), Vector([3, 4]) / Vector(5))
self.assertEqual(normalize(Vector([3, 4, 5])), Vector([3, 4, 5]) / Vector(math.sqrt(50)))

def test_cross(self):
self.assertEqual(cross(null, null), null)
self.assertEqual(cross(Vector([3, 4]), Vector([3, 4, 5])), null)
self.assertEqual(cross(Vector([3, 4, 5]), Vector([3, 4])), null)
self.assertEqual(cross(Vector([3, 4, 5]), Vector([3, 4, 5])), Vector([0, 0, 0]))
self.assertEqual(cross(Vector([1, 0, 0]), Vector([0, 1, 0])), Vector([0, 0, 1]))
self.assertEqual(cross(Vector([0, 1, 0]), Vector([0, 0, 1])), Vector([1, 0, 0]))
self.assertEqual(cross(Vector([0, 0, 1]), Vector([1, 0, 0])), Vector([0, 1, 0]))

def test_dot(self):
self.assertEqual(dot(null, null), null)
self.assertEqual(dot(Vector([3, 4]), Vector([3, 4, 5])), Vector(3*3 + 4*4 + 3*5))
self.assertEqual(dot(Vector([3, 4, 5]), Vector([3, 4])), Vector(3*3 + 4*4 + 5*3))
self.assertEqual(dot(Vector([3, 4, 5]), Vector([5, 4, 3])), Vector(3*5 + 4*4 + 5*3))

def test_polar(self):
self.assertEqual(polar(null), null)
self.assertEqual(polar(Vector('hello')), null)
Expand Down

0 comments on commit a602bea

Please sign in to comment.