diff --git a/emer/layer.go b/emer/layer.go index 516af8d..93f2ba9 100644 --- a/emer/layer.go +++ b/emer/layer.go @@ -50,6 +50,10 @@ type Layer interface { // by the algorithm (and usually set by an enum). TypeName() string + // TypeNumber is the numerical value for the type or category + // of layer, defined by the algorithm (and usually set by an enum). + TypeNumber() int + // UnitVarIndex returns the index of given variable within // the Neuron, according to *this layer's* UnitVarNames() list // (using a map to lookup index), or -1 and error message if @@ -150,9 +154,9 @@ type LayerBase struct { // with multple classes. Class string - // Info contains descriptive information about the layer. + // Doc contains documentation about the layer. // This is displayed in a tooltip in the network view. - Info string + Doc string // Off turns off the layer, removing from all computations. // This provides a convenient way to dynamically test for diff --git a/emer/network.go b/emer/network.go index 62e0653..8c19908 100644 --- a/emer/network.go +++ b/emer/network.go @@ -24,8 +24,8 @@ import ( type VarCategory struct { // Category name. Cat string - // Description of the category, used as a tooltip. - Desc string + // Documentation of the category, used as a tooltip. + Doc string } // Network defines the minimal interface for a neural network, diff --git a/emer/path.go b/emer/path.go index 66f3a03..eed35b6 100644 --- a/emer/path.go +++ b/emer/path.go @@ -37,6 +37,10 @@ type Path interface { // by the algorithm (and usually set by an enum). TypeName() string + // TypeNumber is the numerical value for the type or category + // of path, defined by the algorithm (and usually set by an enum). + TypeNumber() int + // SendLayer returns the sending layer for this pathway, // as an emer.Layer interface. The actual Path implmenetation // can use a Send field with the actual Layer struct type. @@ -125,9 +129,9 @@ type PathBase struct { // with multple classes. Class string - // Info contains descriptive information about the pathway. + // Doc contains documentation about the pathway. // This is displayed in a tooltip in the network view. - Info string + Doc string // can record notes about this pathway here. Notes string diff --git a/netview/events.go b/netview/events.go index 1581132..a9b3f18 100644 --- a/netview/events.go +++ b/netview/events.go @@ -44,6 +44,12 @@ func (sw *Scene) Init() { func (sw *Scene) MouseDownEvent(e events.Event) { pos := e.Pos().Sub(sw.Geom.ContentBBox.Min) + pt := sw.PathAtPoint(pos) + if pt != nil { + FormDialog(sw, pt, "Path: "+pt.StyleName()) + e.SetHandled() + return + } lay := sw.LayerLabelAtPoint(pos) if lay != nil { FormDialog(sw, lay, "Layer: "+lay.StyleName()) @@ -65,13 +71,25 @@ func (sw *Scene) WidgetTooltip(pos image.Point) (string, image.Point) { if pos == image.Pt(-1, -1) { return "_", image.Point{} } + nv := sw.NetView lpos := pos.Sub(sw.Geom.ContentBBox.Min) + + pt := sw.PathAtPoint(lpos) + if pt != nil { + pe := pt.AsEmer() + tt := "[Click to edit] " + pe.Name + " " + pt.TypeName() + if pe.Doc != "" { + tt += ": " + pe.Doc + } + return tt, pos + } + lay := sw.LayerLabelAtPoint(lpos) if lay != nil { le := lay.AsEmer() tt := "[Click to edit]" - if le.Info != "" { - tt += " " + le.Info + if le.Doc != "" { + tt += " " + le.Doc } return tt, pos } @@ -81,7 +99,6 @@ func (sw *Scene) WidgetTooltip(pos image.Point) (string, image.Point) { return "", pos } lb := lay.AsEmer() - nv := sw.NetView tt := "" if lb.Is2D() { @@ -123,6 +140,21 @@ func (sw *Scene) LayerLabelAtPoint(pos image.Point) emer.Layer { return nil } +func (sw *Scene) PathAtPoint(pos image.Point) emer.Path { + ns := xyz.NodesUnderPoint(sw.SceneXYZ(), pos) + for _, n := range ns { + ln, ok := n.(*xyz.Solid) + if ok && ln.Parent != nil { + gpnm := ln.Parent.AsTree().Name + pt, _ := sw.NetView.Net.AsEmer().EmerPathByName(gpnm) + if pt != nil { + return pt + } + } + } + return nil +} + func (sw *Scene) LayerUnitAtPoint(pos image.Point) (lay emer.Layer, lx, ly, unIndex int) { sc := sw.SceneXYZ() laysGpi := sc.ChildByName("Layers", 0) diff --git a/netview/netview.go b/netview/netview.go index 66acd4d..9919c25 100644 --- a/netview/netview.go +++ b/netview/netview.go @@ -15,6 +15,7 @@ import ( "image/color" "log" "log/slog" + "math" "reflect" "strings" "sync" @@ -91,6 +92,8 @@ type NetView struct { // mutex on data access DataMu sync.RWMutex `display:"-" copier:"-" json:"-" xml:"-"` + + hasPaths bool // to detect if paths changes } func (nv *NetView) Init() { @@ -119,6 +122,8 @@ func (nv *NetView) Init() { nv.ViewDefaults(se) laysGp := xyz.NewGroup(se) laysGp.Name = "Layers" + pathsGp := xyz.NewGroup(se) + pathsGp.Name = "Paths" }) }) tree.AddChildAt(nv, "counters", func(w *core.Text) { @@ -516,7 +521,7 @@ func (nv *NetView) makeVars(netframe *core.Frame) { tabs := make(map[string]*core.Frame) for _, ct := range cats { tf, tb := w.NewTab(ct.Cat) - tb.Tooltip = ct.Desc + tb.Tooltip = ct.Doc tabs[ct.Cat] = tf tf.Styler(func(s *styles.Style) { s.Display = styles.Grid @@ -529,7 +534,7 @@ func (nv *NetView) makeVars(netframe *core.Frame) { for _, vn := range nv.Vars { cat := "" pstr := "" - desc := "" + doc := "" if strings.HasPrefix(vn, "r.") || strings.HasPrefix(vn, "s.") { pstr = pathprops[vn[2:]] cat = "Wt" // default @@ -539,7 +544,7 @@ func (nv *NetView) makeVars(netframe *core.Frame) { } if pstr != "" { rstr := reflect.StructTag(pstr) - desc = rstr.Get("desc") + doc = rstr.Get("doc") cat = rstr.Get("cat") if rstr.Get("display") == "-" { continue @@ -552,8 +557,8 @@ func (nv *NetView) makeVars(netframe *core.Frame) { tf = tabs[cat] } w := core.NewButton(tf).SetText(vn) - if desc != "" { - w.Tooltip = vn + ": " + desc + if doc != "" { + w.Tooltip = vn + ": " + doc } w.SetText(vn).SetType(core.ButtonAction) w.OnClick(func(e events.Event) { @@ -596,6 +601,9 @@ func (nv *NetView) UpdateLayers() { lmesh := errors.Log1(se.MeshByName(ly.StyleName())) se.SetMesh(lmesh) // does update } + if nv.hasPaths != nv.Params.Paths { + nv.UpdatePaths() + } return } @@ -648,10 +656,124 @@ func (nv *NetView) UpdateLayers() { txt.Styles.Text.Align = styles.Start txt.Styles.Text.AlignV = styles.Start } + nv.UpdatePaths() sw.XYZ.SetNeedsUpdate() sw.NeedsRender() } +// UpdatePaths updates the path display. +// Only called when layers have structural changes. +func (nv *NetView) UpdatePaths() { + sw := nv.SceneWidget() + se := sw.SceneXYZ() + + nb := nv.Net.AsEmer() + nlay := nv.Net.NumLayers() + pathsGp := se.ChildByName("Paths", 0).(*xyz.Group) + pathsGp.DeleteChildren() + + if !nv.Params.Paths { + nv.hasPaths = false + return + } + nv.hasPaths = true + + nmin, nmax := nb.MinPos, nb.MaxPos + nsz := nmax.Sub(nmin).Sub(math32.Vec3(1, 1, 0)).Max(math32.Vec3(1, 1, 1)) + nsc := math32.Vec3(1.0/nsz.X, 1.0/nsz.Y, 1.0/nsz.Z) + poff := math32.Vector3Scalar(0.5) + poff.Y = -0.5 + + lineWidth := nv.Params.PathWidth + + layPosSize := func(lb *emer.LayerBase) (math32.Vector3, math32.Vector3) { + lp := lb.Pos.Pos + lp.Y = -lp.Y + lp = lp.Sub(nmin).Mul(nsc).Sub(poff) + lp.Y, lp.Z = lp.Z, lp.Y + dsz := lb.DisplaySize() + lsz := math32.Vector3{dsz.X * nsc.X, 0, dsz.Y * nsc.Y} + return lp, lsz + } + + // L, R, F, B -- center of each side, z is negative + sideMids := []math32.Vector3{{0, 0, -0.5}, {1, 0, -0.5}, {0.5, 0, 0}, {0.5, 0, -1}} + sideDims := []math32.Dims{math32.Z, math32.Z, math32.X, math32.X} + + sideMtx := func(side int, prop float32) math32.Vector3 { + dim := sideDims[side] + smat := sideMids[side] + smat.SetDim(dim, prop) + if dim == math32.Z { + smat.Z *= -1 + } + return smat + } + + yprop := func(rLayY, sLayY float32) float32 { + if rLayY < sLayY { + return 0.6667 + } else if rLayY == sLayY { + return 0.3333 + } + return 0 + } + + for li := range nlay { + ly := nv.Net.EmerLayer(li) + lb := ly.AsEmer() + sLayPos, sLaySz := layPosSize(lb) + + var sides [16][]emer.Path + nsp := ly.NumSendPaths() + for pi := range nsp { + sp := ly.SendPath(pi) + rb := sp.RecvLayer().AsEmer() + rLayPos, rLaySz := layPosSize(rb) + minDist := float32(math.MaxFloat32) + minSidx := 0 + for sSide := range 4 { + for rSide := range 4 { + prop := yprop(rLayPos.Y, sLayPos.Y) + 0.5*.3333 + smat := sideMtx(sSide, prop) + rmat := sideMtx(rSide, prop) + spos := sLayPos.Add(sLaySz.Mul(smat)) + rpos := rLayPos.Add(rLaySz.Mul(rmat)) + dist := rpos.Sub(spos).Length() + if dist < minDist { + minDist = dist + minSidx = sSide*4 + rSide + } + } + } + sides[minSidx] = append(sides[minSidx], sp) + } + for sSide := range 4 { + for rSide := range 4 { + minSidx := sSide*4 + rSide + pths := sides[minSidx] + nsp := len(pths) + if nsp == 0 { + continue + } + for pi, sp := range pths { + rb := sp.RecvLayer().AsEmer() + sb := sp.AsEmer() + rLayPos, rLaySz := layPosSize(rb) + prop := 0.3333*(float32(pi)+.5)/float32(nsp) + yprop(rLayPos.Y, sLayPos.Y) + smat := sideMtx(sSide, prop) + rmat := sideMtx(rSide, prop) + spos := sLayPos.Add(sLaySz.Mul(smat)) + rpos := rLayPos.Add(rLaySz.Mul(rmat)) + // xyz.NewLine(se, pathsGp, sb.Name, spos, rpos, lineWidth, clr) + clr := colors.Spaced(sp.TypeNumber()) + xyz.NewArrow(se, pathsGp, sb.Name, spos, rpos, lineWidth, clr, xyz.NoStartArrow, xyz.EndArrow, 4, .5, 4) + } + } + } + } +} + // ViewDefaults are the default 3D view params func (nv *NetView) ViewDefaults(se *xyz.Scene) { se.Camera.Pose.Pos.Set(0, 1.5, 2.5) // more "top down" view shows more of layers @@ -837,7 +959,6 @@ func (nv *NetView) MakeToolbar(p *tree.Plan) { }) }) tree.Add(p, func(w *core.Separator) {}) - tree.Add(p, func(w *core.Separator) {}) tree.Add(p, func(w *core.Button) { w.SetText("Weights").SetType(core.ButtonAction).SetMenu(func(m *core.Scene) { fb := core.NewFuncButton(m).SetFunc(nv.SaveWeights) @@ -865,6 +986,14 @@ func (nv *NetView) MakeToolbar(p *tree.Plan) { }) }) tree.Add(p, func(w *core.Separator) {}) + tree.Add(p, func(w *core.Switch) { + w.SetText("Paths").SetChecked(nv.Params.Paths). + SetTooltip("Toggles whether pathways between layers are shown or not"). + OnChange(func(e events.Event) { + nv.Params.Paths = w.IsChecked() + nv.UpdateView() + }) + }) ditp := "data parallel index -- for models running multiple input patterns in parallel, this selects which one is viewed" tree.Add(p, func(w *core.Text) { w.SetText("Di:").SetTooltip(ditp) diff --git a/netview/params.go b/netview/params.go index 27ca655..322f512 100644 --- a/netview/params.go +++ b/netview/params.go @@ -50,6 +50,12 @@ func (nv *RasterParams) Defaults() { // Params holds parameters controlling how the view is rendered type Params struct { //types:add + // whether to display the pathways between layers as arrows + Paths bool + + // width of the path arrows, in normalized units + PathWidth float32 `default:"0.002"` + // raster plot parameters Raster RasterParams `display:"inline"` @@ -88,6 +94,8 @@ func (nv *Params) Defaults() { nv.Raster.Defaults() if nv.NVarCols == 0 { nv.NVarCols = NVarCols + nv.Paths = true + nv.PathWidth = 0.002 } if nv.MaxRecs == 0 { nv.MaxRecs = 210 // 200 cycles + 8 phase updates max + 2 extra..