Skip to content

Commit

Permalink
path rendering in netview fully working including events to edit etc.…
Browse files Browse the repository at this point in the history
… reasonable layout algorithm seems to work well even for complex networks.
  • Loading branch information
rcoreilly committed Aug 26, 2024
1 parent a2d399e commit 14a576d
Show file tree
Hide file tree
Showing 6 changed files with 192 additions and 15 deletions.
8 changes: 6 additions & 2 deletions emer/layer.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ type Layer interface {
// by the algorithm (and usually set by an enum).
TypeName() string

// TypeNumber is the numerical value for the type or category
// of layer, defined by the algorithm (and usually set by an enum).
TypeNumber() int

// UnitVarIndex returns the index of given variable within
// the Neuron, according to *this layer's* UnitVarNames() list
// (using a map to lookup index), or -1 and error message if
Expand Down Expand Up @@ -150,9 +154,9 @@ type LayerBase struct {
// with multple classes.
Class string

// Info contains descriptive information about the layer.
// Doc contains documentation about the layer.
// This is displayed in a tooltip in the network view.
Info string
Doc string

// Off turns off the layer, removing from all computations.
// This provides a convenient way to dynamically test for
Expand Down
4 changes: 2 additions & 2 deletions emer/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ import (
type VarCategory struct {
// Category name.
Cat string
// Description of the category, used as a tooltip.
Desc string
// Documentation of the category, used as a tooltip.
Doc string
}

// Network defines the minimal interface for a neural network,
Expand Down
8 changes: 6 additions & 2 deletions emer/path.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ type Path interface {
// by the algorithm (and usually set by an enum).
TypeName() string

// TypeNumber is the numerical value for the type or category
// of path, defined by the algorithm (and usually set by an enum).
TypeNumber() int

// SendLayer returns the sending layer for this pathway,
// as an emer.Layer interface. The actual Path implmenetation
// can use a Send field with the actual Layer struct type.
Expand Down Expand Up @@ -125,9 +129,9 @@ type PathBase struct {
// with multple classes.
Class string

// Info contains descriptive information about the pathway.
// Doc contains documentation about the pathway.
// This is displayed in a tooltip in the network view.
Info string
Doc string

// can record notes about this pathway here.
Notes string
Expand Down
38 changes: 35 additions & 3 deletions netview/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ func (sw *Scene) Init() {

func (sw *Scene) MouseDownEvent(e events.Event) {
pos := e.Pos().Sub(sw.Geom.ContentBBox.Min)
pt := sw.PathAtPoint(pos)
if pt != nil {
FormDialog(sw, pt, "Path: "+pt.StyleName())
e.SetHandled()
return
}
lay := sw.LayerLabelAtPoint(pos)
if lay != nil {
FormDialog(sw, lay, "Layer: "+lay.StyleName())
Expand All @@ -65,13 +71,25 @@ func (sw *Scene) WidgetTooltip(pos image.Point) (string, image.Point) {
if pos == image.Pt(-1, -1) {
return "_", image.Point{}
}
nv := sw.NetView
lpos := pos.Sub(sw.Geom.ContentBBox.Min)

pt := sw.PathAtPoint(lpos)
if pt != nil {
pe := pt.AsEmer()
tt := "[Click to edit] " + pe.Name + " " + pt.TypeName()
if pe.Doc != "" {
tt += ": " + pe.Doc
}
return tt, pos
}

lay := sw.LayerLabelAtPoint(lpos)
if lay != nil {
le := lay.AsEmer()
tt := "[Click to edit]"
if le.Info != "" {
tt += " " + le.Info
if le.Doc != "" {
tt += " " + le.Doc
}
return tt, pos
}
Expand All @@ -81,7 +99,6 @@ func (sw *Scene) WidgetTooltip(pos image.Point) (string, image.Point) {
return "", pos
}
lb := lay.AsEmer()
nv := sw.NetView

tt := ""
if lb.Is2D() {
Expand Down Expand Up @@ -123,6 +140,21 @@ func (sw *Scene) LayerLabelAtPoint(pos image.Point) emer.Layer {
return nil
}

func (sw *Scene) PathAtPoint(pos image.Point) emer.Path {
ns := xyz.NodesUnderPoint(sw.SceneXYZ(), pos)
for _, n := range ns {
ln, ok := n.(*xyz.Solid)
if ok && ln.Parent != nil {
gpnm := ln.Parent.AsTree().Name
pt, _ := sw.NetView.Net.AsEmer().EmerPathByName(gpnm)
if pt != nil {
return pt
}
}
}
return nil
}

func (sw *Scene) LayerUnitAtPoint(pos image.Point) (lay emer.Layer, lx, ly, unIndex int) {
sc := sw.SceneXYZ()
laysGpi := sc.ChildByName("Layers", 0)
Expand Down
141 changes: 135 additions & 6 deletions netview/netview.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"image/color"
"log"
"log/slog"
"math"
"reflect"
"strings"
"sync"
Expand Down Expand Up @@ -91,6 +92,8 @@ type NetView struct {

// mutex on data access
DataMu sync.RWMutex `display:"-" copier:"-" json:"-" xml:"-"`

hasPaths bool // to detect if paths changes
}

func (nv *NetView) Init() {
Expand Down Expand Up @@ -119,6 +122,8 @@ func (nv *NetView) Init() {
nv.ViewDefaults(se)
laysGp := xyz.NewGroup(se)
laysGp.Name = "Layers"
pathsGp := xyz.NewGroup(se)
pathsGp.Name = "Paths"
})
})
tree.AddChildAt(nv, "counters", func(w *core.Text) {
Expand Down Expand Up @@ -516,7 +521,7 @@ func (nv *NetView) makeVars(netframe *core.Frame) {
tabs := make(map[string]*core.Frame)
for _, ct := range cats {
tf, tb := w.NewTab(ct.Cat)
tb.Tooltip = ct.Desc
tb.Tooltip = ct.Doc
tabs[ct.Cat] = tf
tf.Styler(func(s *styles.Style) {
s.Display = styles.Grid
Expand All @@ -529,7 +534,7 @@ func (nv *NetView) makeVars(netframe *core.Frame) {
for _, vn := range nv.Vars {
cat := ""
pstr := ""
desc := ""
doc := ""
if strings.HasPrefix(vn, "r.") || strings.HasPrefix(vn, "s.") {
pstr = pathprops[vn[2:]]
cat = "Wt" // default
Expand All @@ -539,7 +544,7 @@ func (nv *NetView) makeVars(netframe *core.Frame) {
}
if pstr != "" {
rstr := reflect.StructTag(pstr)
desc = rstr.Get("desc")
doc = rstr.Get("doc")
cat = rstr.Get("cat")
if rstr.Get("display") == "-" {
continue
Expand All @@ -552,8 +557,8 @@ func (nv *NetView) makeVars(netframe *core.Frame) {
tf = tabs[cat]
}
w := core.NewButton(tf).SetText(vn)
if desc != "" {
w.Tooltip = vn + ": " + desc
if doc != "" {
w.Tooltip = vn + ": " + doc
}
w.SetText(vn).SetType(core.ButtonAction)
w.OnClick(func(e events.Event) {
Expand Down Expand Up @@ -596,6 +601,9 @@ func (nv *NetView) UpdateLayers() {
lmesh := errors.Log1(se.MeshByName(ly.StyleName()))
se.SetMesh(lmesh) // does update
}
if nv.hasPaths != nv.Params.Paths {
nv.UpdatePaths()
}
return
}

Expand Down Expand Up @@ -648,10 +656,124 @@ func (nv *NetView) UpdateLayers() {
txt.Styles.Text.Align = styles.Start
txt.Styles.Text.AlignV = styles.Start
}
nv.UpdatePaths()
sw.XYZ.SetNeedsUpdate()
sw.NeedsRender()
}

// UpdatePaths updates the path display.
// Only called when layers have structural changes.
func (nv *NetView) UpdatePaths() {
sw := nv.SceneWidget()
se := sw.SceneXYZ()

nb := nv.Net.AsEmer()
nlay := nv.Net.NumLayers()
pathsGp := se.ChildByName("Paths", 0).(*xyz.Group)
pathsGp.DeleteChildren()

if !nv.Params.Paths {
nv.hasPaths = false
return
}
nv.hasPaths = true

nmin, nmax := nb.MinPos, nb.MaxPos
nsz := nmax.Sub(nmin).Sub(math32.Vec3(1, 1, 0)).Max(math32.Vec3(1, 1, 1))
nsc := math32.Vec3(1.0/nsz.X, 1.0/nsz.Y, 1.0/nsz.Z)
poff := math32.Vector3Scalar(0.5)
poff.Y = -0.5

lineWidth := nv.Params.PathWidth

layPosSize := func(lb *emer.LayerBase) (math32.Vector3, math32.Vector3) {
lp := lb.Pos.Pos
lp.Y = -lp.Y
lp = lp.Sub(nmin).Mul(nsc).Sub(poff)
lp.Y, lp.Z = lp.Z, lp.Y
dsz := lb.DisplaySize()
lsz := math32.Vector3{dsz.X * nsc.X, 0, dsz.Y * nsc.Y}
return lp, lsz
}

// L, R, F, B -- center of each side, z is negative
sideMids := []math32.Vector3{{0, 0, -0.5}, {1, 0, -0.5}, {0.5, 0, 0}, {0.5, 0, -1}}
sideDims := []math32.Dims{math32.Z, math32.Z, math32.X, math32.X}

sideMtx := func(side int, prop float32) math32.Vector3 {
dim := sideDims[side]
smat := sideMids[side]
smat.SetDim(dim, prop)
if dim == math32.Z {
smat.Z *= -1
}
return smat
}

yprop := func(rLayY, sLayY float32) float32 {
if rLayY < sLayY {
return 0.6667
} else if rLayY == sLayY {
return 0.3333
}
return 0
}

for li := range nlay {
ly := nv.Net.EmerLayer(li)
lb := ly.AsEmer()
sLayPos, sLaySz := layPosSize(lb)

var sides [16][]emer.Path
nsp := ly.NumSendPaths()
for pi := range nsp {
sp := ly.SendPath(pi)
rb := sp.RecvLayer().AsEmer()
rLayPos, rLaySz := layPosSize(rb)
minDist := float32(math.MaxFloat32)
minSidx := 0
for sSide := range 4 {
for rSide := range 4 {
prop := yprop(rLayPos.Y, sLayPos.Y) + 0.5*.3333
smat := sideMtx(sSide, prop)
rmat := sideMtx(rSide, prop)
spos := sLayPos.Add(sLaySz.Mul(smat))
rpos := rLayPos.Add(rLaySz.Mul(rmat))
dist := rpos.Sub(spos).Length()
if dist < minDist {
minDist = dist
minSidx = sSide*4 + rSide
}
}
}
sides[minSidx] = append(sides[minSidx], sp)
}
for sSide := range 4 {
for rSide := range 4 {
minSidx := sSide*4 + rSide
pths := sides[minSidx]
nsp := len(pths)
if nsp == 0 {
continue
}
for pi, sp := range pths {
rb := sp.RecvLayer().AsEmer()
sb := sp.AsEmer()
rLayPos, rLaySz := layPosSize(rb)
prop := 0.3333*(float32(pi)+.5)/float32(nsp) + yprop(rLayPos.Y, sLayPos.Y)
smat := sideMtx(sSide, prop)
rmat := sideMtx(rSide, prop)
spos := sLayPos.Add(sLaySz.Mul(smat))
rpos := rLayPos.Add(rLaySz.Mul(rmat))
// xyz.NewLine(se, pathsGp, sb.Name, spos, rpos, lineWidth, clr)
clr := colors.Spaced(sp.TypeNumber())
xyz.NewArrow(se, pathsGp, sb.Name, spos, rpos, lineWidth, clr, xyz.NoStartArrow, xyz.EndArrow, 4, .5, 4)
}
}
}
}
}

// ViewDefaults are the default 3D view params
func (nv *NetView) ViewDefaults(se *xyz.Scene) {
se.Camera.Pose.Pos.Set(0, 1.5, 2.5) // more "top down" view shows more of layers
Expand Down Expand Up @@ -837,7 +959,6 @@ func (nv *NetView) MakeToolbar(p *tree.Plan) {
})
})
tree.Add(p, func(w *core.Separator) {})
tree.Add(p, func(w *core.Separator) {})
tree.Add(p, func(w *core.Button) {
w.SetText("Weights").SetType(core.ButtonAction).SetMenu(func(m *core.Scene) {
fb := core.NewFuncButton(m).SetFunc(nv.SaveWeights)
Expand Down Expand Up @@ -865,6 +986,14 @@ func (nv *NetView) MakeToolbar(p *tree.Plan) {
})
})
tree.Add(p, func(w *core.Separator) {})
tree.Add(p, func(w *core.Switch) {
w.SetText("Paths").SetChecked(nv.Params.Paths).
SetTooltip("Toggles whether pathways between layers are shown or not").
OnChange(func(e events.Event) {
nv.Params.Paths = w.IsChecked()
nv.UpdateView()
})
})
ditp := "data parallel index -- for models running multiple input patterns in parallel, this selects which one is viewed"
tree.Add(p, func(w *core.Text) {
w.SetText("Di:").SetTooltip(ditp)
Expand Down
8 changes: 8 additions & 0 deletions netview/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ func (nv *RasterParams) Defaults() {
// Params holds parameters controlling how the view is rendered
type Params struct { //types:add

// whether to display the pathways between layers as arrows
Paths bool

// width of the path arrows, in normalized units
PathWidth float32 `default:"0.002"`

// raster plot parameters
Raster RasterParams `display:"inline"`

Expand Down Expand Up @@ -88,6 +94,8 @@ func (nv *Params) Defaults() {
nv.Raster.Defaults()
if nv.NVarCols == 0 {
nv.NVarCols = NVarCols
nv.Paths = true
nv.PathWidth = 0.002
}
if nv.MaxRecs == 0 {
nv.MaxRecs = 210 // 200 cycles + 8 phase updates max + 2 extra..
Expand Down

0 comments on commit 14a576d

Please sign in to comment.