-
Notifications
You must be signed in to change notification settings - Fork 0
/
Derivation.scala
91 lines (81 loc) · 2.4 KB
/
Derivation.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
package ilc
package feature
package base
/** Modular derivation.
  *
  * Supplies the type-independent part of the `derive` transformation:
  * the variable case, a recomputing fallback for every other term, and
  * the polymorphic terms `ChangeUpdate` and `Diff`.
  *
  * The type-specific pieces (`deltaType`, `updateTerm`, `diffTerm`)
  * throw `IDontKnow` for every type in this trait. `updateTerm` is
  * explicitly meant to be specialized by subclasses; the other two are
  * presumably meant to be overridden in the same way.
  */
trait Derivation
extends Syntax
   with functions.Syntax // for application
{
  /** The type of changes to values of type `tau`.
    *
    * This base version knows no type and always throws.
    *
    * @param tau the type whose change type is requested
    * @throws IDontKnow always, in this trait
    */
  def deltaType(tau: Type): Type =
    throw IDontKnow(s"the Δ-type of $tau")

  /** A term that updates a value according to a change:
    *
    * {{{
    * update-term : ∀ {τ Γ} → Term Γ (ΔType τ → τ → τ)
    * }}}
    *
    * To be specialized by subclasses; this base version always throws.
    *
    * @param tau the type of the value being updated
    * @throws IDontKnow always, in this trait
    */
  def updateTerm(tau: Type): Term =
    throw IDontKnow(s"the update-term for type $tau")

  /** A term that computes the difference of two values:
    *
    * {{{
    * diff-term : ∀ {τ Γ} → Term Γ (τ → τ → ΔType τ)
    * }}}
    *
    * This base version always throws.
    *
    * @param tau the common type of the two values
    * @throws IDontKnow always, in this trait
    */
  def diffTerm(tau: Type): Term =
    throw IDontKnow(s"the diff-term for type $tau")

  /** Computes the derivative of a term.
    *
    * A variable is mapped to its change variable via [[DVar]]. Every
    * other term `t` is mapped to `Diff ! t ! t`, i.e. `Diff` applied to
    * `t` twice.
    *
    * @param t the term to derive
    * @return the derivative of `t`
    */
  def derive(t: Term): Term = t match {
    case v: Var =>
      DVar(v)

    // For all terms we don't know how to derive,
    // we produce a derivative that does recomputation.
    // This makes adding new constants easy.
    case _ =>
      Diff ! t ! t
  }

  /** Factory for the change variable `dx` of a variable `x`.
    *
    * The cool thing with this setup is that we can nest `DVar`s.
    * For example, deriving `Var("x", xType)` twice yields
    *
    * {{{
    * val ddx = DVar(DVar(Var("x", xType)))
    * }}}
    *
    * such that
    *
    * {{{
    * ddx.getType == deltaType(deltaType(xType))
    * }}}
    */
  object DVar {
    /** Creates the variable `dx`.
      *
      * @param original the variable `x`
      * @return a variable whose name is the delta-name of `original`'s
      *         name and whose type is `deltaType(original.getType)`
      */
    def apply(original: Var): Var = {
      val deltaName = original.getName match {
        // Keep the index outermost; only the wrapped name is marked
        // as a delta.
        case IndexedName(n, idx) =>
          IndexedName(DeltaName(n), idx)

        case n: NonIndexedName =>
          DeltaName(n)
      }
      Var(deltaName, deltaType(original.getType))
    }
  }

  /** Polymorphic term that resolves to the update-term of a type.
    *
    * Only the first two argument types are inspected. When the first
    * equals `deltaType` of the second, the term specializes to
    * `updateTerm` of the second. Any other shape, including fewer than
    * two argument types, is handed to `typeErrorNotTheSame`.
    *
    * Note that the check itself calls `deltaType`, so an `IDontKnow`
    * raised there propagates out of `specialize`.
    */
  object ChangeUpdate extends PolymorphicTerm {
    def specialize(argumentTypes: Type*): Term =
      argumentTypes.take(2) match {
        case Seq(changeType, valueType)
            if changeType == deltaType(valueType) =>
          updateTerm(valueType)

        case wrongTypes =>
          typeErrorNotTheSame("specializing ChangeUpdate",
            "a delta type and a type",
            wrongTypes)
      }
  }

  /** Polymorphic term that resolves to the diff-term of a type.
    *
    * Only the first two argument types are inspected. When they are
    * equal, the term specializes to `diffTerm` of that type. Any other
    * shape, including fewer than two argument types, is handed to
    * `typeErrorNotTheSame`.
    */
  object Diff extends PolymorphicTerm {
    def specialize(argumentTypes: Type*): Term =
      argumentTypes.take(2) match {
        case Seq(valueType, valueType2)
            if valueType == valueType2 =>
          diffTerm(valueType)

        case wrongTypes =>
          typeErrorNotTheSame("specializing Diff",
            "two arguments of the same type",
            wrongTypes)
      }
  }
}