// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package actrf
//go:generate core generate -add-types
import (
"slices"
"cogentcore.org/lab/stats/stats"
"cogentcore.org/lab/tensor"
)
// RF is used for computing an activation-based receptive field.
// It simply computes the activation weighted average of other
// *source* patterns of activation -- i.e., sum(act * src) / sum(src)
// which then shows you the patterns of source activity for which
// a given unit was active.
// You must call Init to initialize everything, Reset to restart the accumulation of the data,
// and Avg to compute the resulting averages based on the accumulated data.
// Avg does not erase the accumulated data so it can continue beyond that point.
type RF struct {
// name of this RF -- used for managing multiple RFs in the RFs manager
Name string
// computed receptive field, as SumProd / SumSrc -- only after Avg has been called
RF tensor.Float32 `display:"no-inline"`
// unit normalized version of RF per source (inner 2D dimensions) -- good for display
NormRF tensor.Float32 `display:"no-inline"`
// normalized version of SumSrc -- sum of each point in the source -- good for viewing the completeness and uniformity of the sampling of the source space
NormSrc tensor.Float32 `display:"no-inline"`
// sum of the products of act * src
SumProd tensor.Float32 `display:"no-inline"`
// sum of the sources (denominator)
SumSrc tensor.Float32 `display:"no-inline"`
// temporary destination sum for MPI -- only used when MPISum called
MPITmp tensor.Float32 `display:"no-inline"`
}
// Init initializes this RF based on name and shapes of given
// tensors representing the activations and source values.
func (af *RF) Init(name string, act, src tensor.Tensor) {
af.Name = name
af.InitShape(act, src)
af.Reset()
}
// InitShape initializes shape for this RF based on shapes of given
// tensors representing the activations and source values.
// Does nothing if the shape is already correct.
// Returns the 4D shape sizes: {act Y, act X, src Y, src X}.
func (af *RF) InitShape(act, src tensor.Tensor) []int {
aNy, aNx, _, _ := tensor.Projection2DShape(act.Shape(), false)
sNy, sNx, _, _ := tensor.Projection2DShape(src.Shape(), false)
oshp := []int{aNy, aNx, sNy, sNx}
if slices.Equal(af.RF.Shape().Sizes, oshp) {
return oshp
}
sshp := []int{sNy, sNx}
af.RF.SetShapeSizes(oshp...)
af.NormRF.SetShapeSizes(oshp...)
af.SumProd.SetShapeSizes(oshp...)
af.NormSrc.SetShapeSizes(sshp...)
af.SumSrc.SetShapeSizes(sshp...)
af.ConfigView(&af.RF)
af.ConfigView(&af.NormRF)
af.ConfigView(&af.SumProd)
af.ConfigView(&af.NormSrc)
af.ConfigView(&af.SumSrc)
return oshp
}
// ConfigView configures the view params on the tensor
func (af *RF) ConfigView(tsr *tensor.Float32) {
// todo:meta
// tsr.SetMetaData("colormap", "Viridis")
// tsr.SetMetaData("grid-fill", "1") // remove extra lines
// tsr.SetMetaData("fix-min", "true")
// tsr.SetMetaData("min", "0")
}
// Reset reinitializes the Sum accumulators -- must have called Init first
func (af *RF) Reset() {
af.SumProd.SetZeros()
af.SumSrc.SetZeros()
}
// Add adds one sample based on activation and source tensor values.
// These must be of the same shape as used when Init was called.
// thr is a threshold value on sources, below which values are not added
// (prevents numerical issues with very small numbers).
func (af *RF) Add(act, src tensor.Tensor, thr float32) {
shp := af.InitShape(act, src) // ensure
aNy, aNx, sNy, sNx := shp[0], shp[1], shp[2], shp[3]
for sy := 0; sy < sNy; sy++ {
for sx := 0; sx < sNx; sx++ {
tv := float32(tensor.Projection2DValue(src, false, sy, sx))
if tv < thr {
continue
}
af.SumSrc.SetAdd(tv, sy, sx)
for ay := 0; ay < aNy; ay++ {
for ax := 0; ax < aNx; ax++ {
av := float32(tensor.Projection2DValue(act, false, ay, ax))
af.SumProd.SetAdd(av*tv, ay, ax, sy, sx)
}
}
}
}
}
// Avg computes RF as SumProd / SumSrc. Does not Reset sums.
func (af *RF) Avg() {
aNy := af.SumProd.DimSize(0)
aNx := af.SumProd.DimSize(1)
sNy := af.SumProd.DimSize(2)
sNx := af.SumProd.DimSize(3)
var maxSrc float32
for sy := 0; sy < sNy; sy++ {
for sx := 0; sx < sNx; sx++ {
src := af.SumSrc.Value(sy, sx)
if src == 0 {
continue
}
if src > maxSrc {
maxSrc = src
}
for ay := 0; ay < aNy; ay++ {
for ax := 0; ax < aNx; ax++ {
oo := af.SumProd.Shape().IndexTo1D(ay, ax, sy, sx)
af.RF.Values[oo] = af.SumProd.Values[oo] / src
}
}
}
}
if maxSrc == 0 {
maxSrc = 1
}
for i, v := range af.SumSrc.Values {
af.NormSrc.Values[i] = v / maxSrc
}
}
// Norm computes unit norm of RF values -- must be called after Avg
func (af *RF) Norm() {
stats.UnitNormOut(&af.RF, &af.NormRF)
}
// AvgNorm computes RF as SumProd / SumSrc and then does Norm.
// This is what you typically want to call before viewing RFs.
// Does not Reset sums.
func (af *RF) AvgNorm() {
af.Avg()
af.Norm()
}
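// The following is a minimal usage sketch (not part of the original API),
// showing the typical Init / Add / AvgNorm sequence for a single RF.
// The tensor shapes and the 0.01 threshold are arbitrary assumptions for illustration.
func exampleRFUsage() {
	act := &tensor.Float32{} // activations of the layer whose RFs we want
	act.SetShapeSizes(4, 4)
	src := &tensor.Float32{} // source activations to project the RFs onto
	src.SetShapeSizes(2, 2)
	var rf RF
	rf.Init("V1fmInput", act, src)
	// on each trial, after act and src have been updated, accumulate:
	rf.Add(act, src, 0.01) // threshold excludes very small source values
	// when ready to view:
	rf.AvgNorm() // RF = SumProd / SumSrc, then unit-normalized into NormRF
}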
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package actrf
import (
"fmt"
"cogentcore.org/core/base/errors"
"cogentcore.org/lab/tensor"
)
// RFs manages multiple named RFs -- each one must be initialized first,
// but functions like Avg, Norm, and Reset can be called generically on all of them.
type RFs struct {
// map of names to indexes of RFs
NameMap map[string]int
// the RFs
RFs []*RF
}
// RFByName returns RF of given name, nil and error msg if not found.
func (af *RFs) RFByName(name string) (*RF, error) {
if af.NameMap != nil {
idx, ok := af.NameMap[name]
if ok {
return af.RFs[idx], nil
}
}
return nil, fmt.Errorf("Name: %s not found in list of named RFs", name)
}
// AddRF adds a new RF, calling Init on it using given act, src tensors
func (af *RFs) AddRF(name string, act, src tensor.Tensor) *RF {
if af.NameMap == nil {
af.NameMap = make(map[string]int)
}
sz := len(af.RFs)
af.NameMap[name] = sz
rf := &RF{}
af.RFs = append(af.RFs, rf)
rf.Init(name, act, src)
return rf
}
// Add adds a new act sample to the accumulated data for given named rf
func (af *RFs) Add(name string, act, src tensor.Tensor, thr float32) error {
rf, err := af.RFByName(name)
if errors.Log(err) != nil {
return err
}
rf.Add(act, src, thr)
return nil
}
// Reset resets Sum accumulations for all rfs
func (af *RFs) Reset() {
for _, rf := range af.RFs {
rf.Reset()
}
}
// Avg computes RF as SumProd / SumSrc. Does not Reset sums.
func (af *RFs) Avg() {
for _, rf := range af.RFs {
rf.Avg()
}
}
// Norm computes unit norm of RF values -- must be called after Avg
func (af *RFs) Norm() {
for _, rf := range af.RFs {
rf.Norm()
}
}
// AvgNorm computes RF as SumProd / SumSrc and then does Norm.
// This is what you typically want to call before viewing RFs.
// Does not Reset sums.
func (af *RFs) AvgNorm() {
for _, rf := range af.RFs {
rf.AvgNorm()
}
}
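// A minimal usage sketch (not part of the original API) for the RFs manager,
// which handles several named receptive fields at once. The names, shapes, and
// threshold here are illustrative assumptions.
func exampleRFsUsage() {
	act := &tensor.Float32{}
	act.SetShapeSizes(4, 4)
	src := &tensor.Float32{}
	src.SetShapeSizes(2, 2)
	var rfs RFs
	rfs.AddRF("V1fmLGN", act, src) // calls Init on the new RF
	// each trial:
	errors.Log(rfs.Add("V1fmLGN", act, src, 0.01))
	// before viewing:
	rfs.AvgNorm()
}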
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package actrf
import (
"cogentcore.org/lab/base/mpi"
"cogentcore.org/lab/tensor/tensormpi"
)
// MPISum aggregates RF Sum data across all processors in given mpi communicator.
// It adds to SumProd and SumSrc. Call this prior to calling AvgNorm().
func (af *RF) MPISum(comm *mpi.Comm) {
if mpi.WorldSize() == 1 {
return
}
tensormpi.ReduceTensor(&af.MPITmp, &af.SumProd, comm, mpi.OpSum)
af.SumProd.CopyFrom(&af.MPITmp)
tensormpi.ReduceTensor(&af.MPITmp, &af.SumSrc, comm, mpi.OpSum)
af.SumSrc.CopyFrom(&af.MPITmp)
}
// MPISum aggregates Sum data for all RFs across all processors in given mpi communicator.
// It adds to SumProd and SumSrc for each RF. Call this prior to calling AvgNorm().
func (af *RFs) MPISum(comm *mpi.Comm) {
for _, rf := range af.RFs {
rf.MPISum(comm)
}
}
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package actrf
import "cogentcore.org/lab/tensor"
// RunningAvg computes a running-average activation-based receptive field
// for activities act relative to source activations src (the thing we're projecting rf onto)
// accumulating into output out, with time constant tau.
// act and src are projected into a 2D space (tensor.Projection2D* methods), and
// resulting out is 4D of act outer and src inner.
func RunningAvg(out *tensor.Float32, act, src tensor.Tensor, tau float32) {
dt := 1 / tau
cdt := 1 - dt
aNy, aNx, _, _ := tensor.Projection2DShape(act.Shape(), false)
tNy, tNx, _, _ := tensor.Projection2DShape(src.Shape(), false)
oshp := []int{aNy, aNx, tNy, tNx}
out.SetShapeSizes(oshp...)
for ay := 0; ay < aNy; ay++ {
for ax := 0; ax < aNx; ax++ {
av := float32(tensor.Projection2DValue(act, false, ay, ax))
for ty := 0; ty < tNy; ty++ {
for tx := 0; tx < tNx; tx++ {
tv := float32(tensor.Projection2DValue(src, false, ty, tx))
oo := out.Shape().IndexTo1D(ay, ax, ty, tx)
ov := out.Values[oo]
nv := cdt*ov + dt*tv*av
out.Values[oo] = nv
}
}
}
}
}
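// A minimal usage sketch (not part of the original API): RunningAvg is typically
// called once per update with a time constant tau expressed in update steps.
// The shapes, step count, and tau = 100 here are arbitrary assumptions.
func exampleRunningAvg() {
	act := &tensor.Float32{}
	act.SetShapeSizes(4, 4)
	src := &tensor.Float32{}
	src.SetShapeSizes(2, 2)
	out := &tensor.Float32{}
	for step := 0; step < 1000; step++ {
		// ... update act and src from the network here ...
		RunningAvg(out, act, src, 100) // out += (1/100) * (act*src - out), elementwise
	}
}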
// Copyright (c) 2021 The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package chem
// Buffer provides a soft buffering driving deltas relative to a target N
// which can be set by concentration and volume.
type Buffer struct {
// rate of buffering (akin to permeability / conductance of a channel)
K float64
// buffer target concentration -- drives delta relative to this
Target float64
}
// SetTargVol sets the buffer Target N from concentration targ and volume vol.
func (bf *Buffer) SetTargVol(targ, vol float64) {
bf.Target = CoToN(targ, vol)
}
// Step computes the da delta for current value ca relative to the Target value
func (bf *Buffer) Step(ca float64, da *float64) {
*da += bf.K * (bf.Target - ca)
}
// Copyright (c) 2021 The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package chem
// CoToN returns N based on concentration, for given volume: co * vol
func CoToN(co, vol float64) float64 {
return co * vol
}
// CoFromN returns concentration from N, for given volume: n / vol
func CoFromN(n, vol float64) float64 {
return n / vol
}
// Copyright (c) 2021 The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package chem
// Diffuse models diffusion between two compartments A and B as
// a function of concentration in each and potentially asymmetric
// rate constants: A -> B with rate Kf, and B -> A with rate Kb.
// Step computes the net difference between the two directions and applies it to each compartment.
type Diffuse struct {
// A -> B forward diffusion rate constant, sec-1
Kf float64
// B -> A backward diffusion rate constant, sec-1
Kb float64
}
// Set sets both diffusion rates
func (rt *Diffuse) Set(kf, kb float64) {
rt.Kf = kf
rt.Kb = kb
}
// SetSym sets symmetric diffusion rate (Kf == Kb)
func (rt *Diffuse) SetSym(kfb float64) {
rt.Kf = kfb
rt.Kb = kfb
}
// Step computes delta A and B values based on current A, B values.
// Inputs are N values, converted to concentrations (using the given volumes) to drive the rate.
func (rt *Diffuse) Step(ca, cb, va, vb float64, da, db *float64) {
df := rt.Kf*(ca/va) - rt.Kb*(cb/vb)
*da -= df
*db += df
}
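// A minimal integration sketch (not part of the original source): Step accumulates
// deltas, which are then applied with Integrate once per time step. The rate,
// volumes, and initial N values below are arbitrary assumptions for illustration.
func exampleDiffuseStep() {
	var df Diffuse
	df.SetSym(20) // symmetric diffusion rate
	caA, caB := 10.0, 0.0
	volA, volB := 0.5, 1.0
	for t := 0; t < 1000; t++ {
		var dA, dB float64
		df.Step(caA, caB, volA, volB, &dA, &dB)
		Integrate(&caA, dA)
		Integrate(&caB, dB)
	}
}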
// Copyright (c) 2021 The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package chem
// Enz models an enzyme-catalyzed reaction based on the Michaelis-Menten kinetics
// that transforms S = substrate into P product via SE-bound C complex
//
// K1 K3
//
// S + E --> C(SE) ---> P + E
//
// <-- K2
//
// S = substrate, E = enzyme, C = SE complex, P = product
// The source K constants are in terms of concentrations μM-1 and sec-1
// but calculations take place using N's, and the forward direction has
// two factors while reverse only has one, so a corrective volume factor needs
// to be divided out to set the actual forward factor.
type Enz struct {
// S+E forward rate constant, in μM-1 msec-1
K1 float64
// SE backward rate constant, in μM-1 msec-1
K2 float64
// SE -> P + E catalyzed rate constant, in μM-1 msec-1
K3 float64
// Michaelis constant = (K2 + K3) / K1
Km float64 `edit:"-"`
}
// Update computes the Michaelis constant Km from K1, K2, K3.
func (rt *Enz) Update() {
rt.Km = (rt.K2 + rt.K3) / rt.K1
}
// SetKmVol sets time constants in seconds using Km, K2, K3
// dividing forward K1 by volume to compensate for 2 volume-based concentrations
// occurring in forward component (s * e), vs just 1 in back
func (rt *Enz) SetKmVol(km, vol, k2, k3 float64) {
k1 := (k2 + k3) / km
rt.K1 = CoFromN(k1, vol)
rt.K2 = k2
rt.K3 = k3
rt.Update()
}
// SetKm sets time constants in seconds using Km, K2, K3
func (rt *Enz) SetKm(km, k2, k3 float64) {
k1 := (k2 + k3) / km
rt.K1 = k1
rt.K2 = k2
rt.K3 = k3
rt.Update()
}
// Set sets time constants in seconds directly
func (rt *Enz) Set(k1, k2, k3 float64) {
rt.K1 = k1
rt.K2 = k2
rt.K3 = k3
rt.Update()
}
// Step computes delta values based on current S, E, C, and P values
func (rt *Enz) Step(cs, ce, cc, cp float64, ds, de, dc, dp *float64) {
df := rt.K1 * cs * ce // forward
db := rt.K2 * cc // backward
do := rt.K3 * cc // out to product
*dp += do
*dc += df - (do + db) // complex = forward - back - output
*de += (do + db) - df // e is released with product and backward from complex, consumed by forward
*ds -= (df - db) // substrate = back - forward
}
// StepK computes delta values based on current S, E, C, and P values
// K version has additional rate multiplier for Kf = K1
func (rt *Enz) StepK(kf, cs, ce, cc, cp float64, ds, de, dc, dp *float64) {
df := kf * rt.K1 * cs * ce // forward
db := rt.K2 * cc // backward
do := rt.K3 * cc // out to product
*dp += do
*dc += df - (do + db) // complex = forward - back - output
*de += (do + db) - df // e is released with product and backward from complex, consumed by forward
*ds -= (df - db) // substrate = back - forward
}
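// A minimal usage sketch (not part of the original source) showing how an Enz
// reaction is typically stepped and integrated. The Km, volume, rate constants,
// and initial N values are arbitrary assumptions, not taken from any model.
func exampleEnzStep() {
	var enz Enz
	enz.SetKmVol(10, 0.5, 4, 1) // Km, volume, K2, K3
	s, e, c, p := 5.0, 1.0, 0.0, 0.0
	for t := 0; t < 1000; t++ {
		var ds, de, dc, dp float64
		enz.Step(s, e, c, p, &ds, &de, &dc, &dp)
		Integrate(&s, ds)
		Integrate(&e, de)
		Integrate(&c, dc)
		Integrate(&p, dp)
	}
}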
// Copyright (c) 2021 The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package chem
// EnzRate models an enzyme-catalyzed reaction based on the Michaelis-Menten kinetics
// that transforms S = substrate into P product via SE bound C complex
//
// K1 K3
//
// S + E --> C(SE) ---> P + E
//
// <-- K2
//
// S = substrate, E = enzyme, C = SE complex, P = product
// This version does NOT consume the E enzyme or directly use the C complex
// as an accumulated factor: instead it directly computes an overall rate
// for the end-to-end S <-> P reaction based on the K constants:
// rate = S * E * K3 / (S + Km)
// This amount is added to the P and subtracted from the S, and recorded
// in the C complex variable as rate / K3 -- it is just directly set.
// In some situations this C variable can be used for other things.
// The source K constants are in terms of concentrations μM-1 and sec-1
// but calculations take place using N's, and the forward direction has
// two factors while reverse only has one, so a corrective volume factor needs
// to be divided out to set the actual forward factor.
type EnzRate struct {
// S+E forward rate constant, in μM-1 msec-1
K1 float64
// SE backward rate constant, in μM-1 msec-1
K2 float64
// SE -> P + E catalyzed rate constant, in μM-1 msec-1
K3 float64
// Michaelis constant = (K2 + K3) / K1 -- goes into the rate
Km float64 `edit:"-"`
}
// Update computes the Michaelis constant Km from K1, K2, K3.
func (rt *EnzRate) Update() {
rt.Km = (rt.K2 + rt.K3) / rt.K1
}
// SetKmVol sets time constants in seconds using Km, K2, K3
// dividing forward K1 by volume to compensate for 2 volume-based concentrations
// occurring in forward component (s * e), vs just 1 in back
func (rt *EnzRate) SetKmVol(km, vol, k2, k3 float64) {
k1 := (k2 + k3) / km
rt.K1 = CoFromN(k1, vol)
rt.K2 = k2
rt.K3 = k3
rt.Update()
}
// SetKm sets time constants in seconds using Km, K2, K3
func (rt *EnzRate) SetKm(km, k2, k3 float64) {
k1 := (k2 + k3) / km
rt.K1 = k1
rt.K2 = k2
rt.K3 = k3
rt.Update()
}
// Set sets time constants in seconds directly
func (rt *EnzRate) Set(k1, k2, k3 float64) {
rt.K1 = k1
rt.K2 = k2
rt.K3 = k3
rt.Update()
}
// Step computes delta values based on current S, E values, setting dS, dP and C = rate
func (rt *EnzRate) Step(cs, ce float64, ds, dp, cc *float64) {
rate := (rt.K3 * cs * ce) / (cs + rt.Km)
*dp += rate
*ds -= rate
*cc += rate // directly stored
}
// StepK computes delta values based on current S, E values, setting dS, dP and C = rate.
// K version has additional rate multiplier for Kf = K1
func (rt *EnzRate) StepK(kf, cs, ce float64, ds, dp, cc *float64) {
km := (rt.K2 + rt.K3) / (rt.K1 * kf)
rate := (rt.K3 * cs * ce) / (cs + km)
*dp += rate
*ds -= rate
*cc += rate // directly stored
}
// Copyright (c) 2021 The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package chem
// IntegrationDt is the time step of integration
// for Urakubo et al, 2008: uses 5e-5, 2e-4 is barely stable, 5e-4 is not
// The AC1act dynamics in particular are not stable due to large ATP, AMP numbers
var IntegrationDt = 5.0e-5
// Integrate adds delta to current value with integration rate constant IntegrationDt
// new value cannot go below 0
func Integrate(c *float64, d float64) {
*c += IntegrationDt * d
if *c < 0 {
*c = 0
}
}
// note: genesis kkit uses exponential Euler which requires separate A - B deltas
// advantages are unclear.
// if *c > 1e-10 && d > 1e-10 { // note: exponential Euler requires separate A - B deltas
// dd := math.Exp()
// } else {
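// A minimal sketch (not part of the original source) of the Euler integration
// idiom used throughout this package: accumulate deltas from one or more
// reactions into a local variable, then apply them once per time step.
// The decay rate here is an arbitrary assumption for illustration.
func exampleIntegrate() {
	c := 1.0 // current N of some species
	for t := 0; t < 1000; t++ {
		var d float64
		d += -0.5 * c // e.g., simple first-order decay
		Integrate(&c, d)
	}
}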
// Copyright (c) 2021 The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package chem
//go:generate core generate -add-types
// React models a basic chemical reaction:
//
// Kf
//
// A + B --> AB
//
// <-- Kb
//
// where Kf is the forward and Kb is the backward time constant.
// The source Kf and Kb constants are in terms of concentrations μM-1 and sec-1
// but calculations take place using N's, and the forward direction has
// two factors while reverse only has one, so a corrective volume factor needs
// to be divided out to set the actual forward factor.
type React struct {
// forward rate constant for N / sec assuming 2 forward factors
Kf float64
// backward rate constant for N / sec assuming 1 backward factor
Kb float64
}
// SetVol sets reaction forward / backward time constants in seconds,
// dividing forward Kf by volume to compensate for 2 volume-based concentrations
// occurring in forward component, vs just 1 in back
func (rt *React) SetVol(f, vol, b float64) {
rt.Kf = CoFromN(f, vol)
rt.Kb = b
}
// Set sets reaction forward / backward time constants in seconds
func (rt *React) Set(f, b float64) {
rt.Kf = f
rt.Kb = b
}
// Step computes delta A, B, AB values based on current A, B, and AB values
func (rt *React) Step(ca, cb, cab float64, da, db, dab *float64) {
df := rt.Kf*ca*cb - rt.Kb*cab
*dab += df
*da -= df
*db -= df
}
// StepK computes delta A, B, AB values based on current A, B, and AB values
// K version has additional rate multiplier for Kf
func (rt *React) StepK(kf, ca, cb, cab float64, da, db, dab *float64) {
df := kf*rt.Kf*ca*cb - rt.Kb*cab
*dab += df
*da -= df
*db -= df
}
// StepCB computes delta A, AB values based on current A, B, and AB values
// assumes B does not change -- does not compute db
func (rt *React) StepCB(ca, cb, cab float64, da, dab *float64) {
df := rt.Kf*ca*cb - rt.Kb*cab
*dab += df
*da -= df
}
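// A minimal usage sketch (not part of the original source) for the basic binding
// reaction A + B <-> AB. The rate constants, volume, and initial N values are
// arbitrary assumptions for illustration.
func exampleReactStep() {
	var rx React
	rx.SetVol(2, 0.5, 1) // forward (divided by volume), volume, backward
	a, b, ab := 10.0, 5.0, 0.0
	for t := 0; t < 1000; t++ {
		var da, db, dab float64
		rx.Step(a, b, ab, &da, &db, &dab)
		Integrate(&a, da)
		Integrate(&b, db)
		Integrate(&ab, dab)
	}
}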
// Copyright (c) 2021 The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package chem
// SimpleEnz models a simple enzyme-catalyzed reaction
// that transforms S = substrate into P product via E which is not consumed
// assuming there is much more E than S and P -- E effectively acts as a
// rate constant multiplier
//
// Kf*E
//
// S ----> P
//
// S = substrate, E = enzyme, P = product, Kf is the rate of the reaction
type SimpleEnz struct {
// S->P forward rate constant, in μM-1 msec-1
Kf float64
}
// SetVol sets the forward rate constant in terms of seconds,
// dividing Kf by volume to compensate for the 2 volume-based concentrations
// (S and E) occurring in the forward component.
func (rt *SimpleEnz) SetVol(f, vol float64) {
rt.Kf = CoFromN(f, vol)
}
// Step computes delta S and P values based on current S, E values
func (rt *SimpleEnz) Step(cs, ce float64, ds, dp *float64) {
df := rt.Kf * cs * ce // forward
*ds -= df
*dp += df
}
// StepCo computes delta S and P values based on current S, E values
// based on concentration
func (rt *SimpleEnz) StepCo(cs, ce, vol float64, ds, dp *float64) {
df := rt.Kf * CoFromN(cs, vol) * CoFromN(ce, vol) // forward
*ds -= df
*dp += df
}
// StepK computes delta S and P values based on current S, E values
// K version has additional rate multiplier for Kf
func (rt *SimpleEnz) StepK(kf, cs, ce float64, ds, dp *float64) {
df := kf * rt.Kf * cs * ce // forward
*ds -= df
*dp += df
}
// Copyright (c) 2023, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package decoder
//go:generate core generate -add-types
import (
"fmt"
"cogentcore.org/core/math32"
"cogentcore.org/lab/base/mpi"
"cogentcore.org/lab/tensor"
)
// ActivationFunc is an activation function that maps a net input value to an activation value.
type ActivationFunc func(float32) float32
// Linear is a linear neural network, which can be configured with a custom
// activation function. By default it will use the identity function.
// It learns using the delta rule for each output unit.
type Linear struct {
// learning rate
LRate float32 `default:"0.1"`
// layers to decode
Layers []Layer
// unit values -- read this for decoded output
Units []LinearUnit
// number of inputs -- total sizes of layer inputs
NInputs int
// number of outputs -- number of categories to decode
NOutputs int
// input values, copied from layers
Inputs []float32
// for holding layer values
ValuesTsrs map[string]*tensor.Float32 `display:"-"`
// synaptic weights: outer loop is units, inner loop is inputs
Weights tensor.Float32
// activation function
ActivationFn ActivationFunc
// which pool to use within a layer
PoolIndex int
// mpi communicator -- MPI users must set this to their comm -- do direct assignment
Comm *mpi.Comm `display:"-"`
// delta weight changes: only for MPI mode -- outer loop is units, inner loop is inputs
MPIDWts tensor.Float32
}
// Layer is the subset of emer.Layer that is used by this code
type Layer interface {
Name() string
UnitValuesTensor(tsr tensor.Values, varNm string, di int) error
Shape() *tensor.Shape
}
// IdentityFunc is the identity activation function, returning x unchanged.
func IdentityFunc(x float32) float32 { return x }
// LogisticFunc implements the standard logistic function.
// Its outputs are in the range (0, 1).
// Also known as Sigmoid. See https://en.wikipedia.org/wiki/Logistic_function.
func LogisticFunc(x float32) float32 { return 1 / (1 + math32.FastExp(-x)) }
// LinearUnit has variables for Linear decoder unit
type LinearUnit struct {
// target activation value -- typically 0 or 1 but can be within that range too
Target float32
// final activation = ActivationFn(Net) -- this is the decoded output
Act float32
// net input = sum x * w
Net float32
}
// InitLayer initializes the decoder with number of categories and layers
func (dec *Linear) InitLayer(nOutputs int, layers []Layer, activationFn ActivationFunc) {
dec.Layers = layers
nIn := 0
for _, ly := range dec.Layers {
nIn += ly.Shape().Len()
}
dec.Init(nOutputs, nIn, -1, activationFn)
}
// InitPool initializes the decoder with number of categories, a single layer,
// and the index of the pool within that layer to decode.
func (dec *Linear) InitPool(nOutputs int, layer Layer, poolIndex int, activationFn ActivationFunc) {
dec.Layers = []Layer{layer}
shape := layer.Shape()
// TODO: assert that it's a 4D layer
nIn := shape.DimSize(2) * shape.DimSize(3)
dec.Init(nOutputs, nIn, poolIndex, activationFn)
}
// Init initializes the decoder with number of categories and number of inputs
func (dec *Linear) Init(nOutputs, nInputs int, poolIndex int, activationFn ActivationFunc) {
dec.NInputs = nInputs
dec.LRate = 0.1
dec.NOutputs = nOutputs
dec.Units = make([]LinearUnit, dec.NOutputs)
dec.Inputs = make([]float32, dec.NInputs)
dec.Weights.SetShapeSizes(dec.NOutputs, dec.NInputs)
for i := range dec.Weights.Values {
dec.Weights.Values[i] = 0.1
}
dec.PoolIndex = poolIndex
dec.ActivationFn = activationFn
}
// Decode decodes the given variable name from layers (forward pass).
// Decoded values are in Units[i].Act -- see also Output to get into a []float32.
// di is a data parallel index, for networks capable
// of processing input patterns in parallel.
func (dec *Linear) Decode(varNm string, di int) {
dec.Input(varNm, di)
dec.Forward()
}
// Output returns the resulting Decoded output activation values into given slice
// which is automatically resized if not of sufficient size.
func (dec *Linear) Output(acts *[]float32) {
if cap(*acts) < dec.NOutputs {
*acts = make([]float32, dec.NOutputs)
} else if len(*acts) != dec.NOutputs {
*acts = (*acts)[:dec.NOutputs]
}
for ui := range dec.Units {
u := &dec.Units[ui]
(*acts)[ui] = u.Act
}
}
// Train trains the decoder with given target correct answers, as []float32 values.
// Returns SSE (sum squared error) of difference between targets and outputs.
// Also returns and prints an error if targets are not sufficient length for NOutputs.
func (dec *Linear) Train(targs []float32) (float32, error) {
err := dec.SetTargets(targs)
if err != nil {
return 0, err
}
sse := dec.Back()
return sse, nil
}
// TrainMPI trains the decoder with given target correct answers, as []float32 values.
// Returns SSE (sum squared error) of difference between targets and outputs.
// Also returns and prints an error if targets are not sufficient length for NOutputs.
// MPI version uses mpi to synchronize weight changes across parallel nodes.
func (dec *Linear) TrainMPI(targs []float32) (float32, error) {
err := dec.SetTargets(targs)
if err != nil {
return 0, err
}
sse := dec.BackMPI()
return sse, nil
}
// SetTargets sets given target correct answers, as []float32 values.
// Also returns and prints an error if targets are not sufficient length for NOutputs.
func (dec *Linear) SetTargets(targs []float32) error {
if len(targs) < dec.NOutputs {
err := fmt.Errorf("decoder.Linear: number of targets < NOutputs: %d < %d", len(targs), dec.NOutputs)
fmt.Println(err)
return err
}
for ui := range dec.Units {
u := &dec.Units[ui]
u.Target = targs[ui]
}
return nil
}
// ValuesTsr gets value tensor of given name, creating if not yet made
func (dec *Linear) ValuesTsr(name string) *tensor.Float32 {
if dec.ValuesTsrs == nil {
dec.ValuesTsrs = make(map[string]*tensor.Float32)
}
tsr, ok := dec.ValuesTsrs[name]
if !ok {
tsr = &tensor.Float32{}
dec.ValuesTsrs[name] = tsr
}
return tsr
}
// Input grabs the input from given variable in layers
// di is a data parallel index, for networks capable
// of processing input patterns in parallel.
func (dec *Linear) Input(varNm string, di int) {
off := 0
for _, ly := range dec.Layers {
tsr := dec.ValuesTsr(ly.Name())
ly.UnitValuesTensor(tsr, varNm, di)
if dec.PoolIndex >= 0 {
shape := ly.Shape()
y := dec.PoolIndex / shape.DimSize(1)
x := dec.PoolIndex % shape.DimSize(1)
tsr = tsr.SubSpace(y, x).(*tensor.Float32)
}
for j, v := range tsr.Values {
dec.Inputs[off+j] = v
}
off += ly.Shape().Len()
}
}
// Forward computes the forward pass from the input
func (dec *Linear) Forward() {
for ui := range dec.Units {
u := &dec.Units[ui]
net := float32(0)
off := ui * dec.NInputs
for j, in := range dec.Inputs {
net += dec.Weights.Values[off+j] * in
}
u.Net = net
u.Act = dec.ActivationFn(net)
}
}
// https://en.wikipedia.org/wiki/Delta_rule
// Delta rule: delta = learning rate * error * input
// We don't need the g' (derivative of activation function) term assuming:
// 1. Identity activation function with SSE loss (because it's 1), OR
// 2. Logistic activation function with Cross Entropy loss (because it cancels out, see
// https://towardsdatascience.com/deriving-backpropagation-with-cross-entropy-loss-d24811edeaf9)
// The fact that we return SSE does not mean we're optimizing SSE.
// Back computes the backward error propagation pass.
// Returns SSE (sum squared error) of difference between targets and outputs.
func (dec *Linear) Back() float32 {
var sse float32
for ui := range dec.Units {
u := &dec.Units[ui]
err := u.Target - u.Act
sse += err * err
del := dec.LRate * err
off := ui * dec.NInputs
for j, in := range dec.Inputs {
dec.Weights.Values[off+j] += del * in
}
}
return sse
}
// BackMPI computes the backward error propagation pass.
// Returns SSE (sum squared error) of difference between targets and outputs.
func (dec *Linear) BackMPI() float32 {
if dec.MPIDWts.Len() != dec.Weights.Len() {
tensor.SetShapeFrom(&dec.MPIDWts, &dec.Weights)
}
var sse float32
for ui := range dec.Units {
u := &dec.Units[ui]
err := u.Target - u.Act
sse += err * err
del := dec.LRate * err
off := ui * dec.NInputs
for j, in := range dec.Inputs {
dec.MPIDWts.Values[off+j] = del * in
}
}
dec.Comm.AllReduceF32(mpi.OpSum, dec.MPIDWts.Values, nil)
for i, dw := range dec.MPIDWts.Values {
dec.Weights.Values[i] += dw
}
return sse
}
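// A minimal usage sketch (not part of the original source). Normally InitLayer
// and Decode are used to pull activations from emer layers; here Init is called
// directly and the inputs are written by hand, purely for illustration.
// The sizes, input values, and targets are arbitrary assumptions.
func exampleLinearUsage() {
	var dec Linear
	dec.Init(2, 4, -1, IdentityFunc) // 2 outputs, 4 inputs, no pool restriction
	copy(dec.Inputs, []float32{0.1, 0.9, 0.2, 0.7})
	dec.Forward()                          // compute Units[i].Act from Inputs and Weights
	sse, err := dec.Train([]float32{0, 1}) // delta-rule update toward targets
	if err == nil {
		fmt.Println(sse, dec.Units[1].Act)
	}
}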
// Copyright (c) 2021, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package decoder
import (
"bufio"
"compress/gzip"
"encoding/json"
"fmt"
"io"
"math"
"os"
"path/filepath"
"sort"
"cogentcore.org/core/math32"
"cogentcore.org/lab/base/mpi"
"cogentcore.org/lab/tensor"
"github.com/emer/emergent/v2/emer"
)
// SoftMax is a softmax decoder, which is the best choice for a 1-hot classification
// using the widely used SoftMax function: https://en.wikipedia.org/wiki/Softmax_function
type SoftMax struct {
// learning rate
Lrate float32 `default:"0.1"`
// layers to decode
Layers []emer.Layer
// number of different categories to decode
NCats int
// unit values
Units []SoftMaxUnit
// sorted list of indexes into Units, in descending order from strongest to weakest -- i.e., Sorted[0] has the most likely categorization, and its activity is Units[Sorted[0]].Act
Sorted []int
// number of inputs -- total sizes of layer inputs
NInputs int
// input values, copied from layers
Inputs []float32
// current target index of correct category
Target int
// for holding layer values
ValuesTsrs map[string]*tensor.Float32 `display:"-"`
// synaptic weights: outer loop is units, inner loop is inputs
Weights tensor.Float32
// mpi communicator -- MPI users must set this to their comm -- do direct assignment
Comm *mpi.Comm `display:"-"`
// delta weight changes: only for MPI mode -- outer loop is units, inner loop is inputs
MPIDWts tensor.Float32
}
// SoftMaxUnit has variables for softmax decoder unit
type SoftMaxUnit struct {
// final activation = e^Net / sum e^Net
Act float32
// net input = sum x * w
Net float32
// exp(Net)
Exp float32
}
// InitLayer initializes the decoder with number of categories and layers
func (sm *SoftMax) InitLayer(ncats int, layers []emer.Layer) {
sm.Layers = layers
nin := 0
for _, ly := range sm.Layers {
nin += ly.AsEmer().Shape.Len()
}
sm.Init(ncats, nin)
}
// Init initializes the decoder with number of categories and number of inputs
func (sm *SoftMax) Init(ncats, ninputs int) {
sm.NInputs = ninputs
sm.Lrate = 0.1 // seems pretty good
sm.NCats = ncats
sm.Units = make([]SoftMaxUnit, ncats)
sm.Sorted = make([]int, ncats)
sm.Inputs = make([]float32, sm.NInputs)
sm.Weights.SetShapeSizes(sm.NCats, sm.NInputs)
for i := range sm.Weights.Values {
sm.Weights.Values[i] = .1
}
}
// Decode decodes the given variable name from layers (forward pass)
// See Sorted list of indexes for the decoding output -- i.e., Sorted[0]
// is the most likely -- that is returned here as a convenience.
// di is a data parallel index, for networks capable
// of processing input patterns in parallel.
func (sm *SoftMax) Decode(varNm string, di int) int {
sm.Input(varNm, di)
sm.Forward()
sm.Sort()
return sm.Sorted[0]
}
// Train trains the decoder with given target correct answer (0..NCats-1)
func (sm *SoftMax) Train(targ int) {
sm.Target = targ
sm.Back()
}
// TrainMPI trains the decoder with given target correct answer (0..NCats-1)
// MPI version uses mpi to synchronize weight changes across parallel nodes.
func (sm *SoftMax) TrainMPI(targ int) {
sm.Target = targ
sm.BackMPI()
}
// ValuesTsr gets value tensor of given name, creating if not yet made
func (sm *SoftMax) ValuesTsr(name string) *tensor.Float32 {
if sm.ValuesTsrs == nil {
sm.ValuesTsrs = make(map[string]*tensor.Float32)
}
tsr, ok := sm.ValuesTsrs[name]
if !ok {
tsr = &tensor.Float32{}
sm.ValuesTsrs[name] = tsr
}
return tsr
}
// Input grabs the input from given variable in layers
// di is a data parallel index, for networks capable
// of processing input patterns in parallel.
func (sm *SoftMax) Input(varNm string, di int) {
off := 0
for _, ly := range sm.Layers {
lb := ly.AsEmer()
tsr := sm.ValuesTsr(lb.Name)
lb.UnitValuesTensor(tsr, varNm, di)
for j, v := range tsr.Values {
sm.Inputs[off+j] = v
}
off += lb.Shape.Len()
}
}
// Forward computes the forward pass from the input
func (sm *SoftMax) Forward() {
max := float32(-math.MaxFloat32)
for ui := range sm.Units {
u := &sm.Units[ui]
net := float32(0)
off := ui * sm.NInputs
for j, in := range sm.Inputs {
net += sm.Weights.Values[off+j] * in
}
u.Net = net
if net > max {
max = net
}
}
sum := float32(0)
for ui := range sm.Units {
u := &sm.Units[ui]
u.Net -= max
u.Exp = math32.FastExp(u.Net)
sum += u.Exp
}
for ui := range sm.Units {
u := &sm.Units[ui]
u.Act = u.Exp / sum
}
}
// Sort updates Sorted indexes of the current Unit category activations sorted
// from highest to lowest. i.e., the 0-index value has the strongest
// decoded output category, 1 the next-strongest, etc.
func (sm *SoftMax) Sort() {
for i := range sm.Sorted {
sm.Sorted[i] = i
}
sort.Slice(sm.Sorted, func(i, j int) bool {
return sm.Units[sm.Sorted[i]].Act > sm.Units[sm.Sorted[j]].Act
})
}
// Back computes the backward error propagation pass
func (sm *SoftMax) Back() {
lr := sm.Lrate
for ui := range sm.Units {
u := &sm.Units[ui]
var del float32
if ui == sm.Target {
del = lr * (1 - u.Act)
} else {
del = -lr * u.Act
}
off := ui * sm.NInputs
for j, in := range sm.Inputs {
sm.Weights.Values[off+j] += del * in
}
}
}
// BackMPI computes the backward error propagation pass
// MPI version shares weight changes across nodes
func (sm *SoftMax) BackMPI() {
if sm.MPIDWts.Len() != sm.Weights.Len() {
tensor.SetShapeFrom(&sm.MPIDWts, &sm.Weights)
}
lr := sm.Lrate
for ui := range sm.Units {
u := &sm.Units[ui]
var del float32
if ui == sm.Target {
del = lr * (1 - u.Act)
} else {
del = -lr * u.Act
}
off := ui * sm.NInputs
for j, in := range sm.Inputs {
sm.MPIDWts.Values[off+j] = del * in
}
}
sm.Comm.AllReduceF32(mpi.OpSum, sm.MPIDWts.Values, nil)
for i, dw := range sm.MPIDWts.Values {
sm.Weights.Values[i] += dw
}
}
type softMaxForSerialization struct {
Weights []float32 `json:"weights"`
}
// Save saves the decoder weights to the given file path.
// If path ends in .gz, it will be gzipped.
func (sm *SoftMax) Save(path string) error {
file, err := os.Create(path)
if err != nil {
return err
}
defer file.Close()
ext := filepath.Ext(path)
var writer io.Writer
if ext == ".gz" {
gw := gzip.NewWriter(file)
defer gw.Close()
writer = gw
} else {
bw := bufio.NewWriter(file)
defer bw.Flush()
writer = bw
}
encoder := json.NewEncoder(writer)
return encoder.Encode(softMaxForSerialization{Weights: sm.Weights.Values})
}
// Load loads the decoder weights from the given file path.
// If the shape of the decoder does not match the shape of the saved weights,
// an error will be returned.
func (sm *SoftMax) Load(path string) error {
ext := filepath.Ext(path)
var reader io.Reader
file, err := os.Open(path)
if err != nil {
return err
}
defer file.Close()
if ext == ".gz" {
gr, err := gzip.NewReader(file)
if err != nil {
return err
}
defer gr.Close()
reader = gr
} else {
reader = bufio.NewReader(file)
}
decoder := json.NewDecoder(reader)
var s softMaxForSerialization
if err := decoder.Decode(&s); err != nil {
return err
}
if len(sm.Weights.Values) != len(s.Weights) {
return fmt.Errorf("loaded weights length %d does not match expected length %d", len(s.Weights), len(sm.Weights.Values))
}
sm.Weights.Values = s.Weights
return nil
}
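// A minimal usage sketch (not part of the original source). Normally InitLayer
// and Decode pull activations from emer layers; here Init is called directly and
// the inputs are written by hand, purely for illustration. The category count,
// input values, and target are arbitrary assumptions.
func exampleSoftMaxUsage() {
	var sm SoftMax
	sm.Init(3, 4) // 3 categories, 4 inputs
	copy(sm.Inputs, []float32{0.1, 0.9, 0.2, 0.7})
	sm.Forward()
	sm.Sort()
	best := sm.Sorted[0] // most likely category
	sm.Train(1)          // delta-rule update toward correct category 1
	fmt.Println(best, sm.Units[best].Act)
}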
// Copyright (c) 2021, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package decoder
import (
"math/rand"
"sort"
)
// TopVoteInt returns the choice with the most votes among a list of votes
// as integer-valued choices, and also returns the number of votes for that item.
// In the case of ties, it chooses one at random (otherwise it would have a bias
// toward the lowest numbered item).
func TopVoteInt(votes []int) (int, int) {
sort.Ints(votes)
prv := votes[0]
cur := prv
top := prv
topn := 1
curn := 1
n := len(votes)
var ties []int
for i := 1; i < n; i++ {
cur = votes[i]
if cur != prv {
if curn > topn {
top = prv
topn = curn
ties = []int{top}
} else if curn == topn {
ties = append(ties, prv)
}
curn = 1
prv = cur
} else {
curn++
}
}
if curn > topn {
top = cur
topn = curn
ties = []int{top}
} else if curn == topn {
ties = append(ties, cur)
}
if len(ties) > 1 {
ti := rand.Intn(len(ties))
top = ties[ti]
}
return top, topn
}
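// A small illustration (not part of the original source): 2 and 7 tie with two
// votes each here, so the winner is chosen at random between them.
func exampleTopVoteInt() {
	votes := []int{2, 7, 2, 7, 3}
	top, n := TopVoteInt(votes)
	_, _ = top, n // top is either 2 or 7 (random tiebreak), n == 2
}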
// TopVoteString returns the choice with the most votes among a list of votes
// as string-valued choices, and also returns the number of votes for that item.
// In the case of ties, it chooses one at random (otherwise it would have a bias
// toward the alphabetically first item).
func TopVoteString(votes []string) (string, int) {
sort.Strings(votes)
prv := votes[0]
cur := prv
top := prv
topn := 1
curn := 1
n := len(votes)
var ties []string
for i := 1; i < n; i++ {
cur = votes[i]
if cur != prv {
if curn > topn {
top = prv
topn = curn
ties = []string{top}
} else if curn == topn {
ties = append(ties, prv)
}
curn = 1
prv = cur
} else {
curn++
}
}
if curn > topn {
top = cur
topn = curn
ties = []string{top}
} else if curn == topn {
ties = append(ties, cur)
}
if len(ties) > 1 {
ti := rand.Intn(len(ties))
top = ties[ti]
}
return top, topn
}
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edge
// Edge returns the coordinate value based on either wrapping or clipping at the edge,
// and a bool indicating whether the coordinate is out of range and should be ignored
// (only true when not wrapping).
func Edge(ci, max int, wrap bool) (int, bool) {
if ci < 0 {
if wrap {
return (max + ci) % max, false
}
return 0, true
}
if ci >= max {
if wrap {
return (ci - max) % max, false
}
return max - 1, true
}
return ci, false
}
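// A small illustration (not part of the original source) of wrap vs. clip behavior
// for a coordinate just past the right edge of a size-10 range.
func exampleEdge() {
	ci, clip := Edge(10, 10, true)  // wrap: ci == 0, clip == false
	cj, skip := Edge(10, 10, false) // clip: cj == 9, skip == true (ignore this coordinate)
	_, _, _, _ = ci, clip, cj, skip
}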
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package edge
import "cogentcore.org/core/math32"
// WrapMinDist returns the wrapped coordinate value that is closest to ctr:
// if wrapping above max or below 0 brings the coordinate closer to ctr,
// that wrapped value is returned; otherwise the original coordinate is returned.
func WrapMinDist(ci, max, ctr float32) float32 {
nwd := math32.Abs(ci - ctr) // no-wrap dist
if math32.Abs((ci+max)-ctr) < nwd {
return ci + max
}
if math32.Abs((ci-max)-ctr) < nwd {
return ci - max
}
return ci
}
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package efuns has misc functions, such as Gaussian and Logistic,
// that are used in neural models, and do not have a home elsewhere.
package efuns
//go:generate core generate -add-types
import (
"cogentcore.org/core/math32"
)
// GaussVecDistNoNorm returns the gaussian of the distance between two 2D vectors
// using given sigma standard deviation, without normalizing area under gaussian
// (i.e., max value is 1 at dist = 0)
func GaussVecDistNoNorm(a, b math32.Vector2, sigma float32) float32 {
dsq := a.DistanceToSquared(b)
return math32.FastExp((-0.5 * dsq) / (sigma * sigma))
}
// Gauss1DNoNorm returns the gaussian of a given x value, without normalizing
// (i.e., max value is 1 at x = 0)
func Gauss1DNoNorm(x, sig float32) float32 {
x /= sig
return math32.FastExp(-0.5 * x * x)
}
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package efuns
import "cogentcore.org/core/math32"
// Logistic is the logistic (sigmoid) function of x: 1/(1 + e^(-gain*(x-off)))
func Logistic(x, gain, off float32) float32 {
return 1.0 / (1.0 + math32.FastExp(-gain*(x-off)))
}
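// A small illustration (not part of the original source): with gain 6 and offset 0.5,
// Logistic maps the offset to 0.5 and saturates toward 0 and 1 away from it.
func exampleLogistic() {
	mid := Logistic(0.5, 6, 0.5) // = 0.5 at the offset
	hi := Logistic(1.0, 6, 0.5)  // ~ 0.95
	_, _ = mid, hi
}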
// Copyright (c) 2025, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package egui
import (
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/fileinfo"
"cogentcore.org/core/base/reflectx"
"cogentcore.org/core/core"
"cogentcore.org/core/styles"
"cogentcore.org/core/system"
"cogentcore.org/core/text/textcore"
)
// Config is an interface implemented by all [Sim] config types.
// To implement Config, you must embed [BaseConfig]. You must
// implement [Config.Defaults] yourself.
type Config interface {
// AsBaseConfig returns the embedded [BaseConfig].
AsBaseConfig() *BaseConfig
// Defaults sets default values for config fields.
// Helper functions such as [Run], [Embed], and [NewConfig] already set defaults
// based on struct tags, so you only need to set non-tag-based defaults here.
Defaults()
}
// BaseConfig contains the basic configuration parameters common to all sims.
type BaseConfig struct {
// Name is the short name of the sim.
Name string `display:"-"`
// Title is the longer title of the sim.
Title string `display:"-"`
// URL is a link to the online README or other documentation for this sim.
URL string `display:"-"`
// Doc is brief documentation of the sim.
Doc string `display:"-"`
// Includes has a list of additional config files to include.
// After configuration, it contains list of include files added.
Includes []string
// GUI indicates to open the GUI. Otherwise it runs automatically and quits,
// saving results to log files.
GUI bool `default:"true"`
// Debug indicates to report debugging information.
Debug bool
// GPU indicates to use the GPU for computation. This is on by default, except
// on web, where it is currently off by default.
GPU bool
}
func (bc *BaseConfig) AsBaseConfig() *BaseConfig { return bc }
func (bc *BaseConfig) IncludesPtr() *[]string { return &bc.Includes }
// BaseDefaults sets default values not specified by struct tags.
// It is called automatically by [NewConfig].
func (bc *BaseConfig) BaseDefaults() {
bc.GPU = core.TheApp.Platform() != system.Web // GPU compute not fully working on web yet
}
// ScriptFieldWidget is a core FieldWidget function to use a text Editor
// for the Params Script (or any other field named Script).
func ScriptFieldWidget(field string) core.Value {
if field == "Script" {
tx := textcore.NewEditor()
tx.Styler(func(s *styles.Style) {
s.Min.X.Em(60)
tx.Lines.SetLanguage(fileinfo.Go)
})
return tx
}
return nil
}
// NewConfig makes a new [Config] of type *C with defaults set.
func NewConfig[C any]() (*C, Config) { //yaegi:add
cfgC := new(C)
cfg := any(cfgC).(Config)
errors.Log(reflectx.SetFromDefaultTags(cfg))
cfg.AsBaseConfig().BaseDefaults()
cfg.Defaults()
return cfgC, cfg
}
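// The following is a minimal sketch (not part of the original source) of a
// hypothetical sim config using this pattern: embed BaseConfig, implement
// Defaults, and construct it with NewConfig. The simConfigSketch type and its
// fields are illustrative assumptions only.
type simConfigSketch struct {
	BaseConfig

	// NTrials is the number of trials per epoch.
	NTrials int `default:"32"`
}

func (cfg *simConfigSketch) Defaults() {
	// Non-tag-based defaults go here; tag-based ones (like NTrials) are
	// already applied by NewConfig via SetFromDefaultTags.
	cfg.Name = "SketchSim"
}

func newSimConfigSketch() *simConfigSketch {
	cfgC, _ := NewConfig[simConfigSketch]()
	return cfgC
}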
// Code generated by "core generate -add-types"; DO NOT EDIT.
package egui
import (
"cogentcore.org/core/enums"
)
var _ToolGhostingValues = []ToolGhosting{0, 1, 2}
// ToolGhostingN is the highest valid value for type ToolGhosting, plus one.
const ToolGhostingN ToolGhosting = 3
var _ToolGhostingValueMap = map[string]ToolGhosting{`ActiveStopped`: 0, `ActiveRunning`: 1, `ActiveAlways`: 2}
var _ToolGhostingDescMap = map[ToolGhosting]string{0: ``, 1: ``, 2: ``}
var _ToolGhostingMap = map[ToolGhosting]string{0: `ActiveStopped`, 1: `ActiveRunning`, 2: `ActiveAlways`}
// String returns the string representation of this ToolGhosting value.
func (i ToolGhosting) String() string { return enums.String(i, _ToolGhostingMap) }
// SetString sets the ToolGhosting value from its string representation,
// and returns an error if the string is invalid.
func (i *ToolGhosting) SetString(s string) error {
return enums.SetString(i, s, _ToolGhostingValueMap, "ToolGhosting")
}
// Int64 returns the ToolGhosting value as an int64.
func (i ToolGhosting) Int64() int64 { return int64(i) }
// SetInt64 sets the ToolGhosting value from an int64.
func (i *ToolGhosting) SetInt64(in int64) { *i = ToolGhosting(in) }
// Desc returns the description of the ToolGhosting value.
func (i ToolGhosting) Desc() string { return enums.Desc(i, _ToolGhostingDescMap) }
// ToolGhostingValues returns all possible values for the type ToolGhosting.
func ToolGhostingValues() []ToolGhosting { return _ToolGhostingValues }
// Values returns all possible values for the type ToolGhosting.
func (i ToolGhosting) Values() []enums.Enum { return enums.Values(_ToolGhostingValues) }
// MarshalText implements the [encoding.TextMarshaler] interface.
func (i ToolGhosting) MarshalText() ([]byte, error) { return []byte(i.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *ToolGhosting) UnmarshalText(text []byte) error {
return enums.UnmarshalText(i, text, "ToolGhosting")
}
// Copyright (c) 2022, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package egui
//go:generate core generate -add-types
import (
"embed"
"fmt"
"io/fs"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/fileinfo/mimedata"
"cogentcore.org/core/base/labels"
"cogentcore.org/core/core"
"cogentcore.org/core/enums"
"cogentcore.org/core/events"
"cogentcore.org/core/htmlcore"
"cogentcore.org/core/styles"
"cogentcore.org/core/styles/abilities"
"cogentcore.org/core/system"
"cogentcore.org/core/text/textcore"
"cogentcore.org/core/tree"
_ "cogentcore.org/lab/gosl/slbool/slboolcore" // include to get gui views
"cogentcore.org/lab/lab"
"github.com/emer/emergent/v2/etime"
"github.com/emer/emergent/v2/netview"
)
// GUI manages all standard elements of a simulation Graphical User Interface
type GUI struct {
lab.Browser
// CycleUpdateInterval is number of cycles between updates of cycle-level plots.
CycleUpdateInterval int
// Active is true if the GUI is configured and running
Active bool `display:"-"`
// NetViews are the created netviews.
NetViews []*netview.NetView
// SimForm displays the Sim object fields in the left panel.
SimForm *core.Form `display:"-"`
// Body is the entire content of the sim window.
Body *core.Body `display:"-"`
// Readme is the sim readme frame
Readme *core.Frame `display:"-"`
// OnStop is called when running is stopped through the GUI,
// via the Stopped method. It should update the network view for example.
OnStop func(mode, level enums.Enum)
// StopLevel is the enum to use when the stop button is pressed.
StopLevel enums.Enum
// isRunning is true if sim is running.
isRunning bool
// stopNow can be set via SetStopNow method under mutex protection
// to signal the current sim to stop running.
// It is not used directly in the looper-based control logic, which has
// its own direct Stop function, but it is set there in case there are
// other processes that are looking at this flag.
stopNow bool
runMu sync.Mutex
}
// UpdateWindow triggers an update on window body,
// to be called from within the normal event processing loop.
// See GoUpdateWindow for version to call from separate goroutine.
func (gui *GUI) UpdateWindow() {
if gui.Toolbar != nil {
gui.Toolbar.Restyle()
}
gui.SimForm.Update()
gui.Splits.NeedsRender()
// todo: could update other stuff but not really necessary
}
// GoUpdateWindow triggers an update on window body,
// for calling from a separate goroutine.
func (gui *GUI) GoUpdateWindow() {
gui.Splits.Scene.AsyncLock()
defer gui.Splits.Scene.AsyncUnlock()
gui.UpdateWindow()
}
// StartRun should be called whenever a process starts running.
// It sets stopNow = false and isRunning = true under a mutex.
func (gui *GUI) StartRun() {
gui.runMu.Lock()
gui.stopNow = false
gui.isRunning = true
gui.runMu.Unlock()
}
// IsRunning returns the state of the isRunning flag, under a mutex.
func (gui *GUI) IsRunning() bool {
gui.runMu.Lock()
defer gui.runMu.Unlock()
return gui.isRunning
}
// StopNow returns the state of the stopNow flag, under a mutex.
func (gui *GUI) StopNow() bool {
gui.runMu.Lock()
defer gui.runMu.Unlock()
return gui.stopNow
}
// SetStopNow sets the stopNow flag to true, under a mutex.
func (gui *GUI) SetStopNow() {
gui.runMu.Lock()
gui.stopNow = true
gui.runMu.Unlock()
}
// Stopped is called when a run method stops running,
// from a separate goroutine (do not call from main event loop).
// Turns off the isRunning flag, calls OnStop with the given arguments,
// and calls GoUpdateWindow to update window state.
func (gui *GUI) Stopped(mode, level enums.Enum) {
gui.runMu.Lock()
gui.isRunning = false
gui.stopNow = true // in case anyone else is looking
gui.runMu.Unlock()
if gui.OnStop != nil {
gui.OnStop(mode, level)
}
gui.GoUpdateWindow()
}
// NewGUIBody returns a new GUI, with an initialized Body by calling [gui.MakeBody].
func NewGUIBody(b tree.Node, sim any, fsroot fs.FS, appname, title, about string) *GUI {
gu := &GUI{}
gu.MakeBody(b, sim, fsroot, appname, title, about)
return gu
}
// MakeBody initializes default Body with a top-level [core.Splits] containing
// a [core.Form] editor of the given sim object, and a filetree for the data filesystem
// rooted at fsroot, and with given app name, title, and about information.
// The first arg is an optional existing [core.Body] to make into: if nil then
// a new body is made first. It takes an optional fs with a README.md file.
func (gui *GUI) MakeBody(b tree.Node, sim any, fsroot fs.FS, appname, title, about string, readme ...embed.FS) {
gui.StopLevel = etime.NoTime // corresponds to the first level typically
core.NoSentenceCaseFor = append(core.NoSentenceCaseFor, "github.com/emer")
if b == nil {
gui.Body = core.NewBody(appname).SetTitle(title)
b = gui.Body
core.AppAbout = about
} else {
gui.Toolbar = core.NewToolbar(b)
}
split := core.NewSplits(b)
split.Styler(func(s *styles.Style) {
s.Min.Y.Em(40)
})
split.Name = "split"
gui.Splits = split
gui.SimForm = core.NewForm(split).SetStruct(sim)
gui.SimForm.Name = "sim-form"
if tb, ok := sim.(core.ToolbarMaker); ok {
if gui.Body != nil {
gui.Body.AddTopBar(func(bar *core.Frame) {
gui.Toolbar = core.NewToolbar(bar)
gui.Toolbar.Maker(gui.MakeToolbar)
gui.Toolbar.Maker(tb.MakeToolbar)
})
} else {
gui.Toolbar.Maker(gui.MakeToolbar)
gui.Toolbar.Maker(tb.MakeToolbar)
}
}
fform := core.NewFrame(split)
fform.Styler(func(s *styles.Style) {
s.Direction = styles.Column
s.Overflow.Set(styles.OverflowAuto)
s.Grow.Set(1, 1)
})
gui.Files = lab.NewDataTree(fform)
tabs := lab.NewTabs(split)
gui.Tabs = tabs
lab.Lab = tabs
tabs.Name = "tabs"
gui.FS = fsroot
gui.DataRoot = "Root"
gui.CycleUpdateInterval = 10
gui.UpdateFiles()
gui.Files.Tabber = tabs
if len(readme) > 0 {
gui.addReadme(readme[0], split, appname)
} else {
split.SetTiles(core.TileSplit, core.TileSpan)
split.SetSplits(.2, .5, .8)
}
}
func (gui *GUI) addReadme(readmefs embed.FS, split *core.Splits, appname string) {
gui.Readme = core.NewFrame(split)
gui.Readme.Name = "readme"
split.SetTiles(core.TileSplit, core.TileSpan, core.TileSpan)
split.SetSplits(.2, .5, .5, .3)
ctx := htmlcore.NewContext()
ctx.GetURL = func(rawURL string) (*http.Response, error) {
return htmlcore.GetURLFromFS(readmefs, rawURL)
}
ctx.AddWikilinkHandler(gui.readmeWikilink("sim"))
ctx.OpenURL = gui.readmeOpenURL
eds := []*textcore.Editor{}
ctx.ElementHandlers["sim-question"] = func(ctx *htmlcore.Context) bool {
ed := textcore.NewEditor(ctx.BlockParent)
ed.Lines.Settings.LineNumbers = false
eds = append(eds, ed)
id := htmlcore.GetAttr(ctx.Node, "id")
ed.SetName(id)
ed.Styler(func(s *styles.Style) {
s.Min.Y.Em(10)
})
// used with Embed in shared context so need appname in filename to avoid conflicts
saveFile := filepath.Join(core.TheApp.AppDataDir(), appname+"-"+"q"+id+".md")
err := ed.Lines.Open(saveFile)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
err := os.WriteFile(saveFile, nil, 0666)
core.ErrorSnackbar(ed, err, "Error creating answer file")
if err == nil {
err := ed.Lines.Open(saveFile)
core.ErrorSnackbar(ed, err, "Error loading answer")
}
} else {
core.ErrorSnackbar(ed, err, "Error loading answer")
}
}
ed.OnChange(func(e events.Event) {
core.ErrorSnackbar(ed, ed.SaveQuiet(), "Error saving answer")
})
return true
}
core.NewButton(gui.Readme).SetText("Copy answers").OnClick(func(e events.Event) {
clipboard := gui.Readme.Clipboard()
var ab strings.Builder
for _, ed := range eds {
ab.WriteString("## Question " + ed.Name + "\n" + ed.Lines.String() + "\n")
}
answers := ab.String()
md := mimedata.NewText(answers)
clipboard.Write(md)
core.MessageSnackbar(gui.Body, "Answers copied to clipboard")
})
readme, err := readmefs.ReadFile("README.md")
if errors.Log(err) == nil {
htmlcore.ReadMDString(ctx, gui.Readme, string(readme))
}
}
func (gui *GUI) readmeWikilink(prefix string) htmlcore.WikilinkHandler {
return func(text string) (url string, label string) {
if !strings.HasPrefix(text, prefix+":") {
return "", ""
}
text = strings.TrimPrefix(text, prefix+":")
url = prefix + "://" + text
if strings.Contains(text, "/") {
_, text, _ = strings.Cut(text, "/")
}
return url, text
}
}
// readmeOpenURL parses the given URL and either focuses the linked sim widget or opens external URLs in the system browser.
func (gui *GUI) readmeOpenURL(url string) {
focusSet := false
if !strings.HasPrefix(url, "sim://") {
system.TheApp.OpenURL(url)
return
}
text := strings.TrimPrefix(url, "sim://")
var pathPrefix string = ""
hasPath := false
if strings.Contains(text, "/") {
pathPrefix, text, hasPath = strings.Cut(text, "/")
}
gui.Body.Scene.WidgetWalkDown(func(cw core.Widget, cwb *core.WidgetBase) bool {
if focusSet {
return tree.Break
}
if !hasPath && !cwb.IsDisplayable() {
return tree.Break
}
if hasPath && !strings.Contains(cw.AsTree().Path(), pathPrefix) {
return tree.Continue
}
label := labels.ToLabel(cw)
if !strings.EqualFold(label, text) {
return tree.Continue
}
if cwb.AbilityIs(abilities.Focusable) {
cwb.SetFocus()
focusSet = true
return tree.Break
}
next := core.AsWidget(tree.Next(cwb))
if next.AbilityIs(abilities.Focusable) {
next.SetFocus()
focusSet = true
return tree.Break
}
return tree.Continue
})
if !focusSet {
core.ErrorSnackbar(gui.Body, fmt.Errorf("invalid sim url %q", url))
}
}
// AddNetView adds NetView in tab with given name
func (gui *GUI) AddNetView(tabName string) *netview.NetView {
nv := lab.NewTab(gui.Tabs, tabName, func(tab *core.Frame) *netview.NetView {
nv := netview.NewNetView(tab)
nv.Var = "Act"
// tb.OnFinal(events.Click, func(e events.Event) {
// nv.Current()
// nv.Update()
// })
gui.NetViews = append(gui.NetViews, nv)
return nv
})
return nv
}
// NetView returns the first created netview, or nil if none.
func (gui *GUI) NetView() *netview.NetView {
if len(gui.NetViews) == 0 {
return nil
}
return gui.NetViews[0]
}
// FinalizeGUI finishes GUI setup, marking it Active and optionally adding a close-confirmation dialog.
func (gui *GUI) FinalizeGUI(closePrompt bool) {
gui.Active = true
if !closePrompt || gui.Body == nil {
return
}
gui.Body.AddCloseDialog(func(d *core.Body) bool {
d.SetTitle("Close?")
core.NewText(d).SetType(core.TextSupporting).SetText("Are you sure you want to close?")
d.AddBottomBar(func(bar *core.Frame) {
d.AddOK(bar).SetText("Close").OnClick(func(e events.Event) {
gui.Body.Close()
})
})
return true
})
}
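// The following is a minimal usage sketch (not part of the original source) of how
// a sim might assemble its window with these helpers. The simSketch type, the use
// of os.DirFS for the data filesystem, and the final RunMainWindow call are
// illustrative assumptions about typical usage, not requirements of this package.
//
//	type simSketch struct{ Note string }
//
//	func runGUI(ss *simSketch) {
//		gui := NewGUIBody(nil, ss, os.DirFS("."), "sketch", "Sketch Sim", "An example sim")
//		gui.AddNetView("Network")
//		gui.FinalizeGUI(true) // adds a close-confirmation dialog
//		gui.Body.RunMainWindow()
//	}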
// Copyright (c) 2022, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package egui
import (
"cmp"
"slices"
"strings"
"cogentcore.org/core/core"
"cogentcore.org/core/enums"
"cogentcore.org/core/events"
"cogentcore.org/core/icons"
"cogentcore.org/core/styles"
"cogentcore.org/core/styles/abilities"
"cogentcore.org/core/tree"
"github.com/emer/emergent/v2/looper"
)
// AddLooperCtrl adds toolbar controls for looper.Stacks, with Init, Stop, Run, and Step buttons,
// and a selector for which stack (mode) is being controlled.
// A prefix can optionally be provided if multiple loop stacks are controlled from the same toolbar.
func (gui *GUI) AddLooperCtrl(p *tree.Plan, loops *looper.Stacks, prefix ...string) {
pfx := ""
lblpfx := ""
if len(prefix) == 1 {
pfx = strings.ToLower(prefix[0]) + "-"
lblpfx = prefix[0] + " "
}
modes := make([]enums.Enum, len(loops.Stacks))
var stepChoose *core.Chooser
var stepNSpin *core.Spinner
i := 0
for m := range loops.Stacks {
modes[i] = m
i++
}
slices.SortFunc(modes, func(a, b enums.Enum) int {
return cmp.Compare(a.Int64(), b.Int64())
})
curMode := modes[0]
curStep := loops.Stacks[curMode].StepLevel
updateSteps := func() {
st := loops.Stacks[curMode]
stepStrs := make([]string, len(st.Order))
cur := ""
for i, s := range st.Order {
sv := s.String()
stepStrs[i] = sv
if s.Int64() == curStep.Int64() {
cur = sv
}
}
stepChoose.SetStrings(stepStrs...)
stepChoose.SetCurrentValue(cur)
}
if len(modes) > 1 {
tree.AddAt(p, pfx+"loop-mode", func(w *core.Switches) {
w.SetType(core.SwitchSegmentedButton)
w.Mutex = true
w.SetEnums(modes...)
w.SelectValue(curMode)
w.FinalStyler(func(s *styles.Style) {
s.Grow.Set(0, 0)
})
w.OnChange(func(e events.Event) {
sel := w.SelectedItem()
if sel == nil || sel.Value == nil {
return
}
curMode = sel.Value.(enums.Enum)
st := loops.Stacks[curMode]
if st != nil {
curStep = st.StepLevel
}
updateSteps()
stepChoose.Update()
stepN := st.Loops[curStep].StepCount
stepNSpin.SetValue(float32(stepN))
stepNSpin.Update()
})
})
}
gui.AddToolbarItem(p, ToolbarItem{Label: lblpfx + "Init",
Icon: icons.Update,
Tooltip: "Initializes running and state for current mode.",
Active: ActiveStopped,
Func: func() {
loops.InitMode(curMode)
},
})
gui.AddToolbarItem(p, ToolbarItem{Label: lblpfx + "Stop",
Icon: icons.Stop,
Tooltip: "Interrupts current running. Will pick back up where it left off.",
Active: ActiveRunning,
Func: func() {
loops.Stop(gui.StopLevel)
// fmt.Println("Stop time!")
gui.SetStopNow()
},
})
tree.AddAt(p, pfx+"loop-run", func(w *core.Button) {
tb := gui.Toolbar
w.SetText("Run").SetIcon(icons.PlayArrow).
SetTooltip("Run the current mode, picking up from where it left off last time (Init to restart)")
w.FirstStyler(func(s *styles.Style) { s.SetEnabled(!gui.IsRunning()) })
w.OnClick(func(e events.Event) {
if !gui.IsRunning() {
gui.StartRun()
tb.Restyle()
go func() {
stop := loops.Run(curMode)
gui.Stopped(curMode, stop)
}()
}
})
})
tree.AddAt(p, pfx+"loop-step", func(w *core.Button) {
tb := gui.Toolbar
w.SetText("Step").SetIcon(icons.SkipNext).
SetTooltip("Step the current mode, according to the following step level and N")
w.FirstStyler(func(s *styles.Style) {
s.SetEnabled(!gui.IsRunning())
s.SetAbilities(true, abilities.RepeatClickable)
})
w.OnClick(func(e events.Event) {
if !gui.IsRunning() {
gui.StartRun()
tb.Restyle()
go func() {
st := loops.Stacks[curMode]
nst := int(stepNSpin.Value)
stop := loops.Step(curMode, nst, st.StepLevel)
gui.Stopped(curMode, stop)
}()
}
})
})
tree.AddAt(p, pfx+"step-level", func(w *core.Chooser) {
stepChoose = w
updateSteps()
w.SetCurrentValue(curStep.String())
w.OnChange(func(e events.Event) {
st := loops.Stacks[curMode]
if w.CurrentItem.Value == nil {
return
}
cs := w.CurrentItem.Value.(string)
for _, l := range st.Order {
if l.String() == cs {
st.StepLevel = l
stepNSpin.Value = float32(st.Loops[l].StepCount)
stepNSpin.Update()
break
}
}
})
})
tree.AddAt(p, pfx+"step-n", func(w *core.Spinner) {
stepNSpin = w
w.SetStep(1).SetMin(1).SetValue(1)
w.SetTooltip("number of iterations per step")
w.OnChange(func(e events.Event) {
st := loops.Stacks[curMode]
if st != nil {
st.StepCount = int(w.Value)
st.Loops[st.StepLevel].StepCount = st.StepCount
}
})
})
}
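// A hedged usage sketch: in a sim's toolbar configuration, the looper controls
// are added alongside other items (ss, ss.GUI, and ss.Loops are hypothetical
// names for the sim, its GUI, and its looper.Stacks):
//
//	ss.GUI.AddLooperCtrl(p, ss.Loops)
//
// With a prefix, e.g. ss.GUI.AddLooperCtrl(p, ss.TestLoops, "Test"), multiple
// independent loop stacks can each get their own controls.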
// Copyright (c) 2025, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package egui
import (
"cogentcore.org/core/cli"
"cogentcore.org/core/core"
"cogentcore.org/core/tree"
)
// Sim is an interface implemented by all sim types.
// It is parameterized by the config type C. *C must implement [Config].
//
// See [Run], [RunSim], and [Embed].
type Sim[C any] interface {
// SetConfig sets the sim config.
SetConfig(cfg *C)
ConfigSim()
Init()
ConfigGUI(b tree.Node)
// Body returns the [core.Body] used by the sim.
Body() *core.Body
RunNoGUI()
}
// Run runs a sim of the given type S with config type C. *S must implement [Sim][C]
// (interface [Sim] parameterized by config type C), and *C must implement [Config].
//
// This is a high-level helper function designed to be called as one-liner
// from the main() function of the sim's command subdirectory with package main.
// This subdirectory has the same name as the sim name itself, ex: sims/ra25
// has the package with the sim logic, and sims/ra25/ra25 has the compilable main().
//
// Run uses the config type C to make a new [Config] object and set its default values
// with [Config.Defaults].
func Run[S, C any]() {
cfgC, cfg := NewConfig[C]()
bc := cfg.AsBaseConfig()
opts := cli.DefaultOptions(bc.Name, bc.Title)
opts.DefaultFiles = append(opts.DefaultFiles, "config.toml")
opts.SearchUp = true // so that the sim can be run from the command subdirectory
opts.IncludePaths = append(opts.IncludePaths, "../configs")
cli.Run(opts, cfgC, RunSim[S, C])
}
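// A minimal sketch of such a one-liner main(), with hypothetical package and
// type names (a sim package "ra25" providing Sim and Config is assumed, not
// defined here):
//
//	package main
//
//	import (
//		"github.com/emer/emergent/v2/egui"
//
//		"<module>/sims/ra25" // hypothetical sim package path
//	)
//
//	func main() {
//		egui.Run[ra25.Sim, ra25.Config]()
//	}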
// RunSim runs a sim with the given config. *S must implement [Sim][C]
// (interface [Sim] parameterized by config type C).
//
// Unlike [Run], this does not handle command-line config parsing. End users
// should typically use [Run], which uses RunSim under the hood.
func RunSim[S, C any](cfg *C) error {
simS := new(S)
sim := any(simS).(Sim[C])
bc := any(cfg).(Config).AsBaseConfig()
sim.SetConfig(cfg)
sim.ConfigSim()
if bc.GUI {
sim.Init()
sim.ConfigGUI(nil)
sim.Body().RunMainWindow()
} else {
sim.RunNoGUI()
}
return nil
}
// Embed runs a sim with the default config, embedding it under the given parent node.
// It returns the resulting sim. *S must implement [Sim][C] (interface [Sim]
// parameterized by config type C).
//
// See also [Run] and [RunSim].
func Embed[S, C any](parent tree.Node) *S { //yaegi:add
cfgC, cfg := NewConfig[C]()
cfg.AsBaseConfig().GUI = true // force GUI on
simS := new(S)
sim := any(simS).(Sim[C])
sim.SetConfig(cfgC)
sim.ConfigSim()
sim.Init()
sim.ConfigGUI(parent)
return simS
}
// Copyright (c) 2022, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package egui
import (
"cogentcore.org/core/core"
"cogentcore.org/core/events"
"cogentcore.org/core/icons"
"cogentcore.org/core/styles"
"cogentcore.org/core/tree"
)
// ToolbarItem holds the configuration values for a toolbar item
type ToolbarItem struct {
Label string
Icon icons.Icon
Tooltip string
Active ToolGhosting
Func func()
}
// AddToolbarItem adds a toolbar item and also configures when it should be enabled in the UI, based on the Active setting.
func (gui *GUI) AddToolbarItem(p *tree.Plan, item ToolbarItem) {
tree.AddAt(p, item.Label, func(w *core.Button) {
w.SetText(item.Label).SetIcon(item.Icon).
SetTooltip(item.Tooltip).OnClick(func(e events.Event) {
item.Func()
})
switch item.Active {
case ActiveStopped:
w.FirstStyler(func(s *styles.Style) { s.SetEnabled(!gui.IsRunning()) })
case ActiveRunning:
w.FirstStyler(func(s *styles.Style) { s.SetEnabled(gui.IsRunning()) })
}
})
}
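// A hedged sketch of adding a custom item from a sim (ss and the label,
// tooltip, and function body are illustrative only):
//
//	ss.GUI.AddToolbarItem(p, egui.ToolbarItem{
//		Label:   "Reset Logs",
//		Icon:    icons.Update,
//		Tooltip: "resets the accumulated logs",
//		Active:  egui.ActiveStopped, // only enabled when not running
//		Func:    func() { /* reset logs here */ },
//	})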
// Copyright (c) 2024, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"os"
"path/filepath"
"text/template"
"cogentcore.org/core/base/exec"
)
// Build builds a Docker image for the emergent model in the current directory.
func Build(c *Config) error { //types:add
f, err := os.Create("Dockerfile")
if err != nil {
return err
}
defer f.Close()
err = DockerfileTmpl.Execute(f, c)
if err != nil {
return err
}
err = exec.Verbose().SetBuffer(false).Run("docker", "build", "-t", filepath.Base(c.Dir)+":latest", ".")
if err != nil {
return err
}
err = f.Close()
if err != nil {
return err
}
err = os.RemoveAll("Dockerfile")
if err != nil {
return err
}
return nil
}
// Partially based on https://github.com/rickyjames35/vulkan_docker_test/blob/main/Dockerfile
var DockerfileTmpl = template.Must(template.New("Dockerfile").Parse(
`FROM golang:1.21-bookworm as builder
WORKDIR /build
# By copying the go.mod and go.sum and downloading the deps first, it can cache all of the dependencies
COPY go.* ./
RUN go mod download
COPY . ./
WORKDIR /build/{{.Dir}}
RUN go build -tags offscreen -o ./app
FROM ubuntu:latest as runner
# Needed to share GPU
ENV NVIDIA_DRIVER_CAPABILITIES=all
ENV NVIDIA_VISIBLE_DEVICES=all
RUN apt-get update && \
export DEBIAN_FRONTEND=noninteractive && \
apt-get install -y pciutils vulkan-tools mesa-utils
COPY --from=builder /build/{{.Dir}} /build
WORKDIR /build
CMD ["./app", "-nogui"]
`))
// Copyright (c) 2024, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Command ekube provides easy building of Docker images for emergent models
// and the deployment of those images to Kubernetes clusters.
package main
import "cogentcore.org/core/cli"
//go:generate core generate
func main() {
opts := cli.DefaultOptions("ekube", "ekube provides easy building of Docker images for emergent models and the deployment of those images to Kubernetes clusters.")
cli.Run(opts, &Config{}, Build)
}
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package emer
import (
"fmt"
"io"
"math"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/slicesx"
"cogentcore.org/core/math32"
"cogentcore.org/lab/tensor"
"github.com/emer/emergent/v2/params"
"github.com/emer/emergent/v2/relpos"
"github.com/emer/emergent/v2/weights"
)
var (
// LayerDimNames2D provides the standard Shape dimension names for 2D layers
LayerDimNames2D = []string{"Y", "X"}
// LayerDimNames4D provides the standard Shape dimension names for 4D layers
// which have Pools and then neurons within pools.
LayerDimNames4D = []string{"PoolY", "PoolX", "NeurY", "NeurX"}
)
// Layer defines the minimal interface for neural network layers,
// necessary to support the visualization (NetView), I/O,
// and parameter setting functionality provided by emergent.
// Most of the standard expected functionality is defined in the
// LayerBase struct, and this interface only has methods that must be
// implemented specifically for a given algorithmic implementation.
type Layer interface {
// AsEmer returns the layer as an *emer.LayerBase,
// to access base functionality.
AsEmer() *LayerBase
// Label satisfies the core.Labeler interface for getting
// the name of objects generically. Use to access Name via interface.
Label() string
// TypeName is the type or category of layer, defined
// by the algorithm (and usually set by an enum).
TypeName() string
// TypeNumber is the numerical value for the type or category
// of layer, defined by the algorithm (and usually set by an enum).
TypeNumber() int
// UnitVarIndex returns the index of given variable within
// the Neuron, according to *this layer's* UnitVarNames() list
// (using a map to lookup index), or -1 and error message if
// not found.
UnitVarIndex(varNm string) (int, error)
// UnitValue1D returns value of given variable index on given unit,
// using 1-dimensional index, and a data parallel index di,
// for networks capable of processing multiple input patterns
// in parallel. Returns NaN on invalid index.
// This is the core unit var access method used by other methods,
// so it is the only one that needs to be updated for derived layer types.
UnitValue1D(varIndex int, idx, di int) float32
// VarRange returns the min / max values for given variable
VarRange(varNm string) (min, max float32, err error)
// NumRecvPaths returns the number of receiving pathways.
NumRecvPaths() int
// RecvPath returns a specific receiving pathway.
RecvPath(idx int) Path
// NumSendPaths returns the number of sending pathways.
NumSendPaths() int
// SendPath returns a specific sending pathway.
SendPath(idx int) Path
// RecvPathValues fills in values of given synapse variable name,
// for pathway from given sending layer and neuron 1D index,
// for all receiving neurons in this layer,
// into given float32 slice (only resized if not big enough).
// pathType is the string representation of the path type;
// used if non-empty, useful when there are multiple pathways
// between two layers.
// If the receiving neuron is not connected to the given sending
// layer or neuron, then the value is set to math32.NaN().
// Returns error on invalid var name or lack of recv path
// (vals are always set to NaN on a path error).
RecvPathValues(vals *[]float32, varNm string, sendLay Layer, sendIndex1D int, pathType string) error
// SendPathValues fills in values of given synapse variable name,
// for pathway into given receiving layer and neuron 1D index,
// for all sending neurons in this layer,
// into given float32 slice (only resized if not big enough).
// pathType is the string representation of the path type -- used if non-empty,
// useful when there are multiple pathways between two layers.
// If the sending neuron is not connected to the given receiving layer or neuron,
// then the value is set to math32.NaN().
// Returns error on invalid var name or lack of recv path (vals are always set to NaN on a path error).
SendPathValues(vals *[]float32, varNm string, recvLay Layer, recvIndex1D int, pathType string) error
// ParamsString returns a listing of all parameters in the Layer and
// pathways within the layer. If nonDefault is true, only report those
// not at their default values.
ParamsString(nonDefault bool) string
// WriteWeightsJSON writes the weights from this layer from the
// receiver-side perspective in a JSON text format.
WriteWeightsJSON(w io.Writer, depth int)
// SetWeights sets the weights for this layer from weights.Layer
// decoded values
SetWeights(lw *weights.Layer) error
}
// LayerBase defines the basic shared data for neural network layers,
// used for managing the structural elements of a network,
// and for visualization, I/O, etc.
// Nothing algorithm-specific is implemented here
type LayerBase struct {
// EmerLayer provides access to the emer.Layer interface
// methods for functions defined in the LayerBase type.
// Must set this with a pointer to the actual instance
// when created, using InitLayer function.
EmerLayer Layer `display:"-"`
// Name of the layer, which must be unique within the network.
// Layers are typically accessed directly by name, via a map.
Name string
// Class is for applying parameter styles across multiple layers
// that all get the same parameters. This can be space separated
// with multiple classes.
Class string
// Doc contains documentation about the layer.
// This is displayed in a tooltip in the network view.
Doc string
// Off turns off the layer, removing from all computations.
// This provides a convenient way to dynamically test for
// the contributions of the layer, for example.
Off bool
// Shape of the layer, either 2D or 4D. Although spatial topology
// is not relevant to all algorithms, the 2D shape is important for
// efficiently visualizing large numbers of units / neurons.
// 4D layers have 2D Pools of units embedded within a larger 2D
// organization of such pools. This is used for max-pooling or
// pooled inhibition at a finer-grained level, and biologically
// corresponds to hypercolumns in the cortex for example.
// Order is outer-to-inner (row major), so Y then X for 2D;
// 4D: Y-X unit pools then Y-X neurons within pools.
Shape tensor.Shape
// Pos specifies the relative spatial relationship to another
// layer, which determines positioning. Every layer except one
// "anchor" layer should be positioned relative to another,
// e.g., RightOf, Above, etc. This provides robust positioning
// in the face of layer size changes etc.
// Layers are arranged in X-Y planes, stacked vertically along the Z axis.
Pos relpos.Pos `table:"-" display:"inline"`
// Index is a 0..n-1 index of the position of the layer within
// the list of layers in the network.
Index int `display:"-" edit:"-"`
// SampleIndexes are the current set of "sample" unit indexes,
// which are a smaller subset of units that represent the behavior
// of the layer, for computationally intensive statistics and displays
// (e.g., PCA, ActRF, NetView rasters), when the layer is large.
// If none have been set, then all units are used.
// See utility function CenterPoolIndexes that returns indexes of
// units in the central pools of a 4D layer.
SampleIndexes []int `table:"-"`
// SampleShape is the shape to use for the subset of sample
// unit indexes, in terms of an array of dimensions.
// See Shape for more info.
// Layers that set SampleIndexes should also set this,
// otherwise a 1D array of len SampleIndexes will be used.
// See utility function CenterPoolShape that returns shape of
// units in the central pools of a 4D layer.
SampleShape tensor.Shape `table:"-"`
// optional metadata that is saved in network weights files,
// e.g., can indicate number of epochs that were trained,
// or any other information about this network that would be useful to save.
MetaData map[string]string
}
// InitLayer initializes the layer, setting the EmerLayer interface
// to provide access to it for LayerBase methods, along with the name.
func InitLayer(l Layer, name string) {
lb := l.AsEmer()
lb.EmerLayer = l
lb.Name = name
}
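// A hedged sketch of how an algorithm-specific layer type typically embeds
// LayerBase and calls InitLayer (MyLayer is a hypothetical type that also
// implements the algorithm-specific emer.Layer methods):
//
//	type MyLayer struct {
//		emer.LayerBase
//		// algorithm-specific state here
//	}
//
//	func NewMyLayer(name string) *MyLayer {
//		ly := &MyLayer{}
//		emer.InitLayer(ly, name) // sets the EmerLayer back-pointer and Name
//		return ly
//	}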
func (ly *LayerBase) AsEmer() *LayerBase { return ly }
func (ly *LayerBase) Label() string { return ly.Name }
// AddClass adds a CSS-style class name(s) for this layer,
// ensuring that it is not a duplicate, and properly space separated.
// Returns Layer so it can be chained to set other properties too.
func (ly *LayerBase) AddClass(cls ...string) *LayerBase {
ly.Class = params.AddClass(ly.Class, cls...)
return ly
}
// Is2D returns true if this is a 2D layer (no Pools).
func (ly *LayerBase) Is2D() bool { return ly.Shape.NumDims() == 2 }
// Is4D returns true if this is a 4D layer (has Pools as inner 2 dimensions).
func (ly *LayerBase) Is4D() bool { return ly.Shape.NumDims() == 4 }
func (ly *LayerBase) NumUnits() int { return ly.Shape.Len() }
// Index4DFrom2D returns the 4D index from 2D coordinates,
// in which the inner (unit) dims are interleaved within the outer (pool) dims.
// Returns false if the 2D coordinates are invalid.
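// A brief worked example: for a (hypothetical) 4D layer with Shape [2, 3, 2, 4]
// (PoolY, PoolX, NeurY, NeurX), the interleaved 2D grid is 4 x 12 (Y x X);
// the 2D coordinate (x=5, y=3) maps to unit (ux=1, uy=1) within pool
// (px=1, py=1), so Index4DFrom2D(5, 3) returns []int{1, 1, 1, 1}, true.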
func (ly *LayerBase) Index4DFrom2D(x, y int) ([]int, bool) {
lshp := ly.Shape
nux := lshp.DimSize(3)
nuy := lshp.DimSize(2)
ux := x % nux
uy := y % nuy
px := x / nux
py := y / nuy
idx := []int{py, px, uy, ux}
if !lshp.IndexIsValid(idx...) {
return nil, false
}
return idx, true
}
// PlaceRightOf positions the layer to the right of the other layer,
// with given spacing, using default YAlign = Front alignment.
func (ly *LayerBase) PlaceRightOf(other Layer, space float32) {
ly.Pos.SetRightOf(other.AsEmer().Name, space)
}
// PlaceBehind positions the layer behind the other layer,
// with given spacing, using default XAlign = Left alignment.
func (ly *LayerBase) PlaceBehind(other Layer, space float32) {
ly.Pos.SetBehind(other.AsEmer().Name, space)
}
// PlaceAbove positions the layer above the other layer,
// using default XAlign = Left, YAlign = Front alignment.
func (ly *LayerBase) PlaceAbove(other Layer) {
ly.Pos.SetAbove(other.AsEmer().Name)
}
// DisplaySize returns the display size of this layer for the 3D view.
// see Pos field for general info.
// This is multiplied by the Pos.Scale factor to rescale
// layer sizes, and takes into account 2D and 4D layer structures.
func (ly *LayerBase) DisplaySize() math32.Vector2 {
if ly.Pos.Scale == 0 {
ly.Pos.Defaults()
}
var sz math32.Vector2
switch {
case ly.Is2D():
sz = math32.Vec2(float32(ly.Shape.DimSize(1)), float32(ly.Shape.DimSize(0))) // Y, X
case ly.Is4D():
// note: pool spacing is handled internally in display and does not affect overall size
sz = math32.Vec2(float32(ly.Shape.DimSize(1)*ly.Shape.DimSize(3)), float32(ly.Shape.DimSize(0)*ly.Shape.DimSize(2))) // Y, X
default:
sz = math32.Vec2(float32(ly.Shape.Len()), 1)
}
return sz.MulScalar(ly.Pos.Scale)
}
// SetShape sets the layer shape and also uses default dim names.
func (ly *LayerBase) SetShape(shape ...int) {
ly.Shape.SetShapeSizes(shape...)
}
// SetSampleShape sets the SampleIndexes, and the SampleShape
// as a list of dimension sizes, for a subset sample of units
// to represent the entire layer.
// This is critical for large layers that are otherwise unwieldy
// to visualize and for computationally-intensive statistics.
func (ly *LayerBase) SetSampleShape(idxs, shape []int) {
ly.SampleIndexes = idxs
ly.SampleShape.SetShapeSizes(shape...)
}
// GetSampleShape returns the shape to use for representative units.
func (ly *LayerBase) GetSampleShape() *tensor.Shape {
sz := len(ly.SampleIndexes)
if sz == 0 {
return &ly.Shape
}
if ly.SampleShape.Len() != sz {
ly.SampleShape.SetShapeSizes(sz)
}
return &ly.SampleShape
}
// NumPools returns the number of sub-pools of neurons
// according to the shape parameters. 2D shapes have 0 sub-pools.
// For a 4D shape, the pools are the outer set of 2 Y,X dims,
// and the neurons within the pools are the inner set of 2 Y,X dims.
func (ly *LayerBase) NumPools() int {
if ly.Shape.NumDims() != 4 {
return 0
}
return ly.Shape.DimSize(0) * ly.Shape.DimSize(1)
}
// UnitValues fills in values of given variable name on unit,
// for each unit in the layer, into given float32 slice
// (only resized if not big enough).
// di is a data parallel index, for networks capable of
// processing multiple input patterns in parallel.
// Returns error on invalid var name.
func (ly *LayerBase) UnitValues(vals *[]float32, varNm string, di int) error {
nn := ly.NumUnits()
*vals = slicesx.SetLength(*vals, nn)
vidx, err := ly.EmerLayer.UnitVarIndex(varNm)
if err != nil {
nan := math32.NaN()
for lni := range nn {
(*vals)[lni] = nan
}
return err
}
for lni := range nn {
(*vals)[lni] = ly.EmerLayer.UnitValue1D(vidx, lni, di)
}
return nil
}
// UnitValuesTensor fills in values of given variable name
// on unit for each unit in the layer, into given tensor.
// di is a data parallel index, for networks capable of
// processing multiple input patterns in parallel.
// If tensor is not already big enough to hold the values, it is
// set to the same shape as the layer.
// Returns error on invalid var name.
func (ly *LayerBase) UnitValuesTensor(tsr tensor.Values, varNm string, di int) error {
if tsr == nil {
err := fmt.Errorf("emer.UnitValuesTensor: Tensor is nil")
return errors.Log(err)
}
nn := ly.NumUnits()
tsr.SetShapeSizes(ly.Shape.Sizes...)
vidx, err := ly.EmerLayer.UnitVarIndex(varNm)
if err != nil {
nan := math.NaN()
for lni := 0; lni < nn; lni++ {
tsr.SetFloat1D(nan, lni)
}
return err
}
for lni := 0; lni < nn; lni++ {
v := ly.EmerLayer.UnitValue1D(vidx, lni, di)
if math32.IsNaN(v) {
tsr.SetFloat1D(math.NaN(), lni)
} else {
tsr.SetFloat1D(float64(v), lni)
}
}
return nil
}
// UnitValuesSampleTensor fills in values of given variable name
// on unit for a smaller subset of representative units
// in the layer, into given tensor.
// di is a data parallel index, for networks capable of
// processing multiple input patterns in parallel.
// This is used for computationally intensive stats or displays that work
// much better with a smaller number of units.
// The set of representative units is defined by SetSampleShape; all units
// are used if no such subset has been defined.
// If tensor is not already big enough to hold the values, it is
// set to SampleShape to hold all the values if subset is defined,
// otherwise it calls UnitValuesTensor and is identical to that.
// Returns error on invalid var name.
func (ly *LayerBase) UnitValuesSampleTensor(tsr tensor.Values, varNm string, di int) error {
nu := len(ly.SampleIndexes)
if nu == 0 {
return ly.UnitValuesTensor(tsr, varNm, di)
}
if tsr == nil {
err := fmt.Errorf("emer.UnitValuesSampleTensor: Tensor is nil")
return errors.Log(err)
}
if tsr.Len() != nu {
rs := ly.GetSampleShape()
tsr.SetShapeSizes(rs.Sizes...)
}
vidx, err := ly.EmerLayer.UnitVarIndex(varNm)
if err != nil {
nan := math.NaN()
for i := range ly.SampleIndexes {
tsr.SetFloat1D(nan, i)
}
return err
}
for i, ui := range ly.SampleIndexes {
v := ly.EmerLayer.UnitValue1D(vidx, ui, di)
if math32.IsNaN(v) {
tsr.SetFloat1D(math.NaN(), i)
} else {
tsr.SetFloat1D(float64(v), i)
}
}
return nil
}
// UnitValue returns value of given variable name on given unit,
// using shape-based dimensional index.
// Returns NaN on invalid var name or index.
// di is a data parallel index, for networks capable of
// processing multiple input patterns in parallel.
func (ly *LayerBase) UnitValue(varNm string, idx []int, di int) float32 {
vidx, err := ly.EmerLayer.UnitVarIndex(varNm)
if err != nil {
return math32.NaN()
}
fidx := ly.Shape.IndexTo1D(idx...)
return ly.EmerLayer.UnitValue1D(vidx, fidx, di)
}
// CenterPoolIndexes returns the indexes for n x n center pools of given 4D layer.
// Useful for setting SampleIndexes on Layer.
// Will crash if called on non-4D layers.
func CenterPoolIndexes(ly Layer, n int) []int {
lb := ly.AsEmer()
nPy := lb.Shape.DimSize(0)
nPx := lb.Shape.DimSize(1)
sPy := (nPy - n) / 2
sPx := (nPx - n) / 2
nu := lb.Shape.DimSize(2) * lb.Shape.DimSize(3)
nt := n * n * nu
idxs := make([]int, nt)
ix := 0
for py := 0; py < n; py++ {
for px := 0; px < n; px++ {
si := ((py+sPy)*nPx + px + sPx) * nu
for ui := 0; ui < nu; ui++ {
idxs[ix+ui] = si + ui
}
ix += nu
}
}
return idxs
}
// CenterPoolShape returns shape for n x n center pools of given 4D layer.
// Useful for setting SampleShape on Layer.
func CenterPoolShape(ly Layer, n int) []int {
lb := ly.AsEmer()
return []int{n, n, lb.Shape.DimSize(2), lb.Shape.DimSize(3)}
}
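// A hedged usage sketch: set the sample units of a large 4D layer to its
// 2 x 2 center pools (ly is any emer.Layer with a 4D shape):
//
//	lb := ly.AsEmer()
//	lb.SetSampleShape(emer.CenterPoolIndexes(ly, 2), emer.CenterPoolShape(ly, 2))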
// Layer2DSampleIndexes returns neuron indexes and corresponding 2D shape
// for the representative neurons within a large 2D layer, for passing to
// [SetSampleShape]. These neurons are used for the raster plot
// in the GUI and for computing PCA, among other cases where the full set
// of neurons is problematic. The lower-left corner of neurons up to
// given maxSize is selected.
func Layer2DSampleIndexes(ly Layer, maxSize int) (idxs, shape []int) {
lb := ly.AsEmer()
sh := lb.Shape
my := min(maxSize, sh.DimSize(0))
mx := min(maxSize, sh.DimSize(1))
shape = []int{my, mx}
idxs = make([]int, my*mx)
i := 0
for y := 0; y < my; y++ {
for x := 0; x < mx; x++ {
idxs[i] = sh.IndexTo1D(y, x)
i++
}
}
return
}
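// A hedged usage sketch for a large 2D layer: sample at most a 12 x 12
// lower-left corner of neurons (ly is any emer.Layer with a 2D shape):
//
//	idxs, shp := emer.Layer2DSampleIndexes(ly, 12)
//	ly.AsEmer().SetSampleShape(idxs, shp)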
// RecvPathBySendName returns the receiving Path with given
// sending layer name (the first one if multiple exist).
func (ly *LayerBase) RecvPathBySendName(sender string) (Path, error) {
el := ly.EmerLayer
for pi := range el.NumRecvPaths() {
pt := el.RecvPath(pi)
if pt.SendLayer().Label() == sender {
return pt, nil
}
}
return nil, fmt.Errorf("sending layer named: %s not found in list of receiving pathways", sender)
}
// SendPathByRecvName returns the sending Path with given
// receiving layer name (the first one if multiple exist).
func (ly *LayerBase) SendPathByRecvName(recv string) (Path, error) {
el := ly.EmerLayer
for pi := range el.NumSendPaths() {
pt := el.SendPath(pi)
if pt.RecvLayer().Label() == recv {
return pt, nil
}
}
return nil, fmt.Errorf("receiving layer named: %s not found in list of sending pathways", recv)
}
// RecvPathBySendNameType returns the receiving Path with given
// sending layer name, with the given type name
// (the first one if multiple exist).
func (ly *LayerBase) RecvPathBySendNameType(sender, typeName string) (Path, error) {
el := ly.EmerLayer
for pi := range el.NumRecvPaths() {
pt := el.RecvPath(pi)
if pt.SendLayer().Label() == sender && pt.TypeName() == typeName {
return pt, nil
}
}
return nil, fmt.Errorf("sending layer named: %s of type %s not found in list of receiving pathways", sender, typeName)
}
// SendPathByRecvNameType returns the sending Path with given
// receiving layer name, with the given type name
// (the first one if multiple exist).
func (ly *LayerBase) SendPathByRecvNameType(recv, typeName string) (Path, error) {
el := ly.EmerLayer
for pi := range el.NumSendPaths() {
pt := el.SendPath(pi)
if pt.RecvLayer().Label() == recv && pt.TypeName() == typeName {
return pt, nil
}
}
return nil, fmt.Errorf("receiving layer named: %s, type: %s not found in list of sending pathways", recv, typeName)
}
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package emer
//go:generate core generate -add-types
import (
"fmt"
"io"
"os"
"strings"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/core"
"cogentcore.org/core/math32"
"cogentcore.org/lab/base/randx"
"github.com/emer/emergent/v2/relpos"
)
// VarCategory represents one category of unit, synapse variables.
type VarCategory struct {
// Category name.
Cat string
// Documentation of the category, used as a tooltip.
Doc string
}
// Network defines the minimal interface for a neural network,
// used for managing the structural elements of a network,
// and for visualization, I/O, etc.
// Most of the standard expected functionality is defined in the
// NetworkBase struct, and this interface only has methods that must be
// implemented specifically for a given algorithmic implementation.
type Network interface {
// AsEmer returns the network as an *emer.NetworkBase,
// to access base functionality.
AsEmer() *NetworkBase
// Label satisfies the core.Labeler interface for getting
// the name of objects generically.
Label() string
// NumLayers returns the number of layers in the network.
NumLayers() int
// EmerLayer returns layer as emer.Layer interface at given index.
// Does not do extra bounds checking.
EmerLayer(idx int) Layer
// MaxParallelData returns the maximum number of data inputs that can be
// processed in parallel by the network.
// The NetView supports display of up to this many data elements.
MaxParallelData() int
// NParallelData returns the number of data inputs currently being
// processed in parallel by the network.
// Logging supports recording each of these where appropriate.
NParallelData() int
// Defaults sets default parameter values for everything in the Network.
Defaults()
// UpdateParams() updates parameter values for all Network parameters,
// based on any other params that might have changed.
UpdateParams()
// KeyLayerParams returns a listing for all layers in the network,
// of the most important layer-level params (specific to each algorithm).
KeyLayerParams() string
// KeyPathParams returns a listing for all Recv pathways in the network,
// of the most important pathway-level params (specific to each algorithm).
KeyPathParams() string
// UnitVarNames returns a list of variable names available on
// the units in this network.
// This list determines what is shown in the NetView
// (and the order of vars list).
// Not all layers need to support all variables,
// but must safely return math32.NaN() for unsupported ones.
// This is typically a global list so do not modify!
UnitVarNames() []string
// UnitVarProps returns a map of unit variable properties,
// with the key being the name of the variable,
// and the value gives a space-separated list of
// go-tag-style properties for that variable.
// The NetView recognizes the following properties:
// - range:"##" = +- range around 0 for default display scaling
// - min:"##" max:"##" = min, max display range
// - auto-scale:"+" or "-" = use automatic scaling instead of fixed range or not.
// - zeroctr:"+" or "-" = control whether zero-centering is used
// - desc:"txt" tooltip description of the variable
// - cat:"cat" variable category, for category tabs
UnitVarProps() map[string]string
// VarCategories is a list of unit & synapse variable categories,
// which organizes the variables into separate tabs in the network view.
// Using categories results in a more compact display and makes it easier
// to find variables.
// Set the 'cat' property in the UnitVarProps, SynVarProps for each variable.
// If no categories returned, the default is Unit, Wt.
VarCategories() []VarCategory
// SynVarNames returns the names of all the variables
// on the synapses in this network.
// This list determines what is shown in the NetView
// (and the order of vars list).
// Not all pathways need to support all variables,
// but must safely return math32.NaN() for
// unsupported ones.
// This is typically a global list so do not modify!
SynVarNames() []string
// SynVarProps returns a map of synapse variable properties,
// with the key being the name of the variable,
// and the value gives a space-separated list of
// go-tag-style properties for that variable.
// The NetView recognizes the following properties:
// range:"##" = +- range around 0 for default display scaling
// min:"##" max:"##" = min, max display range
// auto-scale:"+" or "-" = use automatic scaling instead of fixed range or not.
// zeroctr:"+" or "-" = control whether zero-centering is used
// Note: this is typically a global list so do not modify!
SynVarProps() map[string]string
// ReadWeightsJSON reads network weights from the receiver-side perspective
// in a JSON text format. Reads entire file into a temporary weights.Weights
// structure that is then passed to Layers etc using SetWeights method.
// Call the NetworkBase version followed by any post-load updates.
ReadWeightsJSON(r io.Reader) error
// WriteWeightsJSON writes the weights from this network
// from the receiver-side perspective in a JSON text format.
// Call the NetworkBase version after pre-load updates.
WriteWeightsJSON(w io.Writer) error
}
// NetworkBase defines the basic data for a neural network,
// used for managing the structural elements of a network,
// and for visualization, I/O, etc.
type NetworkBase struct {
// EmerNetwork provides access to the emer.Network interface
// methods for functions defined in the NetworkBase type.
// Must set this with a pointer to the actual instance
// when created, using InitNetwork function.
EmerNetwork Network `display:"-"`
// overall name of network, which helps discriminate if there are multiple.
Name string
// filename of last weights file loaded or saved.
WeightsFile string
// map of name to layers, for EmerLayerByName methods
LayerNameMap map[string]Layer `display:"-"`
// minimum display position in network
MinPos math32.Vector3 `display:"-"`
// maximum display position in network
MaxPos math32.Vector3 `display:"-"`
// optional metadata that is saved in network weights files,
// e.g., can indicate number of epochs that were trained,
// or any other information about this network that would be useful to save.
MetaData map[string]string
// random number generator for the network.
// all random calls must use this.
// Set seed here for weight initialization values.
Rand randx.SysRand `display:"-"`
// Random seed to be set at the start of configuring
// the network and initializing the weights.
// Set this to get a different set of weights.
RandSeed int64 `edit:"-"`
}
// InitNetwork initializes the network, setting the EmerNetwork interface
// to provide access to it for NetworkBase methods, along with the name.
func InitNetwork(nt Network, name string) {
nb := nt.AsEmer()
nb.EmerNetwork = nt
nb.Name = name
}
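// A hedged sketch, paralleling InitLayer: an algorithm-specific network type
// embeds NetworkBase and calls InitNetwork when created (MyNetwork is a
// hypothetical type that also implements the rest of emer.Network):
//
//	nt := &MyNetwork{}          // embeds emer.NetworkBase
//	emer.InitNetwork(nt, "Net") // sets the EmerNetwork back-pointer and Name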
func (nt *NetworkBase) AsEmer() *NetworkBase { return nt }
func (nt *NetworkBase) Label() string { return nt.Name }
// UpdateLayerNameMap updates the LayerNameMap.
func (nt *NetworkBase) UpdateLayerNameMap() {
if nt.LayerNameMap == nil {
nt.LayerNameMap = make(map[string]Layer)
}
nl := nt.EmerNetwork.NumLayers()
for li := range nl {
ly := nt.EmerNetwork.EmerLayer(li)
lnm := ly.Label()
nt.LayerNameMap[lnm] = ly
}
}
// EmerLayerByName returns a layer by looking it up by name.
// returns error message if layer is not found.
func (nt *NetworkBase) EmerLayerByName(name string) (Layer, error) {
if nt.LayerNameMap == nil || len(nt.LayerNameMap) != nt.EmerNetwork.NumLayers() {
nt.UpdateLayerNameMap()
}
if ly, ok := nt.LayerNameMap[name]; ok {
return ly, nil
}
err := fmt.Errorf("Layer named: %s not found in Network: %s", name, nt.Name)
return nil, err
}
// EmerPathByName returns a path by looking it up by name.
// Paths are named SendToRecv = sending layer name "To" recv layer name.
// returns error message if path is not found.
func (nt *NetworkBase) EmerPathByName(name string) (Path, error) {
ti := strings.Index(name, "To")
if ti < 0 {
return nil, errors.Log(fmt.Errorf("EmerPathByName: path name must contain 'To': %s", name))
}
sendNm := name[:ti]
recvNm := name[ti+2:]
_, err := nt.EmerLayerByName(sendNm)
if errors.Log(err) != nil {
return nil, err
}
recv, err := nt.EmerLayerByName(recvNm)
if errors.Log(err) != nil {
return nil, err
}
path, err := recv.AsEmer().RecvPathBySendName(sendNm)
if errors.Log(err) != nil {
return nil, err
}
return path, nil
}
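// A brief hedged example of the SendToRecv naming convention ("Hidden" and
// "Output" are hypothetical layer names):
//
//	pt, err := nt.EmerPathByName("HiddenToOutput") // path from Hidden to Output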
// LayoutLayers computes the 3D layout of layers based on their relative
// position settings.
func (nt *NetworkBase) LayoutLayers() {
en := nt.EmerNetwork
nlay := en.NumLayers()
for range 5 {
var lstly *LayerBase
for li := range nlay {
ly := en.EmerLayer(li).AsEmer()
var oly *LayerBase
if lstly != nil && ly.Pos.Rel == relpos.NoRel {
if ly.Pos.Pos.X != 0 || ly.Pos.Pos.Y != 0 || ly.Pos.Pos.Z != 0 {
// Position has been modified, don't mess with it.
continue
}
oly = lstly
ly.Pos = relpos.Pos{Rel: relpos.Above, Other: lstly.Name, XAlign: relpos.Middle, YAlign: relpos.Front}
} else {
if ly.Pos.Other != "" {
olyi, err := nt.EmerLayerByName(ly.Pos.Other)
if errors.Log(err) != nil {
continue
}
oly = olyi.AsEmer()
} else if lstly != nil {
oly = lstly
ly.Pos = relpos.Pos{Rel: relpos.Above, Other: lstly.Name, XAlign: relpos.Middle, YAlign: relpos.Front}
}
}
if oly != nil {
ly.Pos.SetPos(oly.Pos.Pos, oly.DisplaySize(), ly.DisplaySize())
}
lstly = ly
}
}
nt.layoutBoundsUpdate()
}
// layoutBoundsUpdate updates the Min / Max display bounds for 3D display.
func (nt *NetworkBase) layoutBoundsUpdate() {
en := nt.EmerNetwork
nlay := en.NumLayers()
mn := math32.Vector3Scalar(math32.Infinity)
mx := math32.Vector3{}
for li := range nlay {
ly := en.EmerLayer(li).AsEmer()
sz := ly.DisplaySize()
ru := ly.Pos.Pos
ru.X += sz.X
ru.Y += sz.Y
mn.SetMin(ly.Pos.Pos)
mx.SetMax(ru)
}
nt.MinPos = mn
nt.MaxPos = mx
}
// VerticalLayerLayout arranges layers in a standard vertical (z axis stack)
// layout, by setting the Pos settings.
func (nt *NetworkBase) VerticalLayerLayout() {
lstnm := ""
en := nt.EmerNetwork
nlay := en.NumLayers()
for li := range nlay {
ly := en.EmerLayer(li).AsEmer()
if li == 0 {
ly.Pos = relpos.Pos{Rel: relpos.NoRel}
lstnm = ly.Name
} else {
ly.Pos = relpos.Pos{Rel: relpos.Above, Other: lstnm, XAlign: relpos.Middle, YAlign: relpos.Front}
}
}
}
// VarRange returns the min / max values for given variable.
// error occurs when variable name is not found.
func (nt *NetworkBase) VarRange(varNm string) (min, max float32, err error) {
first := true
en := nt.EmerNetwork
nlay := en.NumLayers()
for li := range nlay {
ly := en.EmerLayer(li)
lmin, lmax, lerr := ly.VarRange(varNm)
if lerr != nil {
err = lerr
return
}
if first {
min = lmin
max = lmax
first = false
continue
}
if lmin < min {
min = lmin
}
if lmax > max {
max = lmax
}
}
return
}
//////// Params
const (
// AllParams can be used for ParamsString arg to list all params.
AllParams = false
// NonDefault can be used for ParamsString arg to list only non-default params.
NonDefault = true
)
// ParamsString returns a listing of all parameters in the Layer and
// pathways within the layer. If nonDefault is true, only report those
// not at their default values.
func (nt *NetworkBase) ParamsString(nonDefault bool) string {
var b strings.Builder
en := nt.EmerNetwork
nlay := en.NumLayers()
for li := range nlay {
ly := en.EmerLayer(li)
b.WriteString(ly.ParamsString(nonDefault))
}
return b.String()
}
// SaveParams saves list of parameters in Network to given file.
// If nonDefault is true, only report those not at their default values.
func (nt *NetworkBase) SaveParams(nonDefault bool, filename core.Filename) error {
str := nt.ParamsString(nonDefault)
err := os.WriteFile(string(filename), []byte(str), 0666)
return errors.Log(err)
}
// SetRandSeed sets random seed and calls ResetRandSeed
func (nt *NetworkBase) SetRandSeed(seed int64) {
nt.RandSeed = seed
nt.ResetRandSeed()
}
// ResetRandSeed sets random seed to saved RandSeed, ensuring that the
// network-specific random seed generator has been created.
func (nt *NetworkBase) ResetRandSeed() {
if nt.Rand.Rand == nil {
nt.Rand.NewRand(nt.RandSeed)
} else {
nt.Rand.Seed(nt.RandSeed)
}
}
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package emer
import (
"io"
"strings"
"cogentcore.org/core/math32"
"github.com/emer/emergent/v2/params"
"github.com/emer/emergent/v2/paths"
"github.com/emer/emergent/v2/weights"
)
// Path defines the minimal interface for a pathway
// which connects two layers, using a specific Pattern
// of connectivity, and with its own set of parameters.
// This supports visualization (NetView), I/O,
// and parameter setting functionality provided by emergent.
// Most of the standard expected functionality is defined in the
// PathBase struct, and this interface only has methods that must be
// implemented specifically for a given algorithmic implementation.
type Path interface {
// AsEmer returns the path as an *emer.PathBase,
// to access base functionality.
AsEmer() *PathBase
// Label satisfies the core.Labeler interface for getting
// the name of objects generically. Use to access Name via interface.
Label() string
// TypeName is the type or category of path, defined
// by the algorithm (and usually set by an enum).
TypeName() string
// TypeNumber is the numerical value for the type or category
// of path, defined by the algorithm (and usually set by an enum).
TypeNumber() int
// SendLayer returns the sending layer for this pathway,
// as an emer.Layer interface. The actual Path implementation
// can use a Send field with the actual Layer struct type.
SendLayer() Layer
// RecvLayer returns the receiving layer for this pathway,
// as an emer.Layer interface. The actual Path implementation
// can use a Recv field with the actual Layer struct type.
RecvLayer() Layer
// NumSyns returns the number of synapses for this path.
// This is the max idx for SynValue1D and the number
// of vals set by SynValues.
NumSyns() int
// SynIndex returns the index of the synapse between given send, recv unit indexes
// (1D, flat indexes). Returns -1 if synapse not found between these two neurons.
// This requires searching within connections for receiving unit (a bit slow).
SynIndex(sidx, ridx int) int
// SynVarNames returns the names of all the variables on the synapse
// This is typically a global list so do not modify!
SynVarNames() []string
// SynVarNum returns the number of synapse-level variables
// for this path. This is needed for extending indexes in derived types.
SynVarNum() int
// SynVarIndex returns the index of given variable within the synapse,
// according to *this path's* SynVarNames() list (using a map to lookup index),
// or -1 and error message if not found.
SynVarIndex(varNm string) (int, error)
// SynValues sets values of given variable name for each synapse,
// using the natural ordering of the synapses (sender based for Axon),
// into given float32 slice (only resized if not big enough).
// Returns error on invalid var name.
SynValues(vals *[]float32, varNm string) error
// SynValue1D returns value of given variable index
// (from SynVarIndex) on given SynIndex.
// Returns NaN on invalid index.
// This is the core synapse var access method used by other methods,
// so it is the only one that needs to be updated for derived types.
SynValue1D(varIndex int, synIndex int) float32
// ParamsString returns a listing of all parameters in the pathway.
// If nonDefault is true, only report those not at their default values.
ParamsString(nonDefault bool) string
// WriteWeightsJSON writes the weights from this pathway
// from the receiver-side perspective in a JSON text format.
WriteWeightsJSON(w io.Writer, depth int)
// SetWeights sets the weights for this pathway from weights.Path
// decoded values
SetWeights(pw *weights.Path) error
}
// PathBase defines the basic shared data for a pathway
// which connects two layers, using a specific Pattern
// of connectivity, and with its own set of parameters.
// The same struct token is added to the Recv and Send
// layer path lists.
type PathBase struct {
// EmerPath provides access to the emer.Path interface
// methods for functions defined in the PathBase type.
// Must set this with a pointer to the actual instance
// when created, using InitPath function.
EmerPath Path
// Name of the path, which can be automatically set to
// SendLayer().Name + "To" + RecvLayer().Name via
// SetStandardName method.
Name string
// Class is for applying parameter styles across multiple paths
// that all get the same parameters. This can be space separated
// with multple classes.
Class string
// Doc contains documentation about the pathway.
// This is displayed in a tooltip in the network view.
Doc string
// can record notes about this pathway here.
Notes string
// Pattern specifies the pattern of connectivity
// for interconnecting the sending and receiving layers.
Pattern paths.Pattern
// Off inactivates this pathway, allowing for easy experimentation.
Off bool
}
// InitPath initializes the path, setting the EmerPath interface
// to provide access to it for PathBase methods.
func InitPath(pt Path) {
pb := pt.AsEmer()
pb.EmerPath = pt
}
func (pt *PathBase) AsEmer() *PathBase { return pt }
func (pt *PathBase) Label() string { return pt.Name }
// AddClass adds a CSS-style class name(s) for this path,
// ensuring that it is not a duplicate, and properly space separated.
// Returns Path so it can be chained to set other properties too.
func (pt *PathBase) AddClass(cls ...string) *PathBase {
pt.Class = params.AddClass(pt.Class, cls...)
return pt
}
// IsTypeOrClass returns true if the TypeName or parameter Class for this
// pathway matches the space separated list of values given, using
// case-insensitive, "contains" logic for each match.
func (pt *PathBase) IsTypeOrClass(types string) bool {
cls := strings.Fields(strings.ToLower(pt.Class))
cls = append([]string{strings.ToLower(pt.EmerPath.TypeName())}, cls...)
fs := strings.Fields(strings.ToLower(types))
for _, pt := range fs {
for _, cl := range cls {
if strings.Contains(cl, pt) {
return true
}
}
}
return false
}
// SynValue returns value of given variable name on the synapse
// between given send, recv unit indexes (1D, flat indexes).
// Returns math32.NaN() for access errors.
func (pt *PathBase) SynValue(varNm string, sidx, ridx int) float32 {
vidx, err := pt.EmerPath.SynVarIndex(varNm)
if err != nil {
return math32.NaN()
}
syi := pt.EmerPath.SynIndex(sidx, ridx)
return pt.EmerPath.SynValue1D(vidx, syi)
}
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package emer
import (
"bufio"
"compress/gzip"
"fmt"
"io"
"io/fs"
"os"
"path/filepath"
"sort"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/indent"
"cogentcore.org/core/core"
"github.com/emer/emergent/v2/weights"
"golang.org/x/exp/maps"
)
// SaveWeightsJSON saves network weights (and any other state that adapts with learning)
// to a JSON-formatted file. If filename has .gz extension, then file is gzip compressed.
func (nt *NetworkBase) SaveWeightsJSON(filename core.Filename) error { //types:add
fp, err := os.Create(string(filename))
if err != nil {
return errors.Log(err)
}
defer fp.Close()
ext := filepath.Ext(string(filename))
if ext == ".gz" {
gzr := gzip.NewWriter(fp)
err = nt.EmerNetwork.WriteWeightsJSON(gzr)
gzr.Close()
} else {
bw := bufio.NewWriter(fp)
err = nt.EmerNetwork.WriteWeightsJSON(bw)
bw.Flush()
}
return err
}
// OpenWeightsJSON opens network weights (and any other state that adapts with learning)
// from a JSON-formatted file. If filename has .gz extension, then file is gzip uncompressed.
func (nt *NetworkBase) OpenWeightsJSON(filename core.Filename) error { //types:add
fp, err := os.Open(string(filename))
if err != nil {
return errors.Log(err)
}
defer fp.Close()
ext := filepath.Ext(string(filename))
if ext == ".gz" {
gzr, err := gzip.NewReader(fp)
if err != nil {
return errors.Log(err)
}
defer gzr.Close()
return nt.EmerNetwork.ReadWeightsJSON(gzr)
} else {
return nt.EmerNetwork.ReadWeightsJSON(bufio.NewReader(fp))
}
}
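// A hedged usage sketch (the file name is illustrative; nt is any network
// embedding NetworkBase). The .gz extension selects gzip compression on save
// and decompression on open:
//
//	err := nt.SaveWeightsJSON("trained.wts.gz")
//	err = nt.OpenWeightsJSON("trained.wts.gz")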
// OpenWeightsFS opens network weights (and any other state that adapts with learning)
// from a JSON-formatted file, in filesystem.
// If filename has .gz extension, then file is gzip uncompressed.
func (nt *NetworkBase) OpenWeightsFS(fsys fs.FS, filename string) error {
fp, err := fsys.Open(filename)
if err != nil {
return errors.Log(err)
}
defer fp.Close()
ext := filepath.Ext(filename)
if ext == ".gz" {
gzr, err := gzip.NewReader(fp)
if err != nil {
return errors.Log(err)
}
defer gzr.Close()
return nt.EmerNetwork.ReadWeightsJSON(gzr)
} else {
return nt.EmerNetwork.ReadWeightsJSON(bufio.NewReader(fp))
}
}
// todo: proper error handling here!
// WriteWeightsJSON writes the weights from this network
// from the receiver-side perspective in a JSON text format.
func (nt *NetworkBase) WriteWeightsJSON(w io.Writer) error {
en := nt.EmerNetwork
nlay := en.NumLayers()
depth := 0
w.Write(indent.TabBytes(depth))
w.Write([]byte("{\n"))
depth++
w.Write(indent.TabBytes(depth))
w.Write([]byte(fmt.Sprintf("\"Network\": %q,\n", nt.Name))) // note: can't use \n in `` so need "
w.Write(indent.TabBytes(depth))
onls := make([]Layer, 0, nlay)
for li := range nlay {
ly := en.EmerLayer(li)
if !ly.AsEmer().Off {
onls = append(onls, ly)
}
}
nl := len(onls)
if nl == 0 {
w.Write([]byte("\"Layers\": null\n"))
} else {
w.Write([]byte("\"Layers\": [\n"))
depth++
for li, ly := range onls {
ly.WriteWeightsJSON(w, depth)
if li == nl-1 {
w.Write([]byte("\n"))
} else {
w.Write([]byte(",\n"))
}
}
depth--
w.Write(indent.TabBytes(depth))
w.Write([]byte("]\n"))
}
depth--
w.Write(indent.TabBytes(depth))
_, err := w.Write([]byte("}\n"))
return err
}
// ReadWeightsJSON reads network weights from the receiver-side perspective
// in a JSON text format. Reads entire file into a temporary weights.Weights
// structure that is then passed to Layers etc using SetWeights method.
func (nt *NetworkBase) ReadWeightsJSON(r io.Reader) error {
nw, err := weights.NetReadJSON(r)
if err != nil {
return err // note: already logged
}
err = nt.SetWeights(nw)
if err != nil {
errors.Log(err)
}
return err
}
// SetWeights sets the weights for this network from weights.Network decoded values
func (nt *NetworkBase) SetWeights(nw *weights.Network) error {
var errs []error
if nw.Network != "" {
nt.Name = nw.Network
}
if nw.MetaData != nil {
if nt.MetaData == nil {
nt.MetaData = nw.MetaData
} else {
for mk, mv := range nw.MetaData {
nt.MetaData[mk] = mv
}
}
}
for li := range nw.Layers {
lw := &nw.Layers[li]
ly, err := nt.EmerLayerByName(lw.Layer)
if err != nil {
errs = append(errs, err)
continue
}
ly.SetWeights(lw)
}
return errors.Join(errs...)
}
// WriteWeightsJSONBase writes the weights from this layer
// in a JSON text format. Any values in the layer MetaData
// will be written first, and unit-level variables in unitVars
// are saved as well. Then, all the receiving path data is saved.
func (ly *LayerBase) WriteWeightsJSONBase(w io.Writer, depth int, unitVars ...string) {
el := ly.EmerLayer
w.Write(indent.TabBytes(depth))
w.Write([]byte("{\n"))
depth++
w.Write(indent.TabBytes(depth))
w.Write([]byte(fmt.Sprintf("\"Layer\": %q,\n", ly.Name)))
if len(ly.MetaData) > 0 {
w.Write(indent.TabBytes(depth))
w.Write([]byte("\"MetaData\": {\n"))
depth++
kys := maps.Keys(ly.MetaData)
sort.StringSlice(kys).Sort()
for i, k := range kys {
w.Write(indent.TabBytes(depth))
comma := ","
if i == len(kys)-1 { // note: last one has no comma
comma = ""
}
w.Write([]byte(fmt.Sprintf("%q: %q%s\n", k, ly.MetaData[k], comma)))
}
depth--
w.Write(indent.TabBytes(depth))
w.Write([]byte("},\n"))
}
if len(unitVars) > 0 {
w.Write(indent.TabBytes(depth))
w.Write([]byte("\"Units\": {\n"))
depth++
for i, vname := range unitVars {
vidx, err := el.UnitVarIndex(vname)
if errors.Log(err) != nil {
continue
}
w.Write(indent.TabBytes(depth))
w.Write([]byte(fmt.Sprintf("%q: [ ", vname)))
nu := ly.NumUnits()
for ni := range nu {
val := el.UnitValue1D(vidx, ni, 0)
w.Write([]byte(fmt.Sprintf("%g", val)))
if ni < nu-1 {
w.Write([]byte(", "))
}
}
comma := ","
if i == len(unitVars)-1 { // note: last one has no comma
comma = ""
}
w.Write([]byte(fmt.Sprintf(" ]%s\n", comma)))
}
depth--
w.Write(indent.TabBytes(depth))
w.Write([]byte("},\n"))
}
w.Write(indent.TabBytes(depth))
onps := make([]Path, 0, el.NumRecvPaths())
for pi := range el.NumRecvPaths() {
pt := el.RecvPath(pi)
if !pt.AsEmer().Off {
onps = append(onps, pt)
}
}
np := len(onps)
if np == 0 {
w.Write([]byte("\"Paths\": null\n"))
} else {
w.Write([]byte("\"Paths\": [\n"))
depth++
for pi := range el.NumRecvPaths() {
pt := el.RecvPath(pi)
pt.WriteWeightsJSON(w, depth) // this leaves path unterminated
if pi == np-1 {
w.Write([]byte("\n"))
} else {
w.Write([]byte(",\n"))
}
}
depth--
w.Write(indent.TabBytes(depth))
w.Write([]byte(" ]\n"))
}
depth--
w.Write(indent.TabBytes(depth))
w.Write([]byte("}")) // note: leave unterminated as outer loop needs to add , or just \n depending
}
// ReadWeightsJSON reads the weights from this layer from the
// receiver-side perspective in a JSON text format.
// This is for a set of weights that were saved *for one layer only*
// and is not used for the network-level ReadWeightsJSON,
// which reads into a separate structure -- see SetWeights method.
func (ly *LayerBase) ReadWeightsJSON(r io.Reader) error {
lw, err := weights.LayReadJSON(r)
if err != nil {
return err // note: already logged
}
return ly.EmerLayer.SetWeights(lw)
}
// ReadWeightsJSON reads the weights from this pathway from the
// receiver-side perspective in a JSON text format.
// This is for a set of weights that were saved *for one path only*
// and is not used for the network-level ReadWeightsJSON,
// which reads into a separate structure -- see SetWeights method.
func (pt *PathBase) ReadWeightsJSON(r io.Reader) error {
pw, err := weights.PathReadJSON(r)
if err != nil {
return err // note: already logged
}
return pt.EmerPath.SetWeights(pw)
}
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package env
// Counter maintains a current and previous counter value,
// and a Max value with methods to manage.
type Counter struct {
// Cur is the current counter value.
Cur int
// Prev is the previous counter value, prior to the last Incr() call (initialized to -1).
Prev int `display:"-"`
// Changed reports if it changed on the last Step() call or not.
Changed bool `display:"-"`
// Max is the maximum counter value, above which the counter will reset back to 0.
// Only used if > 0.
Max int
}
// Init initializes counter: Cur = 0, Prev = -1
func (ct *Counter) Init() {
ct.Prev = -1
ct.Cur = 0
ct.Changed = false
}
// Same resets Changed to false. It is a good idea to call this on all counters
// at the start of Step; the alternative (setting it in an else branch) is more error-prone.
func (ct *Counter) Same() {
ct.Changed = false
}
// Incr increments the counter by 1. If Max > 0 and the incremented Cur >= Max,
// the counter is reset to 0 and true is returned; otherwise false.
func (ct *Counter) Incr() bool {
ct.Changed = true
ct.Prev = ct.Cur
ct.Cur++
if ct.Max > 0 && ct.Cur >= ct.Max {
ct.Cur = 0
return true
}
return false
}
// Set sets Cur to the given value if it differs from the current value,
// preserving the previous value and setting Changed appropriately.
// Returns true if changed. Does NOT check Cur vs. Max.
func (ct *Counter) Set(cur int) bool {
if ct.Cur == cur {
ct.Changed = false
return false
}
ct.Changed = true
ct.Prev = ct.Cur
ct.Cur = cur
return true
}
// Query returns the current, previous and changed values for this counter
func (ct *Counter) Query() (cur, prev int, chg bool) {
return ct.Cur, ct.Prev, ct.Changed
}
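// A hedged usage sketch of Counter in an environment's Step function
// (ev, Trial, and Epoch are hypothetical; Trial.Max is assumed > 0):
//
//	ev.Epoch.Same() // clear Changed before possibly incrementing
//	if ev.Trial.Incr() {
//		ev.Epoch.Incr() // Trial wrapped back to 0, so advance Epoch
//	}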
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package env
import "cogentcore.org/lab/tensor"
// CurPrev manages current and previous values for basic data types.
type CurPrev[T tensor.DataTypes] struct {
Cur, Prev T
}
// Set sets the new current value, after saving Cur to Prev.
func (cv *CurPrev[T]) Set(cur T) {
cv.Prev = cv.Cur
cv.Cur = cur
}
type CurPrevString = CurPrev[string]
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package env
import (
"fmt"
"cogentcore.org/core/enums"
)
// Envs is a map of environments organized according
// to the evaluation mode string (recommended key value)
type Envs map[string]Env
// Init initializes the map if not yet
func (es *Envs) Init() {
if *es == nil {
*es = make(map[string]Env)
}
}
// Add adds Env(s), using its Label as the key
func (es *Envs) Add(evs ...Env) {
es.Init()
for _, ev := range evs {
(*es)[ev.Label()] = ev
}
}
// ByMode returns the Env stored under the given evaluation mode's string as the map key.
// Returns nil if not found.
func (es *Envs) ByMode(mode enums.Enum) Env {
return (*es)[mode.String()]
}
// ModeDi returns the string of the given mode appended with
// _di data index with leading zero.
func ModeDi(mode enums.Enum, di int) string {
return fmt.Sprintf("%s_%02d", mode.String(), di)
}
// ByModeDi returns the Env keyed by the given evaluation mode and data parallel
// index, using the ModeDi function to form the key.
// Returns nil if not found.
func (es *Envs) ByModeDi(mode enums.Enum, di int) Env {
return (*es)[ModeDi(mode, di)]
}
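// Minimal usage sketch (not from the original source); trainEnv and testEnv are
// hypothetical Env values, and etime.Train / etime.Test are the standard modes:
//
//	var envs Envs
//	envs.Add(trainEnv, testEnv)          // keyed by each Env's Label()
//	ev := envs.ByMode(etime.Train)       // the Env labeled "Train", or nil
//	ev01 := envs.ByModeDi(etime.Test, 1) // the Env labeled "Test_01", or nil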
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package env
import (
"fmt"
"log/slog"
"math/rand"
"cogentcore.org/lab/base/randx"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
)
// FixedTable is a basic Env that manages patterns from a [table.Table], with
// either sequential or permuted random ordering, and a Trial counter
// to record progress and iterations through the table.
// Use [table.NewView] to provide a unique indexed view of a shared table.
type FixedTable struct {
// name of this environment, usually Train vs. Test.
Name string
// Table has the set of patterns to output.
// The indexes are used for the *sequential* view so you can easily
// sort / split / filter the patterns to be presented using this view.
// This adds the random permuted Order on top of those if !sequential.
Table *table.Table
// present items from the table in sequential order (i.e., according to
// the indexed view on the Table)? otherwise permuted random order.
Sequential bool
// permuted order of items to present if not sequential.
// updated every time through the list.
Order []int
// current ordinal item in Table. if Sequential then = row number in table,
// otherwise is index in Order list that then gives row number in Table.
Trial Counter `display:"inline"`
// if Table has a Name column, this is the contents of that.
TrialName CurPrevString
// if Table has a Group column, this is contents of that.
GroupName CurPrevString
// name of the Name column -- defaults to 'Name'.
NameCol string
// name of the Group column -- defaults to 'Group'.
GroupCol string
}
func (ft *FixedTable) Validate() error {
if ft.Table == nil {
return fmt.Errorf("env.FixedTable: %v has no Table set", ft.Name)
}
if ft.Table.NumColumns() == 0 {
return fmt.Errorf("env.FixedTable: %v Table has no columns -- Outputs will be invalid", ft.Name)
}
return nil
}
func (ft *FixedTable) Label() string { return ft.Name }
func (ft *FixedTable) String() string {
s := ft.TrialName.Cur
if ft.GroupName.Cur != "" {
s = ft.GroupName.Cur + "_" + s
}
return s
}
func (ft *FixedTable) Init(run int) {
if ft.NameCol == "" {
ft.NameCol = "Name"
}
if ft.GroupCol == "" {
ft.GroupCol = "Group"
}
ft.Trial.Init()
ft.NewOrder()
ft.Trial.Cur = -1 // init state -- key so that first Step() = 0
}
// Config configures the environment to use the given Table, and calls Init.
// NameCol and GroupCol are initialized to "Name" and "Group" by Init,
// so set these to something else afterward if needed.
func (ft *FixedTable) Config(tbl *table.Table) {
ft.Table = tbl
ft.Init(0)
}
// NewOrder sets a new random Order based on number of rows in the table.
func (ft *FixedTable) NewOrder() {
np := ft.Table.NumRows()
ft.Order = rand.Perm(np) // always start with new one so random order is identical
// and always maintain Order so random number usage is same regardless, and if
// user switches between Sequential and random at any point, it all works..
ft.Trial.Max = np
}
// PermuteOrder permutes the existing order table to get a new random sequence of inputs
// just calls: randx.PermuteInts(ft.Order)
func (ft *FixedTable) PermuteOrder() {
randx.PermuteInts(ft.Order)
}
// Row returns the current row number in the table, based on Sequential / permuted Order.
func (ft *FixedTable) Row() int {
if ft.Sequential {
return ft.Trial.Cur
}
return ft.Order[ft.Trial.Cur]
}
func (ft *FixedTable) SetTrialName() {
if nms := ft.Table.Column(ft.NameCol); nms != nil {
rw := ft.Row()
if rw >= 0 && rw < nms.Len() {
ft.TrialName.Set(nms.StringRow(rw, 0))
}
}
}
func (ft *FixedTable) SetGroupName() {
if nms := ft.Table.Column(ft.GroupCol); nms != nil {
rw := ft.Row()
if rw >= 0 && rw < nms.Len() {
ft.GroupName.Set(nms.StringRow(rw, 0))
}
}
}
// Step advances the Trial counter and updates TrialName and GroupName;
// when the counter wraps past Max, the Order is re-permuted. Always returns true.
func (ft *FixedTable) Step() bool {
if ft.Trial.Incr() { // if true, hit max, reset to 0
ft.PermuteOrder()
}
ft.SetTrialName()
ft.SetGroupName()
return true
}
// State returns the tensor for the given column (element) name at the current row.
func (ft *FixedTable) State(element string) tensor.Values {
et := ft.Table.Column(element).RowTensor(ft.Row())
if et == nil {
slog.Error("FixedTable.State: could not find", "element", element)
}
return et
}
// Compile-time check that FixedTable implements the Env interface
var _ Env = (*FixedTable)(nil)
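// Minimal usage sketch (not from the original source), assuming a patterns table
// with "Input" and "Output" tensor columns plus the standard "Name" column:
//
//	ft := &FixedTable{Name: "Train"}
//	ft.Config(table.NewView(patterns)) // unique indexed view of a shared table
//	for i := 0; i < ft.Trial.Max; i++ {
//		ft.Step()               // advances Trial, sets TrialName / GroupName
//		in := ft.State("Input") // tensor for the current row
//		out := ft.State("Output")
//		_, _ = in, out
//	}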
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package env
import (
"fmt"
"log/slog"
"math"
"cogentcore.org/lab/base/randx"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
)
// FreqTable is an Env that manages patterns from a [table.Table] with frequency
// information, so that items are presented according to their associated frequencies,
// which are effectively probabilities of presenting any given input. The table must
// have a Freq column with these numbers (actual column name in FreqCol).
// Either sequential or permuted random ordering is supported, with a Trial counter
// to record progress and iterations through the table.
// Use [table.NewView] to provide a unique indexed view of a shared table, so a single
// table can be used across different environments, each with its own view.
type FreqTable struct {
// name of this environment
Name string
// Table has the set of patterns to output.
// The indexes are used for the *sequential* view so you can easily
// sort / split / filter the patterns to be presented using this view.
// This adds the random permuted Order on top of those if !sequential.
Table *table.Table
// number of samples to use in constructing the list of items to present according to frequency -- number per epoch ~ NSamples * Freq -- see RandSamp option
NSamples float64
// if true, use random sampling of items NSamples times according to given Freq probability value -- otherwise just directly add NSamples * Freq items to the list
RandSamp bool
// present items from the table in sequential order (i.e., according to the indexed view on the Table)? otherwise permuted random order. All repetitions of given item will be sequential if Sequential
Sequential bool
// list of items to present, with repetitions -- updated every time through the list
Order []int
// current ordinal item in Table -- if Sequential then = row number in table, otherwise is index in Order list that then gives row number in Table
Trial Counter `display:"inline"`
// if Table has a Name column, this is the contents of that
TrialName CurPrevString
// if Table has a Group column, this is contents of that
GroupName CurPrevString
// name of the Name column -- defaults to 'Name'
NameCol string
// name of the Group column -- defaults to 'Group'
GroupCol string
// name of the Freq column -- defaults to 'Freq'
FreqCol string
}
func (ft *FreqTable) Validate() error {
if ft.Table == nil {
return fmt.Errorf("env.FreqTable: %v has no Table set", ft.Name)
}
if ft.Table.NumColumns() == 0 {
return fmt.Errorf("env.FreqTable: %v Table has no columns -- Outputs will be invalid", ft.Name)
}
fc := ft.Table.Column(ft.FreqCol)
if fc == nil {
return fmt.Errorf("env.FreqTable: %v Table has no FreqCol", ft.FreqCol)
}
return nil
}
func (ft *FreqTable) Label() string { return ft.Name }
func (ft *FreqTable) String() string {
s := ft.TrialName.Cur
if ft.GroupName.Cur != "" {
s = ft.GroupName.Cur + "_" + s
}
return s
}
func (ft *FreqTable) Init(run int) {
if ft.NameCol == "" {
ft.NameCol = "Name"
}
if ft.GroupCol == "" {
ft.GroupCol = "Group"
}
if ft.FreqCol == "" {
ft.FreqCol = "Freq"
}
ft.Trial.Init()
ft.Sample()
ft.Trial.Max = len(ft.Order)
ft.Trial.Cur = -1 // init state -- key so that first Step() = 0
}
// Sample generates a new sample of items
func (ft *FreqTable) Sample() {
if ft.NSamples <= 0 {
ft.NSamples = 1
}
np := ft.Table.NumRows()
if ft.Order == nil {
ft.Order = make([]int, 0, int(math.Round(float64(np)*ft.NSamples)))
} else {
ft.Order = ft.Order[:0]
}
frqs := ft.Table.Column(ft.FreqCol)
for ri := 0; ri < np; ri++ {
frq := frqs.FloatRow(ri, 0)
if ft.RandSamp {
n := int(ft.NSamples)
for i := 0; i < n; i++ {
if randx.BoolP(frq) {
ft.Order = append(ft.Order, ri)
}
}
} else {
n := int(math.Round(ft.NSamples * frq))
for i := 0; i < n; i++ {
ft.Order = append(ft.Order, ri)
}
}
}
if !ft.Sequential {
randx.PermuteInts(ft.Order)
}
}
// Row returns the current row number in the table, based on Sequential / permuted Order,
// already de-referenced through the Table's Indexes to get the actual row in the table.
func (ft *FreqTable) Row() int {
return ft.Table.Indexes[ft.Order[ft.Trial.Cur]]
}
func (ft *FreqTable) SetTrialName() {
if nms := ft.Table.Column(ft.NameCol); nms != nil {
rw := ft.Row()
if rw >= 0 && rw < nms.Len() {
ft.TrialName.Set(nms.StringRow(rw, 0))
}
}
}
func (ft *FreqTable) SetGroupName() {
if nms := ft.Table.Column(ft.GroupCol); nms != nil {
rw := ft.Row()
if rw >= 0 && rw < nms.Len() {
ft.GroupName.Set(nms.StringRow(rw, 0))
}
}
}
func (ft *FreqTable) Step() bool {
if ft.Trial.Incr() { // if true, hit max, reset to 0
ft.Sample()
ft.Trial.Max = len(ft.Order)
}
ft.SetTrialName()
ft.SetGroupName()
return true
}
func (ft *FreqTable) State(element string) tensor.Values {
et := ft.Table.Column(element).RowTensor(ft.Row())
if et == nil {
slog.Error("FreqTable.State: could not find:", "element", element)
}
return et
}
// Compile-time check that FreqTable implements the Env interface
var _ Env = (*FreqTable)(nil)
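// Minimal usage sketch (not from the original source), assuming the table has a
// "Freq" column of per-row probabilities along with hypothetical pattern columns:
//
//	ft := &FreqTable{Name: "Train", NSamples: 10}
//	ft.Table = table.NewView(patterns)
//	ft.Init(0)                  // builds the Order list from NSamples * Freq
//	for i := 0; i < ft.Trial.Max; i++ {
//		ft.Step()
//		in := ft.State("Input") // "Input" is a hypothetical column name
//		_ = in
//	}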
// Copyright (c) 2020, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package env
import (
"fmt"
"log/slog"
"math/rand"
"cogentcore.org/lab/base/randx"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
"cogentcore.org/lab/tensor/tensormpi"
)
// MPIFixedTable is an MPI-enabled version of the [FixedTable], which is
// a basic Env that manages patterns from a [table.Table], with
// either sequential or permuted random ordering, and a Trial counter to
// record iterations through the table.
// Use [table.NewView] to provide a unique indexed view of a shared table.
// The MPI version distributes trials across MPI procs, in the Order list.
// It is ESSENTIAL that the number of trials (rows) in Table is
// evenly divisible by number of MPI procs!
// If all nodes start with the same seed, it should remain synchronized.
type MPIFixedTable struct {
// name of this environment
Name string
// Table has the set of patterns to output.
// The indexes are used for the *sequential* view so you can easily
// sort / split / filter the patterns to be presented using this view.
// This adds the random permuted Order on top of those if !sequential.
Table *table.Table
// present items from the table in sequential order (i.e., according to the indexed view on the Table)? otherwise permuted random order
Sequential bool
// permuted order of items to present if not sequential -- updated every time through the list
Order []int
// current ordinal item in Table -- if Sequential then = row number in table, otherwise is index in Order list that then gives row number in Table
Trial Counter `display:"inline"`
// if Table has a Name column, this is the contents of that
TrialName CurPrevString
// if Table has a Group column, this is contents of that
GroupName CurPrevString
// name of the Name column -- defaults to 'Name'
NameCol string
// name of the Group column -- defaults to 'Group'
GroupCol string
// TrialSt is the trial (index into Order) that this MPI proc starts each epoch on.
TrialSt int
// TrialEd is the trial index (exclusive) at which this proc's epoch ends;
// when the counter reaches TrialEd, it restarts at TrialSt.
TrialEd int
}
func (ft *MPIFixedTable) Validate() error {
if ft.Table == nil {
return fmt.Errorf("MPIFixedTable: %v has no Table set", ft.Name)
}
if ft.Table.NumColumns() == 0 {
return fmt.Errorf("MPIFixedTable: %v Table has no columns -- Outputs will be invalid", ft.Name)
}
return nil
}
func (ft *MPIFixedTable) Label() string { return ft.Name }
func (ft *MPIFixedTable) String() string {
s := ft.TrialName.Cur
if ft.GroupName.Cur != "" {
s = ft.GroupName.Cur + "_" + s
}
return s
}
func (ft *MPIFixedTable) Init(run int) {
if ft.NameCol == "" {
ft.NameCol = "Name"
}
if ft.GroupCol == "" {
ft.GroupCol = "Group"
}
ft.Trial.Init()
ft.NewOrder()
ft.Trial.Cur = ft.TrialSt - 1 // init state -- key so that first Step() = ft.TrialSt
}
// NewOrder sets a new random Order based on number of rows in the table.
func (ft *MPIFixedTable) NewOrder() {
np := ft.Table.NumRows()
ft.Order = rand.Perm(np) // always start with new one so random order is identical
// and always maintain Order so random number usage is same regardless, and if
// user switches between Sequential and random at any point, it all works..
ft.TrialSt, ft.TrialEd, _ = tensormpi.AllocN(np)
ft.Trial.Max = ft.TrialEd
}
// PermuteOrder permutes the existing order table to get a new random sequence of inputs
// just calls: randx.PermuteInts(ft.Order)
func (ft *MPIFixedTable) PermuteOrder() {
randx.PermuteInts(ft.Order)
}
// Row returns the current row number in the table, based on Sequential / permuted Order.
func (ft *MPIFixedTable) Row() int {
if ft.Sequential {
return ft.Trial.Cur
}
return ft.Order[ft.Trial.Cur]
}
func (ft *MPIFixedTable) SetTrialName() {
if nms := ft.Table.Column(ft.NameCol); nms != nil {
rw := ft.Row()
if rw >= 0 && rw < nms.Len() {
ft.TrialName.Set(nms.StringRow(rw, 0))
}
}
}
func (ft *MPIFixedTable) SetGroupName() {
if nms := ft.Table.Column(ft.GroupCol); nms != nil {
rw := ft.Row()
if rw >= 0 && rw < nms.Len() {
ft.GroupName.Set(nms.StringRow(rw, 0))
}
}
}
func (ft *MPIFixedTable) Step() bool {
if ft.Trial.Incr() { // if true, hit max, reset to 0
ft.Trial.Cur = ft.TrialSt // key to reset always to start
ft.PermuteOrder()
}
ft.SetTrialName()
ft.SetGroupName()
return true
}
func (ft *MPIFixedTable) State(element string) tensor.Values {
et := ft.Table.Column(element).RowTensor(ft.Row())
if et == nil {
slog.Error("MPIFixedTable.State: could not find:", "element", element)
}
return et
}
// Compile-time check that MPIFixedTable implements the Env interface
var _ Env = (*MPIFixedTable)(nil)
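// Usage mirrors FixedTable (see the sketch above); this minimal sketch (not from
// the original source) only highlights the MPI-specific points: the table rows
// must divide evenly across MPI procs, and each proc steps only through its own
// TrialSt..TrialEd share of the Order:
//
//	ft := &MPIFixedTable{Name: "Train"}
//	ft.Table = table.NewView(patterns) // NumRows must be divisible by the number of procs
//	ft.Init(0)                         // allocates this proc's TrialSt..TrialEd range
//	for i := ft.TrialSt; i < ft.TrialEd; i++ {
//		ft.Step()
//	}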
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package esg
//go:generate core generate
import (
"fmt"
)
// Conds are conditionals
type Conds []*Cond
// String returns string rep
func (cs *Conds) String() string {
str := ""
for ci := range *cs {
cd := (*cs)[ci]
str += cd.String() + " "
}
return str
}
// Eval returns true if the conditional expression evaluates to true
func (cs *Conds) Eval(rls *Rules) bool {
cval := true
hasCval := false
lastBin := CondElsN // binary op
lastNot := false
for ci := range *cs {
cd := (*cs)[ci]
switch cd.El {
case And, Or:
lastBin = cd.El
case Not:
lastNot = true
default:
tst := cd.Eval(rls)
if lastNot {
tst = !tst
lastNot = false
}
if !hasCval {
cval = tst
hasCval = true
continue
}
hasCval = true
switch lastBin {
case And:
cval = cval && tst
case Or:
cval = cval || tst
}
lastBin = CondElsN
}
}
return cval
}
// Validate checks for errors in expression
func (cs *Conds) Validate(rl *Rule, it *Item, rls *Rules) []error {
lastBin := CondElsN // binary op
lastNot := false
var errs []error
ncd := len(*cs)
for ci := range *cs {
cd := (*cs)[ci]
switch cd.El {
case And, Or:
if lastBin != CondElsN {
errs = append(errs, fmt.Errorf("Rule: %v Item: %v Condition has two binary logical operators in a row", rl.Name, it.String()))
}
if ci == 0 || ci == ncd-1 {
errs = append(errs, fmt.Errorf("Rule: %v Item: %v Condition has binary logical operator at start or end", rl.Name, it.String()))
}
lastBin = cd.El
case Not:
if lastNot {
errs = append(errs, fmt.Errorf("Rule: %v Item: %v Condition has two Not operators in a row", rl.Name, it.String()))
}
if ci == 0 || ci == ncd-1 {
errs = append(errs, fmt.Errorf("Rule: %v Item: %v Condition has Not operator at start or end", rl.Name, it.String()))
}
lastNot = true
default:
elers := cd.Validate(rl, it, rls)
if elers != nil {
errs = append(errs, elers...)
}
lastNot = false
lastBin = CondElsN
}
}
return errs
}
/////////////////////////////////////////////////////////////////////////
// Cond
// Cond is one element of a conditional
type Cond struct {
// what type of conditional element is this
El CondEls
// name of rule or token to evaluate for CRule
Rule string
// sub-conditions for SubCond
Conds Conds
}
// String returns string rep
func (cd *Cond) String() string {
switch cd.El {
case And:
return "&&"
case Or:
return "||"
case Not:
return "!"
case CRule:
return cd.Rule
case SubCond:
return "(" + cd.Conds.String() + ")"
}
return ""
}
// Eval returns true if this conditional element evaluates to true
func (cd *Cond) Eval(rls *Rules) bool {
if cd.El == CRule {
if cd.Rule[0] == '\'' {
return rls.HasOutput(cd.Rule)
} else {
return rls.HasFired(cd.Rule)
}
}
if cd.El == SubCond && cd.Conds != nil {
return cd.Conds.Eval(rls)
}
return false
}
// Validate checks for errors in expression
func (cd *Cond) Validate(rl *Rule, it *Item, rls *Rules) []error {
if cd.El == CRule {
if cd.Rule == "" {
return []error{fmt.Errorf("Rule: %v Item: %v Rule Condition has empty Rule", rl.Name, it.String())}
}
if cd.Rule[0] != '\'' {
_, err := rls.Rule(cd.Rule)
if err != nil {
return []error{err}
}
}
return nil
}
if cd.El == SubCond {
if len(cd.Conds) == 0 {
return []error{fmt.Errorf("Rule: %v Item: %v Rule SubConds are empty", rl.Name, it.String())}
}
return cd.Conds.Validate(rl, it, rls)
}
return nil
}
// CondEls are different types of conditional elements
type CondEls int32 //enums:enum
const (
// CRule means Rule is name of a rule to evaluate truth value
CRule CondEls = iota
And
Or
Not
// SubCond is a sub-condition expression
SubCond
)
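// Minimal sketch (not from the original source) of how a conditional expression
// such as "A && !B" is represented and evaluated; "A" and "B" are hypothetical
// rule names, and rls is an existing *Rules whose Fired map has been populated:
//
//	cs := Conds{
//		{El: CRule, Rule: "A"},
//		{El: And},
//		{El: Not},
//		{El: CRule, Rule: "B"},
//	}
//	ok := cs.Eval(rls) // true if rule A has fired and rule B has not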
// Code generated by "core generate -add-types"; DO NOT EDIT.
package esg
import (
"cogentcore.org/core/enums"
)
var _CondElsValues = []CondEls{0, 1, 2, 3, 4}
// CondElsN is the highest valid value for type CondEls, plus one.
const CondElsN CondEls = 5
var _CondElsValueMap = map[string]CondEls{`CRule`: 0, `And`: 1, `Or`: 2, `Not`: 3, `SubCond`: 4}
var _CondElsDescMap = map[CondEls]string{0: `CRule means Rule is name of a rule to evaluate truth value`, 1: ``, 2: ``, 3: ``, 4: `SubCond is a sub-condition expression`}
var _CondElsMap = map[CondEls]string{0: `CRule`, 1: `And`, 2: `Or`, 3: `Not`, 4: `SubCond`}
// String returns the string representation of this CondEls value.
func (i CondEls) String() string { return enums.String(i, _CondElsMap) }
// SetString sets the CondEls value from its string representation,
// and returns an error if the string is invalid.
func (i *CondEls) SetString(s string) error {
return enums.SetString(i, s, _CondElsValueMap, "CondEls")
}
// Int64 returns the CondEls value as an int64.
func (i CondEls) Int64() int64 { return int64(i) }
// SetInt64 sets the CondEls value from an int64.
func (i *CondEls) SetInt64(in int64) { *i = CondEls(in) }
// Desc returns the description of the CondEls value.
func (i CondEls) Desc() string { return enums.Desc(i, _CondElsDescMap) }
// CondElsValues returns all possible values for the type CondEls.
func CondElsValues() []CondEls { return _CondElsValues }
// Values returns all possible values for the type CondEls.
func (i CondEls) Values() []enums.Enum { return enums.Values(_CondElsValues) }
// MarshalText implements the [encoding.TextMarshaler] interface.
func (i CondEls) MarshalText() ([]byte, error) { return []byte(i.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *CondEls) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "CondEls") }
var _ElementsValues = []Elements{0, 1}
// ElementsN is the highest valid value for type Elements, plus one.
const ElementsN Elements = 2
var _ElementsValueMap = map[string]Elements{`RuleEl`: 0, `TokenEl`: 1}
var _ElementsDescMap = map[Elements]string{0: `RuleEl means Value is name of a rule`, 1: `TokenEl means Value is a token to emit`}
var _ElementsMap = map[Elements]string{0: `RuleEl`, 1: `TokenEl`}
// String returns the string representation of this Elements value.
func (i Elements) String() string { return enums.String(i, _ElementsMap) }
// SetString sets the Elements value from its string representation,
// and returns an error if the string is invalid.
func (i *Elements) SetString(s string) error {
return enums.SetString(i, s, _ElementsValueMap, "Elements")
}
// Int64 returns the Elements value as an int64.
func (i Elements) Int64() int64 { return int64(i) }
// SetInt64 sets the Elements value from an int64.
func (i *Elements) SetInt64(in int64) { *i = Elements(in) }
// Desc returns the description of the Elements value.
func (i Elements) Desc() string { return enums.Desc(i, _ElementsDescMap) }
// ElementsValues returns all possible values for the type Elements.
func ElementsValues() []Elements { return _ElementsValues }
// Values returns all possible values for the type Elements.
func (i Elements) Values() []enums.Enum { return enums.Values(_ElementsValues) }
// MarshalText implements the [encoding.TextMarshaler] interface.
func (i Elements) MarshalText() ([]byte, error) { return []byte(i.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *Elements) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "Elements") }
var _RuleTypesValues = []RuleTypes{0, 1, 2, 3, 4}
// RuleTypesN is the highest valid value for type RuleTypes, plus one.
const RuleTypesN RuleTypes = 5
var _RuleTypesValueMap = map[string]RuleTypes{`UniformItems`: 0, `ProbItems`: 1, `CondItems`: 2, `SequentialItems`: 3, `PermutedItems`: 4}
var _RuleTypesDescMap = map[RuleTypes]string{0: `UniformItems is the default mutually exclusive items chosen at uniform random`, 1: `ProbItems has specific probabilities for each item`, 2: `CondItems has conditionals for each item, indicated by ?`, 3: `SequentialItems progresses through items in sequential order, indicated by |`, 4: `PermutedItems progresses through items in permuted order, indicated by $`}
var _RuleTypesMap = map[RuleTypes]string{0: `UniformItems`, 1: `ProbItems`, 2: `CondItems`, 3: `SequentialItems`, 4: `PermutedItems`}
// String returns the string representation of this RuleTypes value.
func (i RuleTypes) String() string { return enums.String(i, _RuleTypesMap) }
// SetString sets the RuleTypes value from its string representation,
// and returns an error if the string is invalid.
func (i *RuleTypes) SetString(s string) error {
return enums.SetString(i, s, _RuleTypesValueMap, "RuleTypes")
}
// Int64 returns the RuleTypes value as an int64.
func (i RuleTypes) Int64() int64 { return int64(i) }
// SetInt64 sets the RuleTypes value from an int64.
func (i *RuleTypes) SetInt64(in int64) { *i = RuleTypes(in) }
// Desc returns the description of the RuleTypes value.
func (i RuleTypes) Desc() string { return enums.Desc(i, _RuleTypesDescMap) }
// RuleTypesValues returns all possible values for the type RuleTypes.
func RuleTypesValues() []RuleTypes { return _RuleTypesValues }
// Values returns all possible values for the type RuleTypes.
func (i RuleTypes) Values() []enums.Enum { return enums.Values(_RuleTypesValues) }
// MarshalText implements the [encoding.TextMarshaler] interface.
func (i RuleTypes) MarshalText() ([]byte, error) { return []byte(i.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *RuleTypes) UnmarshalText(text []byte) error {
return enums.UnmarshalText(i, text, "RuleTypes")
}
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package esg
import (
"fmt"
"strings"
)
// Item is one item within a rule
type Item struct { //git:add
// probability for choosing this item -- 0 if uniform random
Prob float32
// elements of the rule -- for non-Cond rules
Elems []Elem
// conditions for this item -- specified by ?
Cond Conds
// for conditional, this is the sub-rule that is run with sub-items
SubRule *Rule
// state update name=value to set for rule
State State
}
// String returns string rep
func (it *Item) String() string {
if it.Cond != nil {
return it.Cond.String() + it.SubRule.String()
}
sout := ""
if it.Prob > 0 {
sout = "%" + fmt.Sprintf("%g ", it.Prob)
}
for i := range it.Elems {
el := &it.Elems[i]
sout += el.String() + " "
}
return sout
}
// Gen generates expression according to the item
func (it *Item) Gen(rl *Rule, rls *Rules) {
if it.SubRule != nil {
it.State.Set(rls, "") // no value
it.SubRule.Gen(rls)
}
if len(it.Elems) > 0 {
it.State.Set(rls, it.Elems[0].Value)
for i := range it.Elems {
el := &it.Elems[i]
el.Gen(rl, rls)
}
}
}
// CondEval evaluates whether the condition is true
func (it *Item) CondEval(rl *Rule, rls *Rules) bool {
return it.Cond.Eval(rls)
}
// Validate checks for config errors
func (it *Item) Validate(rl *Rule, rls *Rules) []error {
if it.Cond != nil {
ers := it.Cond.Validate(rl, it, rls)
if it.SubRule == nil {
ers = append(ers, fmt.Errorf("Rule: %v Item: %v IsCond but SubRule == nil", rl.Name, it.String()))
} else {
srs := it.SubRule.Validate(rls)
if len(srs) > 0 {
ers = append(ers, srs...)
}
}
return ers
}
var errs []error
for i := range it.Elems {
el := &it.Elems[i]
ers := el.Validate(it, rl, rls)
if len(ers) > 0 {
errs = append(errs, ers...)
}
}
return errs
}
/////////////////////////////////////////////////////////////////////
// Elem
// Elem is one element in a concrete Item: either a rule or a token
type Elem struct { //git:add
// El is the type of element: Rule or Token
El Elements
// value of the token: name of Rule or Token
Value string
}
// String returns string rep
func (el *Elem) String() string {
if el.El == TokenEl {
return "'" + el.Value + "'"
}
return el.Value
}
// Gen generates expression according to the element
func (el *Elem) Gen(rl *Rule, rls *Rules) {
switch el.El {
case RuleEl:
rl, _ := rls.Rule(el.Value)
rl.Gen(rls)
case TokenEl:
if rls.Trace {
fmt.Printf("Rule: %v added Token output: %v\n", rl.Name, el.Value)
}
rls.AddOutput(el.Value)
}
}
// Validate checks for config errors
func (el *Elem) Validate(it *Item, rl *Rule, rls *Rules) []error {
switch el.El {
case RuleEl:
_, err := rls.Rule(el.Value)
if err != nil {
return []error{err}
}
return nil
case TokenEl:
if el.Value == "" {
err := fmt.Errorf("Rule: %v Item: %v has empty Token element", rl.Name, it.String())
return []error{err}
}
}
return nil
}
// Elements are different types of elements
type Elements int32 //enums:enum
const (
// RuleEl means Value is name of a rule
RuleEl Elements = iota
// TokenEl means Value is a token to emit
TokenEl
)
/////////////////////////////////////////////////////////////////////
// State
// State holds the name=value state settings associated with rule or item
// as a string, string map
type State map[string]string
// Add adds the given name, value pair to the state map
func (ss *State) Add(name, val string) {
if *ss == nil {
*ss = make(map[string]string)
}
(*ss)[name] = val
}
// Set sets this state in the rules States map, using the given value for any entries
// that have empty values. Returns true if there was any state to set.
func (ss *State) Set(rls *Rules, val string) bool {
if len(*ss) == 0 {
return false
}
for k, v := range *ss {
if v == "" {
v = val
}
rls.States[k] = v
if rls.Trace {
fmt.Printf("Set State: %v = %v\n", k, v)
}
}
return true
}
// TrimQualifiers removes any :X qualifiers after state values
func (ss *State) TrimQualifiers() {
for k, v := range *ss {
ci := strings.Index(v, ":")
if ci > 0 {
(*ss)[k] = v[:ci]
}
}
}
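// Minimal usage sketch (not from the original source); the key and value strings
// are hypothetical, and rls is an existing *Rules with its States map initialized
// (as Gen does):
//
//	var st State
//	st.Add("Agent", "")           // empty value: filled in at Set time
//	st.Add("Tense", "past:irreg")
//	st.TrimQualifiers()           // "past:irreg" -> "past"
//	st.Set(rls, "dog")            // writes Agent=dog, Tense=past into rls.States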
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package esg
import (
"bufio"
"fmt"
"io"
"os"
"strconv"
"strings"
"cogentcore.org/core/base/errors"
)
// OpenRules reads in a text file with rules, line-by-line simple parser
func (rls *Rules) OpenRules(fname string) []error {
fp, err := os.Open(fname)
if errors.Log(err) != nil {
return []error{err}
}
defer fp.Close()
return rls.ReadRules(fp)
}
// OpenRulesPy reads in a text file with rules, discarding any parse errors.
func (rls *Rules) OpenRulesPy(fname string) {
rls.OpenRules(fname)
}
// AddParseErr adds the given parse error to ParseErrs, automatically including the current line number
func (rls *Rules) AddParseErr(msg string) {
err := fmt.Errorf("Line: %d \tesg Parse Error: %s", rls.ParseLn, msg)
rls.ParseErrs = append(rls.ParseErrs, err)
}
// ReadRules reads in a text file with rules, line-by-line simple parser
func (rls *Rules) ReadRules(r io.Reader) []error {
rls.Map = nil
rls.Top = nil
rls.ParseErrs = nil
rls.ParseLn = 0
scan := bufio.NewScanner(r) // line at a time
rstack := []*Rule{}
lastwascmt := false
lastcmt := ""
for scan.Scan() {
rls.ParseLn++
b := scan.Bytes()
bs := string(b)
sp := strings.Fields(bs)
nsp := len(sp)
if nsp > 2 && sp[0] != "//" { // get rid of trailing comments
for i, s := range sp {
if s == "//" {
nsp = i
sp = sp[:i]
break
}
}
}
switch {
case nsp == 0:
lastwascmt = false
case sp[0] == "//":
ncmt := strings.Join(sp[1:], " ")
if lastwascmt {
lastcmt += "\n" + ncmt
} else {
lastcmt = ncmt
lastwascmt = true
}
case len(sp[0]) > 2 && sp[0][:2] == "//":
lastwascmt = false // comment with no space after the //: skip these
case sp[0] == "}":
lastwascmt = false
sz := len(rstack)
if sz == 0 {
rls.AddParseErr("mismatched end bracket } has no match")
continue
}
rstack = rstack[:sz-1]
case sp[nsp-1] == "{":
desc := ""
if lastwascmt {
desc = lastcmt
lastwascmt = false
}
if nsp == 1 {
rls.AddParseErr("start bracket: '{' needs at least a rule name")
continue
}
rnm := sp[0]
var rptp float32
prp := sp[nsp-2]
if len(prp) > 2 && prp[0:2] == "=%" {
pct, err := strconv.ParseFloat(prp[2:], 32)
if err != nil {
rls.AddParseErr(err.Error())
} else {
rptp = float32(pct / 100)
}
}
typ := UniformItems
switch prp {
case "?":
typ = CondItems
case "|":
typ = SequentialItems
case "$":
typ = PermutedItems
}
if typ != UniformItems {
if nsp == 2 {
rls.AddParseErr("start special bracket: '? {' needs at least a rule name")
continue
}
}
sz := len(rstack)
if sz > 0 {
cr, ci := rls.ParseAddItem(rstack, sp)
ci.SubRule = &Rule{Name: cr.Name + "SubRule", Desc: desc, Type: typ, RepeatP: rptp}
rstack = append(rstack, ci.SubRule)
ncond := nsp - 1
if typ == CondItems {
ncond--
}
ci.Cond = rls.ParseConds(sp[:ncond])
} else {
nr := &Rule{Name: rnm, Desc: desc, Type: typ, RepeatP: rptp}
rstack = append(rstack, nr)
rls.Add(nr)
}
case sp[nsp-1] == "}":
cr, ci := rls.ParseAddItem(rstack, sp)
if cr == nil {
continue
}
ci.SubRule = &Rule{Name: cr.Name + "SubRule"}
sbidx := 0
for si, s := range sp {
if s == "{" {
sbidx = si
}
}
ci.Cond = rls.ParseConds(sp[:sbidx])
it := &Item{}
ci.SubRule.Items = append(ci.SubRule.Items, it)
rls.ParseElems(ci.SubRule, it, sp[sbidx+1:nsp-1])
case sp[0][0] == '=':
rl := rls.ParseCurRule(rstack, sp)
rls.ParseState(sp[0][1:], &rl.State)
case sp[0][0] == '%':
rl, it := rls.ParseAddItem(rstack, sp)
if rl == nil {
continue
}
pct, err := strconv.ParseFloat(sp[0][1:], 32)
if err != nil {
rls.AddParseErr(err.Error())
}
it.Prob = float32(pct / 100)
if rl.Type == UniformItems {
rl.Type = ProbItems
}
rls.ParseElems(rl, it, sp[1:])
default:
rl, it := rls.ParseAddItem(rstack, sp)
if rl == nil {
continue
}
rls.ParseElems(rl, it, sp)
}
}
if len(rls.ParseErrs) > 0 {
fmt.Printf("\nesg Parse errors for: %s\n", rls.Name)
for _, err := range rls.ParseErrs {
fmt.Println(err)
}
}
return rls.ParseErrs
}
func (rls *Rules) ParseCurRule(rstack []*Rule, sp []string) *Rule {
sz := len(rstack)
if sz == 0 {
rls.AddParseErr(fmt.Sprintf("no active rule when defining items: %v", sp))
return nil
}
return rstack[sz-1]
}
func (rls *Rules) ParseAddItem(rstack []*Rule, sp []string) (*Rule, *Item) {
rl := rls.ParseCurRule(rstack, sp)
if rl == nil {
return nil, nil
}
it := &Item{}
rl.Items = append(rl.Items, it)
return rl, it
}
func (rls *Rules) ParseElems(rl *Rule, it *Item, els []string) {
for _, es := range els {
switch {
case es[0] == '=':
rls.ParseState(es[1:], &it.State)
case es[0] == '\'':
if len(es) < 3 {
rls.AddParseErr(fmt.Sprintf("empty token: %v in els: %v", es, els))
} else {
tok := es[1 : len(es)-1]
it.Elems = append(it.Elems, Elem{El: TokenEl, Value: tok})
}
default:
it.Elems = append(it.Elems, Elem{El: RuleEl, Value: es})
}
}
}
func (rls *Rules) ParseState(ststr string, state *State) {
stsp := strings.Split(ststr, "=")
if len(stsp) == 0 {
rls.AddParseErr(fmt.Sprintf("state expr: %v empty", ststr))
} else {
if len(stsp) > 1 {
state.Add(stsp[0], stsp[1])
} else {
state.Add(stsp[0], "")
}
}
}
func (rls *Rules) ParseConds(cds []string) Conds {
cs := Conds{}
cur := &cs
substack := []*Conds{cur}
for _, c := range cds {
for {
csz := len(c)
switch {
case csz == 0:
rls.AddParseErr("no text left in cond expr")
case c == "&&":
*cur = append(*cur, &Cond{El: And})
case c == "||":
*cur = append(*cur, &Cond{El: Or})
case c[0] == '!':
*cur = append(*cur, &Cond{El: Not})
c = c[1:]
continue
case c == "(":
sub := &Cond{El: SubCond}
*cur = append(*cur, sub)
cur = &sub.Conds
substack = append(substack, cur)
case c[0] == '(':
sub := &Cond{El: SubCond}
*cur = append(*cur, sub)
cur = &sub.Conds
substack = append(substack, cur)
c = c[1:]
continue
case c[csz-1] == ')':
ssz := len(substack)
if ssz == 1 {
rls.AddParseErr("imbalanced parens in cond expr: " + strings.Join(cds, " "))
} else {
*cur = append(*cur, &Cond{El: CRule, Rule: c[:csz-1]})
cur = substack[ssz-2]
substack = substack[:ssz-1]
}
case c == ")":
ssz := len(substack)
if ssz == 1 {
rls.AddParseErr("imbalanced parens in cond expr: " + strings.Join(cds, " "))
} else {
cur = substack[ssz-2]
substack = substack[:ssz-1]
}
default:
*cur = append(*cur, &Cond{El: CRule, Rule: c})
}
break
}
}
return cs
}
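// For reference, a minimal sketch (not from the original source) of the rules
// text format that ReadRules accepts; the grammar content is hypothetical.
// Unquoted elements are rule names, 'quoted' elements are output tokens:
//
//	Sentence {
//		NP VP
//	}
//	NP {
//		'the' 'dog'
//		'the' 'cat'
//	}
//	VP {
//		'ran'
//		'slept'
//	}
//
// Such a file is loaded with OpenRules, or from any io.Reader with ReadRules
// (see the end-to-end sketch after Rules.Add below).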
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package esg
//go:generate core generate -add-types
import (
"fmt"
"math/rand"
"strings"
"cogentcore.org/lab/base/randx"
)
// RuleTypes are different types of rules (i.e., how the items are selected)
type RuleTypes int32 //enums:enum
const (
// UniformItems is the default mutually exclusive items chosen at uniform random
UniformItems RuleTypes = iota
// ProbItems has specific probabilities for each item
ProbItems
// CondItems has conditionals for each item, indicated by ?
CondItems
// SequentialItems progresses through items in sequential order, indicated by |
SequentialItems
// PermutedItems progresses through items in permuted order, indicated by $
PermutedItems
)
// Rule is one rule containing some number of items
type Rule struct { //git:add
// name of rule
Name string
// description / notes on rule
Desc string
// type of rule -- how to choose the items
Type RuleTypes
// items in rule
Items []*Item
// state update for rule
State State
// previously selected item (from perspective of current rule)
PrevIndex int
// current index in Items (what will be used next)
CurIndex int
// probability of repeating same item -- signaled by =%p
RepeatP float32
// permuted order if doing that
Order []int
}
// Init initializes the rule's ordering state -- only relevant for Sequential and
// Permuted rules, which restart at the beginning.
func (rl *Rule) Init() {
rl.CurIndex = 0
rl.PrevIndex = -1
if rl.Type == PermutedItems {
rl.Order = rand.Perm(len(rl.Items))
}
}
// Gen generates expression according to the rule, appending output
// to the rls.Output array
func (rl *Rule) Gen(rls *Rules) {
rls.SetFired(rl.Name)
rl.State.Set(rls, rl.Name)
if rls.Trace {
fmt.Printf("Fired Rule: %v\n", rl.Name)
}
if rl.RepeatP > 0 && rl.PrevIndex >= 0 {
rpt := randx.BoolP32(rl.RepeatP)
if rpt {
if rls.Trace {
fmt.Printf("Selected item: %v due to RepeatP = %v\n", rl.PrevIndex, rl.RepeatP)
}
rl.Items[rl.PrevIndex].Gen(rl, rls)
return
}
}
switch rl.Type {
case UniformItems:
no := len(rl.Items)
opt := rand.Intn(no)
if rls.Trace {
fmt.Printf("Selected item: %v from: %v uniform random\n", opt, no)
}
rl.PrevIndex = opt
rl.Items[opt].Gen(rl, rls)
case ProbItems:
pv := rand.Float32()
sum := float32(0)
for ii, it := range rl.Items {
sum += it.Prob
if pv < sum { // note: lower values already excluded
if rls.Trace {
fmt.Printf("Selected item: %v using rnd val: %v sum: %v\n", ii, pv, sum)
}
rl.PrevIndex = ii
it.Gen(rl, rls)
return
}
}
rl.PrevIndex = -1
if rls.Trace {
fmt.Printf("No items selected using rnd val: %v sum: %v\n", pv, sum)
}
case CondItems:
var copts []int
for ii, it := range rl.Items {
if it.CondEval(rl, rls) {
copts = append(copts, ii)
}
}
no := len(copts)
if no == 0 {
if rls.Trace {
fmt.Printf("No items match Conds\n")
}
return
}
opt := rand.Intn(no)
if rls.Trace {
fmt.Printf("Selected item: %v from: %v matching Conds\n", copts[opt], no)
}
rl.PrevIndex = copts[opt]
rl.Items[copts[opt]].Gen(rl, rls)
case SequentialItems:
no := len(rl.Items)
if no == 0 {
return
}
if rl.CurIndex >= no {
rl.CurIndex = 0
}
opt := rl.CurIndex
if rls.Trace {
fmt.Printf("Selected item: %v sequentially\n", opt)
}
rl.PrevIndex = opt
rl.CurIndex++
rl.Items[opt].Gen(rl, rls)
case PermutedItems:
no := len(rl.Items)
if no == 0 {
return
}
if len(rl.Order) != no {
rl.Order = rand.Perm(no)
rl.CurIndex = 0
}
if rl.CurIndex >= no {
randx.PermuteInts(rl.Order)
rl.CurIndex = 0
}
opt := rl.Order[rl.CurIndex]
if rls.Trace {
fmt.Printf("Selected item: %v sequentially\n", opt)
}
rl.PrevIndex = opt
rl.CurIndex++
rl.Items[opt].Gen(rl, rls)
}
}
// String generates string representation of rule
func (rl *Rule) String() string {
if strings.HasSuffix(rl.Name, "SubRule") {
str := " {\n"
for _, it := range rl.Items {
str += "\t\t" + it.String() + "\n"
}
str += "\t}\n"
return str
} else {
str := "\n\n"
if rl.Desc != "" {
str += "// " + rl.Desc + "\n"
}
str += rl.Name
switch rl.Type {
case CondItems:
str += " ? "
case SequentialItems:
str += " | "
case PermutedItems:
str += " $ "
}
str += " {\n"
for _, it := range rl.Items {
str += "\t" + it.String() + "\n"
}
str += "}\n"
return str
}
}
// Validate checks for config errors
func (rl *Rule) Validate(rls *Rules) []error {
nr := len(rl.Items)
if nr == 0 {
err := fmt.Errorf("Rule: %v has no items", rl.Name)
return []error{err}
}
var errs []error
for _, it := range rl.Items {
if rl.Type == CondItems {
if len(it.Cond) == 0 {
errs = append(errs, fmt.Errorf("Rule: %v is CondItems, but Item: %v has no Cond", rl.Name, it.String()))
}
if it.SubRule == nil {
errs = append(errs, fmt.Errorf("Rule: %v is CondItems, but Item: %v has nil SubRule", rl.Name, it.String()))
}
} else {
if rl.Type == ProbItems && it.Prob == 0 {
errs = append(errs, fmt.Errorf("Rule: %v is ProbItems, but Item: %v has 0 Prob", rl.Name, it.String()))
} else if rl.Type == UniformItems && it.Prob > 0 {
errs = append(errs, fmt.Errorf("Rule: %v is UniformItems, but Item: %v has > 0 Prob", rl.Name, it.String()))
}
}
iterrs := it.Validate(rl, rls)
if len(iterrs) > 0 {
errs = append(errs, iterrs...)
}
}
return errs
}
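// For reference, a minimal sketch (not from the original source) of how the
// different RuleTypes appear in the rules text, per the parser and String()
// above; rule and token names are hypothetical:
//
//	Uniform {        // UniformItems: items chosen at uniform random
//		'a'
//		'b'
//	}
//	Weighted {       // ProbItems: %p prefix gives each item's percent probability
//		%60 'a'
//		%40 'b'
//	}
//	Conditional ? {  // CondItems: each item is a condition plus a sub-rule
//		Uniform && !Weighted {
//			'c'
//		}
//	}
//	Ordered | {      // SequentialItems: items presented in sequential order
//		'a'
//		'b'
//	}
//	Shuffled $ {     // PermutedItems: items presented in permuted order
//		'a'
//		'b'
//	}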
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package esg
import (
"fmt"
)
// Rules is a collection of rules
type Rules struct { //git:add
// name of this rule collection
Name string
// description of this rule collection
Desc string
// if true, will print out a trace during generation
Trace bool
// top-level rule -- this is where to start generating
Top *Rule
// map of each rule
Map map[string]*Rule
// map of names of all the rules that have fired
Fired map[string]bool
// array of output strings -- appended as the rules generate output
Output []string
// user-defined state map optionally created during generation
States State
// errors from parsing
ParseErrs []error
// current line number during parsing
ParseLn int
}
// Gen generates one expression according to the rules.
// Returns the token output, which is also available as rls.Output.
func (rls *Rules) Gen() []string {
rls.Fired = make(map[string]bool)
rls.States = make(State)
rls.Output = nil
if rls.Trace {
fmt.Printf("\n#########################\nRules: %v starting Gen\n", rls.Name)
}
rls.Top.Gen(rls)
return rls.Output
}
// String generates string representation of all rules
func (rls *Rules) String() string {
str := "Rules: " + rls.Name
if rls.Desc != "" {
str += ": " + rls.Desc
}
str += "\n"
for _, rl := range rls.Map {
str += rl.String()
}
return str
}
// Validate checks for config errors
func (rls *Rules) Validate() []error {
if len(rls.Map) == 0 {
return []error{fmt.Errorf("Rules: %v has no Rules", rls.Name)}
}
var errs []error
if rls.Top == nil {
errs = append(errs, fmt.Errorf("Rules: %v Top is nil", rls.Name))
}
for _, rl := range rls.Map {
ers := rl.Validate(rls)
if len(ers) > 0 {
errs = append(errs, ers...)
}
}
if len(errs) > 0 {
fmt.Printf("\nValidation errors:\n")
for _, err := range errs {
fmt.Println(err)
}
}
return errs
}
// Init initializes rule order state
func (rls *Rules) Init() {
rls.Top.Init()
for _, rl := range rls.Map {
rl.Init()
}
}
// Rule returns rule of given name, and error if not found
func (rls *Rules) Rule(name string) (*Rule, error) {
rl, ok := rls.Map[name]
if ok {
return rl, nil
}
return nil, fmt.Errorf("Rule: %v not found in Rules: %v", name, rls.Name)
}
// HasFired returns true if rule of given name has fired
func (rls *Rules) HasFired(name string) bool {
_, has := rls.Fired[name]
return has
}
// HasOutput returns true if the given token is in the Output list,
// stripping the ' delimiters from the given string if present.
func (rls *Rules) HasOutput(out string) bool {
if out[0] == '\'' {
out = out[1 : len(out)-1]
}
for _, o := range rls.Output {
if o == out {
return true
}
}
return false
}
// SetFired adds given rule name to map of those that fired this round
func (rls *Rules) SetFired(name string) {
rls.Fired[name] = true
}
// AddOutput adds given string to Output array
func (rls *Rules) AddOutput(out string) {
rls.Output = append(rls.Output, out)
}
// Add adds the given rule to the Map; the first rule added becomes Top.
func (rls *Rules) Add(rl *Rule) {
if rls.Map == nil {
rls.Map = make(map[string]*Rule)
rls.Top = rl
}
rls.Map[rl.Name] = rl
}
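// Minimal end-to-end sketch (not from the original source): load, validate, and
// generate from a rules file; "grammar.txt" is a hypothetical path:
//
//	rls := &Rules{Name: "Gen"}
//	if errs := rls.OpenRules("grammar.txt"); len(errs) > 0 {
//		// handle parse errors
//	}
//	if errs := rls.Validate(); len(errs) > 0 {
//		// handle config errors
//	}
//	rls.Init()
//	for i := 0; i < 10; i++ {
//		toks := rls.Gen() // e.g., ["the", "dog", "ran"] for the sketch grammar above
//		_ = toks
//	}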
// Code generated by "core generate -add-types"; DO NOT EDIT.
package etime
import (
"cogentcore.org/core/enums"
)
var _ModesValues = []Modes{0, 1, 2, 3, 4, 5, 6}
// ModesN is the highest valid value for type Modes, plus one.
const ModesN Modes = 7
var _ModesValueMap = map[string]Modes{`NoEvalMode`: 0, `AllModes`: 1, `Train`: 2, `Test`: 3, `Validate`: 4, `Analyze`: 5, `Debug`: 6}
var _ModesDescMap = map[Modes]string{0: ``, 1: `AllModes indicates that the log should occur over all modes present in other items.`, 2: `Train is when the network is learning`, 3: `Test is when testing, typically without learning`, 4: `Validate is typically for a special held-out testing set`, 5: `Analyze is when analyzing the representations and behavior of the network`, 6: `Debug is for recording info particularly useful for debugging`}
var _ModesMap = map[Modes]string{0: `NoEvalMode`, 1: `AllModes`, 2: `Train`, 3: `Test`, 4: `Validate`, 5: `Analyze`, 6: `Debug`}
// String returns the string representation of this Modes value.
func (i Modes) String() string { return enums.String(i, _ModesMap) }
// SetString sets the Modes value from its string representation,
// and returns an error if the string is invalid.
func (i *Modes) SetString(s string) error { return enums.SetString(i, s, _ModesValueMap, "Modes") }
// Int64 returns the Modes value as an int64.
func (i Modes) Int64() int64 { return int64(i) }
// SetInt64 sets the Modes value from an int64.
func (i *Modes) SetInt64(in int64) { *i = Modes(in) }
// Desc returns the description of the Modes value.
func (i Modes) Desc() string { return enums.Desc(i, _ModesDescMap) }
// ModesValues returns all possible values for the type Modes.
func ModesValues() []Modes { return _ModesValues }
// Values returns all possible values for the type Modes.
func (i Modes) Values() []enums.Enum { return enums.Values(_ModesValues) }
// MarshalText implements the [encoding.TextMarshaler] interface.
func (i Modes) MarshalText() ([]byte, error) { return []byte(i.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *Modes) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "Modes") }
var _TimesValues = []Times{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}
// TimesN is the highest valid value for type Times, plus one.
const TimesN Times = 20
var _TimesValueMap = map[string]Times{`NoTime`: 0, `AllTimes`: 1, `Cycle`: 2, `FastSpike`: 3, `GammaCycle`: 4, `Phase`: 5, `BetaCycle`: 6, `AlphaCycle`: 7, `ThetaCycle`: 8, `Event`: 9, `Trial`: 10, `Tick`: 11, `Sequence`: 12, `Epoch`: 13, `Block`: 14, `Condition`: 15, `Run`: 16, `Expt`: 17, `Scene`: 18, `Episode`: 19}
var _TimesDescMap = map[Times]string{0: `NoTime represents a non-initialized value, or a null result`, 1: `AllTimes indicates that the log should occur over all times present in other items.`, 2: `Cycle is the finest time scale -- typically 1 msec -- a single activation update.`, 3: `FastSpike is typically 10 cycles = 10 msec (100hz) = the fastest spiking time generally observed in the brain. This can be useful for visualizing updates at a granularity in between Cycle and GammaCycle.`, 4: `GammaCycle is typically 25 cycles = 25 msec (40hz)`, 5: `Phase is typically a Minus or Plus phase, where plus phase is bursting / outcome that drives positive learning relative to prediction in minus phase. It can also be used for other time scales involving multiple Cycles.`, 6: `BetaCycle is typically 50 cycles = 50 msec (20 hz) = one beta-frequency cycle. Gating in the basal ganglia and associated updating in prefrontal cortex occurs at this frequency.`, 7: `AlphaCycle is typically 100 cycles = 100 msec (10 hz) = one alpha-frequency cycle.`, 8: `ThetaCycle is typically 200 cycles = 200 msec (5 hz) = two alpha-frequency cycles. This is the modal duration of a saccade, the update frequency of medial temporal lobe episodic memory, and the minimal predictive learning cycle (perceive an Alpha 1, predict on 2).`, 9: `Event is the smallest unit of naturalistic experience that coheres unto itself (e.g., something that could be described in a sentence). Typically this is on the time scale of a few seconds: e.g., reaching for something, catching a ball.`, 10: `Trial is one unit of behavior in an experiment -- it is typically environmentally defined instead of endogenously defined in terms of basic brain rhythms. In the minimal case it could be one ThetaCycle, but could be multiple, and could encompass multiple Events (e.g., one event is fixation, next is stimulus, last is response)`, 11: `Tick is one step in a sequence -- often it is useful to have Trial count up throughout the entire Epoch but also include a Tick to count trials within a Sequence`, 12: `Sequence is a sequential group of Trials (not always needed).`, 13: `Epoch is used in two different contexts. In machine learning, it represents a collection of Trials, Sequences or Events that constitute a "representative sample" of the environment. In the simplest case, it is the entire collection of Trials used for training. In electrophysiology, it is a timing window used for organizing the analysis of electrode data.`, 14: `Block is a collection of Trials, Sequences or Events, often used in experiments when conditions are varied across blocks.`, 15: `Condition is a collection of Blocks that share the same set of parameters. This is intermediate between Block and Run levels. Aggregation of stats at this level is based on the last 5 rows by default.`, 16: `Run is a complete run of a model / subject, from training to testing, etc. Often multiple runs are done in an Expt to obtain statistics over initial random weights etc. Aggregation of stats at this level is based on the last 5 rows by default.`, 17: `Expt is an entire experiment -- multiple Runs through a given protocol / set of parameters.`, 18: `Scene is a sequence of events that constitutes the next larger-scale coherent unit of naturalistic experience corresponding e.g., to a scene in a movie. Typically consists of events that all take place in one location over e.g., a minute or so. 
This could be a paragraph or a page or so in a book.`, 19: `Episode is a sequence of scenes that constitutes the next larger-scale unit of naturalistic experience e.g., going to the grocery store or eating at a restaurant, attending a wedding or other "event". This could be a chapter in a book.`}
var _TimesMap = map[Times]string{0: `NoTime`, 1: `AllTimes`, 2: `Cycle`, 3: `FastSpike`, 4: `GammaCycle`, 5: `Phase`, 6: `BetaCycle`, 7: `AlphaCycle`, 8: `ThetaCycle`, 9: `Event`, 10: `Trial`, 11: `Tick`, 12: `Sequence`, 13: `Epoch`, 14: `Block`, 15: `Condition`, 16: `Run`, 17: `Expt`, 18: `Scene`, 19: `Episode`}
// String returns the string representation of this Times value.
func (i Times) String() string { return enums.String(i, _TimesMap) }
// SetString sets the Times value from its string representation,
// and returns an error if the string is invalid.
func (i *Times) SetString(s string) error { return enums.SetString(i, s, _TimesValueMap, "Times") }
// Int64 returns the Times value as an int64.
func (i Times) Int64() int64 { return int64(i) }
// SetInt64 sets the Times value from an int64.
func (i *Times) SetInt64(in int64) { *i = Times(in) }
// Desc returns the description of the Times value.
func (i Times) Desc() string { return enums.Desc(i, _TimesDescMap) }
// TimesValues returns all possible values for the type Times.
func TimesValues() []Times { return _TimesValues }
// Values returns all possible values for the type Times.
func (i Times) Values() []enums.Enum { return enums.Values(_TimesValues) }
// MarshalText implements the [encoding.TextMarshaler] interface.
func (i Times) MarshalText() ([]byte, error) { return []byte(i.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *Times) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "Times") }
// Copyright (c) 2022, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package etime
//go:generate core generate
// todo: this only has Train, Test in final V2 case
//gosl:start
// Modes are evaluation modes (Training, Testing, etc)
type Modes int32 //enums:enum
// The evaluation modes
const (
NoEvalMode Modes = iota
// AllModes indicates that the log should occur over all modes present in other items.
AllModes
// Train is when the network is learning
Train
// Test is when testing, typically without learning
Test
// Validate is typically for a special held-out testing set
Validate
// Analyze is when analyzing the representations and behavior of the network
Analyze
// Debug is for recording info particularly useful for debugging
Debug
)
//gosl:end
// ModeFromString returns Mode int value from string name
func ModeFromString(str string) Modes {
var mode Modes
mode.SetString(str)
return mode
}
// Copyright (c) 2022, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package etime
import (
"sort"
"strings"
)
// ScopeKey is the string representation of a scope or scopes.
// They include one or more Modes and one or more Times.
// It is fully extensible with arbitrary mode and time strings --
// the enums are a convenience for standard cases.
// Ultimately a single mode, time pair is used concretely, but the
// All* cases and lists of multiple can be used as a convenience
// to specify ranges
type ScopeKey string
// Like "Train|Test&Epoch|Trial"
var (
ScopeKeySeparator = "&" // between mode and time
ScopeKeyList = "|" // between multiple modes, times
)
// FromScopesStr creates an associated scope merging
// the modes and times that are specified as strings.
// If you modify this, also modify ModesAndTimes, below.
func (sk *ScopeKey) FromScopesStr(modes, times []string) {
var mstr string
var tstr string
for _, str := range modes {
if mstr == "" {
mstr = str
} else {
mstr += ScopeKeyList + str
}
}
for _, str := range times {
if tstr == "" {
tstr = str
} else {
tstr += ScopeKeyList + str
}
}
*sk = ScopeKey(mstr + ScopeKeySeparator + tstr)
}
// FromScopes creates an associated scope merging
// the modes and times that are specified.
// If you modify this, also modify ModesAndTimes, below.
func (sk *ScopeKey) FromScopes(modes []Modes, times []Times) {
mstr := make([]string, len(modes))
for i, mode := range modes {
mstr[i] = mode.String()
}
tstr := make([]string, len(times))
for i, time := range times {
tstr[i] = time.String()
}
sk.FromScopesStr(mstr, tstr)
}
// FromScope creates an associated scope from the given
// standard mode and time
func (sk *ScopeKey) FromScope(mode Modes, time Times) {
sk.FromScopesStr([]string{mode.String()}, []string{time.String()})
}
// FromScopeStr creates an associated scope from the given
// mode and time as strings
func (sk *ScopeKey) FromScopeStr(mode, time string) {
sk.FromScopesStr([]string{mode}, []string{time})
}
// ModesAndTimes returns the mode(s) and time(s) as strings
// from the current key value. This must be the inverse
// of FromScopesStr
func (sk *ScopeKey) ModesAndTimes() (modes, times []string) {
skstr := strings.Split(string(*sk), ScopeKeySeparator)
modestr := skstr[0]
timestr := skstr[1]
modes = strings.Split(modestr, ScopeKeyList)
times = strings.Split(timestr, ScopeKeyList)
return
}
// ModeAndTimeStr returns the mode and time as strings
// from the current key value. Returns the first of each
// if there are multiple (intended for case when only 1).
func (sk *ScopeKey) ModeAndTimeStr() (mode, time string) {
md, tm := sk.ModesAndTimes()
return md[0], tm[0]
}
// ModeAndTime returns the singular mode and time as enums from a
// concrete scope key having one of each (No* cases if not standard)
func (sk *ScopeKey) ModeAndTime() (mode Modes, time Times) {
modes, times := sk.ModesAndTimes()
if len(modes) != 1 {
mode = NoEvalMode
} else {
if mode.SetString(modes[0]) != nil {
mode = NoEvalMode
}
}
if len(times) != 1 {
time = NoTime
} else {
if time.SetString(times[0]) != nil {
time = NoTime
}
}
return
}
// FromScopesMap creates an associated scope key merging
// the modes and times that are specified by map of strings.
func (sk *ScopeKey) FromScopesMap(modes, times map[string]bool) {
ml := make([]string, len(modes))
tl := make([]string, len(times))
idx := 0
for m := range modes {
ml[idx] = m
idx++
}
idx = 0
for t := range times {
tl[idx] = t
idx++
}
sk.FromScopesStr(ml, tl)
}
// ModesAndTimesMap returns maps of modes and times as strings
// parsed from the current scopekey
func (sk *ScopeKey) ModesAndTimesMap() (modes, times map[string]bool) {
ml, tl := sk.ModesAndTimes()
modes = make(map[string]bool)
times = make(map[string]bool)
for _, m := range ml {
modes[m] = true
}
for _, t := range tl {
times[t] = true
}
return
}
//////////////////////////////////////////////////
// Standalone funcs
// Scope generates a scope key string from one mode and time
func Scope(mode Modes, time Times) ScopeKey {
var ss ScopeKey
ss.FromScope(mode, time)
return ss
}
// ScopeStr generates a scope key string from string
// values for mode, time
func ScopeStr(mode, time string) ScopeKey {
var ss ScopeKey
ss.FromScopeStr(mode, time)
return ss
}
// Scopes generates a scope key string from multiple modes, times
func Scopes(modes []Modes, times []Times) ScopeKey {
var ss ScopeKey
ss.FromScopes(modes, times)
return ss
}
// ScopesStr generates a scope key string from multiple modes, times
func ScopesStr(modes, times []string) ScopeKey {
var ss ScopeKey
ss.FromScopesStr(modes, times)
return ss
}
// ScopesMap generates a scope key from maps of modes and times (warning: ordering is random!)
func ScopesMap(modes, times map[string]bool) ScopeKey {
var ss ScopeKey
ss.FromScopesMap(modes, times)
return ss
}
// ScopeName generates a string name as just the concatenation of mode + time
// e.g., used for naming log tables
func ScopeName(mode Modes, time Times) string {
return mode.String() + time.String()
}
// SortScopes sorts a list of concrete mode, time
// scopes according to the Modes and Times enum ordering
func SortScopes(scopes []ScopeKey) []ScopeKey {
sort.Slice(scopes, func(i, j int) bool {
mi, ti := scopes[i].ModeAndTime()
mj, tj := scopes[j].ModeAndTime()
if mi < mj {
return true
}
if mi > mj {
return false
}
return ti < tj
})
return scopes
}
// CloneScopeSlice returns a copy of the given ScopeKey slice.
func CloneScopeSlice(ss []ScopeKey) []ScopeKey {
	cp := make([]ScopeKey, len(ss))
	copy(cp, ss)
	return cp
}
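// exampleScopeKeys is an illustrative sketch (not part of the original source)
// showing how the standalone helpers above compose and parse scope keys.
// It uses the string-based constructors only, so no particular Modes or Times
// enum values are assumed; "Train", "Test", and "Epoch" are just example names here.
func exampleScopeKeys() {
	sk := ScopeStr("Train", "Epoch") // single mode and time
	mode, time := sk.ModeAndTimeStr()
	_, _ = mode, time // "Train", "Epoch"

	multi := ScopesStr([]string{"Train", "Test"}, []string{"Epoch"})
	modes, times := multi.ModesAndTimes() // inverse of the string list constructor
	_, _ = modes, times
}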
// Copyright (c) 2022, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package etime
//go:generate core generate -add-types
// Times is the enum of time scales.
type Times int32 //enums:enum
// A list of predefined time scales at which logging can occur
const (
// NoTime represents a non-initialized value, or a null result
NoTime Times = iota
// AllTimes indicates that the log should occur over all times present in other items.
AllTimes
// Cycle is the finest time scale -- typically 1 msec -- a single activation update.
Cycle
// FastSpike is typically 10 cycles = 10 msec (100hz) = the fastest spiking time
// generally observed in the brain. This can be useful for visualizing updates
// at a granularity in between Cycle and GammaCycle.
FastSpike
// GammaCycle is typically 25 cycles = 25 msec (40hz)
GammaCycle
// Phase is typically a Minus or Plus phase, where plus phase is bursting / outcome
// that drives positive learning relative to prediction in minus phase.
// It can also be used for other time scales involving multiple Cycles.
Phase
// BetaCycle is typically 50 cycles = 50 msec (20 hz) = one beta-frequency cycle.
// Gating in the basal ganglia and associated updating in prefrontal cortex
// occurs at this frequency.
BetaCycle
// AlphaCycle is typically 100 cycles = 100 msec (10 hz) = one alpha-frequency cycle.
AlphaCycle
// ThetaCycle is typically 200 cycles = 200 msec (5 hz) = two alpha-frequency cycles.
// This is the modal duration of a saccade, the update frequency of medial temporal lobe
// episodic memory, and the minimal predictive learning cycle (perceive an Alpha 1, predict on 2).
ThetaCycle
// Event is the smallest unit of naturalistic experience that coheres unto itself
// (e.g., something that could be described in a sentence).
// Typically this is on the time scale of a few seconds: e.g., reaching for
// something, catching a ball.
Event
// Trial is one unit of behavior in an experiment -- it is typically environmentally
// defined instead of endogenously defined in terms of basic brain rhythms.
// In the minimal case it could be one ThetaCycle, but could be multiple, and
// could encompass multiple Events (e.g., one event is fixation, next is stimulus,
// last is response)
Trial
// Tick is one step in a sequence -- often it is useful to have Trial count
// up throughout the entire Epoch but also include a Tick to count trials
// within a Sequence
Tick
// Sequence is a sequential group of Trials (not always needed).
Sequence
// Epoch is used in two different contexts. In machine learning, it represents a
// collection of Trials, Sequences or Events that constitute a "representative sample"
// of the environment. In the simplest case, it is the entire collection of Trials
// used for training. In electrophysiology, it is a timing window used for organizing
// the analysis of electrode data.
Epoch
// Block is a collection of Trials, Sequences or Events, often used in experiments
// when conditions are varied across blocks.
Block
// Condition is a collection of Blocks that share the same set of parameters.
// This is intermediate between Block and Run levels.
// Aggregation of stats at this level is based on the last 5 rows by default.
Condition
// Run is a complete run of a model / subject, from training to testing, etc.
// Often multiple runs are done in an Expt to obtain statistics over initial
// random weights etc.
// Aggregation of stats at this level is based on the last 5 rows by default.
Run
// Expt is an entire experiment -- multiple Runs through a given protocol / set of
// parameters.
Expt
// Scene is a sequence of events that constitutes the next larger-scale coherent unit
// of naturalistic experience corresponding e.g., to a scene in a movie.
// Typically consists of events that all take place in one location over
// e.g., a minute or so. This could be a paragraph or a page or so in a book.
Scene
// Episode is a sequence of scenes that constitutes the next larger-scale unit
// of naturalistic experience e.g., going to the grocery store or eating at a
// restaurant, attending a wedding or other "event".
// This could be a chapter in a book.
Episode
)
// TimeFromString returns Time int value from string name
func TimeFromString(str string) Times {
var time Times
time.SetString(str)
return time
}
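// exampleTimeFromString is an illustrative sketch (not part of the original
// source) showing the round trip between a Times value and its string name.
func exampleTimeFromString() Times {
	tm := TimeFromString("Epoch") // the SetString error is ignored by TimeFromString
	_ = tm.String()               // "Epoch"
	return tm
}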
// Copyright (c) 2022, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package looper
// Counter combines an integer with a maximum value.
// It supports iteration tracking within looper.
type Counter struct {
// Cur is the current counter value.
Cur int
// Max is the maximum counter value.
// Only used if > 0; otherwise the [Loop] requires an IsDone condition to stop.
Max int
// Inc is the increment per iteration.
Inc int
}
// SetMaxInc sets the given Max and Inc values for the counter.
func (ct *Counter) SetMaxInc(mx, inc int) {
ct.Max = mx
ct.Inc = inc
}
// Incr increments the counter by Inc. Does not interact with Max.
func (ct *Counter) Incr() {
ct.Cur += ct.Inc
}
// SkipToMax sets the counter to its Max value,
// for skipping over the rest of the loop iterations.
func (ct *Counter) SkipToMax() {
ct.Cur = ct.Max
}
// IsOverMax returns true if counter is at or over Max (only if Max > 0).
func (ct *Counter) IsOverMax() bool {
return ct.Max > 0 && ct.Cur >= ct.Max
}
// Set sets the Cur value with return value indicating whether it is different
// from current Cur.
func (ct *Counter) Set(cur int) bool {
if ct.Cur == cur {
return false
}
ct.Cur = cur
return true
}
// SetCurMax sets the Cur and Max values, as a convenience.
func (ct *Counter) SetCurMax(cur, max int) {
ct.Cur = cur
ct.Max = max
}
// SetCurMaxPlusN sets the Cur value and Max as Cur + N -- run N more beyond current.
func (ct *Counter) SetCurMaxPlusN(cur, n int) {
ct.Cur = cur
ct.Max = cur + n
}
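// exampleCounterLoop is an illustrative sketch (not part of the original source)
// showing the basic Counter pattern used by a [Loop]: increment Cur by Inc each
// iteration and stop once IsOverMax reports true (which requires Max > 0).
func exampleCounterLoop() int {
	ct := Counter{}
	ct.SetMaxInc(5, 1) // Max = 5, Inc = 1
	n := 0
	for !ct.IsOverMax() {
		ct.Incr()
		n++
	}
	return n // 5 iterations
}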
// Copyright (c) 2022, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package looper
import (
"strconv"
)
// An Event has function(s) that are called at a particular point
// in the loop, when the counter is at the AtCounter value.
type Event struct {
// Name of this event.
Name string
// AtCounter is the counter value upon which this Event occurs.
AtCounter int
// OnEvent are the functions to run when Counter == AtCounter.
OnEvent NamedFuncs
}
// String describes the Event in human readable text.
func (event *Event) String() string {
s := event.Name + ": "
s = s + "[at " + strconv.Itoa(event.AtCounter) + "] "
if len(event.OnEvent) > 0 {
s = s + "Events: " + event.OnEvent.String()
}
return s
}
// NewEvent returns a new Event with the given name and function, to run at the given counter value.
func NewEvent(name string, atCtr int, fun func()) *Event {
ev := &Event{Name: name, AtCounter: atCtr}
ev.OnEvent.Add(name, fun)
return ev
}
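// exampleEvent is an illustrative sketch (not part of the original source)
// showing an Event that fires when the enclosing loop counter reaches 150.
// The name and the action are hypothetical.
func exampleEvent() *Event {
	return NewEvent("PlusPhase", 150, func() {
		// e.g., switch the network into the plus phase here
	})
}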
// Copyright (c) 2022, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package looper
import (
"fmt"
"slices"
"cogentcore.org/core/base/errors"
)
// NamedFunc is a function closure with a name.
// The function returns a bool, which is needed for the IsDone
// stopping condition but is otherwise not used.
type NamedFunc struct {
Name string
Func func() bool
}
// NamedFuncs is an ordered list of named functions.
type NamedFuncs []NamedFunc
// Add adds a named function (with no bool return value).
func (funcs *NamedFuncs) Add(name string, fun func()) *NamedFuncs {
*funcs = append(*funcs, NamedFunc{Name: name, Func: func() bool { fun(); return true }})
return funcs
}
// AddBool adds a named function with a bool return value, for IsDone case.
func (funcs *NamedFuncs) AddBool(name string, fun func() bool) *NamedFuncs {
*funcs = append(*funcs, NamedFunc{Name: name, Func: fun})
return funcs
}
// String returns a space-separated list of the function names.
func (funcs *NamedFuncs) String() string {
s := ""
for _, f := range *funcs {
s = s + f.Name + " "
}
return s
}
// Run runs all of the functions, returning true if any of
// the functions returned true.
func (funcs NamedFuncs) Run() bool {
ret := false
for _, fn := range funcs {
r := fn.Func()
if r {
ret = true
}
}
return ret
}
// FuncIndex returns the index of the function with the given name,
// or an error if not found.
func (funcs *NamedFuncs) FuncIndex(name string) (int, error) {
for i, fn := range *funcs {
if fn.Name == name {
return i, nil
}
}
err := fmt.Errorf("looper.Funcs:FuncIndex: named function %s not found", name)
return -1, err
}
// InsertAt inserts function at given index.
func (funcs *NamedFuncs) InsertAt(i int, name string, fun func() bool) {
*funcs = slices.Insert(*funcs, i, NamedFunc{Name: name, Func: fun})
}
// Prepend adds a function to the start of the list.
func (funcs *NamedFuncs) Prepend(name string, fun func() bool) {
funcs.InsertAt(0, name, fun)
}
// InsertBefore inserts function before other function of given name.
func (funcs *NamedFuncs) InsertBefore(before, name string, fun func() bool) error {
i, err := funcs.FuncIndex(before)
if errors.Log(err) != nil {
return err
}
funcs.InsertAt(i, name, fun)
return nil
}
// InsertAfter inserts function after other function of given name.
func (funcs *NamedFuncs) InsertAfter(after, name string, fun func() bool) error {
i, err := funcs.FuncIndex(after)
if errors.Log(err) != nil {
return err
}
funcs.InsertAt(i+1, name, fun)
return nil
}
// Replace replaces the function of the given name with the given function.
func (funcs *NamedFuncs) Replace(name string, fun func() bool) error {
i, err := funcs.FuncIndex(name)
if errors.Log(err) != nil {
return err
}
(*funcs)[i] = NamedFunc{Name: name, Func: fun}
return nil
}
// Delete deletes function of given name.
func (funcs *NamedFuncs) Delete(name string) error {
i, err := funcs.FuncIndex(name)
if errors.Log(err) != nil {
return err
}
*funcs = slices.Delete(*funcs, i, i+1)
return nil
}
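// exampleNamedFuncs is an illustrative sketch (not part of the original source)
// showing how functions are registered and reordered by name. The names and
// actions are hypothetical.
func exampleNamedFuncs() {
	var funcs NamedFuncs
	funcs.Add("Log", func() { /* write a log row */ })
	funcs.Add("Save", func() { /* save weights */ })
	// Insert a stats function so it runs before Save: order is Log, Stats, Save.
	_ = funcs.InsertBefore("Save", "Stats", func() bool {
		return true // the return value only matters for IsDone usage
	})
	funcs.Run()
}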
// Code generated by "core generate"; DO NOT EDIT.
package levels
import (
"cogentcore.org/core/enums"
)
var _ModesValues = []Modes{0, 1}
// ModesN is the highest valid value for type Modes, plus one.
//
//gosl:start
const ModesN Modes = 2
//gosl:end
var _ModesValueMap = map[string]Modes{`Train`: 0, `Test`: 1}
var _ModesDescMap = map[Modes]string{0: ``, 1: ``}
var _ModesMap = map[Modes]string{0: `Train`, 1: `Test`}
// String returns the string representation of this Modes value.
func (i Modes) String() string { return enums.String(i, _ModesMap) }
// SetString sets the Modes value from its string representation,
// and returns an error if the string is invalid.
func (i *Modes) SetString(s string) error { return enums.SetString(i, s, _ModesValueMap, "Modes") }
// Int64 returns the Modes value as an int64.
func (i Modes) Int64() int64 { return int64(i) }
// SetInt64 sets the Modes value from an int64.
func (i *Modes) SetInt64(in int64) { *i = Modes(in) }
// Desc returns the description of the Modes value.
func (i Modes) Desc() string { return enums.Desc(i, _ModesDescMap) }
// ModesValues returns all possible values for the type Modes.
func ModesValues() []Modes { return _ModesValues }
// Values returns all possible values for the type Modes.
func (i Modes) Values() []enums.Enum { return enums.Values(_ModesValues) }
// MarshalText implements the [encoding.TextMarshaler] interface.
func (i Modes) MarshalText() ([]byte, error) { return []byte(i.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *Modes) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "Modes") }
var _LevelsValues = []Levels{0, 1, 2, 3}
// LevelsN is the highest valid value for type Levels, plus one.
//
//gosl:start
const LevelsN Levels = 4
//gosl:end
var _LevelsValueMap = map[string]Levels{`Cycle`: 0, `Trial`: 1, `Epoch`: 2, `Run`: 3}
var _LevelsDescMap = map[Levels]string{0: ``, 1: ``, 2: ``, 3: ``}
var _LevelsMap = map[Levels]string{0: `Cycle`, 1: `Trial`, 2: `Epoch`, 3: `Run`}
// String returns the string representation of this Levels value.
func (i Levels) String() string { return enums.String(i, _LevelsMap) }
// SetString sets the Levels value from its string representation,
// and returns an error if the string is invalid.
func (i *Levels) SetString(s string) error { return enums.SetString(i, s, _LevelsValueMap, "Levels") }
// Int64 returns the Levels value as an int64.
func (i Levels) Int64() int64 { return int64(i) }
// SetInt64 sets the Levels value from an int64.
func (i *Levels) SetInt64(in int64) { *i = Levels(in) }
// Desc returns the description of the Levels value.
func (i Levels) Desc() string { return enums.Desc(i, _LevelsDescMap) }
// LevelsValues returns all possible values for the type Levels.
func LevelsValues() []Levels { return _LevelsValues }
// Values returns all possible values for the type Levels.
func (i Levels) Values() []enums.Enum { return enums.Values(_LevelsValues) }
// MarshalText implements the [encoding.TextMarshaler] interface.
func (i Levels) MarshalText() ([]byte, error) { return []byte(i.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *Levels) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "Levels") }
// Copyright (c) 2022, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package looper
import (
"fmt"
"strings"
)
// Loop contains one level of a multi-level iteration stack,
// with functions that can be called at the start and end
// of each iteration of the loop, and a Counter that increments
// for each iteration, terminating if >= Max, or IsDone returns true.
// Within each iteration, any sub-loop at the next level down
// in its [Stack] runs its full set of iterations.
// The control flow is:
//
// for {
// Events[Counter == AtCounter] // run events at counter
// OnStart()
// Run Sub-Loop to completion
// OnEnd()
// Counter += Inc
// if Counter >= Max || IsDone() {
// break
// }
// }
type Loop struct {
// Counter increments every iteration through the loop, up to [Counter.Max].
Counter Counter
// Events occur when Counter.Cur is at their AtCounter.
Events []*Event
// OnStart functions are called at the beginning of each loop iteration.
OnStart NamedFuncs
// OnEnd functions are called at the end of each loop iteration.
OnEnd NamedFuncs
// IsDone functions are called after each loop iteration,
// and if any return true, then the loop iteration is terminated.
IsDone NamedFuncs
// StepCount is the default step count for this loop level.
StepCount int
}
// NewLoop returns a new loop with given Counter Max and increment.
func NewLoop(ctrMax, ctrIncr int) *Loop {
lp := &Loop{}
lp.Counter.SetMaxInc(ctrMax, ctrIncr)
lp.StepCount = 1
return lp
}
// AddEvent adds a new event at given counter. If an event already exists
// for that counter, the function is added to the list for that event.
func (lp *Loop) AddEvent(name string, atCtr int, fun func()) *Event {
ev := lp.EventByCounter(atCtr)
if ev == nil {
ev = NewEvent(name, atCtr, fun)
lp.Events = append(lp.Events, ev)
} else {
ev.OnEvent.Add(name, fun)
}
return ev
}
// EventByCounter returns event for given atCounter value, nil if not found.
func (lp *Loop) EventByCounter(atCtr int) *Event {
for _, ev := range lp.Events {
if ev.AtCounter == atCtr {
return ev
}
}
return nil
}
// EventByName returns event by name, nil if not found.
func (lp *Loop) EventByName(name string) *Event {
for _, ev := range lp.Events {
if ev.Name == name {
return ev
}
}
return nil
}
// SkipToMax sets the counter to its Max value for this level,
// skipping over the rest of the loop iterations.
func (lp *Loop) SkipToMax() {
lp.Counter.SkipToMax()
}
// DocString returns an indented summary of this loop and those below it.
func (lp *Loop) DocString(st *Stack, level int) string {
var sb strings.Builder
ctrs := ""
if lp.Counter.Inc > 1 {
ctrs = fmt.Sprintf("[0 : %d : %d]:\n", lp.Counter.Max, lp.Counter.Inc)
} else {
ctrs = fmt.Sprintf("[0 : %d]:\n", lp.Counter.Max)
}
sb.WriteString(indent(level+1) + st.Order[level].String() + ctrs)
if len(lp.Events) > 0 {
sb.WriteString(indent(level+2) + "Events:\n")
for _, ev := range lp.Events {
sb.WriteString(indent(level+3) + ev.String() + "\n")
}
}
if len(lp.OnStart) > 0 {
sb.WriteString(indent(level+2) + "Start: " + lp.OnStart.String() + "\n")
}
if level < len(st.Order)-1 {
slp := st.Level(level + 1)
sb.WriteString(slp.DocString(st, level+1))
}
if len(lp.OnEnd) > 0 {
sb.WriteString(indent(level+2) + "End: " + lp.OnEnd.String() + "\n")
}
if len(lp.IsDone) > 0 {
sb.WriteString(indent(level+2) + "IsDone: " + lp.IsDone.String() + "\n")
}
return sb.String()
}
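// exampleLoopSetup is an illustrative sketch (not part of the original source)
// showing how functions and events attach to a single Loop level, following
// the control flow documented on [Loop]. The names, counts, and actions are
// hypothetical.
func exampleLoopSetup() *Loop {
	lp := NewLoop(200, 1) // 200 iterations, incrementing by 1
	lp.OnStart.Add("ResetState", func() { /* per-iteration reset */ })
	lp.AddEvent("MinusPhase:End", 150, func() { /* runs when Counter.Cur == 150 */ })
	lp.IsDone.AddBool("EarlyStop", func() bool { return false }) // stopping check
	return lp
}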
// Copyright (c) 2022, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package looper
import (
"fmt"
"strings"
"cogentcore.org/core/enums"
)
func indent(level int) string {
return strings.Repeat(" ", level)
}
// runLevel implements the nested running of loops recursively.
// It can be stopped and resumed at any point.
// It returns true if the level was completed, along with the level where it stopped.
func (ss *Stacks) runLevel(currentLevel int) (bool, enums.Enum) {
st := ss.Stacks[ss.Mode]
if currentLevel >= len(st.Order) {
return true, st.Order[0] // Stack overflow, should not happen
}
level := st.Order[currentLevel]
stoppedLevel := level // return value for what level it stopped at
loop := st.Loops[level]
ctr := &loop.Counter
for ctr.Cur < ctr.Max || ctr.Max <= 0 { // Loop forever for non-maxes
stoplev := int64(-1)
if st.StopLevel != nil {
stoplev = st.StopLevel.Int64()
stoppedLevel = st.StopLevel
}
stopAtLevelOrLarger := st.Order[currentLevel].Int64() >= stoplev
if st.StopFlag && stopAtLevelOrLarger {
ss.internalStop = true
}
if ss.internalStop {
// This should occur before ctr incrementing and before functions.
st.StopFlag = false
return false, stoppedLevel // Don't continue above, e.g. Stop functions
}
if st.StopNext && st.Order[currentLevel] == st.StopLevel {
st.StopCount -= 1
if st.StopCount <= 0 {
st.StopNext = false
st.StopFlag = true // Stop at the top of the next StopLevel
}
}
// Don't ever Start the same iteration of the same level twice.
lastCounter, ok := ss.lastStartedCounter[ToScope(ss.Mode, level)]
if !ok || ctr.Cur > lastCounter {
ss.lastStartedCounter[ToScope(ss.Mode, level)] = ctr.Cur
if PrintControlFlow {
fmt.Printf("%s%s: Start: %d\n", indent(currentLevel), level.String(), ctr.Cur)
}
for _, ev := range loop.Events {
if ctr.Cur == ev.AtCounter {
ev.OnEvent.Run()
}
}
loop.OnStart.Run()
} else if PrintControlFlow {
fmt.Printf("%s%s: Skipping Start: %d\n", indent(currentLevel), level.String(), ctr.Cur)
}
done := true
if currentLevel+1 < len(st.Order) {
done, stoppedLevel = ss.runLevel(currentLevel + 1)
}
if done {
if PrintControlFlow {
fmt.Printf("%s%s: End: %d\n", indent(currentLevel), level.String(), ctr.Cur)
}
loop.OnEnd.Run()
ctr.Incr()
// Reset the counter at the next level.
// Do this here so that the counter number is visible during loop.OnEnd.
if currentLevel+1 < len(st.Order) {
st.Level(currentLevel + 1).Counter.Cur = 0
ss.lastStartedCounter[ToScope(ss.Mode, st.Order[currentLevel+1])] = -1
}
for _, fun := range loop.IsDone {
if fun.Func() {
if PrintControlFlow {
fmt.Printf("%s%s: IsDone Stop at: %d from: %s\n", indent(currentLevel), level.String(), ctr.Cur, fun.Name)
}
goto exitLoop // Exit IsDone and Counter for-loops without flag variable.
}
}
}
}
exitLoop:
// Only get to this point if this loop is done.
return true, level
}
// Copyright (c) 2022, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package looper
import "cogentcore.org/core/enums"
// Scope is a combined Mode + Level value.
// Mode is encoded by multiples of 1000 and Level is added to that.
type Scope int
// ModeLevel decodes the mode and level Int64 values from the Scope.
func (sc Scope) ModeLevel() (mode, level int64) {
mode = int64(sc / 1000)
level = int64(sc % 1000)
return
}
// ToScope encodes the given mode and level enums into a Scope value.
func ToScope(mode, level enums.Enum) Scope {
return Scope(mode.Int64()*1000 + level.Int64())
}
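// exampleScopeEncoding is an illustrative sketch (not part of the original
// source) showing the Scope round trip: mode*1000 + level, decoded by
// ModeLevel. The literal values stand in for enums with those Int64 values.
func exampleScopeEncoding() (int64, int64) {
	sc := Scope(2*1000 + 3) // mode.Int64() == 2, level.Int64() == 3
	return sc.ModeLevel()   // 2, 3
}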
// Copyright (c) 2022, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package looper
import (
"fmt"
"strings"
"cogentcore.org/core/enums"
)
// Stack contains a list of Loops to run, for a given Mode of processing,
// which distinguishes this stack, and is its key in the map of Stacks.
// The order of Loop stacks is determined by the Order list of loop levels.
type Stack struct {
// Mode identifies the mode of processing this stack performs, e.g., Train or Test.
Mode enums.Enum
// Loops is the set of Loops for this Stack, keyed by the level enum value.
// Order is determined by the Order list.
Loops map[enums.Enum]*Loop
// Order is the list and order of levels looped over by this stack of loops.
// The order is from top to bottom, so longer timescales like Run should be at
// the start and shorter level timescales like Trial should be at the end.
Order []enums.Enum
// OnInit are functions to run when Init is called, to restart processing,
// which also resets the counters for this stack.
OnInit NamedFuncs
// StopNext will stop running at the end of the current StopLevel if set.
StopNext bool
// StopFlag will stop running ASAP if set.
StopFlag bool
// StopLevel sets the level to stop at the end of.
// This is the current active Step level, which will be reset when done.
StopLevel enums.Enum
// StopCount determines how many iterations at StopLevel before actually stopping.
// This is the current active Step control value.
StopCount int
// StepLevel is a saved copy of StopLevel for stepping.
// This is what was set for last Step call (which sets StopLevel) or by GUI.
StepLevel enums.Enum
// StepCount is a saved copy of StopCount for stepping.
// This is what was set for last Step call (which sets StopCount) or by GUI.
StepCount int
}
// NewStack returns a new Stack for given mode and default step level.
func NewStack(mode, stepLevel enums.Enum) *Stack {
st := &Stack{}
st.newInit(mode, stepLevel)
return st
}
// newInit initializes new data structures for a newly created object.
func (st *Stack) newInit(mode, stepLevel enums.Enum) {
st.Mode = mode
st.StepLevel = stepLevel
st.StepCount = 1
st.Loops = map[enums.Enum]*Loop{}
st.Order = []enums.Enum{}
}
// Level returns the [Loop] at the given ordinal level in the Order list.
// Will panic if out of range.
func (st *Stack) Level(i int) *Loop {
return st.Loops[st.Order[i]]
}
// AddLevel adds a new level to this Stack with a given counterMax number of iterations.
// The order in which this method is invoked is important,
// as it adds loops in order from top to bottom.
// Sets a default increment of 1 for the counter -- see AddLevelIncr for different increment.
func (st *Stack) AddLevel(level enums.Enum, counterMax int) *Stack {
st.Loops[level] = NewLoop(counterMax, 1)
st.Order = append(st.Order, level)
return st
}
// AddOnStartToAll adds given function taking mode and level args to OnStart in all loops.
func (st *Stack) AddOnStartToAll(name string, fun func(mode, level enums.Enum)) {
for tt, lp := range st.Loops {
lp.OnStart.Add(name, func() {
fun(st.Mode, tt)
})
}
}
// AddOnEndToAll adds given function taking mode and level args to OnEnd in all loops.
func (st *Stack) AddOnEndToAll(name string, fun func(mode, level enums.Enum)) {
for tt, lp := range st.Loops {
lp.OnEnd.Add(name, func() {
fun(st.Mode, tt)
})
}
}
// AddLevelIncr adds a new level to this Stack with a given counterMax
// number of iterations, and increment per step.
// The order in which this method is invoked is important,
// as it adds loops in order from top to bottom.
func (st *Stack) AddLevelIncr(level enums.Enum, counterMax, counterIncr int) *Stack {
st.Loops[level] = NewLoop(counterMax, counterIncr)
st.Order = append(st.Order, level)
return st
}
// LevelAbove returns the level above the given level in the stack Order,
// returning false if this is the highest level
// or the given level does not exist in Order.
func (st *Stack) LevelAbove(level enums.Enum) (enums.Enum, bool) {
for i, tt := range st.Order {
if tt == level && i > 0 {
return st.Order[i-1], true
}
}
return level, false
}
// LevelBelow returns the level below the given level in the stack Order,
// returning false if this is the lowest level
// or the given level does not exist in Order.
func (st *Stack) LevelBelow(level enums.Enum) (enums.Enum, bool) {
for i, tt := range st.Order {
if tt == level && i+1 < len(st.Order) {
return st.Order[i+1], true
}
}
return level, false
}
//////// Control
// SetStep sets stepping to the given level and number of iterations.
// If numSteps <= 0 then the default StepCount for the given level is used.
func (st *Stack) SetStep(numSteps int, stopLevel enums.Enum) {
	st.StopLevel = stopLevel
	lp := st.Loops[stopLevel]
	if numSteps > 0 {
		lp.StepCount = numSteps
	} else {
		numSteps = lp.StepCount
	}
	st.StopCount = numSteps
	st.StepLevel = stopLevel
	st.StepCount = numSteps
	st.StopFlag = false
	st.StopNext = true
}
// ClearStep clears the active stepping control state: StopNext and StopFlag.
func (st *Stack) ClearStep() {
st.StopNext = false
st.StopFlag = false
}
// Counters returns a slice of the current counter values
// for this stack, in Order.
func (st *Stack) Counters() []int {
ctrs := make([]int, len(st.Order))
for i, tm := range st.Order {
ctrs[i] = st.Loops[tm].Counter.Cur
}
return ctrs
}
// CountersString returns a string with loop level and counter values.
func (st *Stack) CountersString() string {
ctrs := ""
for _, tm := range st.Order {
ctrs += fmt.Sprintf("%s: %d ", tm.String(), st.Loops[tm].Counter.Cur)
}
return ctrs
}
// DocString returns an indented summary of the loops and functions in the Stack.
func (st *Stack) DocString() string {
var sb strings.Builder
sb.WriteString("Stack " + st.Mode.String() + ":\n")
sb.WriteString(st.Level(0).DocString(st, 0))
return sb.String()
}
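// exampleStack is an illustrative sketch (not part of the original source)
// showing how levels are added from the top (longest timescale) down.
// The mode and level enums are passed in rather than assumed; the counts
// and function names are hypothetical.
func exampleStack(mode, run, epoch, trial enums.Enum) *Stack {
	st := NewStack(mode, trial) // step at the trial level by default
	st.AddLevel(run, 2).AddLevel(epoch, 10).AddLevel(trial, 25)
	st.AddOnStartToAll("Counters", func(md, lvl enums.Enum) {
		// e.g., update a counters display for md / lvl
	})
	return st
}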
// Copyright (c) 2022, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package looper
//go:generate core generate -add-types
import (
"cmp"
"slices"
"strings"
"cogentcore.org/core/enums"
"golang.org/x/exp/maps"
)
var (
	// PrintControlFlow, if set to true, prints a trace of the loop
	// control flow for debugging.
	PrintControlFlow = false
)
// Stacks holds multiple stacks of loops (one per Mode),
// along with the control state for running and stepping through them,
// and helper methods for constructing and managing the loops.
type Stacks struct {
// Stacks is the map of stacks by Mode.
Stacks map[enums.Enum]*Stack
// Mode has the current evaluation mode.
Mode enums.Enum
// The following fields are internal run control state; see runLevel in run.go.
isRunning bool
lastStartedCounter map[Scope]int
internalStop bool
}
// NewStacks returns a new initialized collection of Stacks.
func NewStacks() *Stacks {
ls := &Stacks{}
ls.newInit()
return ls
}
// newInit initializes the state of the stacks, to be called on a newly created object.
func (ls *Stacks) newInit() {
ls.Stacks = map[enums.Enum]*Stack{}
ls.lastStartedCounter = map[Scope]int{}
}
//////// Run API
// Run runs the stack of loops for given mode (Train, Test, etc).
// This resets any stepping settings for this stack and runs
// until completion or stopped externally.
// Returns the level that was running when it stopped.
func (ls *Stacks) Run(mode enums.Enum) enums.Enum {
ls.Mode = mode
ls.ClearStep(mode)
return ls.Cont()
}
// ResetAndRun calls ResetCountersByMode on this mode
// and then Run. This ensures that the Stack is run from
// the start, regardless of what state it might have been in.
// Returns the level that was running when it stopped.
func (ls *Stacks) ResetAndRun(mode enums.Enum) enums.Enum {
ls.ResetCountersByMode(mode)
return ls.Run(mode)
}
// Cont continues running based on the current state of the stacks.
// This is the common pathway for Step and Run, which set state and
// call Cont. Programmatic calling of Step can continue with Cont.
// Returns the level that was running when it stopped.
func (ls *Stacks) Cont() enums.Enum {
ls.isRunning = true
ls.internalStop = false
_, stop := ls.runLevel(0) // 0 Means the top level loop
ls.isRunning = false
return stop
}
// Step numSteps at given stopLevel. Use this if you want to do exactly one trial
// or two epochs or 50 cycles or whatever. If numSteps <= 0 then the default
// number of steps for given step level is used.
// Returns the level that was running when it stopped.
func (ls *Stacks) Step(mode enums.Enum, numSteps int, stopLevel enums.Enum) enums.Enum {
ls.Mode = mode
st := ls.Stacks[ls.Mode]
st.SetStep(numSteps, stopLevel)
return ls.Cont()
}
// ClearStep clears stepping variables for the given mode,
// so it will run to completion in a subsequent Cont().
// Called by Run.
func (ls *Stacks) ClearStep(mode enums.Enum) {
	st := ls.Stacks[mode]
	st.ClearStep()
}
// Stop stops currently running stack of loops at given run level.
func (ls *Stacks) Stop(level enums.Enum) {
st := ls.Stacks[ls.Mode]
st.StopLevel = level
st.StopCount = 0
st.StopFlag = true
}
//////// Config API
// AddStack adds a new Stack for given mode and default step level.
func (ls *Stacks) AddStack(mode, stepLevel enums.Enum) *Stack {
st := NewStack(mode, stepLevel)
ls.Stacks[mode] = st
return st
}
// Loop returns the Loop associated with given mode and loop level.
func (ls *Stacks) Loop(mode, level enums.Enum) *Loop {
st := ls.Stacks[mode]
if st == nil {
return nil
}
return st.Loops[level]
}
// ModeStack returns the Stack for the current Mode
func (ls *Stacks) ModeStack() *Stack {
return ls.Stacks[ls.Mode]
}
// AddEventAllModes adds a new event for all modes at given loop level.
func (ls *Stacks) AddEventAllModes(level enums.Enum, name string, atCtr int, fun func()) {
for _, st := range ls.Stacks {
st.Loops[level].AddEvent(name, atCtr, fun)
}
}
// AddOnStartToAll adds the given function, taking mode and level args, to OnStart in all loops of all stacks.
func (ls *Stacks) AddOnStartToAll(name string, fun func(mode, level enums.Enum)) {
for _, st := range ls.Stacks {
st.AddOnStartToAll(name, fun)
}
}
// AddOnEndToAll adds the given function, taking mode and level args, to OnEnd in all loops of all stacks.
func (ls *Stacks) AddOnEndToAll(name string, fun func(mode, level enums.Enum)) {
for _, st := range ls.Stacks {
st.AddOnEndToAll(name, fun)
}
}
// AddOnStartToLoop adds given function taking mode arg to OnStart in all stacks for given loop.
func (ls *Stacks) AddOnStartToLoop(level enums.Enum, name string, fun func(mode enums.Enum)) {
for m, st := range ls.Stacks {
st.Loops[level].OnStart.Add(name, func() { fun(m) })
}
}
// AddOnEndToLoop adds given function taking mode arg to OnEnd in all stacks for given loop.
func (ls *Stacks) AddOnEndToLoop(level enums.Enum, name string, fun func(mode enums.Enum)) {
for m, st := range ls.Stacks {
st.Loops[level].OnEnd.Add(name, func() { fun(m) })
}
}
// Modes returns a sorted list of stack modes, for iterating in Mode enum value order.
func (ls *Stacks) Modes() []enums.Enum {
mds := maps.Keys(ls.Stacks)
slices.SortFunc(mds, func(a, b enums.Enum) int {
return cmp.Compare(a.Int64(), b.Int64())
})
return mds
}
//////// More detailed control API
// IsRunning returns true if the stacks are currently running.
func (ls *Stacks) IsRunning() bool {
return ls.isRunning
}
// InitMode initializes [Stack] of given mode,
// resetting counters and calling the OnInit functions.
func (ls *Stacks) InitMode(mode enums.Enum) {
ls.ResetCountersByMode(mode)
st := ls.Stacks[mode]
st.OnInit.Run()
}
// ResetCountersByMode resets counters for given mode.
func (ls *Stacks) ResetCountersByMode(mode enums.Enum) {
for sk := range ls.lastStartedCounter {
skm, _ := sk.ModeLevel()
if skm == mode.Int64() {
delete(ls.lastStartedCounter, sk)
}
}
for m, st := range ls.Stacks {
if m == mode {
for _, loop := range st.Loops {
loop.Counter.Cur = 0
}
}
}
}
// Init initializes all stacks. See [Stacks.InitMode] for more info.
func (ls *Stacks) Init() {
ls.lastStartedCounter = map[Scope]int{}
for _, st := range ls.Stacks {
ls.InitMode(st.Mode)
}
}
// ResetCounters resets the Cur on all loop Counters,
// and resets the Stacks's place in the loops.
func (ls *Stacks) ResetCounters() {
ls.lastStartedCounter = map[Scope]int{}
for _, st := range ls.Stacks {
for _, loop := range st.Loops {
loop.Counter.Cur = 0
}
}
}
// ResetCountersBelow resets the Cur on all loop Counters at the given
// level and below (inclusive), and resets the Stacks's place in those loops.
func (ls *Stacks) ResetCountersBelow(mode, level enums.Enum) {
for _, st := range ls.Stacks {
if st.Mode != mode {
continue
}
for lt, loop := range st.Loops {
if lt.Int64() > level.Int64() {
continue
}
loop.Counter.Cur = 0
sk := ToScope(mode, lt)
delete(ls.lastStartedCounter, sk)
}
}
}
// DocString returns an indented summary of the loops and functions in the stack.
func (ls *Stacks) DocString() string {
var sb strings.Builder
for _, st := range ls.Stacks {
sb.WriteString(st.DocString())
}
return sb.String()
}
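// exampleStacksRun is an illustrative sketch (not part of the original source)
// showing the typical control pattern on an already-configured Stacks.
// The mode and level enums are passed in rather than assumed.
func exampleStacksRun(ls *Stacks, trainMode, trialLevel enums.Enum) {
	ls.Init()                         // reset counters and call OnInit functions
	ls.Run(trainMode)                 // run to completion (or until stopped)
	ls.Step(trainMode, 1, trialLevel) // then step exactly one iteration at trialLevel
}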
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package netview
import (
"cogentcore.org/core/math32"
"github.com/emer/emergent/v2/emer"
)
// LayData maintains a record of all the data for a given layer
type LayData struct {
// the layer name
LayName string
// cached number of units
NUnits int
// the full data for this layer
Data []float32
// receiving pathway data -- shared with SendPaths
RecvPaths []*PathData
// sending pathway data -- shared with RecvPaths
SendPaths []*PathData
}
// AllocSendPaths allocates sending pathway data for the given layer.
// If already allocated, it just refreshes the Path pointers.
func (ld *LayData) AllocSendPaths(ly emer.Layer) {
nsp := ly.NumSendPaths()
if len(ld.SendPaths) == nsp {
for si := range ly.NumSendPaths() {
pt := ly.SendPath(si)
spd := ld.SendPaths[si]
spd.Path = pt
}
return
}
ld.SendPaths = make([]*PathData, nsp)
for si := range ly.NumSendPaths() {
pt := ly.SendPath(si)
pd := &PathData{Send: pt.SendLayer().Label(), Recv: pt.RecvLayer().Label(), Path: pt}
ld.SendPaths[si] = pd
pd.Alloc()
}
}
// FreePaths nils path data -- for NoSynDat
func (ld *LayData) FreePaths() {
ld.RecvPaths = nil
ld.SendPaths = nil
}
// PathData holds display state for a pathway
type PathData struct {
// name of sending layer
Send string
// name of recv layer
Recv string
// source pathway
Path emer.Path
// synaptic data, by variable in SynVars and number of data points
SynData []float32
}
// Alloc allocates SynData to hold the number of variables times the number of synapses.
// If the existing capacity is sufficient, the slice is just resliced to the needed length.
func (pd *PathData) Alloc() {
pt := pd.Path
nvar := pt.SynVarNum()
nsyn := pt.NumSyns()
nt := nvar * nsyn
if cap(pd.SynData) < nt {
pd.SynData = make([]float32, nt)
} else {
pd.SynData = pd.SynData[:nt]
}
}
// RecordData records synaptic data from this pathway into the given NetData,
// updating the min / max range for each synaptic variable.
// Must use sender- or receiver-based ordering depending on the natural ordering of the pathway.
func (pd *PathData) RecordData(nd *NetData) {
pt := pd.Path
vnms := pt.SynVarNames()
nvar := pt.SynVarNum()
nsyn := pt.NumSyns()
for vi := 0; vi < nvar; vi++ {
vnm := vnms[vi]
si := vi * nsyn
sv := pd.SynData[si : si+nsyn]
pt.SynValues(&sv, vnm)
nvi := nd.SynVarIndexes[vnm]
mn := &nd.SynMinVar[nvi]
mx := &nd.SynMaxVar[nvi]
for _, vl := range sv {
if !math32.IsNaN(vl) {
*mn = math32.Min(*mn, vl)
*mx = math32.Max(*mx, vl)
}
}
}
}
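// synSlice is an illustrative sketch (not part of the original source) showing
// how SynData is laid out: variable vi occupies the contiguous block
// [vi*nsyn, (vi+1)*nsyn), matching the indexing used in RecordData.
func (pd *PathData) synSlice(vi int) []float32 {
	nsyn := pd.Path.NumSyns()
	return pd.SynData[vi*nsyn : (vi+1)*nsyn]
}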
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package netview
import (
"fmt"
"image"
"cogentcore.org/core/core"
"cogentcore.org/core/events"
"cogentcore.org/core/math32"
"cogentcore.org/core/xyz"
"cogentcore.org/core/xyz/xyzcore"
"github.com/emer/emergent/v2/emer"
)
// Scene is a Widget for managing the 3D Scene of the NetView
type Scene struct {
xyzcore.Scene
NetView *NetView
}
func (sw *Scene) Init() {
sw.Scene.Init()
sw.On(events.MouseDown, func(e events.Event) {
sw.MouseDownEvent(e)
sw.NeedsRender()
})
sw.On(events.Scroll, func(e events.Event) {
pos := sw.Geom.ContentBBox.Min
e.SetLocalOff(e.LocalOff().Add(pos))
sw.SceneXYZ().MouseScrollEvent(e.(*events.MouseScroll))
sw.NeedsRender()
})
sw.On(events.KeyChord, func(e events.Event) {
sw.SceneXYZ().KeyChordEvent(e)
sw.NeedsRender()
})
// sw.HandleSlideEvents() // TODO: need this
}
func (sw *Scene) MouseDownEvent(e events.Event) {
pos := e.Pos().Sub(sw.Geom.ContentBBox.Min)
pt := sw.PathAtPoint(pos)
if pt != nil {
FormDialog(sw, pt, "Path: "+pt.Label())
e.SetHandled()
return
}
lay := sw.LayerLabelAtPoint(pos)
if lay != nil {
FormDialog(sw, lay, "Layer: "+lay.Label())
e.SetHandled()
return
}
lay, _, _, unIndex := sw.LayerUnitAtPoint(pos)
if lay == nil {
return
}
nv := sw.NetView
nv.Data.PathUnIndex = unIndex
nv.Data.PathLay = lay.Label()
nv.UpdateView()
e.SetHandled()
}
func (sw *Scene) WidgetTooltip(pos image.Point) (string, image.Point) {
if pos == image.Pt(-1, -1) {
return "_", image.Point{}
}
nv := sw.NetView
lpos := pos.Sub(sw.Geom.ContentBBox.Min)
pt := sw.PathAtPoint(lpos)
if pt != nil {
pe := pt.AsEmer()
tt := "[Click to edit] " + pe.Name
if pe.Doc != "" {
tt += ": " + pe.Doc
}
return tt, pos
}
lay := sw.LayerLabelAtPoint(lpos)
if lay != nil {
le := lay.AsEmer()
tt := "[Click to edit]"
if le.Doc != "" {
tt += " " + le.Doc
}
return tt, pos
}
lay, lx, ly, _ := sw.LayerUnitAtPoint(lpos)
if lay == nil {
return "", pos
}
lb := lay.AsEmer()
tt := ""
if lb.Is2D() {
idx := []int{ly, lx}
val, _, _, hasval := nv.UnitValue(lay, idx)
if !hasval {
tt = fmt.Sprintf("[%d,%d]=n/a\n", lx, ly)
} else {
tt = fmt.Sprintf("[%d,%d]=%g\n", lx, ly, val)
}
} else if lb.Is4D() {
idx, ok := lb.Index4DFrom2D(lx, ly)
if !ok {
return "", pos
}
val, _, _, hasval := nv.UnitValue(lay, idx)
if !hasval {
tt = fmt.Sprintf("[%d,%d][%d,%d]=n/a\n", idx[1], idx[0], idx[3], idx[2])
} else {
tt = fmt.Sprintf("[%d,%d][%d,%d]=%g\n", idx[1], idx[0], idx[3], idx[2], val)
}
} else {
return "", pos // not supported
}
return tt, pos
}
func (sw *Scene) LayerLabelAtPoint(pos image.Point) emer.Layer {
ns := xyz.NodesUnderPoint(sw.SceneXYZ(), pos)
for _, n := range ns {
ln, ok := n.(*LayName)
if ok {
lay, _ := ln.NetView.Net.AsEmer().EmerLayerByName(ln.Text)
if lay != nil {
return lay
}
}
}
return nil
}
func (sw *Scene) PathAtPoint(pos image.Point) emer.Path {
ns := xyz.NodesUnderPoint(sw.SceneXYZ(), pos)
for _, n := range ns {
ln, ok := n.(*xyz.Solid)
if ok && ln.Parent != nil {
gpnm := ln.Parent.AsTree().Name
pt, _ := sw.NetView.Net.AsEmer().EmerPathByName(gpnm)
if pt != nil {
return pt
}
}
}
return nil
}
func (sw *Scene) LayerUnitAtPoint(pos image.Point) (lay emer.Layer, lx, ly, unIndex int) {
sc := sw.SceneXYZ()
laysGpi := sc.ChildByName("Layers", 0)
if laysGpi == nil {
return
}
_, laysGp := xyz.AsNode(laysGpi)
nv := sw.NetView
nb := nv.Net.AsEmer()
nmin, nmax := nb.MinPos, nb.MaxPos
nsz := nmax.Sub(nmin).Sub(math32.Vec3(1, 1, 0)).Max(math32.Vec3(1, 1, 1))
nsc := math32.Vec3(1.0/nsz.X, 1.0/nsz.Y, 1.0/nsz.Z)
szc := math32.Max(nsc.X, nsc.Y)
poff := math32.Vector3Scalar(0.5)
poff.Y = -0.5
for li, lgi := range laysGp.Children {
lay = nv.Net.EmerLayer(li)
lb := lay.AsEmer()
lg := lgi.(*xyz.Group)
lp := lb.Pos.Pos
lp.Y = -lp.Y // reverse direction
lp = lp.Sub(nmin).Mul(nsc).Sub(poff)
lg.Pose.Pos.Set(lp.X, lp.Z, lp.Y)
lg.Pose.Scale.Set(nsc.X*lb.Pos.Scale, szc, nsc.Y*lb.Pos.Scale)
lo := lg.Child(0).(*LayObj)
ray := lo.RayPick(pos)
// layer is in XZ plane with norm pointing up in Y axis
// offset is 0 in local coordinates
plane := math32.Plane{Norm: math32.Vec3(0, 1, 0), Off: 0}
pt, ok := ray.IntersectPlane(plane)
if !ok || pt.Z > 0 { // Z > 0 means clicked "in front" of plane -- where labels are
continue
}
lx = int(pt.X)
ly = -int(pt.Z)
// fmt.Printf("selected unit: %v, %v\n", lx, ly)
if lx < 0 || ly < 0 {
continue
}
lshp := lb.Shape
if lb.Is2D() {
idx := []int{ly, lx}
if !lshp.IndexIsValid(idx...) {
continue
}
unIndex = lshp.IndexTo1D(idx...)
return
} else if lb.Is4D() {
idx, ok := lb.Index4DFrom2D(lx, ly)
if !ok {
continue
}
unIndex = lshp.IndexTo1D(idx...)
return
} else {
continue // not supported
}
}
lay = nil
return
}
// FormDialog opens a dialog in a new, separate window
// for viewing / editing the given struct object, in
// the context of the given ctx widget.
func FormDialog(ctx core.Widget, v any, title string) {
d := core.NewBody(title)
core.NewForm(d).SetStruct(v)
if tb, ok := v.(core.ToolbarMaker); ok {
d.AddTopBar(func(bar *core.Frame) {
core.NewToolbar(bar).Maker(tb.MakeToolbar)
})
}
d.RunWindowDialog(ctx)
}
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package netview
import (
"cogentcore.org/core/gpu/shape"
"cogentcore.org/core/math32"
"cogentcore.org/core/xyz"
"cogentcore.org/lab/tensor"
"github.com/emer/emergent/v2/emer"
)
// LayMesh is a xyz.Mesh that represents a layer -- it is dynamically updated using the
// Update method, which only resets the essential Vertex elements.
// The geometry is literal in the layer size: 0,0,0 is the lower-left corner, with X,Z
// increasing by 1 per unit across the width and height of the layer.
// NetView applies an overall scaling to make it fit within the larger view.
type LayMesh struct {
xyz.MeshBase
// layer that we render
Lay emer.Layer
// current shape that has been constructed -- if same, just update
Shape tensor.Shape
// netview that we're in
View *NetView
}
// NewLayMesh adds a LayMesh for the given layer to the given scene.
func NewLayMesh(sc *xyz.Scene, nv *NetView, lay emer.Layer) *LayMesh {
lm := &LayMesh{}
lm.View = nv
lm.Lay = lay
lm.Name = lay.Label()
sc.SetMesh(lm)
return lm
}
func (lm *LayMesh) MeshSize() (nVtx, nIndex int, hasColor bool) {
lm.Transparent = true
lm.HasColor = true
if lm.Lay == nil {
return 0, 0, true
}
shp := &lm.Lay.AsEmer().Shape
lm.Shape.CopyFrom(shp)
if lm.View.Options.Raster.On {
if shp.NumDims() == 4 {
lm.NumVertex, lm.NumIndex = lm.RasterSize4D()
} else {
lm.NumVertex, lm.NumIndex = lm.RasterSize2D()
}
} else {
if shp.NumDims() == 4 {
lm.NumVertex, lm.NumIndex = lm.Size4D()
} else {
lm.NumVertex, lm.NumIndex = lm.Size2D()
}
}
return lm.NumVertex, lm.NumIndex, lm.HasColor
}
func (lm *LayMesh) Size2D() (nVtx, nIndex int) {
nz := lm.Shape.DimSize(0)
nx := lm.Shape.DimSize(1)
segs := 1
vtxSz, idxSz := shape.PlaneN(segs, segs)
nVtx = vtxSz * 5 * nz * nx
nIndex = idxSz * 5 * nz * nx
return
}
func (lm *LayMesh) Size4D() (nVtx, nIndex int) {
npz := lm.Shape.DimSize(0) // p = pool
npx := lm.Shape.DimSize(1)
nuz := lm.Shape.DimSize(2) // u = unit
nux := lm.Shape.DimSize(3)
segs := 1
vtxSz, idxSz := shape.PlaneN(segs, segs)
nVtx = vtxSz * 5 * npz * npx * nuz * nux
nIndex = idxSz * 5 * npz * npx * nuz * nux
return
}
func (lm *LayMesh) Set(vtxAry, normAry, texAry, clrAry math32.ArrayF32, idxAry math32.ArrayU32) {
if lm.Lay == nil || lm.Shape.NumDims() == 0 {
return // nothing
}
if lm.View.Options.Raster.On {
if lm.View.Options.Raster.XAxis {
if lm.Shape.NumDims() == 4 {
lm.RasterSet4DX(vtxAry, normAry, texAry, clrAry, idxAry)
} else {
lm.RasterSet2DX(vtxAry, normAry, texAry, clrAry, idxAry)
}
} else {
if lm.Shape.NumDims() == 4 {
lm.RasterSet4DZ(vtxAry, normAry, texAry, clrAry, idxAry)
} else {
lm.RasterSet2DZ(vtxAry, normAry, texAry, clrAry, idxAry)
}
}
} else {
if lm.Shape.NumDims() == 4 {
lm.Set4D(vtxAry, normAry, texAry, clrAry, idxAry)
} else {
lm.Set2D(vtxAry, normAry, texAry, clrAry, idxAry)
}
}
}
// MinUnitHeight ensures that there is always at least some dimensionality
// to the unit cubes -- affects transparency rendering, etc.
var MinUnitHeight = float32(1.0e-6)
func (lm *LayMesh) Set2D(vtxAry, normAry, texAry, clrAry math32.ArrayF32, idxAry math32.ArrayU32) {
nz := lm.Shape.DimSize(0)
nx := lm.Shape.DimSize(1)
fnz := float32(nz)
fnx := float32(nx)
uw := lm.View.Options.UnitSize
uo := (1.0 - uw)
segs := 1
vtxSz, idxSz := shape.PlaneN(segs, segs)
pidx := 0 // plane index
pos := math32.Vector3{}
lm.View.ReadLock()
for zi := nz - 1; zi >= 0; zi-- {
z0 := uo - float32(zi+1)
for xi := 0; xi < nx; xi++ {
poff := pidx * vtxSz * 5
ioff := pidx * idxSz * 5
x0 := uo + float32(xi)
_, scaled, clr, _ := lm.View.UnitValue(lm.Lay, []int{zi, xi})
v4c := math32.NewVector4Color(clr)
shape.SetColor(clrAry, poff, 5*vtxSz, v4c)
ht := 0.5 * math32.Abs(scaled)
if ht < MinUnitHeight {
ht = MinUnitHeight
}
if scaled >= 0 {
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff, ioff, math32.X, math32.Y, -1, -1, uw, ht, x0, 0, z0, segs, segs, pos) // nz
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+1*vtxSz, ioff+1*idxSz, math32.Z, math32.Y, -1, -1, uw, ht, z0, 0, x0+uw, segs, segs, pos) // px
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+2*vtxSz, ioff+2*idxSz, math32.Z, math32.Y, 1, -1, uw, ht, z0, 0, x0, segs, segs, pos) // nx
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+3*vtxSz, ioff+3*idxSz, math32.X, math32.Z, 1, 1, uw, uw, x0, z0, ht, segs, segs, pos) // py <-
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+4*vtxSz, ioff+4*idxSz, math32.X, math32.Y, 1, -1, uw, ht, x0, 0, z0+uw, segs, segs, pos) // pz
} else {
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff, ioff, math32.X, math32.Y, 1, -1, uw, ht, x0, -ht, z0, segs, segs, pos) // nz = pz norm
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+1*vtxSz, ioff+1*idxSz, math32.Z, math32.Y, 1, -1, uw, ht, z0, -ht, x0+uw, segs, segs, pos) // px = nx norm
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+2*vtxSz, ioff+2*idxSz, math32.Z, math32.Y, 1, -1, uw, ht, z0, -ht, x0, segs, segs, pos) // nx
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+3*vtxSz, ioff+3*idxSz, math32.X, math32.Z, 1, 1, uw, uw, x0, z0, -ht, segs, segs, pos) // ny <-
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+4*vtxSz, ioff+4*idxSz, math32.X, math32.Y, 1, -1, uw, ht, x0, -ht, z0+uw, segs, segs, pos) // pz
}
pidx++
}
}
lm.View.ReadUnlock()
lm.BBox.SetBounds(math32.Vec3(0, -0.5, -fnz), math32.Vec3(fnx, 0.5, 0))
}
func (lm *LayMesh) Set4D(vtxAry, normAry, texAry, clrAry math32.ArrayF32, idxAry math32.ArrayU32) {
npz := lm.Shape.DimSize(0) // p = pool
npx := lm.Shape.DimSize(1)
nuz := lm.Shape.DimSize(2) // u = unit
nux := lm.Shape.DimSize(3)
fnpz := float32(npz)
fnpx := float32(npx)
fnuz := float32(nuz)
fnux := float32(nux)
usz := lm.View.Options.UnitSize
uo := (1.0 - usz) // offset = space
// for 4D, we build in spaces between groups without changing the overall size of layer
// by shrinking the spacing of each unit according to the spaces we introduce
xsc := (fnpx * fnux) / ((fnpx-1)*uo + (fnpx * fnux))
zsc := (fnpz * fnuz) / ((fnpz-1)*uo + (fnpz * fnuz))
xuw := xsc * usz
zuw := zsc * usz
segs := 1
vtxSz, idxSz := shape.PlaneN(segs, segs)
pidx := 0 // plane index
pos := math32.Vector3{}
lm.View.ReadLock()
for zpi := npz - 1; zpi >= 0; zpi-- {
zp0 := zsc * (-float32(zpi) * (uo + fnuz))
for xpi := 0; xpi < npx; xpi++ {
xp0 := xsc * (float32(xpi)*uo + float32(xpi)*fnux)
for zui := nuz - 1; zui >= 0; zui-- {
z0 := zp0 + zsc*(uo-float32(zui+1))
for xui := 0; xui < nux; xui++ {
poff := pidx * vtxSz * 5
ioff := pidx * idxSz * 5
x0 := xp0 + xsc*(uo+float32(xui))
_, scaled, clr, _ := lm.View.UnitValue(lm.Lay, []int{zpi, xpi, zui, xui})
v4c := math32.NewVector4Color(clr)
shape.SetColor(clrAry, poff, 5*vtxSz, v4c)
ht := 0.5 * math32.Abs(scaled)
if ht < MinUnitHeight {
ht = MinUnitHeight
}
if scaled >= 0 {
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff, ioff, math32.X, math32.Y, -1, -1, xuw, ht, x0, 0, z0, segs, segs, pos) // nz
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+1*vtxSz, ioff+1*idxSz, math32.Z, math32.Y, -1, -1, zuw, ht, z0, 0, x0+xuw, segs, segs, pos) // px
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+2*vtxSz, ioff+2*idxSz, math32.Z, math32.Y, 1, -1, zuw, ht, z0, 0, x0, segs, segs, pos) // nx
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+3*vtxSz, ioff+3*idxSz, math32.X, math32.Z, 1, 1, xuw, zuw, x0, z0, ht, segs, segs, pos) // py <-
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+4*vtxSz, ioff+4*idxSz, math32.X, math32.Y, 1, -1, xuw, ht, x0, 0, z0+zuw, segs, segs, pos) // pz
} else {
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff, ioff, math32.X, math32.Y, 1, -1, xuw, ht, x0, -ht, z0, segs, segs, pos) // nz = pz norm
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+1*vtxSz, ioff+1*idxSz, math32.Z, math32.Y, 1, -1, zuw, ht, z0, -ht, x0+xuw, segs, segs, pos) // px = nx norm
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+2*vtxSz, ioff+2*idxSz, math32.Z, math32.Y, 1, -1, zuw, ht, z0, -ht, x0, segs, segs, pos) // nx
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+3*vtxSz, ioff+3*idxSz, math32.X, math32.Z, 1, 1, xuw, zuw, x0, z0, -ht, segs, segs, pos) // ny <-
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+4*vtxSz, ioff+4*idxSz, math32.X, math32.Y, 1, -1, xuw, ht, x0, -ht, z0+zuw, segs, segs, pos) // pz
}
pidx++
}
}
}
}
lm.View.ReadUnlock()
lm.BBox.SetBounds(math32.Vec3(0, -0.5, -fnpz*fnuz), math32.Vec3(fnpx*fnux, 0.5, 0))
}
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package netview
import (
"cogentcore.org/core/gpu/shape"
"cogentcore.org/core/math32"
)
func (lm *LayMesh) RasterSize2D() (nVtx, nIndex int) {
ss := lm.Lay.AsEmer().GetSampleShape()
nuz := ss.DimSize(0)
nux := ss.DimSize(1)
nz := nuz*nux + nuz - 1
nx := lm.View.Options.Raster.Max + 1
segs := 1
vtxSz, idxSz := shape.PlaneN(segs, segs)
nVtx = vtxSz * 5 * nz * nx
nIndex = idxSz * 5 * nz * nx
return
}
func (lm *LayMesh) RasterSize4D() (nVtx, nIndex int) {
ss := lm.Lay.AsEmer().GetSampleShape()
npz := ss.DimSize(0) // p = pool
npx := ss.DimSize(1)
nuz := ss.DimSize(2) // u = unit
nux := ss.DimSize(3)
nz := nuz*nux + nuz - 1
nx := lm.View.Options.Raster.Max + 1
segs := 1
vtxSz, idxSz := shape.PlaneN(segs, segs)
nVtx = vtxSz * 5 * npz * npx * nz * nx
nIndex = idxSz * 5 * npz * npx * nz * nx
return
}
func (lm *LayMesh) RasterSet2DX(vtxAry, normAry, texAry, clrAry math32.ArrayF32, idxAry math32.ArrayU32) {
ss := lm.Lay.AsEmer().GetSampleShape()
nuz := ss.DimSize(0)
nux := ss.DimSize(1)
nz := nuz*nux + nuz - 1
nx := lm.View.Options.Raster.Max + 1
htsc := 0.5 * lm.View.Options.Raster.UnitHeight
fnoz := float32(lm.Shape.DimSize(0))
fnox := float32(lm.Shape.DimSize(1))
fnuz := float32(nuz)
fnux := float32(nux)
fnz := float32(nz)
fnx := float32(nx)
usz := lm.View.Options.Raster.UnitSize
uo := (1.0 - usz)
xsc := fnux / fnx
zsc := fnuz / fnz
// rescale rep -> full size
xsc *= fnox / fnux
zsc *= fnoz / fnuz
xuw := xsc * usz
zuw := zsc * usz
segs := 1
vtxSz, idxSz := shape.PlaneN(segs, segs)
pidx := 0 // plane index
pos := math32.Vector3{}
curRast, _ := lm.View.Data.RasterCtr(-1)
lm.View.ReadLock()
for zi := nz - 1; zi >= 0; zi-- {
z0 := uo - zsc*float32(zi+1)
uy := zi / (nux + 1)
ux := zi % (nux + 1)
xoff := 0
for xi := 0; xi < nx; xi++ {
poff := pidx * vtxSz * 5
ioff := pidx * idxSz * 5
x0 := uo + xsc*float32(xi)
_, scaled, clr, _ := lm.View.UnitValRaster(lm.Lay, []int{uy, ux}, xi-xoff)
if xi-1 == curRast || ux >= nux {
clr = NilColor
scaled = 0
}
if xi-1 == curRast {
xoff++
}
v4c := math32.NewVector4Color(clr)
shape.SetColor(clrAry, poff, 5*vtxSz, v4c)
ht := htsc * math32.Abs(scaled)
if ht < MinUnitHeight {
ht = MinUnitHeight
}
if scaled >= 0 {
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff, ioff, math32.X, math32.Y, -1, -1, xuw, ht, x0, 0, z0, segs, segs, pos) // nz
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+1*vtxSz, ioff+1*idxSz, math32.Z, math32.Y, -1, -1, zuw, ht, z0, 0, x0+xuw, segs, segs, pos) // px
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+2*vtxSz, ioff+2*idxSz, math32.Z, math32.Y, 1, -1, zuw, ht, z0, 0, x0, segs, segs, pos) // nx
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+3*vtxSz, ioff+3*idxSz, math32.X, math32.Z, 1, 1, xuw, zuw, x0, z0, ht, segs, segs, pos) // py <-
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+4*vtxSz, ioff+4*idxSz, math32.X, math32.Y, 1, -1, xuw, ht, x0, 0, z0+zuw, segs, segs, pos) // pz
} else {
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff, ioff, math32.X, math32.Y, 1, -1, xuw, ht, x0, -ht, z0, segs, segs, pos) // nz = pz norm
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+1*vtxSz, ioff+1*idxSz, math32.Z, math32.Y, 1, -1, zuw, ht, z0, -ht, x0+xuw, segs, segs, pos) // px = nx norm
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+2*vtxSz, ioff+2*idxSz, math32.Z, math32.Y, 1, -1, zuw, ht, z0, -ht, x0, segs, segs, pos) // nx
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+3*vtxSz, ioff+3*idxSz, math32.X, math32.Z, 1, 1, xuw, zuw, x0, z0, -ht, segs, segs, pos) // ny <-
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+4*vtxSz, ioff+4*idxSz, math32.X, math32.Y, 1, -1, xuw, ht, x0, -ht, z0+zuw, segs, segs, pos) // pz
}
pidx++
}
}
lm.View.ReadUnlock()
lm.BBox.SetBounds(math32.Vec3(0, -0.5, -fnz), math32.Vec3(fnx, 0.5, 0))
}
func (lm *LayMesh) RasterSet2DZ(vtxAry, normAry, texAry, clrAry math32.ArrayF32, idxAry math32.ArrayU32) {
ss := lm.Lay.AsEmer().GetSampleShape()
nuz := ss.DimSize(0)
nux := ss.DimSize(1)
nx := nuz*nux + nuz - 1
nz := lm.View.Options.Raster.Max + 1
htsc := 0.5 * lm.View.Options.Raster.UnitHeight
fnoz := float32(lm.Shape.DimSize(0))
fnox := float32(lm.Shape.DimSize(1))
fnuz := float32(nuz)
fnux := float32(nux)
fnz := float32(nz)
fnx := float32(nx)
usz := lm.View.Options.Raster.UnitSize
uo := (1.0 - usz)
xsc := fnux / fnx
zsc := fnuz / fnz
// rescale rep -> full size
xsc *= fnox / fnux
zsc *= fnoz / fnuz
xuw := xsc * usz
zuw := zsc * usz
segs := 1
vtxSz, idxSz := shape.PlaneN(segs, segs)
pidx := 0 // plane index
pos := math32.Vector3{}
curRast, _ := lm.View.Data.RasterCtr(-1)
lm.View.ReadLock()
zoff := 1
for zi := nz - 1; zi >= 0; zi-- {
z0 := uo - zsc*float32(zi+1)
for xi := 0; xi < nx; xi++ {
uy := xi / (nux + 1)
ux := xi % (nux + 1)
poff := pidx * vtxSz * 5
ioff := pidx * idxSz * 5
x0 := uo + xsc*float32(xi)
_, scaled, clr, _ := lm.View.UnitValRaster(lm.Lay, []int{uy, ux}, zi-zoff)
if zi-1 == curRast || ux >= nux {
clr = NilColor
scaled = 0
}
if zi-1 == curRast {
zoff = 0
}
v4c := math32.NewVector4Color(clr)
shape.SetColor(clrAry, poff, 5*vtxSz, v4c)
ht := htsc * math32.Abs(scaled)
if ht < MinUnitHeight {
ht = MinUnitHeight
}
if scaled >= 0 {
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff, ioff, math32.X, math32.Y, -1, -1, xuw, ht, x0, 0, z0, segs, segs, pos) // nz
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+1*vtxSz, ioff+1*idxSz, math32.Z, math32.Y, -1, -1, zuw, ht, z0, 0, x0+xuw, segs, segs, pos) // px
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+2*vtxSz, ioff+2*idxSz, math32.Z, math32.Y, 1, -1, zuw, ht, z0, 0, x0, segs, segs, pos) // nx
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+3*vtxSz, ioff+3*idxSz, math32.X, math32.Z, 1, 1, xuw, zuw, x0, z0, ht, segs, segs, pos) // py <-
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+4*vtxSz, ioff+4*idxSz, math32.X, math32.Y, 1, -1, xuw, ht, x0, 0, z0+zuw, segs, segs, pos) // pz
} else {
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff, ioff, math32.X, math32.Y, 1, -1, xuw, ht, x0, -ht, z0, segs, segs, pos) // nz = pz norm
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+1*vtxSz, ioff+1*idxSz, math32.Z, math32.Y, 1, -1, zuw, ht, z0, -ht, x0+xuw, segs, segs, pos) // px = nx norm
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+2*vtxSz, ioff+2*idxSz, math32.Z, math32.Y, 1, -1, zuw, ht, z0, -ht, x0, segs, segs, pos) // nx
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+3*vtxSz, ioff+3*idxSz, math32.X, math32.Z, 1, 1, xuw, zuw, x0, z0, -ht, segs, segs, pos) // ny <-
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+4*vtxSz, ioff+4*idxSz, math32.X, math32.Y, 1, -1, xuw, ht, x0, -ht, z0+zuw, segs, segs, pos) // pz
}
pidx++
}
}
lm.View.ReadUnlock()
lm.BBox.SetBounds(math32.Vec3(0, -0.5, -fnz), math32.Vec3(fnx, 0.5, 0))
}
func (lm *LayMesh) RasterSet4DX(vtxAry, normAry, texAry, clrAry math32.ArrayF32, idxAry math32.ArrayU32) {
ss := lm.Lay.AsEmer().GetSampleShape()
npz := ss.DimSize(0) // p = pool
npx := ss.DimSize(1)
nuz := ss.DimSize(2) // u = unit
nux := ss.DimSize(3)
nz := nuz*nux + nuz - 1
nx := lm.View.Options.Raster.Max + 1
htsc := 0.5 * lm.View.Options.Raster.UnitHeight
fnpoz := float32(lm.Shape.DimSize(0))
fnpox := float32(lm.Shape.DimSize(1))
fnpz := float32(npz)
fnpx := float32(npx)
fnuz := float32(nuz)
fnux := float32(nux)
fnx := float32(nx)
fnz := float32(nz)
usz := lm.View.Options.UnitSize
uo := 2.0 * (1.0 - usz) // offset = space
// for 4D, we build in spaces between groups without changing the overall size of layer
// by shrinking the spacing of each unit according to the spaces we introduce
// these scales are for overall group positioning
xsc := (fnpx * fnux) / ((fnpx-1)*uo + (fnpx * fnux))
zsc := (fnpz * fnuz) / ((fnpz-1)*uo + (fnpz * fnuz))
// rescale rep -> full size
xsc *= fnpox / fnpx
zsc *= fnpoz / fnpz
// these are for the raster within
xscr := xsc * (fnux / fnx)
zscr := zsc * (fnuz / fnz)
uszr := lm.View.Options.Raster.UnitSize
uor := (1.0 - uszr) // offset = space
xuw := xscr * uszr
zuw := zscr * uszr
segs := 1
vtxSz, idxSz := shape.PlaneN(segs, segs)
pidx := 0 // plane index
pos := math32.Vector3{}
curRast, _ := lm.View.Data.RasterCtr(-1)
lm.View.ReadLock()
for zpi := npz - 1; zpi >= 0; zpi-- {
zp0 := zsc * (-float32(zpi) * (uo + fnuz))
for xpi := 0; xpi < npx; xpi++ {
xp0 := xsc * (float32(xpi)*uo + float32(xpi)*fnux)
for zi := nz - 1; zi >= 0; zi-- {
z0 := zp0 + zscr*(uor-float32(zi+1))
uy := zi / (nux + 1)
ux := zi % (nux + 1)
xoff := 0
for xi := 0; xi < nx; xi++ {
poff := pidx * vtxSz * 5
ioff := pidx * idxSz * 5
x0 := xp0 + xscr*(uor+float32(xi))
_, scaled, clr, _ := lm.View.UnitValRaster(lm.Lay, []int{zpi, xpi, uy, ux}, xi-xoff)
if xi-1 == curRast || ux >= nux {
clr = NilColor
scaled = 0
}
if xi-1 == curRast {
xoff++
}
v4c := math32.NewVector4Color(clr)
shape.SetColor(clrAry, poff, 5*vtxSz, v4c)
ht := htsc * math32.Abs(scaled)
if ht < MinUnitHeight {
ht = MinUnitHeight
}
if scaled >= 0 {
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff, ioff, math32.X, math32.Y, -1, -1, xuw, ht, x0, 0, z0, segs, segs, pos) // nz
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+1*vtxSz, ioff+1*idxSz, math32.Z, math32.Y, -1, -1, zuw, ht, z0, 0, x0+xuw, segs, segs, pos) // px
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+2*vtxSz, ioff+2*idxSz, math32.Z, math32.Y, 1, -1, zuw, ht, z0, 0, x0, segs, segs, pos) // nx
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+3*vtxSz, ioff+3*idxSz, math32.X, math32.Z, 1, 1, xuw, zuw, x0, z0, ht, segs, segs, pos) // py <-
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+4*vtxSz, ioff+4*idxSz, math32.X, math32.Y, 1, -1, xuw, ht, x0, 0, z0+zuw, segs, segs, pos) // pz
} else {
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff, ioff, math32.X, math32.Y, 1, -1, xuw, ht, x0, -ht, z0, segs, segs, pos) // nz = pz norm
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+1*vtxSz, ioff+1*idxSz, math32.Z, math32.Y, 1, -1, zuw, ht, z0, -ht, x0+xuw, segs, segs, pos) // px = nx norm
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+2*vtxSz, ioff+2*idxSz, math32.Z, math32.Y, 1, -1, zuw, ht, z0, -ht, x0, segs, segs, pos) // nx
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+3*vtxSz, ioff+3*idxSz, math32.X, math32.Z, 1, 1, xuw, zuw, x0, z0, -ht, segs, segs, pos) // ny <-
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+4*vtxSz, ioff+4*idxSz, math32.X, math32.Y, 1, -1, xuw, ht, x0, -ht, z0+zuw, segs, segs, pos) // pz
}
pidx++
}
}
}
}
lm.View.ReadUnlock()
lm.BBox.SetBounds(math32.Vec3(0, -0.5, -fnpoz*fnuz), math32.Vec3(fnpox*fnux, 0.5, 0))
}
func (lm *LayMesh) RasterSet4DZ(vtxAry, normAry, texAry, clrAry math32.ArrayF32, idxAry math32.ArrayU32) {
ss := lm.Lay.AsEmer().GetSampleShape()
npz := ss.DimSize(0) // p = pool
npx := ss.DimSize(1)
nuz := ss.DimSize(2) // u = unit
nux := ss.DimSize(3)
nx := nuz*nux + nuz - 1
nz := lm.View.Options.Raster.Max + 1
htsc := 0.5 * lm.View.Options.Raster.UnitHeight
fnpoz := float32(lm.Shape.DimSize(0))
fnpox := float32(lm.Shape.DimSize(1))
fnpz := float32(npz)
fnpx := float32(npx)
fnuz := float32(nuz)
fnux := float32(nux)
fnx := float32(nx)
fnz := float32(nz)
usz := lm.View.Options.UnitSize
uo := 2.0 * (1.0 - usz) // offset = space
// for 4D, we build in spaces between groups without changing the overall size of layer
// by shrinking the spacing of each unit according to the spaces we introduce
// these scales are for overall group positioning
xsc := (fnpx * fnux) / ((fnpx-1)*uo + (fnpx * fnux))
zsc := (fnpz * fnuz) / ((fnpz-1)*uo + (fnpz * fnuz))
// rescale rep -> full size
xsc *= fnpox / fnpx
zsc *= fnpoz / fnpz
// these are for the raster within
xscr := xsc * (fnux / fnx)
zscr := zsc * (fnuz / fnz)
uszr := lm.View.Options.Raster.UnitSize
uor := (1.0 - uszr) // offset = space
xuw := xscr * uszr
zuw := zscr * uszr
segs := 1
vtxSz, idxSz := shape.PlaneN(segs, segs)
pidx := 0 // plane index
pos := math32.Vector3{}
curRast, _ := lm.View.Data.RasterCtr(-1)
lm.View.ReadLock()
for zpi := npz - 1; zpi >= 0; zpi-- {
zp0 := zsc * (-float32(zpi) * (uo + fnuz))
for xpi := 0; xpi < npx; xpi++ {
xp0 := xsc * (float32(xpi)*uo + float32(xpi)*fnux)
zoff := 1
for zi := nz - 1; zi >= 0; zi-- {
z0 := zp0 + zscr*(uor-float32(zi+1))
for xi := 0; xi < nx; xi++ {
uy := xi / (nux + 1)
ux := xi % (nux + 1)
poff := pidx * vtxSz * 5
ioff := pidx * idxSz * 5
x0 := xp0 + xscr*(uor+float32(xi))
_, scaled, clr, _ := lm.View.UnitValRaster(lm.Lay, []int{zpi, xpi, uy, ux}, zi-zoff)
if zi-1 == curRast || ux >= nux {
clr = NilColor
scaled = 0
}
if zi-1 == curRast {
zoff = 0
}
v4c := math32.NewVector4Color(clr)
shape.SetColor(clrAry, poff, 5*vtxSz, v4c)
ht := htsc * math32.Abs(scaled)
if ht < MinUnitHeight {
ht = MinUnitHeight
}
if scaled >= 0 {
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff, ioff, math32.X, math32.Y, -1, -1, xuw, ht, x0, 0, z0, segs, segs, pos) // nz
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+1*vtxSz, ioff+1*idxSz, math32.Z, math32.Y, -1, -1, zuw, ht, z0, 0, x0+xuw, segs, segs, pos) // px
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+2*vtxSz, ioff+2*idxSz, math32.Z, math32.Y, 1, -1, zuw, ht, z0, 0, x0, segs, segs, pos) // nx
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+3*vtxSz, ioff+3*idxSz, math32.X, math32.Z, 1, 1, xuw, zuw, x0, z0, ht, segs, segs, pos) // py <-
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+4*vtxSz, ioff+4*idxSz, math32.X, math32.Y, 1, -1, xuw, ht, x0, 0, z0+zuw, segs, segs, pos) // pz
} else {
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff, ioff, math32.X, math32.Y, 1, -1, xuw, ht, x0, -ht, z0, segs, segs, pos) // nz = pz norm
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+1*vtxSz, ioff+1*idxSz, math32.Z, math32.Y, 1, -1, zuw, ht, z0, -ht, x0+xuw, segs, segs, pos) // px = nx norm
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+2*vtxSz, ioff+2*idxSz, math32.Z, math32.Y, 1, -1, zuw, ht, z0, -ht, x0, segs, segs, pos) // nx
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+3*vtxSz, ioff+3*idxSz, math32.X, math32.Z, 1, 1, xuw, zuw, x0, z0, -ht, segs, segs, pos) // ny <-
shape.SetPlane(vtxAry, normAry, texAry, idxAry, poff+4*vtxSz, ioff+4*idxSz, math32.X, math32.Y, 1, -1, xuw, ht, x0, -ht, z0+zuw, segs, segs, pos) // pz
}
pidx++
}
}
}
}
lm.View.ReadUnlock()
lm.BBox.SetBounds(math32.Vec3(0, -0.5, -fnpoz*fnuz), math32.Vec3(fnpox*fnux, 0.5, 0))
}
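// rasterPlaneOffsets is an illustrative sketch (not called by the package)
// of the buffer layout used by the RasterSet* methods above: each unit bar
// is built from 5 planes (four sides plus a top or bottom cap), so vertex
// and index offsets advance in steps of 5 per plane index.
func rasterPlaneOffsets(pidx, vtxSz, idxSz int) (poff, ioff int) {
	poff = pidx * vtxSz * 5 // start of this bar's vertices
	ioff = pidx * idxSz * 5 // start of this bar's indices
	return
}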
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package netview
import (
"bufio"
"compress/gzip"
"encoding/json"
"fmt"
"io"
"math"
"os"
"path/filepath"
"strings"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/metadata"
"cogentcore.org/core/core"
"cogentcore.org/core/math32"
"cogentcore.org/lab/lab"
"cogentcore.org/lab/plot"
"cogentcore.org/lab/plotcore"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
"cogentcore.org/lab/tensorfs"
"github.com/emer/emergent/v2/emer"
"github.com/emer/emergent/v2/ringidx"
)
// NetData maintains a record of all the network data that has been displayed
// up to a given maximum number of records (updates), using efficient ring index logic
// with no copying to store in fixed-sized buffers.
type NetData struct { //types:add
// the network that we're viewing
Net emer.Network `json:"-"`
// copied from Params -- do not record synapse level data -- turn this on for very large networks where recording the entire synaptic state would be prohibitive
NoSynData bool
// name of the layer with unit for viewing pathways (connection / synapse-level values)
PathLay string
// 1D index of the unit within PathLay for viewing pathways
PathUnIndex int
// copied from NetView Params: if non-empty, this is the type of pathway to show when there are multiple pathways from the same layer -- e.g., Inhib, Lateral, Forward, etc.
PathType string `edit:"-"`
// the list of unit variables saved
UnVars []string
// index of each variable in the Vars slice
UnVarIndexes map[string]int
// the list of synaptic variables saved
SynVars []string
// index of synaptic variable in the SynVars slice
SynVarIndexes map[string]int
// the circular ring index -- Max here is max number of values to store, Len is number stored, and Index(Len-1) is the most recent one, etc
Ring ringidx.Index
// maximum number of data parallel values stored per unit
MaxData int
// the layer data -- map keyed by layer name
LayData map[string]*LayData
// unit var min values for each Ring.Max * variable
UnMinPer []float32
// unit var max values for each Ring.Max * variable
UnMaxPer []float32
// min values for unit variables
UnMinVar []float32
// max values for unit variables
UnMaxVar []float32
// min values for syn variables
SynMinVar []float32
// max values for syn variables
SynMaxVar []float32
// counter strings
Counters []string
// raster counter values
RasterCtrs []int
// map of raster counter values to record numbers
RasterMap map[int]int
// dummy raster counter when passed a -1 -- increments and wraps around
RastCtr int
}
// Init initializes the main params and configures the data
func (nd *NetData) Init(net emer.Network, max int, noSynData bool, maxData int) {
nd.Net = net
nd.Ring.Max = max
nd.MaxData = maxData
nd.NoSynData = noSynData
nd.Config()
nd.RastCtr = 0
nd.RasterMap = make(map[int]int)
}
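// exampleNetDataInit is an illustrative sketch (not used by the package) of
// the typical NetData lifecycle: Init once for a network, then Record each
// time display-worthy state changes. The 200-record capacity and counter
// string are assumptions chosen for the example.
func exampleNetDataInit(net emer.Network) *NetData {
	nd := &NetData{}
	nd.Init(net, 200, false, net.MaxParallelData()) // keep 200 records, record synapses
	nd.Record("Run: 0 Epoch: 0 Trial: 0", -1, 200)  // -1 = internal wrapping raster counter
	nd.RecordSyns()                                 // synapse-level data, if not disabled
	return nd
}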
// Config configures the data storage for given network
// only re-allocates if needed.
func (nd *NetData) Config() {
nlay := nd.Net.NumLayers()
if nlay == 0 {
return
}
if nd.Ring.Max == 0 {
nd.Ring.Max = 2
}
rmax := nd.Ring.Max
if nd.Ring.Len > rmax {
nd.Ring.Reset()
}
nvars := nd.Net.UnitVarNames()
vlen := len(nvars)
if len(nd.UnVars) != vlen {
nd.UnVars = nvars
nd.UnVarIndexes = make(map[string]int, vlen)
for vi, vn := range nd.UnVars {
nd.UnVarIndexes[vn] = vi
}
}
svars := nd.Net.SynVarNames()
svlen := len(svars)
if len(nd.SynVars) != svlen {
nd.SynVars = svars
nd.SynVarIndexes = make(map[string]int, svlen)
for vi, vn := range nd.SynVars {
nd.SynVarIndexes[vn] = vi
}
}
makeData:
if len(nd.LayData) != nlay {
nd.LayData = make(map[string]*LayData, nlay)
for li := range nlay {
lay := nd.Net.EmerLayer(li).AsEmer()
nm := lay.Name
ld := &LayData{LayName: nm, NUnits: lay.Shape.Len()}
nd.LayData[nm] = ld
if nd.NoSynData {
ld.FreePaths()
} else {
ld.AllocSendPaths(lay.EmerLayer)
}
}
if !nd.NoSynData {
for li := range nlay {
rlay := nd.Net.EmerLayer(li)
rld := nd.LayData[rlay.Label()]
rld.RecvPaths = make([]*PathData, rlay.NumRecvPaths())
for ri := 0; ri < rlay.NumRecvPaths(); ri++ {
rpj := rlay.RecvPath(ri)
slay := rpj.SendLayer()
sld := nd.LayData[slay.Label()]
for _, spj := range sld.SendPaths {
if spj.Path == rpj {
rld.RecvPaths[ri] = spj // link
}
}
}
}
}
} else {
for li := range nlay {
lay := nd.Net.EmerLayer(li)
ld := nd.LayData[lay.Label()]
if nd.NoSynData {
ld.FreePaths()
} else {
ld.AllocSendPaths(lay)
}
}
}
vmax := vlen * rmax * nd.MaxData
for li := range nlay {
lay := nd.Net.EmerLayer(li).AsEmer()
nm := lay.Name
ld, ok := nd.LayData[nm]
if !ok {
nd.LayData = nil
goto makeData
}
ld.NUnits = lay.Shape.Len()
nu := ld.NUnits
ltot := vmax * nu
if len(ld.Data) != ltot {
ld.Data = make([]float32, ltot)
}
}
if len(nd.UnMinPer) != vmax {
nd.UnMinPer = make([]float32, vmax)
nd.UnMaxPer = make([]float32, vmax)
}
if len(nd.UnMinVar) != vlen {
nd.UnMinVar = make([]float32, vlen)
nd.UnMaxVar = make([]float32, vlen)
}
if len(nd.SynMinVar) != svlen {
nd.SynMinVar = make([]float32, svlen)
nd.SynMaxVar = make([]float32, svlen)
}
if len(nd.Counters) != rmax {
nd.Counters = make([]string, rmax)
nd.RasterCtrs = make([]int, rmax)
}
}
// Record records the current full set of data from the network,
// and the given counters string (displayed at bottom of window)
// and raster counter value -- if negative, then an internal
// wrapping-around counter is used.
func (nd *NetData) Record(ctrs string, rastCtr, rastMax int) {
nlay := nd.Net.NumLayers()
if nlay == 0 {
return
}
nd.Config() // inexpensive if no diff, and safe..
vlen := len(nd.UnVars)
nd.Ring.Add(1)
lidx := nd.Ring.LastIndex()
maxData := nd.MaxData
if rastCtr < 0 {
rastCtr = nd.RastCtr
nd.RastCtr++
if nd.RastCtr >= rastMax {
nd.RastCtr = 0
}
}
nd.Counters[lidx] = ctrs
nd.RasterCtrs[lidx] = rastCtr
nd.RasterMap[rastCtr] = lidx
mmidx := lidx * vlen
for vi := range nd.UnVars {
nd.UnMinPer[mmidx+vi] = math.MaxFloat32
nd.UnMaxPer[mmidx+vi] = -math.MaxFloat32
}
for li := range nlay {
lay := nd.Net.EmerLayer(li).AsEmer()
laynm := lay.Name
ld := nd.LayData[laynm]
nu := lay.Shape.Len()
nvu := vlen * maxData * nu
for vi, vnm := range nd.UnVars {
mn := &nd.UnMinPer[mmidx+vi]
mx := &nd.UnMaxPer[mmidx+vi]
for di := 0; di < maxData; di++ {
idx := lidx*nvu + vi*maxData*nu + di*nu
dvals := ld.Data[idx : idx+nu]
lay.UnitValues(&dvals, vnm, di)
for ui := range dvals {
vl := dvals[ui]
if !math32.IsNaN(vl) {
*mn = math32.Min(*mn, vl)
*mx = math32.Max(*mx, vl)
}
}
}
}
}
nd.UpdateUnVarRange()
}
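// unitDataIndex is an illustrative sketch (not called by the package) that
// spells out the flat layout of LayData.Data used by Record above and by
// UnitValueIndex below: records are outermost, then variables, then the
// data parallel index, then the unit index. Arguments are assumed in range.
func unitDataIndex(ridx, vi, di, uidx1d, vlen, maxData, nu int) int {
	nvu := vlen * maxData * nu // values per record for one layer
	return ridx*nvu + vi*maxData*nu + di*nu + uidx1d
}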
// RecordLastCtrs records just the last counter string to be the given string
// overwriting what was there before.
func (nd *NetData) RecordLastCtrs(ctrs string) {
lidx := nd.Ring.LastIndex()
nd.Counters[lidx] = ctrs
}
// UpdateUnVarRange updates the range for unit variables, integrating over
// the entire range of stored values, so it is valid when iterating
// over history.
func (nd *NetData) UpdateUnVarRange() {
vlen := len(nd.UnVars)
rlen := nd.Ring.Len
for vi := range nd.UnVars {
vmn := &nd.UnMinVar[vi]
vmx := &nd.UnMaxVar[vi]
*vmn = math.MaxFloat32
*vmx = -math.MaxFloat32
for ri := 0; ri < rlen; ri++ {
ridx := nd.Ring.Index(ri)
mmidx := ridx * vlen
mn := nd.UnMinPer[mmidx+vi]
mx := nd.UnMaxPer[mmidx+vi]
*vmn = math32.Min(*vmn, mn)
*vmx = math32.Max(*vmx, mx)
}
}
}
// VarRange returns the current min, max range for given variable.
// Returns false if not found or no data.
func (nd *NetData) VarRange(vnm string) (float32, float32, bool) {
if nd.Ring.Len == 0 {
return 0, 0, false
}
if strings.HasPrefix(vnm, "r.") || strings.HasPrefix(vnm, "s.") {
vnm = vnm[2:]
vi, ok := nd.SynVarIndexes[vnm]
if !ok {
return 0, 0, false
}
return nd.SynMinVar[vi], nd.SynMaxVar[vi], true
}
vi, ok := nd.UnVarIndexes[vnm]
if !ok {
return 0, 0, false
}
return nd.UnMinVar[vi], nd.UnMaxVar[vi], true
}
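// exampleVarRange is an illustrative sketch (not used by the package) showing
// how VarRange handles both unit variables and prefixed synaptic variables
// ("r." / "s."); the variable names here are assumptions for the example.
func exampleVarRange(nd *NetData) {
	if mn, mx, ok := nd.VarRange("Act"); ok {
		fmt.Printf("Act range: %g .. %g\n", mn, mx)
	}
	if mn, mx, ok := nd.VarRange("r.Wt"); ok { // receiving synaptic weights
		fmt.Printf("r.Wt range: %g .. %g\n", mn, mx)
	}
}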
// RecordSyns records synaptic data -- stored separately from unit data
// and only needs to be called when synaptic values are updated.
// Should be done when the DWt values have been computed, before
// updating Wts and zeroing.
// NetView displays this recorded data when Update is next called.
func (nd *NetData) RecordSyns() {
if nd.NoSynData {
return
}
nlay := nd.Net.NumLayers()
if nlay == 0 {
return
}
nd.Config() // inexpensive if no diff, and safe..
for vi := range nd.SynVars {
nd.SynMinVar[vi] = math.MaxFloat32
nd.SynMaxVar[vi] = -math.MaxFloat32
}
for li := range nlay {
lay := nd.Net.EmerLayer(li)
laynm := lay.Label()
ld := nd.LayData[laynm]
for si := 0; si < lay.NumSendPaths(); si++ {
spd := ld.SendPaths[si]
spd.RecordData(nd)
}
}
}
// RecIndex returns record index for given record number,
// which is -1 for current (last) record, or in [0..Len-1] for prior records.
func (nd *NetData) RecIndex(recno int) int {
ridx := nd.Ring.LastIndex()
if nd.Ring.IndexIsValid(recno) {
ridx = nd.Ring.Index(recno)
}
return ridx
}
// CounterRec returns counter string for given record,
// which is -1 for current (last) record, or in [0..Len-1] for prior records.
func (nd *NetData) CounterRec(recno int) string {
if nd.Ring.Len == 0 {
return ""
}
ridx := nd.RecIndex(recno)
return nd.Counters[ridx]
}
// UnitValue returns the value for given layer, variable name, unit index, data parallel index di,
// and record number, which is -1 for current (last) record, or in [0..Len-1] for prior records.
// Returns false if the value is unavailable for any reason (including values recorded as NaN).
func (nd *NetData) UnitValue(laynm string, vnm string, uidx1d int, recno int, di int) (float32, bool) {
if nd.Ring.Len == 0 {
return 0, false
}
ridx := nd.RecIndex(recno)
return nd.UnitValueIndex(laynm, vnm, uidx1d, ridx, di)
}
// RasterCtr returns the raster counter value at given record number (-1 = current)
func (nd *NetData) RasterCtr(recno int) (int, bool) {
if nd.Ring.Len == 0 {
return 0, false
}
ridx := nd.RecIndex(recno)
return nd.RasterCtrs[ridx], true
}
// UnitValRaster returns the value for given layer, variable name, unit index, and
// raster counter number.
// Returns false if the value is unavailable for any reason (including values recorded as NaN).
func (nd *NetData) UnitValRaster(laynm string, vnm string, uidx1d int, rastCtr int, di int) (float32, bool) {
ridx, has := nd.RasterMap[rastCtr]
if !has {
return 0, false
}
return nd.UnitValueIndex(laynm, vnm, uidx1d, ridx, di)
}
// UnitValueIndex returns the value for given layer, variable name, unit index, stored idx,
// and data parallel index.
// Returns false if the value is unavailable for any reason (including values recorded as NaN).
func (nd *NetData) UnitValueIndex(laynm string, vnm string, uidx1d int, ridx int, di int) (float32, bool) {
if strings.HasPrefix(vnm, "r.") {
svar := vnm[2:]
return nd.RecvUnitValue(laynm, svar, uidx1d)
} else if strings.HasPrefix(vnm, "s.") {
svar := vnm[2:]
return nd.SendUnitValue(laynm, svar, uidx1d)
}
vi, ok := nd.UnVarIndexes[vnm]
if !ok {
return 0, false
}
vlen := len(nd.UnVars)
ld, ok := nd.LayData[laynm]
if !ok {
return 0, false
}
nu := ld.NUnits
nvu := vlen * nd.MaxData * nu
idx := ridx*nvu + vi*nd.MaxData*nu + di*nu + uidx1d
val := ld.Data[idx]
if math32.IsNaN(val) {
return 0, false
}
return val, true
}
// RecvUnitValue returns the value for given layer, variable name, and unit index,
// for a receiving pathway variable, based on recorded synaptic pathway data.
// Returns false if the value is unavailable for any reason (including values recorded as NaN).
func (nd *NetData) RecvUnitValue(laynm string, vnm string, uidx1d int) (float32, bool) {
ld, ok := nd.LayData[laynm]
if nd.NoSynData || !ok || nd.PathLay == "" {
return 0, false
}
recvLay := errors.Ignore1(nd.Net.AsEmer().EmerLayerByName(nd.PathLay)).AsEmer()
if recvLay == nil {
return 0, false
}
var pj emer.Path
var err error
if nd.PathType != "" {
pj, err = recvLay.RecvPathBySendNameType(laynm, nd.PathType)
if pj == nil {
pj, err = recvLay.RecvPathBySendName(laynm)
}
} else {
pj, err = recvLay.RecvPathBySendName(laynm)
}
if pj == nil {
return 0, false
}
var spd *PathData
for _, pd := range ld.SendPaths {
if pd.Path == pj {
spd = pd
break
}
}
if spd == nil {
return 0, false
}
varIndex, err := pj.SynVarIndex(vnm)
if err != nil {
return 0, false
}
synIndex := pj.SynIndex(uidx1d, nd.PathUnIndex)
if synIndex < 0 {
return 0, false
}
nsyn := pj.NumSyns()
val := spd.SynData[varIndex*nsyn+synIndex]
return val, true
}
// SendUnitValue returns the value for given layer, variable name, and unit index,
// for a sending pathway variable, based on recorded synaptic pathway data.
// Returns false if the value is unavailable for any reason (including values recorded as NaN).
func (nd *NetData) SendUnitValue(laynm string, vnm string, uidx1d int) (float32, bool) {
ld, ok := nd.LayData[laynm]
if nd.NoSynData || !ok || nd.PathLay == "" {
return 0, false
}
sendLay := errors.Ignore1(nd.Net.AsEmer().EmerLayerByName(nd.PathLay)).AsEmer()
if sendLay == nil {
return 0, false
}
var pj emer.Path
var err error
if nd.PathType != "" {
pj, err = sendLay.SendPathByRecvNameType(laynm, nd.PathType)
if pj == nil {
pj, err = sendLay.SendPathByRecvName(laynm)
}
} else {
pj, err = sendLay.SendPathByRecvName(laynm)
}
if pj == nil {
return 0, false
}
var rpd *PathData
for _, pd := range ld.RecvPaths {
if pd.Path == pj {
rpd = pd
break
}
}
if rpd == nil {
return 0, false
}
varIndex, err := pj.SynVarIndex(vnm)
if err != nil {
return 0, false
}
synIndex := pj.SynIndex(nd.PathUnIndex, uidx1d)
if synIndex < 0 {
return 0, false
}
nsyn := pj.NumSyns()
val := rpd.SynData[varIndex*nsyn+synIndex]
return val, true
}
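// examplePathViewing is an illustrative sketch (not used by the package) of
// how synapse-level values are addressed: select a unit via PathLay and
// PathUnIndex, then query "r." / "s." prefixed variables from other layers.
// The layer names, unit indexes, and "r.Wt" variable are assumptions.
func examplePathViewing(nd *NetData) {
	nd.PathLay = "Hidden" // layer containing the selected unit
	nd.PathUnIndex = 0    // 1D index of the selected unit within PathLay
	if wt, ok := nd.UnitValue("Input", "r.Wt", 3, -1, 0); ok {
		fmt.Printf("weight between Input[3] and the selected unit: %g\n", wt)
	}
}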
////////////////////////////////////////////////////////////////
// IO
// OpenJSON opens recorded network view data from a JSON-formatted file.
func (nd *NetData) OpenJSON(filename core.Filename) error { //types:add
fp, err := os.Open(string(filename))
defer fp.Close()
if err != nil {
return errors.Log(err)
}
ext := filepath.Ext(string(filename))
if ext == ".gz" {
gzr, err := gzip.NewReader(fp)
defer gzr.Close()
if err != nil {
return errors.Log(err)
}
return nd.ReadJSON(gzr)
} else {
return nd.ReadJSON(bufio.NewReader(fp))
}
}
// SaveJSON saves recorded network view data to a JSON-formatted file.
func (nd *NetData) SaveJSON(filename core.Filename) error { //types:add
fp, err := os.Create(string(filename))
defer fp.Close()
if err != nil {
return errors.Log(err)
}
ext := filepath.Ext(string(filename))
if ext == ".gz" {
gzr := gzip.NewWriter(fp)
err = nd.WriteJSON(gzr)
gzr.Close()
} else {
bw := bufio.NewWriter(fp)
err = nd.WriteJSON(bw)
bw.Flush()
}
return err
}
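// exampleSaveRoundTrip is an illustrative sketch (not used by the package) of
// saving and re-loading recorded data; the file name is an assumption, and a
// ".gz" extension triggers gzip compression in both directions.
func exampleSaveRoundTrip(nd *NetData) error {
	fn := core.Filename("netdata.netdat.gz") // hypothetical file name
	if err := nd.SaveJSON(fn); err != nil {
		return err
	}
	return nd.OpenJSON(fn)
}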
// ReadJSON reads netdata from JSON format
func (nd *NetData) ReadJSON(r io.Reader) error {
dec := json.NewDecoder(r)
err := dec.Decode(nd) // decode directly from the reader instead of reading all bytes first
nan := math32.NaN()
for _, ld := range nd.LayData {
for i := range ld.Data {
if ld.Data[i] == NaNSub {
ld.Data[i] = nan
}
}
}
if err == nil || err == io.EOF {
return nil
}
return errors.Log(err)
}
// NaNSub is the sentinel value used to replace NaN values when saving -- JSON does not support NaN.
const NaNSub = -1.11e-37
// WriteJSON writes netdata to JSON format
func (nd *NetData) WriteJSON(w io.Writer) error {
for _, ld := range nd.LayData {
for i := range ld.Data {
if math32.IsNaN(ld.Data[i]) {
ld.Data[i] = NaNSub
}
}
}
enc := json.NewEncoder(w)
enc.SetIndent("", " ")
return errors.Log(enc.Encode(nd))
}
// func (ld *LayData) MarshalJSON() ([]byte, error) {
//
// }
// PlotSelectedUnit opens a window with a plot of all the data for the
// currently selected unit, saving data to the [tensorfs.CurRoot]/NetView
// directory.
// Useful for replaying a detailed trace for units of interest.
func (nv *NetView) PlotSelectedUnit() (*table.Table, *plotcore.Editor) { //types:add
nd := &nv.Data
if nd.PathLay == "" || nd.PathUnIndex < 0 {
fmt.Printf("NetView:PlotSelectedUnit -- no unit selected\n")
return nil, nil
}
selnm := nd.PathLay + fmt.Sprintf("[%d]", nd.PathUnIndex)
dt := nd.SelectedUnitTable(nv.Di)
for _, vnm := range nd.UnVars {
vp, ok := nv.VarOptions[vnm]
if !ok {
continue
}
disp := (vnm == nv.Var)
min := vp.Range.Min
if min < 0 && vp.Range.FixMin && vp.MinMax.Min >= 0 {
min = 0 // netview uses -1..1, which is not great for graphs unless needed
}
dc := dt.Column(vnm)
plot.Styler(dc, func(s *plot.Style) {
s.On = disp
if !vp.Range.FixMax {
s.RightY = true
}
s.Range.SetMin(float64(min)).SetMax(float64(vp.Range.Max))
})
}
if tensorfs.CurRoot != nil && lab.Lab != nil {
dir := tensorfs.CurRoot.Dir("NetView")
udir := dir.Dir(selnm)
tensorfs.DirFromTable(udir, dt)
plt := lab.Lab.PlotTensorFS(udir)
return dt, plt
} else {
b := core.NewBody("netview-selectedunit").SetTitle("NetView SelectedUnit Plot: " + selnm)
plt := plotcore.NewEditor(b)
plt.SetTable(dt)
b.AddTopBar(func(bar *core.Frame) {
core.NewToolbar(bar).Maker(plt.MakeToolbar)
})
b.RunWindow()
return dt, plt
}
}
// SelectedUnitTable returns a table with all of the data for the
// currently-selected unit, and data parallel index.
func (nd *NetData) SelectedUnitTable(di int) *table.Table {
if nd.PathLay == "" || nd.PathUnIndex < 0 {
fmt.Printf("NetView:SelectedUnitTable -- no unit selected\n")
return nil
}
ld, ok := nd.LayData[nd.PathLay]
if !ok {
fmt.Printf("NetView:SelectedUnitTable -- layer name incorrect\n")
return nil
}
selnm := nd.PathLay + fmt.Sprintf("[%d]", nd.PathUnIndex)
dt := table.New()
metadata.SetName(dt, "NetView: "+selnm)
metadata.Set(dt, "read-only", true)
tensor.SetPrecision(dt, 4)
ln := nd.Ring.Len
vlen := len(nd.UnVars)
nu := ld.NUnits
nvu := vlen * nd.MaxData * nu
uidx1d := nd.PathUnIndex
dt.AddIntColumn("Rec")
for _, vnm := range nd.UnVars {
dt.AddFloat64Column(vnm)
}
dt.SetNumRows(ln)
for ri := 0; ri < ln; ri++ {
ridx := nd.RecIndex(ri)
dt.Columns.Values[0].SetFloat1D(float64(ri), ri)
for vi := 0; vi < vlen; vi++ {
idx := ridx*nvu + vi*nd.MaxData*nu + di*nu + uidx1d
val := ld.Data[idx]
dt.Columns.Values[vi+1].SetFloat1D(float64(val), ri)
}
}
return dt
}
/*
var NetDataProps = tree.Props{
"CallMethods": tree.PropSlice{
{"SaveJSON", tree.Props{
"desc": "save recorded network view data to file",
"icon": "file-save",
"Args": tree.PropSlice{
{"File Name", tree.Props{
"ext": ".netdat,.netdat.gz",
}},
},
}},
{"OpenJSON", tree.Props{
"desc": "open recorded network view data from file",
"icon": "file-open",
"Args": tree.PropSlice{
{"File Name", tree.Props{
"ext": ".netdat,.netdat.gz",
}},
},
}},
},
}
*/
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package netview provides the NetView interactive 3D network viewer,
// implemented in the Cogent Core 3D framework.
package netview
//go:generate core generate -add-types
import (
"image/color"
"log"
"log/slog"
"reflect"
"strings"
"sync"
"time"
"cogentcore.org/core/colors"
"cogentcore.org/core/colors/colormap"
"cogentcore.org/core/core"
"cogentcore.org/core/events"
"cogentcore.org/core/math32"
"cogentcore.org/core/math32/minmax"
"cogentcore.org/core/styles"
"cogentcore.org/core/system"
"cogentcore.org/core/text/textcore"
"cogentcore.org/core/tree"
"cogentcore.org/core/types"
"cogentcore.org/core/xyz"
"github.com/emer/emergent/v2/emer"
)
// NetView is a Cogent Core Widget that provides an interactive 3D network view
// using the Cogent Core xyz 3D framework.
type NetView struct {
core.Frame
// the network that we're viewing
Net emer.Network `set:"-"`
// current variable that we're viewing
Var string `set:"-"`
// current data parallel index di, for networks capable of processing input patterns in parallel.
Di int
// the list of variables to view
Vars []string
// list of synaptic variables
SynVars []string
// map of synaptic variable names to index
SynVarsMap map[string]int
// parameters for the list of variables to view
VarOptions map[string]*VarOptions
// current var params -- only valid during Update of display
CurVarOptions *VarOptions `json:"-" xml:"-" display:"-"`
// parameters controlling how the view is rendered
Options Options
// color map for mapping values to colors -- set by name in Options
ColorMap *colormap.Map
// color map value representing ColorMap
ColorMapButton *core.ColorMapButton
// record number to display -- use -1 to always track the latest, otherwise a record in the range [0..Len-1]
RecNo int
// last non-empty counters string provided -- re-used if no new one
LastCtrs string
// current counters
CurCtrs string
// contains all the network data with history
Data NetData
// mutex on data access
DataMu sync.RWMutex `display:"-" copier:"-" json:"-" xml:"-"`
// these are used to detect the need to update
layerNameSizeShown float32
hasPaths bool
pathTypeShown string
pathWidthShown float32
}
func (nv *NetView) Init() {
nv.Frame.Init()
nv.Options.Defaults()
nv.ColorMap = colormap.AvailableMaps[string(nv.Options.ColorMap)]
nv.RecNo = -1
nv.Styler(func(s *styles.Style) {
s.Direction = styles.Column
s.Grow.Set(1, 1)
})
tree.AddChildAt(nv, "tbar", func(w *core.Toolbar) {
w.Styler(func(s *styles.Style) {
s.Wrap = true
})
w.Maker(nv.MakeToolbar)
})
tree.AddChildAt(nv, "netframe", func(w *core.Frame) {
w.Styler(func(s *styles.Style) {
s.Direction = styles.Row
s.Grow.Set(1, 1)
})
nv.makeVars(w)
tree.AddChildAt(w, "scene", func(w *Scene) {
w.NetView = nv
se := w.SceneXYZ()
nv.ViewDefaults(se)
pathsGp := xyz.NewGroup(se)
pathsGp.Name = "Paths"
laysGp := xyz.NewGroup(se)
laysGp.Name = "Layers"
})
w.OnShow(func(e events.Event) {
nv.Current()
})
})
tree.AddChildAt(nv, "counters", func(w *core.Text) {
w.SetText("Counters: ").
Styler(func(s *styles.Style) {
s.Min.X.Pw(90)
})
w.Updater(func() {
if w.Text != nv.CurCtrs && nv.CurCtrs != "" {
w.SetText(nv.CurCtrs)
}
})
})
tree.AddChildAt(nv, "vbar", func(w *core.Toolbar) {
w.Styler(func(s *styles.Style) {
s.Wrap = true
})
w.Maker(nv.MakeViewbar)
})
}
// SetNet sets the network to view and updates view
func (nv *NetView) SetNet(net emer.Network) {
nv.Net = net
nv.DataMu.Lock()
nv.Data.Init(nv.Net, nv.Options.MaxRecs, nv.Options.NoSynData, nv.Net.MaxParallelData())
nv.DataMu.Unlock()
nv.UpdateTree() // need children
nv.UpdateLayers()
nv.Current()
}
// SetVar sets the variable to view and updates the display
func (nv *NetView) SetVar(vr string) {
nv.DataMu.Lock()
nv.Var = vr
nv.VarsFrame().Update()
nv.DataMu.Unlock()
nv.Toolbar().Update()
nv.UpdateView()
}
// SetMaxRecs sets the maximum number of records that are maintained (default 210).
// It resets the current data in the process.
func (nv *NetView) SetMaxRecs(max int) {
nv.Options.MaxRecs = max
nv.Data.Init(nv.Net, nv.Options.MaxRecs, nv.Options.NoSynData, nv.Net.MaxParallelData())
}
// HasLayers returns true if network has any layers -- else no display
func (nv *NetView) HasLayers() bool {
if nv.Net == nil || nv.Net.NumLayers() == 0 {
return false
}
return true
}
// IsViewingSynapse returns true if netview is viewing synapses.
func (nv *NetView) IsViewingSynapse() bool {
if !nv.IsVisible() {
return false
}
vvar := nv.Var
if strings.HasPrefix(vvar, "r.") || strings.HasPrefix(vvar, "s.") {
return true
}
return false
}
// RecordCounters saves the counters, so they are available for a Current update
func (nv *NetView) RecordCounters(counters string) {
nv.DataMu.Lock()
defer nv.DataMu.Unlock()
if counters != "" {
nv.LastCtrs = counters
}
}
// Record records the current state of the network, along with provided counters
// string, which is displayed at the bottom of the view to show the current
// state of the counters. The rastCtr is the raster counter value used for
// the raster plot mode -- use -1 for a default incrementing counter.
// The NetView displays this recorded data when Update is next called.
func (nv *NetView) Record(counters string, rastCtr int) {
nv.DataMu.Lock()
defer nv.DataMu.Unlock()
if counters != "" {
nv.LastCtrs = counters
}
nv.Data.PathType = nv.Options.PathType
nv.Data.Record(nv.LastCtrs, rastCtr, nv.Options.Raster.Max)
nv.RecTrackLatest() // if we make a new record, then user expectation is to track latest..
}
// RecordSyns records synaptic data -- stored separately from unit data
// and only needs to be called when synaptic values are updated.
// Should be done when the DWt values have been computed, before
// updating Wts and zeroing.
// NetView displays this recorded data when Update is next called.
func (nv *NetView) RecordSyns() {
nv.DataMu.Lock()
defer nv.DataMu.Unlock()
nv.Data.RecordSyns()
}
// GoUpdateView is the update call to make from another goroutine.
// It does the proper blocking to coordinate with GUI updates
// generated on the main GUI thread.
func (nv *NetView) GoUpdateView() {
if !nv.IsVisible() || !nv.HasLayers() {
return
}
sw := nv.SceneWidget()
sw.Scene.AsyncLock()
nv.UpdateImpl()
sw.NeedsRender()
sw.Scene.AsyncUnlock()
if core.TheApp.Platform() == system.Web {
time.Sleep(time.Millisecond) // critical to prevent hanging!
}
}
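// exampleSimLoopUpdate is an illustrative sketch (not used by the package) of
// the typical call sequence from a simulation goroutine: record unit state,
// optionally record synapses (e.g., after DWt is computed), then GoUpdateView
// for a thread-safe display refresh. The counter string is an assumption.
func exampleSimLoopUpdate(nv *NetView) {
	nv.Record("Epoch: 1 Trial: 2 Cycle: 150", -1) // -1 = auto-incrementing raster counter
	nv.RecordSyns()                               // only needed when synaptic values changed
	nv.GoUpdateView()                             // safe to call from a non-GUI goroutine
}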
// UpdateView updates the display based on last recorded state of network.
func (nv *NetView) UpdateView() {
if !nv.IsVisible() || !nv.HasLayers() {
return
}
sw := nv.SceneWidget()
nv.UpdateImpl()
sw.NeedsRender()
}
// Current records the current state of the network, including synaptic values,
// and updates the display. Use this when switching to the NetView tab after the network
// has been running while viewing another tab, because the network state
// is typically not recorded then.
func (nv *NetView) Current() { //types:add
nv.Record("", -1)
nv.RecordSyns()
nv.UpdateView()
}
// UpdateImpl does the guts of updating -- backend for UpdateView or GoUpdateView
func (nv *NetView) UpdateImpl() {
nv.DataMu.Lock()
vp, ok := nv.VarOptions[nv.Var]
if !ok {
nv.DataMu.Unlock()
log.Printf("NetView: %v variable: %v not found\n", nv.Name, nv.Var)
return
}
nv.CurVarOptions = vp
if !vp.Range.FixMin || !vp.Range.FixMax {
needUpdate := false
// need to autoscale
min, max, ok := nv.Data.VarRange(nv.Var)
if ok {
vp.MinMax.Set(min, max)
if !vp.Range.FixMin {
nmin := float32(minmax.NiceRoundNumber(float64(min), true)) // true = below
if vp.Range.Min != nmin {
vp.Range.Min = nmin
needUpdate = true
}
}
if !vp.Range.FixMax {
nmax := float32(minmax.NiceRoundNumber(float64(max), false)) // false = above
if vp.Range.Max != nmax {
vp.Range.Max = nmax
needUpdate = true
}
}
if vp.ZeroCtr && !vp.Range.FixMin && !vp.Range.FixMax {
bmax := math32.Max(math32.Abs(vp.Range.Max), math32.Abs(vp.Range.Min))
if !needUpdate {
if vp.Range.Max != bmax || vp.Range.Min != -bmax {
needUpdate = true
}
}
vp.Range.Max = bmax
vp.Range.Min = -bmax
}
if needUpdate {
tb := nv.Toolbar()
tb.UpdateTree()
tb.NeedsRender()
}
}
}
nv.SetCounters(nv.Data.CounterRec(nv.RecNo))
nv.UpdateRecNo()
nv.DataMu.Unlock()
nv.UpdateLayers()
}
// // ReconfigMeshes reconfigures the layer meshes
// func (nv *NetView) ReconfigMeshes() {
// se := nv.SceneXYZ()
// se.ReconfigMeshes()
// }
func (nv *NetView) Toolbar() *core.Toolbar {
return nv.ChildByName("tbar", 0).(*core.Toolbar)
}
func (nv *NetView) NetFrame() *core.Frame {
return nv.ChildByName("netframe", 1).(*core.Frame)
}
func (nv *NetView) Counters() *core.Text {
return nv.ChildByName("counters", 2).(*core.Text)
}
func (nv *NetView) Viewbar() *core.Toolbar {
return nv.ChildByName("vbar", 3).(*core.Toolbar)
}
func (nv *NetView) SceneWidget() *Scene {
return nv.NetFrame().ChildByName("scene", 1).(*Scene)
}
func (nv *NetView) SceneXYZ() *xyz.Scene {
return nv.SceneWidget().SceneXYZ()
}
func (nv *NetView) VarsFrame() *core.Tabs {
return nv.NetFrame().ChildByName("vars", 0).(*core.Tabs)
}
// SetCounters sets the counters widget view display at bottom of netview
func (nv *NetView) SetCounters(ctrs string) {
if ctrs == "" {
return
}
nv.CurCtrs = ctrs
ct := nv.Counters()
ct.UpdateWidget().NeedsRender()
}
// UpdateRecNo updates the record number viewing
func (nv *NetView) UpdateRecNo() {
vbar := nv.Viewbar()
rlbl := vbar.ChildByName("rec", 10)
if rlbl != nil {
rlbl.(*core.Text).UpdateWidget().NeedsRender()
}
}
// RecFullBkwd moves the view record to the start of history. Returns true if updated.
func (nv *NetView) RecFullBkwd() bool {
if nv.RecNo == 0 {
return false
}
nv.RecNo = 0
return true
}
// RecFastBkwd moves the view record N (default 10) steps backward. Returns true if updated.
func (nv *NetView) RecFastBkwd() bool {
if nv.RecNo == 0 {
return false
}
if nv.RecNo < 0 {
nv.RecNo = nv.Data.Ring.Len - nv.Options.NFastSteps
} else {
nv.RecNo -= nv.Options.NFastSteps
}
if nv.RecNo < 0 {
nv.RecNo = 0
}
return true
}
// RecBkwd moves the view record 1 step backward. Returns true if updated.
func (nv *NetView) RecBkwd() bool {
if nv.RecNo == 0 {
return false
}
if nv.RecNo < 0 {
nv.RecNo = nv.Data.Ring.Len - 1
} else {
nv.RecNo -= 1
}
if nv.RecNo < 0 {
nv.RecNo = 0
}
return true
}
// RecFwd moves the view record 1 step forward. Returns true if updated.
func (nv *NetView) RecFwd() bool {
if nv.RecNo >= nv.Data.Ring.Len-1 {
nv.RecNo = nv.Data.Ring.Len - 1
return false
}
if nv.RecNo < 0 {
return false
}
nv.RecNo += 1
if nv.RecNo >= nv.Data.Ring.Len-1 {
nv.RecNo = nv.Data.Ring.Len - 1
}
return true
}
// RecFastFwd moves the view record N (default 10) steps forward. Returns true if updated.
func (nv *NetView) RecFastFwd() bool {
if nv.RecNo >= nv.Data.Ring.Len-1 {
nv.RecNo = nv.Data.Ring.Len - 1
return false
}
if nv.RecNo < 0 {
return false
}
nv.RecNo += nv.Options.NFastSteps
if nv.RecNo >= nv.Data.Ring.Len-1 {
nv.RecNo = nv.Data.Ring.Len - 1
}
return true
}
// RecTrackLatest sets view to track latest record (-1). Returns true if updated.
func (nv *NetView) RecTrackLatest() bool {
if nv.RecNo == -1 {
return false
}
nv.RecNo = -1
return true
}
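// exampleRecNavigation is an illustrative sketch (not used by the package) of
// stepping through recorded history and then returning to live tracking.
func exampleRecNavigation(nv *NetView) {
	if nv.RecBkwd() { // back one record
		nv.UpdateView()
	}
	nv.RecFastBkwd()    // back NFastSteps (default 10) records
	nv.RecTrackLatest() // RecNo = -1: always show the most recent record
	nv.UpdateView()
}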
// NetVarsList returns the list of layer and path variables for given network.
// layEven ensures that the number of layer variables is an even number if true
// (used for display but not storage).
func (nv *NetView) NetVarsList(net emer.Network, layEven bool) (nvars, synvars []string) {
if net == nil || net.NumLayers() == 0 {
return nil, nil
}
unvars := net.UnitVarNames()
synvars = net.SynVarNames()
ulen := len(unvars)
tlen := ulen + 2*len(synvars)
nvars = make([]string, tlen)
copy(nvars, unvars)
st := ulen
for pi := 0; pi < len(synvars); pi++ {
nvars[st+2*pi] = "r." + synvars[pi]
nvars[st+2*pi+1] = "s." + synvars[pi]
}
return
}
// VarsListUpdate updates the list of network variables
func (nv *NetView) VarsListUpdate() {
nvars, synvars := nv.NetVarsList(nv.Net, true) // true = layEven
if len(nvars) == len(nv.Vars) {
return
}
nv.Vars = nvars
nv.VarOptions = make(map[string]*VarOptions, len(nv.Vars))
nv.SynVars = synvars
nv.SynVarsMap = make(map[string]int, len(synvars))
for i, vn := range nv.SynVars {
nv.SynVarsMap[vn] = i
}
unprops := nv.Net.UnitVarProps()
pathprops := nv.Net.SynVarProps()
for _, nm := range nv.Vars {
vp := &VarOptions{Var: nm}
vp.Defaults()
var vtag string
if strings.HasPrefix(nm, "r.") || strings.HasPrefix(nm, "s.") {
vtag = pathprops[nm[2:]]
} else {
vtag = unprops[nm]
}
if vtag != "" {
vp.SetProps(vtag)
}
nv.VarOptions[nm] = vp
}
}
// makeVars configures the variables
func (nv *NetView) makeVars(netframe *core.Frame) {
nv.VarsListUpdate()
if nv.Net == nil {
return
}
unprops := nv.Net.UnitVarProps()
pathprops := nv.Net.SynVarProps()
cats := nv.Net.VarCategories()
if len(cats) == 0 {
cats = []emer.VarCategory{
{"Unit", "unit variables"},
{"Wt", "connection weight variables"},
}
}
tree.AddChildAt(netframe, "vars", func(w *core.Tabs) {
w.Styler(func(s *styles.Style) {
s.Grow.Set(0, 1)
s.Overflow.Y = styles.OverflowAuto
})
tabs := make(map[string]*core.Frame)
for _, ct := range cats {
tf, tb := w.NewTab(ct.Cat)
tb.Tooltip = ct.Doc
tabs[ct.Cat] = tf
tf.Styler(func(s *styles.Style) {
s.Display = styles.Grid
s.Columns = nv.Options.NVarCols
s.Grow.Set(1, 1)
s.Overflow.Y = styles.OverflowAuto
s.Background = colors.Scheme.SurfaceContainerLow
})
}
for _, vn := range nv.Vars {
cat := ""
pstr := ""
doc := ""
if strings.HasPrefix(vn, "r.") || strings.HasPrefix(vn, "s.") {
pstr = pathprops[vn[2:]]
cat = "Wt" // default
} else {
pstr = unprops[vn]
cat = "Unit"
}
if pstr != "" {
rstr := reflect.StructTag(pstr)
doc = rstr.Get("doc")
cat = rstr.Get("cat")
if rstr.Get("display") == "-" {
continue
}
}
tf, ok := tabs[cat]
if !ok {
slog.Error("emergent.NetView UnitVarProps 'cat' name not found in VarCategories list", "cat", cat, "variable", vn)
cat = cats[0].Cat
tf = tabs[cat]
}
w := core.NewButton(tf).SetText(vn)
if doc != "" {
w.Tooltip = vn + ": " + doc
}
w.SetText(vn).SetType(core.ButtonAction)
w.OnClick(func(e events.Event) {
nv.SetVar(vn)
})
w.Updater(func() {
w.SetSelected(w.Text == nv.Var)
})
}
})
}
// ViewDefaults are the default 3D view params
func (nv *NetView) ViewDefaults(se *xyz.Scene) {
se.Camera.Pose.Pos.Set(0, 1.5, 2.5) // more "top down" view shows more of layers
// vs.Camera.Pose.Pos.Set(0, 1, 2.75) // more "head on" for larger / deeper networks
se.Camera.Near = 0.1
se.Camera.LookAt(math32.Vec3(0, 0, 0), math32.Vec3(0, 1, 0))
nv.Styler(func(s *styles.Style) {
se.Background = colors.Scheme.Surface
})
xyz.NewAmbient(se, "ambient", 0.1, xyz.DirectSun)
xyz.NewDirectional(se, "directional", 0.5, xyz.DirectSun).Pos.Set(0, 2, 5)
xyz.NewPoint(se, "point", .2, xyz.DirectSun).Pos.Set(0, 2, -5)
}
// ReadLock locks data for reading -- call ReadUnlock when done.
// Call this surrounding calls to UnitValue.
func (nv *NetView) ReadLock() {
nv.DataMu.RLock()
}
// ReadUnlock unlocks data for reading.
func (nv *NetView) ReadUnlock() {
nv.DataMu.RUnlock()
}
// UnitValue returns the raw value, scaled value, and color representation
// for given unit of given layer. scaled is in range -1..1
func (nv *NetView) UnitValue(lay emer.Layer, idx []int) (raw, scaled float32, clr color.RGBA, hasval bool) {
lb := lay.AsEmer()
idx1d := lb.Shape.IndexTo1D(idx...)
if idx1d >= lb.Shape.Len() {
raw, hasval = 0, false
} else {
raw, hasval = nv.Data.UnitValue(lb.Name, nv.Var, idx1d, nv.RecNo, nv.Di)
}
scaled, clr = nv.UnitValColor(lay, idx1d, raw, hasval)
return
}
// UnitValRaster returns the raw value, scaled value, and color representation
// for given unit of given layer, and given raster counter index value (0..RasterMax).
// scaled is in range -1..1
func (nv *NetView) UnitValRaster(lay emer.Layer, idx []int, rCtr int) (raw, scaled float32, clr color.RGBA, hasval bool) {
lb := lay.AsEmer()
idx1d := lb.GetSampleShape().IndexTo1D(idx...)
ridx := lb.SampleIndexes
if len(ridx) == 0 { // no rep
if idx1d >= lb.Shape.Len() {
raw, hasval = 0, false
} else {
raw, hasval = nv.Data.UnitValRaster(lb.Name, nv.Var, idx1d, rCtr, nv.Di)
}
} else {
if idx1d >= len(ridx) {
raw, hasval = 0, false
} else {
idx1d = ridx[idx1d]
raw, hasval = nv.Data.UnitValRaster(lb.Name, nv.Var, idx1d, rCtr, nv.Di)
}
}
scaled, clr = nv.UnitValColor(lay, idx1d, raw, hasval)
return
}
var NilColor = color.RGBA{0x20, 0x20, 0x20, 0x40}
// UnitValColor returns the raw value, scaled value, and color representation
// for given unit of given layer. scaled is in range -1..1
func (nv *NetView) UnitValColor(lay emer.Layer, idx1d int, raw float32, hasval bool) (scaled float32, clr color.RGBA) {
if nv.CurVarOptions == nil || nv.CurVarOptions.Var != nv.Var {
ok := false
nv.CurVarOptions, ok = nv.VarOptions[nv.Var]
if !ok {
return
}
}
if !hasval {
scaled = 0
if lay.Label() == nv.Data.PathLay && idx1d == nv.Data.PathUnIndex {
clr = color.RGBA{0x20, 0x80, 0x20, 0x80}
} else {
clr = NilColor
}
} else {
clp := nv.CurVarOptions.Range.ClampValue(raw)
norm := nv.CurVarOptions.Range.NormValue(clp)
var op float32
if nv.CurVarOptions.ZeroCtr {
scaled = float32(2*norm - 1)
op = (nv.Options.ZeroAlpha + (1-nv.Options.ZeroAlpha)*math32.Abs(scaled))
} else {
scaled = float32(norm)
op = (nv.Options.ZeroAlpha + (1-nv.Options.ZeroAlpha)*0.8) // no meaningful alpha -- just set at 80%
}
clr = colors.WithAF32(nv.ColorMap.Map(norm), op)
}
return
}
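// zeroCtrScaleAlpha is an illustrative sketch (not called by the package) of
// the zero-centered mapping used in UnitValColor above: a normalized 0..1
// value is rescaled to -1..1, and opacity grows from zeroAlpha at zero to 1
// at either extreme.
func zeroCtrScaleAlpha(norm, zeroAlpha float32) (scaled, opacity float32) {
	scaled = 2*norm - 1
	opacity = zeroAlpha + (1-zeroAlpha)*math32.Abs(scaled)
	return
}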
func (nv *NetView) Labels() *xyz.Group {
se := nv.SceneXYZ()
lgpi := se.ChildByName("Labels", 1)
if lgpi == nil {
return nil
}
return lgpi.(*xyz.Group)
}
func (nv *NetView) Layers() *xyz.Group {
se := nv.SceneXYZ()
lgpi := se.ChildByName("Layers", 0)
if lgpi == nil {
return nil
}
return lgpi.(*xyz.Group)
}
// ConfigLabels ensures that given label xyz.Text2D objects are created and initialized
// in a top-level group called Labels. Use LabelByName() to get a given label, and
// LayerByName() to get a Layer group, whose Pose can be copied to put a label in
// position relative to a layer. Default alignment is Left, Top.
// Returns true if the set of labels was changed (mods).
func (nv *NetView) ConfigLabels(labs []string) bool {
se := nv.SceneXYZ()
lgp := nv.Labels()
if lgp == nil {
lgp = xyz.NewGroup(se)
lgp.Name = "Labels"
}
lbConfig := tree.TypePlan{}
for _, ls := range labs {
lbConfig.Add(types.For[xyz.Text2D](), ls)
}
if tree.Update(lgp, lbConfig) {
for i, ls := range labs {
lb := lgp.ChildByName(ls, i).(*xyz.Text2D)
lb.SetText(ls)
// todo:
// lb.SetProperty("text-align", styles.Start)
// lb.SetProperty("vertical-align", styles.Start)
// lb.SetProperty("white-space", styles.WhiteSpacePre)
}
return true
}
return false
}
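// exampleLayerLabel is an illustrative sketch (not used by the package) of
// positioning a label relative to a layer using ConfigLabels, LabelByName,
// and LayerByName; the label text and layer name are assumptions.
func exampleLayerLabel(nv *NetView) {
	nv.ConfigLabels([]string{"Input note"})
	lb := nv.LabelByName("Input note")
	lg := nv.LayerByName("Input")
	if lb != nil && lg != nil {
		lb.Pose = lg.Pose // copy layer pose to place the label at the layer
	}
}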
// LabelByName returns given Text2D label (see ConfigLabels).
// nil if not found.
func (nv *NetView) LabelByName(lab string) *xyz.Text2D {
lgp := nv.Labels()
txt := lgp.ChildByName(lab, 0)
if txt == nil {
return nil
}
return txt.(*xyz.Text2D)
}
// LayerByName returns the xyz.Group that represents layer of given name.
// nil if not found.
func (nv *NetView) LayerByName(lay string) *xyz.Group {
lgp := nv.Layers()
ly := lgp.ChildByName(lay, 0)
if ly == nil {
return nil
}
return ly.(*xyz.Group)
}
// SaveWeights saves the network weights.
func (nv *NetView) SaveWeights(filename core.Filename) { //types:add
nv.Net.AsEmer().SaveWeightsJSON(filename)
}
// OpenWeights opens the network weights.
func (nv *NetView) OpenWeights(filename core.Filename) { //types:add
nv.Net.AsEmer().OpenWeightsJSON(filename)
}
// ShowNonDefaultParams shows a dialog of all the parameters that
// are not at their default values in the network. Useful for setting params.
func (nv *NetView) ShowNonDefaultParams() string { //types:add
nds := nv.Net.AsEmer().ParamsString(emer.NonDefault)
textcore.TextDialog(nv, "Non Default Params: "+nv.Name, nds)
return nds
}
// ShowAllParams shows a dialog of all the parameters in the network.
func (nv *NetView) ShowAllParams() string { //types:add
nds := nv.Net.AsEmer().ParamsString(emer.AllParams)
textcore.TextDialog(nv, "All Params: "+nv.Name, nds)
return nds
}
// ShowKeyLayerParams shows a dialog with a listing for all layers in the network,
// of the most important layer-level params (specific to each algorithm)
func (nv *NetView) ShowKeyLayerParams() string { //types:add
nds := nv.Net.KeyLayerParams()
textcore.TextDialog(nv, "Key Layer Params: "+nv.Name, nds)
return nds
}
// ShowKeyPathParams shows a dialog with a listing for all Recv pathways in the network,
// of the most important pathway-level params (specific to each algorithm)
func (nv *NetView) ShowKeyPathParams() string { //types:add
nds := nv.Net.KeyPathParams()
textcore.TextDialog(nv, "Key Path Params: "+nv.Name, nds)
return nds
}
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package netview
import (
"log"
"reflect"
"strconv"
"cogentcore.org/core/core"
"cogentcore.org/core/math32/minmax"
)
// NVarCols is the default number of variable columns in the NetView
var NVarCols = 2
// RasterOptions holds parameters controlling the raster plot view
type RasterOptions struct { //types:add
// if true, show a raster plot over time, otherwise units
On bool
// if true, the raster counter (time) is plotted across the X axis -- otherwise the Z depth axis
XAxis bool
// maximum count for the counter defining the raster plot
Max int
// size of a single unit, where 1 = full width and no space; default 1
UnitSize float32 `min:"0.1" max:"1" step:"0.1" default:"1"`
// height multiplier for units, where 1 = full height; default 0.2
UnitHeight float32 `min:"0.1" max:"1" step:"0.1" default:"0.2"`
}
func (nv *RasterOptions) Defaults() {
if nv.Max == 0 {
nv.Max = 200
}
if nv.UnitSize == 0 {
nv.UnitSize = 1
}
if nv.UnitHeight == 0 {
nv.UnitHeight = .2
}
}
// Options holds parameters controlling how the view is rendered
type Options struct { //types:add
// whether to display the pathways between layers as arrows
Paths bool
// PathType has name(s) to display (space separated), for path arrows,
// and when there are multiple pathways from the same layer.
// Uses the parameter Class names in addition to type,
// and case insensitive "contains" logic for each name.
PathType string
// width of the path arrows, in normalized units
PathWidth float32 `min:"0.0001" max:".05" step:"0.001" default:"0.002"`
// raster plot parameters
Raster RasterOptions `display:"inline"`
// do not record synapse level data -- turn this on for very large networks where recording the entire synaptic state would be prohibitive
NoSynData bool
// maximum number of records to store to enable rewinding through prior states
MaxRecs int `min:"1"`
// number of variable columns
NVarCols int
// size of a single unit, where 1 = full width and no space; default 0.9
UnitSize float32 `min:"0.1" max:"1" step:"0.1" default:"0.9"`
// size of the layer name labels -- entire network view is unit sized
LayerNameSize float32 `min:"0.01" max:".1" step:"0.01" default:"0.05"`
// name of color map to use
ColorMap core.ColorMapName
// opacity (0-1) of zero values -- greater magnitude values become increasingly opaque on either side of this minimum
ZeroAlpha float32 `min:"0" max:"1" step:"0.1" default:"0.5"`
// the number of records to jump for fast forward/backward
NFastSteps int
}
func (nv *Options) Defaults() {
nv.Raster.Defaults()
if nv.NVarCols == 0 {
nv.NVarCols = NVarCols
nv.Paths = true
nv.PathWidth = 0.002
}
if nv.MaxRecs == 0 {
nv.MaxRecs = 210 // 200 cycles + 8 phase updates max + 2 extra..
}
if nv.UnitSize == 0 {
nv.UnitSize = .9
}
if nv.LayerNameSize == 0 {
nv.LayerNameSize = .05
}
if nv.ZeroAlpha == 0 {
nv.ZeroAlpha = 0.5
}
if nv.ColorMap == "" {
nv.ColorMap = core.ColorMapName("ColdHot")
}
if nv.NFastSteps == 0 {
nv.NFastSteps = 10
}
}
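// exampleRasterOptions is an illustrative sketch (not used by the package) of
// enabling the raster plot view with time on the X axis; the Max value is an
// assumption matching a 200-cycle trial.
func exampleRasterOptions(nv *NetView) {
	nv.Options.Raster.On = true
	nv.Options.Raster.XAxis = true // time across X; false = Z depth axis
	nv.Options.Raster.Max = 200    // one raster column per counter value, 0..199
}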
// VarOptions holds parameters for display of each variable
type VarOptions struct { //types:add
// name of the variable
Var string
// keep Min - Max centered around 0, and use negative heights for units -- else use full min-max range for height (no negative heights)
ZeroCtr bool
// range to display
Range minmax.Range32 `display:"inline"`
// if not using fixed range, this is the actual range of data
MinMax minmax.F32 `display:"inline"`
}
// Defaults sets default values if otherwise not set
func (vp *VarOptions) Defaults() {
if vp.Range.Max == 0 && vp.Range.Min == 0 {
vp.ZeroCtr = true
vp.Range.SetMin(-1)
vp.Range.SetMax(1)
}
}
// SetProps parses Go struct-tag style properties for variable and sets values accordingly
// for customized defaults
func (vp *VarOptions) SetProps(pstr string) {
rstr := reflect.StructTag(pstr)
if tv, ok := rstr.Lookup("range"); ok {
rg, err := strconv.ParseFloat(tv, 32)
if err != nil {
log.Printf("NetView.VarOptions.SetProps for Var: %v 'range:' err: %v on val: %v\n", vp.Var, err, tv)
} else {
vp.Range.Max = float32(rg)
vp.Range.Min = -float32(rg)
vp.ZeroCtr = true
}
}
if tv, ok := rstr.Lookup("min"); ok {
rg, err := strconv.ParseFloat(tv, 32)
if err != nil {
log.Printf("NetView.VarOptions.SetProps for Var: %v 'min:' err: %v on val: %v\n", vp.Var, err, tv)
} else {
vp.Range.Min = float32(rg)
vp.ZeroCtr = false
}
}
if tv, ok := rstr.Lookup("max"); ok {
rg, err := strconv.ParseFloat(tv, 32)
if err != nil {
log.Printf("NetView.VarOptions.SetProps for Var: %v 'max:' err: %v on val: %v\n", vp.Var, err, tv)
} else {
vp.Range.Max = float32(rg)
vp.ZeroCtr = false
}
}
if tv, ok := rstr.Lookup("auto-scale"); ok {
if tv == "+" {
vp.Range.FixMin = false
vp.Range.FixMax = false
} else {
vp.Range.FixMin = true
vp.Range.FixMax = true
}
}
if tv, ok := rstr.Lookup("zeroctr"); ok {
if tv == "+" {
vp.ZeroCtr = true
} else {
vp.ZeroCtr = false
}
}
}
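// exampleVarProps is an illustrative sketch (not used by the package) of the
// struct-tag style property strings parsed by SetProps; the variable name and
// values are assumptions. This string sets an explicit 0..1 display range and
// turns off zero-centering.
func exampleVarProps() *VarOptions {
	vp := &VarOptions{Var: "Act"}
	vp.Defaults()
	vp.SetProps(`min:"0" max:"1"`)
	return vp
}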
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package netview
import (
"cmp"
"fmt"
"math"
"slices"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/colors"
"cogentcore.org/core/math32"
"cogentcore.org/core/text/text"
"cogentcore.org/core/tree"
"cogentcore.org/core/types"
"cogentcore.org/core/xyz"
"github.com/emer/emergent/v2/emer"
)
// UpdateLayers updates the layer display with any structural or
// current data changes. Very fast if no structural changes.
func (nv *NetView) UpdateLayers() {
sw := nv.SceneWidget()
se := sw.SceneXYZ()
if nv.Net == nil || nv.Net.NumLayers() == 0 {
se.DeleteChildren()
se.Meshes.Reset()
return
}
nb := nv.Net.AsEmer()
if nv.NeedsRebuild() {
se.Background = colors.Scheme.Surface
}
nlay := nv.Net.NumLayers()
laysGp := se.ChildByName("Layers", 0).(*xyz.Group)
layConfig := tree.TypePlan{}
for li := range nlay {
ly := nv.Net.EmerLayer(li)
layConfig.Add(types.For[xyz.Group](), ly.Label())
}
if !tree.Update(laysGp, layConfig) && nv.layerNameSizeShown == nv.Options.LayerNameSize {
for li := range laysGp.Children {
ly := nv.Net.EmerLayer(li)
lmesh := errors.Log1(se.MeshByName(ly.Label()))
se.SetMesh(lmesh) // does update
}
if nv.hasPaths != nv.Options.Paths || nv.pathTypeShown != nv.Options.PathType ||
nv.pathWidthShown != nv.Options.PathWidth {
nv.UpdatePaths()
}
return
}
nv.layerNameSizeShown = nv.Options.LayerNameSize
gpConfig := tree.TypePlan{}
gpConfig.Add(types.For[LayObj](), "layer")
gpConfig.Add(types.For[LayName](), "name")
nmin, nmax := nb.MinPos, nb.MaxPos
nsz := nmax.Sub(nmin).Sub(math32.Vec3(1, 1, 0)).Max(math32.Vec3(1, 1, 1))
nsc := math32.Vec3(1.0/nsz.X, 1.0/nsz.Y, 1.0/nsz.Z)
szc := math32.Max(nsc.X, nsc.Y)
poff := math32.Vector3Scalar(0.5)
poff.Y = -0.5
for li, lgi := range laysGp.Children {
ly := nv.Net.EmerLayer(li)
lb := ly.AsEmer()
lmesh, _ := se.MeshByName(ly.Label())
if lmesh == nil {
NewLayMesh(se, nv, ly)
} else {
lmesh.(*LayMesh).Lay = ly // make sure
}
lg := lgi.(*xyz.Group)
gpConfig[1].Name = ly.Label() // text2d textures use obj name, so must be unique
tree.Update(lg, gpConfig)
lp := lb.Pos.Pos
lp.Y = -lp.Y // reverse direction
lp = lp.Sub(nmin).Mul(nsc).Sub(poff)
lg.Pose.Pos.Set(lp.X, lp.Z, lp.Y)
lg.Pose.Scale.Set(nsc.X*lb.Pos.Scale, szc, nsc.Y*lb.Pos.Scale)
lo := lg.Child(0).(*LayObj)
lo.Defaults()
lo.LayName = ly.Label()
lo.NetView = nv
lo.SetMeshName(ly.Label())
lo.Material.Color = colors.FromRGB(255, 100, 255)
lo.Material.Reflective = 8
lo.Material.Bright = 8
lo.Material.Shiny = 30
// note: would actually be better to NOT cull back so you can view underneath
// but then the front and back fight against each other, causing flickering
txt := lg.Child(1).(*LayName)
txt.Defaults()
txt.NetView = nv
txt.SetText(ly.Label())
txt.Pose.Scale = math32.Vector3Scalar(nv.Options.LayerNameSize).Div(lg.Pose.Scale)
txt.Styles.Background = colors.Uniform(colors.Transparent)
txt.Styles.Text.Align = text.Start
txt.Styles.Text.AlignV = text.Start
}
nv.UpdatePaths()
sw.NeedsRender()
}
// UpdatePaths updates the path display.
// Called when layers have structural changes, or when path display options change.
func (nv *NetView) UpdatePaths() {
sw := nv.SceneWidget()
se := sw.SceneXYZ()
nb := nv.Net.AsEmer()
nlay := nv.Net.NumLayers()
pathsGp := se.ChildByName("Paths", 0).(*xyz.Group)
pathsGp.DeleteChildren()
if !nv.Options.Paths {
nv.hasPaths = false
return
}
nv.hasPaths = true
nmin, nmax := nb.MinPos, nb.MaxPos
nsz := nmax.Sub(nmin).Sub(math32.Vec3(1, 1, 0)).Max(math32.Vec3(1, 1, 1))
nsc := math32.Vec3(1.0/nsz.X, 1.0/nsz.Y, 1.0/nsz.Z)
poff := math32.Vector3Scalar(0.5)
poff.Y = -0.5
lineWidth := nv.Options.PathWidth
// weight factors applied to distance for the different sides,
// to encourage / discourage choice of sides.
// In general the sides are preferred, and back is discouraged.
sideWeights := [4]float32{1.1, 1, 1, 1.1}
type pathData struct {
path emer.Path
sSide, rSide, cat int
sIdx, sN, rIdx, rN int // indexes and numbers for each side
sPos, rPos math32.Vector3
}
pdIdx := func(side, cat int) int {
return side*3 + cat
}
type layerData struct {
paths [12][]*pathData // by side * category
selfPaths []*pathData
}
layPaths := make([]layerData, nlay)
// 0 = forward, "left" side; 1 = lateral, "middle"; 2 = back, "right"
sideCat := func(rLayY, sLayY float32) int {
if rLayY < sLayY {
return 2
} else if rLayY == sLayY {
return 1
}
return 0
}
// returns layer position and size in normalized display coordinates (NDC)
// using the correct rendering coordinate system: X = X, Y <-> Z
layPosSize := func(lb *emer.LayerBase) (math32.Vector3, math32.Vector3) {
lp := lb.Pos.Pos
lp.Y = -lp.Y
lp = lp.Sub(nmin).Mul(nsc).Sub(poff)
lp.Y, lp.Z = lp.Z, lp.Y
dsz := lb.DisplaySize()
lsz := math32.Vector3{dsz.X * nsc.X, 0, dsz.Y * nsc.Y}
return lp, lsz
}
// F, L, R, B -- center of each side, z is negative; order favors front in a tie
sideMids := []math32.Vector3{{0.5, 0, 0}, {0, 0, -0.5}, {1, 0, -0.5}, {0.5, 0, -1}}
sideDims := []math32.Dims{math32.X, math32.Z, math32.Z, math32.X}
// returns the side position multiplier vector: the side midpoint with the along-side dimension set to the given proportion (negated for Z)
sideMtx := func(side int, prop float32) math32.Vector3 {
dim := sideDims[side]
smat := sideMids[side]
smat.SetDim(dim, prop)
if dim == math32.Z {
smat.Z *= -1
}
return smat
}
laySidePos := func(lb *emer.LayerBase, side, cat, idx, n int, off float32) math32.Vector3 {
prop := (float32(cat) / 3) + (float32(idx)+off)/float32(3*n)
pos, sz := layPosSize(lb)
mat := sideMtx(side, prop)
return pos.Add(sz.Mul(mat))
}
// sets the sending and receiving positions of the path,
// for a point at the given index along its side and category
setPathPos := func(pd *pathData) {
pt := pd.path
sb := pt.SendLayer().AsEmer()
rb := pt.RecvLayer().AsEmer()
off := float32(0.4)
if rb.Index < sb.Index {
off = 0.6
}
pd.sPos = laySidePos(sb, pd.sSide, pd.cat, pd.sIdx, pd.sN, off)
pd.rPos = laySidePos(rb, pd.rSide, pd.cat, pd.rIdx, pd.rN, off)
return
}
// first pass: find the side to make connections on, based on shortest weighted length
for si := range nlay {
sl := nv.Net.EmerLayer(si)
sb := sl.AsEmer()
slayData := &layPaths[sb.Index]
sLayPos, _ := layPosSize(sb)
npt := sl.NumSendPaths()
for pi := range npt {
pt := sl.SendPath(pi)
if !nv.pathTypeNameMatch(pt) {
continue
}
rb := pt.RecvLayer().AsEmer()
if sb.Index == rb.Index { // self
slayData.selfPaths = append(slayData.selfPaths, &pathData{path: pt, cat: 1})
continue
}
minDist := float32(math.MaxFloat32)
var minData *pathData
for sSide := range 4 {
swt := sideWeights[sSide]
for rSide := range 4 {
rwt := sideWeights[rSide]
rLayPos, _ := layPosSize(rb)
cat := sideCat(rLayPos.Y, sLayPos.Y)
pd := &pathData{path: pt, sSide: sSide, rSide: rSide, cat: cat, sN: 1, rN: 1}
setPathPos(pd)
dist := pd.rPos.Sub(pd.sPos).Length() * swt * rwt
if dist < minDist {
minDist = dist
minData = pd
}
}
}
i := pdIdx(minData.sSide, minData.cat)
slayData.paths[i] = append(slayData.paths[i], minData)
rlayData := &layPaths[rb.Index]
i = pdIdx(minData.rSide, minData.cat)
rlayData.paths[i] = append(rlayData.paths[i], minData)
}
}
for li := range nlay {
ly := nv.Net.EmerLayer(li)
lb := ly.AsEmer()
layData := &layPaths[lb.Index]
for side := range 4 {
for cat := range 3 {
pidx := pdIdx(side, cat)
pths := layData.paths[pidx]
npt := len(pths)
if npt == 0 {
continue
}
for pi, pd := range pths {
if pd.path.RecvLayer() == ly {
pd.rIdx = pi
pd.rN = npt
} else {
pd.sIdx = pi
pd.sN = npt
}
}
}
}
}
// now that we have the full set of data, iteratively sort positions until the order stabilizes
for range 10 { // 10 iterations does about as well as 100 on complex networks
for li := range nlay {
ly := nv.Net.EmerLayer(li)
lb := ly.AsEmer()
layData := &layPaths[lb.Index]
for side := range 4 {
for cat := range 3 {
pidx := pdIdx(side, cat)
pths := layData.paths[pidx]
npt := len(pths)
if npt == 0 {
continue
}
for _, pd := range pths {
if pd.path.RecvLayer() == ly {
setPathPos(pd)
}
}
}
}
}
orderChanged := false
for li := range nlay {
ly := nv.Net.EmerLayer(li)
lb := ly.AsEmer()
layData := &layPaths[lb.Index]
for side := range 4 {
for cat := range 3 {
pidx := pdIdx(side, cat)
pths := layData.paths[pidx]
npt := len(pths)
if npt == 0 {
continue
}
slices.SortStableFunc(pths, func(a, b *pathData) int {
if a.path.RecvLayer() == ly {
return -cmp.Compare(a.sPos.Dim(sideDims[a.rSide]), b.sPos.Dim(sideDims[b.rSide]))
} else {
return -cmp.Compare(a.rPos.Dim(sideDims[a.sSide]), b.rPos.Dim(sideDims[b.sSide]))
}
})
for pi, pd := range pths {
if pd.path.RecvLayer() == ly {
if pi != pd.rIdx {
orderChanged = true
pd.rIdx = pi
}
} else {
if pi != pd.sIdx {
orderChanged = true
pd.sIdx = pi
}
}
}
}
}
}
if !orderChanged {
break
}
}
// final render
for li := range nlay {
ly := nv.Net.EmerLayer(li)
lb := ly.AsEmer()
layData := &layPaths[lb.Index]
for side := range 4 {
for cat := range 3 {
pidx := pdIdx(side, cat)
pths := layData.paths[pidx]
for _, pd := range pths {
if pd.path.RecvLayer() != ly {
continue
}
pt := pd.path
pb := pt.AsEmer()
clr := colors.Spaced(pt.TypeNumber())
arw := xyz.NewGroup(pathsGp)
arw.SetName(pb.Name)
xyz.InitArrow(arw, pd.sPos, pd.rPos, lineWidth, clr, xyz.NoStartArrow, xyz.EndArrow, 4, .5, 4)
}
}
}
npt := len(layData.selfPaths)
if npt == 0 {
continue
}
// determine which side to put the self connections on:
// left or right, whichever currently has fewer pathways.
var totLeft, totRight int
for side := 1; side <= 2; side++ { // left, right
for cat := range 3 {
pidx := pdIdx(side, cat)
if side == 1 {
totLeft += len(layData.paths[pidx])
} else {
totRight += len(layData.paths[pidx])
}
}
}
selfSide := 1 // left
if totRight < totLeft {
selfSide = 2 // right
}
for pi, pd := range layData.selfPaths {
pt := pd.path
pb := pt.AsEmer()
pd.sSide, pd.rSide = selfSide, selfSide
clr := colors.Spaced(pt.TypeNumber())
spm := nv.selfPrjn(se, pd.sSide)
sfgp := xyz.NewGroup(pathsGp)
sfgp.SetName(pb.Name)
sfp := xyz.NewSolid(sfgp).SetMesh(spm).SetColor(clr)
sfp.SetName(pb.Name)
sfp.Pose.Pos = laySidePos(lb, selfSide, 1, pi, npt, 0)
}
}
nv.pathTypeShown = nv.Options.PathType
nv.pathWidthShown = nv.Options.PathWidth
}
func (nv *NetView) pathTypeNameMatch(pt emer.Path) bool {
if len(nv.Options.PathType) == 0 {
return true
}
return pt.AsEmer().IsTypeOrClass(nv.Options.PathType)
}
// selfPrjn returns the self-pathway mesh for the given side: left = 1 or right = 2
func (nv *NetView) selfPrjn(se *xyz.Scene, side int) xyz.Mesh {
selfnm := fmt.Sprintf("selfPathSide%d", side)
sm, err := se.MeshByName(selfnm)
if err == nil && nv.pathWidthShown == nv.Options.PathWidth {
return sm
}
szm := max(nv.Options.PathWidth/0.002, 1)
lineWidth := 1.5 * nv.Options.PathWidth
size := float32(0.015) * szm
sideFact := float32(1.5)
if side == 1 {
sideFact = -1.5
}
sm = xyz.NewLines(se, selfnm, []math32.Vector3{{0, 0, -size}, {sideFact * size, 0, -size}, {sideFact * size, 0, size}, {0, 0, size}}, math32.Vec2(lineWidth, lineWidth), xyz.OpenLines)
return sm
}
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package netview
import (
"fmt"
"cogentcore.org/core/colors/colormap"
"cogentcore.org/core/core"
"cogentcore.org/core/events"
"cogentcore.org/core/events/key"
"cogentcore.org/core/icons"
"cogentcore.org/core/styles"
"cogentcore.org/core/styles/abilities"
"cogentcore.org/core/tree"
)
func (nv *NetView) MakeToolbar(p *tree.Plan) {
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(nv.Update).SetText("Init").SetIcon(icons.Update)
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(nv.Current).SetIcon(icons.Update)
})
tree.Add(p, func(w *core.Button) {
w.SetText("Options").SetIcon(icons.Settings).
SetTooltip("set parameters that control display (font size etc)").
OnClick(func(e events.Event) {
d := core.NewBody(nv.Name + " Options")
core.NewForm(d).SetStruct(&nv.Options).
OnChange(func(e events.Event) {
nv.GoUpdateView()
})
d.RunWindowDialog(nv)
})
})
tree.Add(p, func(w *core.Separator) {})
tree.Add(p, func(w *core.Button) {
w.SetText("Weights").SetType(core.ButtonAction).SetMenu(func(m *core.Scene) {
fb := core.NewFuncButton(m).SetFunc(nv.SaveWeights)
fb.SetIcon(icons.Save)
fb.Args[0].SetTag(`extension:".wts,.wts.gz"`)
fb = core.NewFuncButton(m).SetFunc(nv.OpenWeights)
fb.SetIcon(icons.Open)
fb.Args[0].SetTag(`extension:".wts,.wts.gz"`)
})
})
tree.Add(p, func(w *core.Button) {
w.SetText("Params").SetIcon(icons.Info).SetMenu(func(m *core.Scene) {
core.NewFuncButton(m).SetFunc(nv.ShowNonDefaultParams).SetIcon(icons.Info)
core.NewFuncButton(m).SetFunc(nv.ShowAllParams).SetIcon(icons.Info)
core.NewFuncButton(m).SetFunc(nv.ShowKeyLayerParams).SetIcon(icons.Info)
core.NewFuncButton(m).SetFunc(nv.ShowKeyPathParams).SetIcon(icons.Info)
})
})
tree.Add(p, func(w *core.Button) {
w.SetText("Net Data").SetIcon(icons.Save).SetMenu(func(m *core.Scene) {
core.NewFuncButton(m).SetFunc(nv.Data.SaveJSON).SetText("Save Net Data").SetIcon(icons.Save)
core.NewFuncButton(m).SetFunc(nv.Data.OpenJSON).SetText("Open Net Data").SetIcon(icons.Open)
core.NewSeparator(m)
core.NewFuncButton(m).SetFunc(nv.PlotSelectedUnit).SetIcon(icons.Open)
})
})
tree.Add(p, func(w *core.Separator) {})
tree.Add(p, func(w *core.Switch) {
w.SetText("Paths").SetChecked(nv.Options.Paths).
SetTooltip("Toggles whether pathways between layers are shown or not").
OnChange(func(e events.Event) {
nv.Options.Paths = w.IsChecked()
nv.UpdateView()
})
})
ditp := "data parallel index -- for models running multiple input patterns in parallel, this selects which one is viewed"
tree.Add(p, func(w *core.Text) {
w.SetText("Di:").SetTooltip(ditp)
})
tree.Add(p, func(w *core.Spinner) {
w.SetMin(0).SetStep(1).SetValue(float32(nv.Di)).SetTooltip(ditp)
w.Styler(func(s *styles.Style) {
s.Max.X.Ch(9)
s.Min.X.Ch(9)
})
w.OnChange(func(e events.Event) {
maxData := nv.Net.MaxParallelData()
md := int(w.Value)
if md < maxData && md >= 0 {
nv.Di = md
}
w.SetValue(float32(nv.Di))
nv.UpdateView()
})
})
tree.Add(p, func(w *core.Separator) {})
tree.Add(p, func(w *core.Switch) {
w.SetText("Raster").SetChecked(nv.Options.Raster.On).
SetTooltip("Toggles raster plot mode -- displays values on one axis (Z by default) and raster counter (time) along the other (X by default)").
OnChange(func(e events.Event) {
nv.Options.Raster.On = w.IsChecked()
// nv.ReconfigMeshes()
nv.UpdateView()
})
})
tree.Add(p, func(w *core.Switch) {
w.SetText("X").SetType(core.SwitchCheckbox).SetChecked(nv.Options.Raster.XAxis).
SetTooltip("If checked, the raster (time) dimension is plotted along the X (horizontal) axis of the layers, otherwise it goes in the depth (Z) dimension").
OnChange(func(e events.Event) {
nv.Options.Raster.XAxis = w.IsChecked()
nv.UpdateView()
})
})
vp, ok := nv.VarOptions[nv.Var]
if !ok {
vp = &VarOptions{}
vp.Defaults()
}
var minSpin, maxSpin *core.Spinner
var minSwitch, maxSwitch *core.Switch
tree.Add(p, func(w *core.Separator) {})
tree.AddAt(p, "minSwitch", func(w *core.Switch) {
minSwitch = w
w.SetText("Min").SetType(core.SwitchCheckbox).SetChecked(vp.Range.FixMin).
SetTooltip("Fix the minimum end of the displayed value range to value shown in next box. Having both min and max fixed is recommended where possible for speed and consistent interpretability of the colors.").
OnChange(func(e events.Event) {
vp := nv.VarOptions[nv.Var]
vp.Range.FixMin = w.IsChecked()
minSpin.UpdateWidget().NeedsRender()
nv.UpdateView()
})
w.Updater(func() {
vp := nv.VarOptions[nv.Var]
if vp != nil {
w.SetChecked(vp.Range.FixMin)
}
})
})
tree.AddAt(p, "minSpin", func(w *core.Spinner) {
minSpin = w
w.Styler(func(s *styles.Style) {
s.Min.X.Ch(15)
s.Max.X.Ch(15)
})
w.SetValue(vp.Range.Min).
OnChange(func(e events.Event) {
vp := nv.VarOptions[nv.Var]
vp.Range.SetMin(w.Value)
vp.Range.FixMin = true
minSwitch.UpdateWidget().NeedsRender()
if vp.ZeroCtr && vp.Range.Min < 0 && vp.Range.FixMax {
vp.Range.SetMax(-vp.Range.Min)
}
if vp.ZeroCtr {
maxSpin.UpdateWidget().NeedsRender()
}
nv.UpdateView()
})
w.Updater(func() {
vp := nv.VarOptions[nv.Var]
if vp != nil {
w.SetValue(vp.Range.Min)
}
})
})
tree.AddAt(p, "cmap", func(w *core.ColorMapButton) {
nv.ColorMapButton = w
w.MapName = string(nv.Options.ColorMap)
w.SetTooltip("Color map for translating values into colors -- click to select alternative.")
w.Styler(func(s *styles.Style) {
s.Min.X.Em(10)
s.Min.Y.Em(1.2)
s.Grow.Set(0, 1)
})
w.OnChange(func(e events.Event) {
cmap, ok := colormap.AvailableMaps[string(nv.ColorMapButton.MapName)]
if ok {
nv.ColorMap = cmap
}
nv.UpdateView()
})
})
tree.AddAt(p, "maxSwitch", func(w *core.Switch) {
maxSwitch = w
w.SetText("Max").SetType(core.SwitchCheckbox).SetChecked(vp.Range.FixMax).
SetTooltip("Fix the maximum end of the displayed value range to value shown in next box. Having both min and max fixed is recommended where possible for speed and consistent interpretability of the colors.").
OnChange(func(e events.Event) {
vp := nv.VarOptions[nv.Var]
vp.Range.FixMax = w.IsChecked()
maxSpin.UpdateWidget().NeedsRender()
nv.UpdateView()
})
w.Updater(func() {
vp := nv.VarOptions[nv.Var]
if vp != nil {
w.SetChecked(vp.Range.FixMax)
}
})
})
tree.AddAt(p, "maxSpin", func(w *core.Spinner) {
maxSpin = w
w.Styler(func(s *styles.Style) {
s.Min.X.Ch(15)
s.Max.X.Ch(15)
})
w.SetValue(vp.Range.Max).OnChange(func(e events.Event) {
vp := nv.VarOptions[nv.Var]
vp.Range.SetMax(w.Value)
vp.Range.FixMax = true
maxSwitch.UpdateWidget().NeedsRender()
if vp.ZeroCtr && vp.Range.Max > 0 && vp.Range.FixMin {
vp.Range.SetMin(-vp.Range.Max)
}
if vp.ZeroCtr {
minSpin.UpdateWidget().NeedsRender()
}
nv.UpdateView()
})
w.Updater(func() {
vp := nv.VarOptions[nv.Var]
if vp != nil {
w.SetValue(vp.Range.Max)
}
})
})
tree.AddAt(p, "zeroCtrSwitch", func(w *core.Switch) {
w.SetText("ZeroCtr").SetChecked(vp.ZeroCtr).
SetTooltip("keep Min - Max centered around 0, and use negative heights for units -- else use full min-max range for height (no negative heights)").
OnChange(func(e events.Event) {
vp := nv.VarOptions[nv.Var]
vp.ZeroCtr = w.IsChecked()
nv.UpdateView()
})
w.Updater(func() {
vp := nv.VarOptions[nv.Var]
if vp != nil {
w.SetChecked(vp.ZeroCtr)
}
})
})
}
func (nv *NetView) MakeViewbar(p *tree.Plan) {
tree.Add(p, func(w *core.Button) {
w.SetIcon(icons.Update).SetTooltip("reset to default initial display").
OnClick(func(e events.Event) {
nv.SceneXYZ().SetCamera("default")
nv.UpdateView()
})
})
tree.Add(p, func(w *core.Button) {
w.SetIcon(icons.ZoomIn).SetTooltip("zoom in")
w.Styler(func(s *styles.Style) {
s.SetAbilities(true, abilities.RepeatClickable)
})
w.OnClick(func(e events.Event) {
nv.SceneXYZ().Camera.Zoom(-.05)
nv.UpdateView()
})
})
tree.Add(p, func(w *core.Button) {
w.SetIcon(icons.ZoomOut).SetTooltip("zoom out")
w.Styler(func(s *styles.Style) {
s.SetAbilities(true, abilities.RepeatClickable)
})
w.OnClick(func(e events.Event) {
nv.SceneXYZ().Camera.Zoom(.05)
nv.UpdateView()
})
})
tree.Add(p, func(w *core.Separator) {})
tree.Add(p, func(w *core.Text) {
w.SetText("Rot:").SetTooltip("rotate display")
})
tree.Add(p, func(w *core.Button) {
w.SetIcon(icons.KeyboardArrowLeft)
w.Styler(func(s *styles.Style) {
s.SetAbilities(true, abilities.RepeatClickable)
})
w.OnClick(func(e events.Event) {
nv.SceneXYZ().Camera.Orbit(5, 0)
nv.UpdateView()
})
})
tree.Add(p, func(w *core.Button) {
w.SetIcon(icons.KeyboardArrowUp)
w.Styler(func(s *styles.Style) {
s.SetAbilities(true, abilities.RepeatClickable)
})
w.OnClick(func(e events.Event) {
nv.SceneXYZ().Camera.Orbit(0, 5)
nv.UpdateView()
})
})
tree.Add(p, func(w *core.Button) {
w.SetIcon(icons.KeyboardArrowDown)
w.Styler(func(s *styles.Style) {
s.SetAbilities(true, abilities.RepeatClickable)
})
w.OnClick(func(e events.Event) {
nv.SceneXYZ().Camera.Orbit(0, -5)
nv.UpdateView()
})
})
tree.Add(p, func(w *core.Button) {
w.SetIcon(icons.KeyboardArrowRight)
w.Styler(func(s *styles.Style) {
s.SetAbilities(true, abilities.RepeatClickable)
})
w.OnClick(func(e events.Event) {
nv.SceneXYZ().Camera.Orbit(-5, 0)
nv.UpdateView()
})
})
tree.Add(p, func(w *core.Separator) {})
tree.Add(p, func(w *core.Text) {
w.SetText("Pan:").SetTooltip("pan display")
})
tree.Add(p, func(w *core.Button) {
w.SetIcon(icons.KeyboardArrowLeft)
w.Styler(func(s *styles.Style) {
s.SetAbilities(true, abilities.RepeatClickable)
})
w.OnClick(func(e events.Event) {
nv.SceneXYZ().Camera.Pan(-.2, 0)
nv.UpdateView()
})
})
tree.Add(p, func(w *core.Button) {
w.SetIcon(icons.KeyboardArrowUp)
w.Styler(func(s *styles.Style) {
s.SetAbilities(true, abilities.RepeatClickable)
})
w.OnClick(func(e events.Event) {
nv.SceneXYZ().Camera.Pan(0, .2)
nv.UpdateView()
})
})
tree.Add(p, func(w *core.Button) {
w.SetIcon(icons.KeyboardArrowDown)
w.Styler(func(s *styles.Style) {
s.SetAbilities(true, abilities.RepeatClickable)
})
w.OnClick(func(e events.Event) {
nv.SceneXYZ().Camera.Pan(0, -.2)
nv.UpdateView()
})
})
tree.Add(p, func(w *core.Button) {
w.SetIcon(icons.KeyboardArrowRight)
w.Styler(func(s *styles.Style) {
s.SetAbilities(true, abilities.RepeatClickable)
})
w.OnClick(func(e events.Event) {
nv.SceneXYZ().Camera.Pan(.2, 0)
nv.UpdateView()
})
})
tree.Add(p, func(w *core.Separator) {})
tree.Add(p, func(w *core.Text) { w.SetText("Save:") })
for i := 1; i <= 4; i++ {
nm := fmt.Sprintf("%d", i)
tree.AddAt(p, "saved-"+nm, func(w *core.Button) {
w.SetText(nm).
SetTooltip("first click (or + Shift) saves current view, second click restores to saved state")
w.OnClick(func(e events.Event) {
sc := nv.SceneXYZ()
cam := nm
if e.HasAllModifiers(e.Modifiers(), key.Shift) {
sc.SaveCamera(cam)
} else {
err := sc.SetCamera(cam)
if err != nil {
sc.SaveCamera(cam)
}
}
fmt.Printf("Camera %s: %v\n", cam, sc.Camera.GenGoSet(""))
nv.UpdateView()
})
})
}
tree.Add(p, func(w *core.Separator) {})
tree.Add(p, func(w *core.Text) {
w.SetText("Time:").
SetTooltip("states are recorded over time -- last N can be reviewed using these buttons")
})
tree.AddAt(p, "rec", func(w *core.Text) {
w.SetText(fmt.Sprintf(" %4d ", nv.RecNo)).
SetTooltip("current view record: -1 means latest, 0 = earliest")
w.Styler(func(s *styles.Style) {
s.Min.X.Ch(5)
})
w.Updater(func() {
w.SetText(fmt.Sprintf(" %4d ", nv.RecNo))
})
})
tree.Add(p, func(w *core.Button) {
w.SetIcon(icons.FirstPage).SetTooltip("move to first record (start of history)")
w.OnClick(func(e events.Event) {
if nv.RecFullBkwd() {
nv.UpdateView()
}
})
})
tree.Add(p, func(w *core.Button) {
w.SetIcon(icons.FastRewind).SetTooltip("move earlier by N records (default 10)")
w.Styler(func(s *styles.Style) {
s.SetAbilities(true, abilities.RepeatClickable)
})
w.OnClick(func(e events.Event) {
if nv.RecFastBkwd() {
nv.UpdateView()
}
})
})
tree.Add(p, func(w *core.Button) {
w.SetIcon(icons.SkipPrevious).SetTooltip("move earlier by 1")
w.Styler(func(s *styles.Style) {
s.SetAbilities(true, abilities.RepeatClickable)
})
w.OnClick(func(e events.Event) {
if nv.RecBkwd() {
nv.UpdateView()
}
})
})
tree.Add(p, func(w *core.Button) {
w.SetIcon(icons.PlayArrow).SetTooltip("move to latest and always display latest (-1)")
w.OnClick(func(e events.Event) {
if nv.RecTrackLatest() {
nv.UpdateView()
}
})
})
tree.Add(p, func(w *core.Button) {
w.SetIcon(icons.SkipNext).SetTooltip("move later by 1")
w.Styler(func(s *styles.Style) {
s.SetAbilities(true, abilities.RepeatClickable)
})
w.OnClick(func(e events.Event) {
if nv.RecFwd() {
nv.UpdateView()
}
})
})
tree.Add(p, func(w *core.Button) {
w.SetIcon(icons.FastForward).SetTooltip("move later by N (default 10)")
w.Styler(func(s *styles.Style) {
s.SetAbilities(true, abilities.RepeatClickable)
})
w.OnClick(func(e events.Event) {
if nv.RecFastFwd() {
nv.UpdateView()
}
})
})
tree.Add(p, func(w *core.Button) {
w.SetIcon(icons.LastPage).SetTooltip("move to end (current time, tracking latest updates)")
w.OnClick(func(e events.Event) {
if nv.RecTrackLatest() {
nv.UpdateView()
}
})
})
}
// Code generated by "core generate -add-types"; DO NOT EDIT.
package netview
import (
"sync"
"cogentcore.org/core/colors/colormap"
"cogentcore.org/core/core"
"cogentcore.org/core/tree"
"cogentcore.org/core/types"
)
var _ = types.AddType(&types.Type{Name: "github.com/emer/emergent/v2/netview.LayData", IDName: "lay-data", Doc: "LayData maintains a record of all the data for a given layer", Fields: []types.Field{{Name: "LayName", Doc: "the layer name"}, {Name: "NUnits", Doc: "cached number of units"}, {Name: "Data", Doc: "the full data, in that order"}, {Name: "RecvPaths", Doc: "receiving pathway data -- shared with SendPaths"}, {Name: "SendPaths", Doc: "sending pathway data -- shared with RecvPaths"}}})
var _ = types.AddType(&types.Type{Name: "github.com/emer/emergent/v2/netview.PathData", IDName: "path-data", Doc: "PathData holds display state for a pathway", Fields: []types.Field{{Name: "Send", Doc: "name of sending layer"}, {Name: "Recv", Doc: "name of recv layer"}, {Name: "Path", Doc: "source pathway"}, {Name: "SynData", Doc: "synaptic data, by variable in SynVars and number of data points"}}})
var _ = types.AddType(&types.Type{Name: "github.com/emer/emergent/v2/netview.Scene", IDName: "scene", Doc: "Scene is a Widget for managing the 3D Scene of the NetView", Embeds: []types.Field{{Name: "Scene"}}, Fields: []types.Field{{Name: "NetView"}}})
// NewScene returns a new [Scene] with the given optional parent:
// Scene is a Widget for managing the 3D Scene of the NetView
func NewScene(parent ...tree.Node) *Scene { return tree.New[Scene](parent...) }
// SetNetView sets the [Scene.NetView]
func (t *Scene) SetNetView(v *NetView) *Scene { t.NetView = v; return t }
var _ = types.AddType(&types.Type{Name: "github.com/emer/emergent/v2/netview.LayMesh", IDName: "lay-mesh", Doc: "LayMesh is a xyz.Mesh that represents a layer -- it is dynamically updated using the\nUpdate method which only resets the essential Vertex elements.\nThe geometry is literal in the layer size: 0,0,0 lower-left corner and increasing X,Z\nfor the width and height of the layer, in unit (1) increments per unit..\nNetView applies an overall scaling to make it fit within the larger view.", Embeds: []types.Field{{Name: "MeshBase"}}, Fields: []types.Field{{Name: "Lay", Doc: "layer that we render"}, {Name: "Shape", Doc: "current shape that has been constructed -- if same, just update"}, {Name: "View", Doc: "netview that we're in"}}})
var _ = types.AddType(&types.Type{Name: "github.com/emer/emergent/v2/netview.LayObj", IDName: "lay-obj", Doc: "LayObj is the Layer 3D object within the NetView", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Embeds: []types.Field{{Name: "Solid"}}, Fields: []types.Field{{Name: "LayName", Doc: "name of the layer we represent"}, {Name: "NetView", Doc: "our netview"}}})
// NewLayObj returns a new [LayObj] with the given optional parent:
// LayObj is the Layer 3D object within the NetView
func NewLayObj(parent ...tree.Node) *LayObj { return tree.New[LayObj](parent...) }
// SetLayName sets the [LayObj.LayName]:
// name of the layer we represent
func (t *LayObj) SetLayName(v string) *LayObj { t.LayName = v; return t }
// SetNetView sets the [LayObj.NetView]:
// our netview
func (t *LayObj) SetNetView(v *NetView) *LayObj { t.NetView = v; return t }
var _ = types.AddType(&types.Type{Name: "github.com/emer/emergent/v2/netview.LayName", IDName: "lay-name", Doc: "LayName is the Layer name as a Text2D within the NetView", Embeds: []types.Field{{Name: "Text2D"}}, Fields: []types.Field{{Name: "NetView", Doc: "our netview"}}})
// NewLayName returns a new [LayName] with the given optional parent:
// LayName is the Layer name as a Text2D within the NetView
func NewLayName(parent ...tree.Node) *LayName { return tree.New[LayName](parent...) }
// SetNetView sets the [LayName.NetView]:
// our netview
func (t *LayName) SetNetView(v *NetView) *LayName { t.NetView = v; return t }
var _ = types.AddType(&types.Type{Name: "github.com/emer/emergent/v2/netview.NetData", IDName: "net-data", Doc: "NetData maintains a record of all the network data that has been displayed\nup to a given maximum number of records (updates), using efficient ring index logic\nwith no copying to store in fixed-sized buffers.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Methods: []types.Method{{Name: "OpenJSON", Doc: "OpenJSON opens colors from a JSON-formatted file.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"filename"}, Returns: []string{"error"}}, {Name: "SaveJSON", Doc: "SaveJSON saves colors to a JSON-formatted file.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"filename"}, Returns: []string{"error"}}}, Fields: []types.Field{{Name: "Net", Doc: "the network that we're viewing"}, {Name: "NoSynData", Doc: "copied from Params -- do not record synapse level data -- turn this on for very large networks where recording the entire synaptic state would be prohibitive"}, {Name: "PathLay", Doc: "name of the layer with unit for viewing pathways (connection / synapse-level values)"}, {Name: "PathUnIndex", Doc: "1D index of unit within PathLay for for viewing pathways"}, {Name: "PathType", Doc: "copied from NetView Params: if non-empty, this is the type pathway to show when there are multiple pathways from the same layer -- e.g., Inhib, Lateral, Forward, etc"}, {Name: "UnVars", Doc: "the list of unit variables saved"}, {Name: "UnVarIndexes", Doc: "index of each variable in the Vars slice"}, {Name: "SynVars", Doc: "the list of synaptic variables saved"}, {Name: "SynVarIndexes", Doc: "index of synaptic variable in the SynVars slice"}, {Name: "Ring", Doc: "the circular ring index -- Max here is max number of values to store, Len is number stored, and Index(Len-1) is the most recent one, etc"}, {Name: "MaxData", Doc: "max data parallel data per unit"}, {Name: "LayData", Doc: "the layer data -- map keyed by layer name"}, {Name: "UnMinPer", Doc: "unit var min values for each Ring.Max * variable"}, {Name: "UnMaxPer", Doc: "unit var max values for each Ring.Max * variable"}, {Name: "UnMinVar", Doc: "min values for unit variables"}, {Name: "UnMaxVar", Doc: "max values for unit variables"}, {Name: "SynMinVar", Doc: "min values for syn variables"}, {Name: "SynMaxVar", Doc: "max values for syn variables"}, {Name: "Counters", Doc: "counter strings"}, {Name: "RasterCtrs", Doc: "raster counter values"}, {Name: "RasterMap", Doc: "map of raster counter values to record numbers"}, {Name: "RastCtr", Doc: "dummy raster counter when passed a -1 -- increments and wraps around"}}})
var _ = types.AddType(&types.Type{Name: "github.com/emer/emergent/v2/netview.NetView", IDName: "net-view", Doc: "NetView is a Cogent Core Widget that provides a 3D network view using the Cogent Core gi3d\n3D framework.", Methods: []types.Method{{Name: "PlotSelectedUnit", Doc: "PlotSelectedUnit opens a window with a plot of all the data for the\ncurrently selected unit.\nUseful for replaying detailed trace for units of interest.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Returns: []string{"Table", "PlotEditor"}}, {Name: "Current", Doc: "Current records the current state of the network, including synaptic values,\nand updates the display. Use this when switching to NetView tab after network\nhas been running while viewing another tab, because the network state\nis typically not recored then.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "SaveWeights", Doc: "SaveWeights saves the network weights.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"filename"}}, {Name: "OpenWeights", Doc: "OpenWeights opens the network weights.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"filename"}}, {Name: "ShowNonDefaultParams", Doc: "ShowNonDefaultParams shows a dialog of all the parameters that\nare not at their default values in the network. Useful for setting params.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Returns: []string{"string"}}, {Name: "ShowAllParams", Doc: "ShowAllParams shows a dialog of all the parameters in the network.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Returns: []string{"string"}}, {Name: "ShowKeyLayerParams", Doc: "ShowKeyLayerParams shows a dialog with a listing for all layers in the network,\nof the most important layer-level params (specific to each algorithm)", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Returns: []string{"string"}}, {Name: "ShowKeyPathParams", Doc: "ShowKeyPathParams shows a dialog with a listing for all Recv pathways in the network,\nof the most important pathway-level params (specific to each algorithm)", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Returns: []string{"string"}}}, Embeds: []types.Field{{Name: "Frame"}}, Fields: []types.Field{{Name: "Net", Doc: "the network that we're viewing"}, {Name: "Var", Doc: "current variable that we're viewing"}, {Name: "Di", Doc: "current data parallel index di, for networks capable of processing input patterns in parallel."}, {Name: "Vars", Doc: "the list of variables to view"}, {Name: "SynVars", Doc: "list of synaptic variables"}, {Name: "SynVarsMap", Doc: "map of synaptic variable names to index"}, {Name: "VarOptions", Doc: "parameters for the list of variables to view"}, {Name: "CurVarOptions", Doc: "current var params -- only valid during Update of display"}, {Name: "Options", Doc: "parameters controlling how the view is rendered"}, {Name: "ColorMap", Doc: "color map for mapping values to colors -- set by name in Options"}, {Name: "ColorMapButton", Doc: "color map value representing ColorMap"}, {Name: "RecNo", Doc: "record number to display -- use -1 to always track latest, otherwise in range"}, {Name: "LastCtrs", Doc: "last non-empty counters string provided -- re-used if no new one"}, {Name: "CurCtrs", Doc: "current counters"}, {Name: "Data", Doc: "contains all the network data with history"}, {Name: "DataMu", Doc: "mutex on data access"}, {Name: "layerNameSizeShown", Doc: "these are used to 
detect need to update"}, {Name: "hasPaths"}, {Name: "pathTypeShown"}, {Name: "pathWidthShown"}}})
// NewNetView returns a new [NetView] with the given optional parent:
// NetView is a Cogent Core Widget that provides a 3D network view using the Cogent Core gi3d
// 3D framework.
func NewNetView(parent ...tree.Node) *NetView { return tree.New[NetView](parent...) }
// SetDi sets the [NetView.Di]:
// current data parallel index di, for networks capable of processing input patterns in parallel.
func (t *NetView) SetDi(v int) *NetView { t.Di = v; return t }
// SetVars sets the [NetView.Vars]:
// the list of variables to view
func (t *NetView) SetVars(v ...string) *NetView { t.Vars = v; return t }
// SetSynVars sets the [NetView.SynVars]:
// list of synaptic variables
func (t *NetView) SetSynVars(v ...string) *NetView { t.SynVars = v; return t }
// SetSynVarsMap sets the [NetView.SynVarsMap]:
// map of synaptic variable names to index
func (t *NetView) SetSynVarsMap(v map[string]int) *NetView { t.SynVarsMap = v; return t }
// SetVarOptions sets the [NetView.VarOptions]:
// parameters for the list of variables to view
func (t *NetView) SetVarOptions(v map[string]*VarOptions) *NetView { t.VarOptions = v; return t }
// SetCurVarOptions sets the [NetView.CurVarOptions]:
// current var params -- only valid during Update of display
func (t *NetView) SetCurVarOptions(v *VarOptions) *NetView { t.CurVarOptions = v; return t }
// SetOptions sets the [NetView.Options]:
// parameters controlling how the view is rendered
func (t *NetView) SetOptions(v Options) *NetView { t.Options = v; return t }
// SetColorMap sets the [NetView.ColorMap]:
// color map for mapping values to colors -- set by name in Options
func (t *NetView) SetColorMap(v *colormap.Map) *NetView { t.ColorMap = v; return t }
// SetColorMapButton sets the [NetView.ColorMapButton]:
// color map value representing ColorMap
func (t *NetView) SetColorMapButton(v *core.ColorMapButton) *NetView { t.ColorMapButton = v; return t }
// SetRecNo sets the [NetView.RecNo]:
// record number to display -- use -1 to always track latest, otherwise in range
func (t *NetView) SetRecNo(v int) *NetView { t.RecNo = v; return t }
// SetLastCtrs sets the [NetView.LastCtrs]:
// last non-empty counters string provided -- re-used if no new one
func (t *NetView) SetLastCtrs(v string) *NetView { t.LastCtrs = v; return t }
// SetCurCtrs sets the [NetView.CurCtrs]:
// current counters
func (t *NetView) SetCurCtrs(v string) *NetView { t.CurCtrs = v; return t }
// SetData sets the [NetView.Data]:
// contains all the network data with history
func (t *NetView) SetData(v NetData) *NetView { t.Data = v; return t }
// SetDataMu sets the [NetView.DataMu]:
// mutex on data access
func (t *NetView) SetDataMu(v sync.RWMutex) *NetView { t.DataMu = v; return t }
var _ = types.AddType(&types.Type{Name: "github.com/emer/emergent/v2/netview.RasterOptions", IDName: "raster-options", Doc: "RasterOptions holds parameters controlling the raster plot view", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Fields: []types.Field{{Name: "On", Doc: "if true, show a raster plot over time, otherwise units"}, {Name: "XAxis", Doc: "if true, the raster counter (time) is plotted across the X axis -- otherwise the Z depth axis"}, {Name: "Max", Doc: "maximum count for the counter defining the raster plot"}, {Name: "UnitSize", Doc: "size of a single unit, where 1 = full width and no space.. 1 default"}, {Name: "UnitHeight", Doc: "height multiplier for units, where 1 = full height.. 0.2 default"}}})
var _ = types.AddType(&types.Type{Name: "github.com/emer/emergent/v2/netview.Options", IDName: "options", Doc: "Options holds parameters controlling how the view is rendered", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Fields: []types.Field{{Name: "Paths", Doc: "whether to display the pathways between layers as arrows"}, {Name: "PathType", Doc: "path type name(s) to display (space separated), for path arrows,\nand when there are multiple pathways from the same layer.\nFor arrows, uses the style class names to match, which includes type name\nand other factors.\nUses case insensitive contains logic for each name."}, {Name: "PathWidth", Doc: "width of the path arrows, in normalized units"}, {Name: "Raster", Doc: "raster plot parameters"}, {Name: "NoSynData", Doc: "do not record synapse level data -- turn this on for very large networks where recording the entire synaptic state would be prohibitive"}, {Name: "MaxRecs", Doc: "maximum number of records to store to enable rewinding through prior states"}, {Name: "NVarCols", Doc: "number of variable columns"}, {Name: "UnitSize", Doc: "size of a single unit, where 1 = full width and no space.. .9 default"}, {Name: "LayerNameSize", Doc: "size of the layer name labels -- entire network view is unit sized"}, {Name: "ColorMap", Doc: "name of color map to use"}, {Name: "ZeroAlpha", Doc: "opacity (0-1) of zero values -- greater magnitude values become increasingly opaque on either side of this minimum"}, {Name: "NFastSteps", Doc: "the number of records to jump for fast forward/backward"}}})
var _ = types.AddType(&types.Type{Name: "github.com/emer/emergent/v2/netview.VarOptions", IDName: "var-options", Doc: "VarOptions holds parameters for display of each variable", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Fields: []types.Field{{Name: "Var", Doc: "name of the variable"}, {Name: "ZeroCtr", Doc: "keep Min - Max centered around 0, and use negative heights for units -- else use full min-max range for height (no negative heights)"}, {Name: "Range", Doc: "range to display"}, {Name: "MinMax", Doc: "if not using fixed range, this is the actual range of data"}}})
var _ = types.AddType(&types.Type{Name: "github.com/emer/emergent/v2/netview.pathData", IDName: "path-data", Fields: []types.Field{{Name: "path"}, {Name: "sSide"}, {Name: "rSide"}, {Name: "cat"}, {Name: "sIdx"}, {Name: "sN"}, {Name: "rIdx"}, {Name: "rN"}, {Name: "sPos"}, {Name: "rPos"}}})
var _ = types.AddType(&types.Type{Name: "github.com/emer/emergent/v2/netview.layerData", IDName: "layer-data", Fields: []types.Field{{Name: "paths"}, {Name: "selfPaths"}}})
var _ = types.AddType(&types.Type{Name: "github.com/emer/emergent/v2/netview.ViewUpdate", IDName: "view-update", Doc: "ViewUpdate manages time scales for updating the NetView", Fields: []types.Field{{Name: "View", Doc: "the network view"}, {Name: "Testing", Doc: "whether in testing mode -- can be set in advance to drive appropriate updating"}, {Name: "Text", Doc: "text to display at the bottom of the view"}, {Name: "On", Doc: "toggles update of display on"}, {Name: "SkipInvis", Doc: "if true, do not record network data when the NetView is invisible -- this speeds up running when not visible, but the NetView display will not show the current state when switching back to it"}, {Name: "Train", Doc: "at what time scale to update the display during training?"}, {Name: "Test", Doc: "at what time scale to update the display during testing?"}}})
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package netview
import (
"strings"
"github.com/emer/emergent/v2/etime"
)
// ViewUpdate manages time scales for updating the NetView
type ViewUpdate struct {
// View is the network view.
View *NetView `display:"-"`
// whether in testing mode -- can be set in advance to drive appropriate updating
Testing bool `display:"-"`
// text to display at the bottom of the view
Text string `display:"-"`
// toggles update of display on
On bool
// SkipInvis means do not record network data when the NetView is invisible.
// This speeds up running when not visible, but the NetView display will
// not show the current state when switching back to it.
SkipInvis bool
// at what time scale to update the display during training?
Train etime.Times
// at what time scale to update the display during testing?
Test etime.Times
}
// Config configures for given NetView and default train, test times
func (vu *ViewUpdate) Config(nv *NetView, train, test etime.Times) {
vu.View = nv
vu.On = true
vu.Train = train
vu.Test = test
vu.SkipInvis = true // more often running than debugging probably
}
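// Illustrative usage sketch (not part of this file; nv and cyc are assumed to
// come from the surrounding simulation code):
//
//	var vu ViewUpdate
//	vu.Config(nv, etime.AlphaCycle, etime.Cycle) // train, test update time scales
//	// from the simulation goroutine, each cycle and at the end of each trial:
//	vu.UpdateCycle(cyc)
//	vu.UpdateTime(etime.Trial)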
// GetUpdateTime returns the relevant update time based on testing flag
func (vu *ViewUpdate) GetUpdateTime(testing bool) etime.Times {
if testing {
return vu.Test
}
return vu.Train
}
// GoUpdate does an update if view is On, visible and active,
// including recording new data and driving update of display.
// This version is only for calling from a separate goroutine,
// not the main event loop (see also Update).
func (vu *ViewUpdate) GoUpdate() {
if !vu.On || vu.View == nil {
return
}
if !vu.View.IsVisible() && vu.SkipInvis {
vu.View.RecordCounters(vu.Text)
return
}
vu.View.Record(vu.Text, -1) // -1 = use a dummy counter
// note: essential to use Go version of update when called from another goroutine
if vu.View.IsVisible() {
vu.View.GoUpdateView()
}
}
// Update does an update if view is On, visible and active,
// including recording new data and driving update of display.
// This version is only for calling from the main event loop
// (see also GoUpdate).
func (vu *ViewUpdate) Update() {
if !vu.On || vu.View == nil {
return
}
if !vu.View.IsVisible() && vu.SkipInvis {
vu.View.RecordCounters(vu.Text)
return
}
vu.View.Record(vu.Text, -1) // -1 = use a dummy counter
// note: this is the main event loop version, so the regular UpdateView is used here (see GoUpdate for the goroutine version)
if vu.View.IsVisible() {
vu.View.UpdateView()
}
}
// UpdateWhenStopped does an update when the network updating was stopped
// either via stepping or hitting the stop button.
// This has different logic for the raster view vs. regular.
// This is only for calling from a separate goroutine,
// not the main event loop.
func (vu *ViewUpdate) UpdateWhenStopped() {
if !vu.On || vu.View == nil {
return
}
if !vu.View.IsVisible() && vu.SkipInvis {
vu.View.RecordCounters(vu.Text)
return
}
if !vu.View.Options.Raster.On { // always record when not in raster mode
vu.View.Record(vu.Text, -1) // -1 = use a dummy counter
}
// todo: updating is not available here -- needed?
// if vu.View.Scene.Is(core.ScUpdating) {
// return
// }
vu.View.GoUpdateView()
}
// UpdateTime triggers an update at given timescale.
func (vu *ViewUpdate) UpdateTime(time etime.Times) {
if !vu.On || vu.View == nil {
return
}
viewUpdate := vu.GetUpdateTime(vu.Testing)
if viewUpdate == time {
vu.GoUpdate()
} else {
if viewUpdate < etime.Trial && time == etime.Trial {
if vu.View.Options.Raster.On { // no extra rec here
vu.View.Data.RecordLastCtrs(vu.Text)
if vu.View.IsVisible() {
vu.View.GoUpdateView()
}
} else {
vu.GoUpdate()
}
}
}
}
// IsCycleUpdating returns true if the view is updating at a cycle level,
// either from raster or literal cycle level.
func (vu *ViewUpdate) IsCycleUpdating() bool {
if !vu.On || vu.View == nil || !(vu.View.IsVisible() || !vu.SkipInvis) {
return false
}
viewUpdate := vu.GetUpdateTime(vu.Testing)
if viewUpdate > etime.ThetaCycle {
return false
}
if viewUpdate == etime.Cycle {
return true
}
if vu.View.Options.Raster.On {
return true
}
return false
}
// IsViewingSynapse returns true if netview is actively viewing synapses.
func (vu *ViewUpdate) IsViewingSynapse() bool {
if !vu.On || vu.View == nil || !(vu.View.IsVisible() || !vu.SkipInvis) {
return false
}
vvar := vu.View.Var
if strings.HasPrefix(vvar, "r.") || strings.HasPrefix(vvar, "s.") {
return true
}
return false
}
// UpdateCycle triggers an update at the Cycle (Millisecond) timescale,
// for the given cycle number, if the configured update time is at a cycle level.
func (vu *ViewUpdate) UpdateCycle(cyc int) {
if !vu.On || vu.View == nil {
return
}
viewUpdate := vu.GetUpdateTime(vu.Testing)
if viewUpdate > etime.ThetaCycle {
return
}
if vu.View.Options.Raster.On {
vu.UpdateCycleRaster(cyc)
return
}
switch viewUpdate {
case etime.Cycle:
vu.GoUpdate()
case etime.FastSpike:
if cyc%10 == 0 {
vu.GoUpdate()
}
case etime.GammaCycle:
if cyc%25 == 0 {
vu.GoUpdate()
}
case etime.BetaCycle:
if cyc%50 == 0 {
vu.GoUpdate()
}
case etime.AlphaCycle:
if cyc%100 == 0 {
vu.GoUpdate()
}
case etime.ThetaCycle:
if cyc%200 == 0 {
vu.GoUpdate()
}
}
}
// UpdateCycleRaster is the raster version of the Cycle update.
func (vu *ViewUpdate) UpdateCycleRaster(cyc int) {
if !vu.View.IsVisible() && vu.SkipInvis {
vu.View.RecordCounters(vu.Text)
return
}
viewUpdate := vu.GetUpdateTime(vu.Testing)
vu.View.Record(vu.Text, cyc)
switch viewUpdate {
case etime.Cycle:
vu.View.GoUpdateView()
case etime.FastSpike:
if cyc%10 == 0 {
vu.View.GoUpdateView()
}
case etime.GammaCycle:
if cyc%25 == 0 {
vu.View.GoUpdateView()
}
case etime.BetaCycle:
if cyc%50 == 0 {
vu.View.GoUpdateView()
}
case etime.AlphaCycle:
if cyc%100 == 0 {
vu.View.GoUpdateView()
}
case etime.ThetaCycle:
if cyc%200 == 0 {
vu.View.GoUpdateView()
}
}
}
// RecordSyns records synaptic data, which is stored separately from unit data.
// It only needs to be called when synaptic values are updated.
// Should be done when the DWt values have been computed, before
// updating Wts and zeroing.
// NetView displays this recorded data when Update is next called.
func (vu *ViewUpdate) RecordSyns() {
if !vu.On || vu.View == nil {
return
}
if !vu.View.IsVisible() {
if vu.SkipInvis || !vu.IsViewingSynapse() {
return
}
}
vu.View.RecordSyns()
}
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package params
import (
"errors"
"fmt"
"log/slog"
"slices"
"strings"
)
// Apply checks if Sel selector applies to this object according to (.Class, #Name, Type)
// using the Styler interface, and returns false if it does not. If it does apply,
// then the Set function is called on the object.
func (ps *Sel[T]) Apply(obj T) bool {
if !SelMatch(ps.Sel, obj) {
return false
}
ps.Set(obj)
return true
}
// SelMatch returns true if Sel selector matches the target object properties.
func SelMatch[T Styler](sel string, obj T) bool {
if sel == "" {
return true
}
if sel[0] == '.' { // class
return ClassMatch(sel[1:], obj.StyleClass())
}
if sel[0] == '#' { // name
return obj.StyleName() == sel[1:]
}
return true // type always matches
}
// ClassMatch returns true if given class names match.
// Handles space-separated multiple class names.
func ClassMatch(sel, cls string) bool {
return slices.Contains(strings.Fields(cls), sel)
}
//////// Sheet
// Apply applies entire sheet to given object, using Sel's in order.
// Returns true if any Sels applied to the object.
func (ps *Sheet[T]) Apply(obj T) bool {
applied := false
for _, sl := range *ps {
app := sl.Apply(obj)
if app {
applied = true
sl.NMatch++
}
}
return applied
}
// SelMatchReset resets the Sel.NMatch counters, which are used to find cases where
// a Sel matched no target objects. Call this once at the start of the application
// process: that process may be an outer loop of Apply calls (e.g., for a Network,
// Apply is called for each Layer and Path), so this reset must be called separately,
// before that loop. See SelNoMatchWarn for the warning call at the end.
func (ps *Sheet[T]) SelMatchReset() {
for _, sl := range *ps {
sl.NMatch = 0
}
}
// SelNoMatchWarn issues warning messages for any Sel selectors that had no
// matches during the last Apply process -- see SelMatchReset.
// The sheetName and objName provide info about the Sheet and obj being applied.
// Returns an error message with the non-matching sets if any, else nil.
func (ps *Sheet[T]) SelNoMatchWarn(sheetName, objName string) error {
msg := ""
for _, sl := range *ps {
if sl.NMatch == 0 {
msg += "\tSel: " + sl.Sel + "\n"
}
}
if msg != "" {
msg = fmt.Sprintf("param.Sheet from Sheet: %s for object: %s had the following non-matching Selectors:\n%s", sheetName, objName, msg)
slog.Warn(msg)
return errors.New(msg)
}
return nil
}
// Code generated by "core generate -add-types"; DO NOT EDIT.
package params
import (
"cogentcore.org/core/enums"
)
var _TweakTypesValues = []TweakTypes{0, 1}
// TweakTypesN is the highest valid value for type TweakTypes, plus one.
const TweakTypesN TweakTypes = 2
var _TweakTypesValueMap = map[string]TweakTypes{`Increment`: 0, `Log`: 1}
var _TweakTypesDescMap = map[TweakTypes]string{0: `Increment increments around current value, e.g., if .5, generates .4 and .6`, 1: `Log uses the quasi-log scheme: 1, 2, 5, 10 etc, which only applies if value is one of those numbers.`}
var _TweakTypesMap = map[TweakTypes]string{0: `Increment`, 1: `Log`}
// String returns the string representation of this TweakTypes value.
func (i TweakTypes) String() string { return enums.String(i, _TweakTypesMap) }
// SetString sets the TweakTypes value from its string representation,
// and returns an error if the string is invalid.
func (i *TweakTypes) SetString(s string) error {
return enums.SetString(i, s, _TweakTypesValueMap, "TweakTypes")
}
// Int64 returns the TweakTypes value as an int64.
func (i TweakTypes) Int64() int64 { return int64(i) }
// SetInt64 sets the TweakTypes value from an int64.
func (i *TweakTypes) SetInt64(in int64) { *i = TweakTypes(in) }
// Desc returns the description of the TweakTypes value.
func (i TweakTypes) Desc() string { return enums.Desc(i, _TweakTypesDescMap) }
// TweakTypesValues returns all possible values for the type TweakTypes.
func TweakTypesValues() []TweakTypes { return _TweakTypesValues }
// Values returns all possible values for the type TweakTypes.
func (i TweakTypes) Values() []enums.Enum { return enums.Values(_TweakTypesValues) }
// MarshalText implements the [encoding.TextMarshaler] interface.
func (i TweakTypes) MarshalText() ([]byte, error) { return []byte(i.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *TweakTypes) UnmarshalText(text []byte) error {
return enums.UnmarshalText(i, text, "TweakTypes")
}
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package params
//go:generate core generate -add-types
import (
"fmt"
"cogentcore.org/core/base/errors"
)
// Sel specifies a selector for the scope of application of a set of
// parameters, using standard css selector syntax (. prefix = class, # prefix = name,
// and no prefix = type). Type always matches, and generally should come first as an
// initial set of defaults.
type Sel[T Styler] struct {
// Sel is the selector for what to apply the parameters to,
// using standard css selector syntax:
// - .Example applies to anything with a Class tag of 'Example'
// - #Example applies to anything with a Name of 'Example'
// - Example with no prefix or blank selector always applies as the presumed Type.
Sel string `width:"30"`
// Doc is documentation of these parameter values: what effect
// do they have? what range was explored? It is valuable to record
// this information as you explore the params.
Doc string `width:"60"`
// Set function applies parameter values to the given object of the target type.
Set func(v T) `display:"-"`
// NMatch is the number of times this selector matched a target
// during the last Apply process. A warning is issued for any
// that remain at 0: See Sheet SelMatchReset and SelNoMatchWarn methods.
NMatch int `table:"-" toml:"-" json:"-" xml:"-" edit:"-"`
}
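// Illustrative usage sketch (MyLayer and its Inhib.Gi field are hypothetical;
// MyLayer is assumed to implement the Styler interface):
//
//	hiddenSel := &Sel[*MyLayer]{
//		Sel: ".Hidden",
//		Doc: "stronger inhibition for hidden layers",
//		Set: func(ly *MyLayer) { ly.Inhib.Gi = 1.2 },
//	}
//	applied := hiddenSel.Apply(ly) // true if ly.StyleClass() contains "Hidden"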
////////
// Sheet is a CSS-like style-sheet of params.Sel values, each of which represents
// a different set of specific parameter values applied according to the Sel selector:
// .Class #Name or Type.
//
// The order of elements in the Sheet list is critical, as they are applied
// in the order given by the list (slice), and thus later Sel's can override
// those applied earlier. Generally put more general Type-level parameters first,
// and then subsequently more specific ones (.Class and #Name).
type Sheet[T Styler] []*Sel[T]
// NewSheet returns a new Sheet for given type.
func NewSheet[T Styler]() *Sheet[T] {
sh := make(Sheet[T], 0)
return &sh
}
// ElemLabel satisfies the core.SliceLabeler interface to provide labels for slice elements.
func (sh *Sheet[T]) ElemLabel(idx int) string {
return (*sh)[idx].Sel
}
// SelByName returns the selector with the given Sel string within the Sheet.
// Returns and logs error if not found.
func (sh *Sheet[T]) SelByName(sel string) (*Sel[T], error) {
for _, sl := range *sh {
if sl.Sel == sel {
return sl, nil
}
}
return nil, errors.Log(fmt.Errorf("params.Sheet: Sel named %v not found", sel))
}
////////
// Sheets are named collections of Sheet elements that can be chosen among
// depending on different desired configurations.
// Conventionally, there is always a Base configuration with basic-level
// defaults, and then any number of more specific sets to apply after that.
type Sheets[T Styler] map[string]*Sheet[T]
// SheetByName tries to find the given Sheet by name.
// Returns and logs error if not found.
func (ps *Sheets[T]) SheetByName(name string) (*Sheet[T], error) {
st, ok := (*ps)[name]
if ok {
return st, nil
}
return nil, errors.Log(fmt.Errorf("params.Sheets: Param Sheet named %q not found", name))
}
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package params
import (
"reflect"
"strings"
"cogentcore.org/core/base/indent"
"cogentcore.org/core/base/reflectx"
)
// PrintStruct returns a string representation of a struct
// for printing out parameter values. It uses standard Cogent Core
// display tags to produce results that resemble the GUI interface,
// and only includes exported fields.
// The optional filter function determines whether a field is included, based
// on the full path to the field (using . separators) and the field value.
// indent provides the starting indent level (2 spaces per level).
// The optional format function returns a string representation of the value,
// if you want to override the default, which just uses
// reflectx.ToString (returning an empty string means use the default).
func PrintStruct(v any, indent int,
filter func(path string, ft reflect.StructField, fv any) bool,
format func(path string, ft reflect.StructField, fv any) string) string {
return printStruct("", indent, v, filter, format)
}
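// Illustrative usage sketch (MyParams is a hypothetical struct; nil filter and
// format use the defaults):
//
//	type MyParams struct {
//		LRate    float32
//		Momentum float32
//	}
//	str := PrintStruct(&MyParams{LRate: 0.02, Momentum: 0.9}, 1, nil, nil)
//	// str lists each exported field name and value, indented by one level.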
func addPath(par, field string) string {
if par == "" {
return field
}
return par + "." + field
}
func printStruct(parPath string, ident int, v any,
filter func(path string, ft reflect.StructField, fv any) bool,
format func(path string, ft reflect.StructField, fv any) string) string {
rv := reflectx.Underlying(reflect.ValueOf(v))
if reflectx.IsNil(rv) {
return ""
}
var b strings.Builder
rt := rv.Type()
nf := rt.NumField()
var fis []int
maxFieldW := 0
for i := range nf {
ft := rt.Field(i)
if !ft.IsExported() {
continue
}
fv := rv.Field(i)
fvi := fv.Interface()
pp := addPath(parPath, ft.Name)
if filter != nil && !filter(pp, ft, fvi) {
continue
}
fis = append(fis, i)
maxFieldW = max(maxFieldW, len(ft.Name))
}
for _, i := range fis {
fv := rv.Field(i)
ft := rt.Field(i)
fvi := fv.Interface()
pp := addPath(parPath, ft.Name)
is := indent.Spaces(ident, 2)
printName := func() {
b.WriteString(is)
b.WriteString(ft.Name)
b.WriteString(strings.Repeat(" ", 1+maxFieldW-len(ft.Name)))
}
ps := ""
if reflectx.NonPointerType(ft.Type).Kind() == reflect.Struct {
if ft.Tag.Get("display") == "inline" {
ps = printStructInline(pp, ident+1, fvi, filter, format)
if ps != "{ }" {
printName()
b.WriteString(ps)
b.WriteString("\n")
}
} else {
ps := printStruct(pp, ident+1, fvi, filter, format)
if ps != "" {
printName()
b.WriteString("{\n")
b.WriteString(ps)
b.WriteString(is)
b.WriteString("}\n")
}
}
continue
}
if ps == "" && format != nil {
ps = format(pp, ft, fvi)
if ps != "" {
printName()
b.WriteString(ps + "\n")
continue
}
}
printName()
ps = reflectx.ToString(fvi)
b.WriteString(ps + "\n")
}
return b.String()
}
func printStructInline(parPath string, ident int, v any,
filter func(path string, ft reflect.StructField, fv any) bool,
format func(path string, ft reflect.StructField, fv any) string) string {
rv := reflectx.Underlying(reflect.ValueOf(v))
if reflectx.IsNil(rv) {
return ""
}
var b strings.Builder
b.WriteString("{ ")
rt := rv.Type()
nf := rt.NumField()
didPrint := false
for i := range nf {
ft := rt.Field(i)
if !ft.IsExported() {
continue
}
fv := rv.Field(i)
fvi := fv.Interface()
pp := addPath(parPath, ft.Name)
if filter != nil && !filter(pp, ft, fvi) {
continue
}
printName := func() {
if didPrint {
b.WriteString(", ")
}
b.WriteString(ft.Name)
b.WriteString(": ")
didPrint = true
}
ps := ""
if reflectx.NonPointerType(ft.Type).Kind() == reflect.Struct {
ps = printStructInline(pp, ident+1, fvi, filter, format)
if ps != "{ }" {
printName()
b.WriteString(ps)
}
continue
}
if ps == "" && format != nil {
ps = format(pp, ft, fvi)
if ps != "" {
printName()
b.WriteString(ps)
continue
}
}
printName()
ps = reflectx.ToString(fvi)
b.WriteString(ps)
}
b.WriteString(" }")
return b.String()
}
// Copyright (c) 2024, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package params
import (
"fmt"
)
// Search represents one parameter search element, applying given selector
// to objects of given type, using float32 values generated by given function.
type Search[T Styler] struct {
// Sel is the selector for what to apply the parameters to,
// using standard css selector syntax:
// - .Example applies to anything with a Class tag of 'Example'
// - #Example applies to anything with a Name of 'Example'
// - Example with no prefix or blank selector always applies as the presumed Type.
Sel string
// Set function applies given parameter value to the given object
// of the target type.
Set func(v T, val float32)
// Vals is a function that generates values to search over. This is
// typically called once at the start of the entire search process to get
// a list of values that are cached, to determine the total number of search
// jobs required, and then it is called again when this parameter is at the
// top of the list to be searched, so that new values can potentially be
// generated.
Vals func() []float32
cached []float32
}
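// Illustrative usage sketch (MyLayer and its Learn.LRate field are hypothetical;
// MyLayer is assumed to implement the Styler interface):
//
//	lrate := &Search[*MyLayer]{
//		Sel:  ".Hidden",
//		Set:  func(ly *MyLayer, val float32) { ly.Learn.LRate = val },
//		Vals: func() []float32 { return []float32{0.01, 0.02, 0.05} },
//	}
//	total := Searches[*MyLayer]{lrate}.NumParams() // == 3: one job per value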
// Apply checks if Sel selector applies to this object according to
// (.Class, #Name, Type) using the Styler interface, and returns
// false if it does not. If it does apply, then the given value
// is passed to the Set function on the object, and true is returned.
func (ps *Search[T]) Apply(obj T, val float32) bool {
if !SelMatch(ps.Sel, obj) {
return false
}
ps.Set(obj, val)
return true
}
// Value returns the parameter value at given value index.
// Returns false if the value index is invalid.
func (ps *Search[T]) Value(valIndex int) (bool, float32) {
vals := ps.Values()
if valIndex >= len(vals) {
return false, 0
}
val := vals[valIndex]
return true, val
}
// Values returns the search values, using cached values if already set,
// and calling [Search.CacheValues] otherwise.
func (ps *Search[T]) Values() []float32 {
if ps.cached != nil {
return ps.cached
}
return ps.CacheValues()
}
// CacheValues calls the Vals function and caches the resulting values
// for later use. It can be called again to refresh the values when the
// search is dynamic and depends on other state.
// Returns the updated cached values.
func (ps *Search[T]) CacheValues() []float32 {
ps.cached = ps.Vals()
return ps.cached
}
// JobString returns a string that identifies a param search job
// for given parameter index and value, for this item.
func (ps *Search[T]) JobString(paramIndex int, val float32) string {
return fmt.Sprintf("Search %d: %s=%g", paramIndex, ps.Sel, val)
}
// Searches is a list of [Search] elements, representing an entire
// parameter search process, where multiple parameters with multiple
// values per parameter are searched, typically in parallel across
// independent sim run jobs. Thus, you first get the total number of
// parameter values via the [Searches.NumParams] method, and then launch
// one job per param index, each of which applies its corresponding value.
type Searches[T Styler] []*Search[T]
// NumParams returns the total number of parameter values to search.
// This calls the Vals function to generate initial search values.
func (sr Searches[T]) NumParams() int {
n := 0
for _, ps := range sr {
n += len(ps.Values())
}
return n
}
// SearchAtIndex returns the [Search] element at given parameter
// index in range [0..NumParams), and the value index within that search.
func (sr Searches[T]) SearchAtIndex(paramIndex int) (*Search[T], int) {
n := 0
for _, ps := range sr {
sn := len(ps.Values())
if paramIndex >= n && paramIndex < n+sn {
return ps, paramIndex - n
}
n += sn
}
return nil, -1
}
// SearchValue returns [Search] and value for parameter at given
// param index within searches, returning error for invalid index.
// Also returns string descriptor for parameter.
func (sr Searches[T]) SearchValue(paramIndex int) (*Search[T], float32, string, error) {
ps, valIndex := sr.SearchAtIndex(paramIndex)
if ps == nil {
return ps, 0, "", fmt.Errorf("No param search found for param index: %d", paramIndex)
}
ok, val := ps.Value(valIndex)
if !ok {
return ps, 0, "", fmt.Errorf("Param value index out of range: %d, for Sel: %s", valIndex, ps.Sel)
}
lbl := ps.JobString(paramIndex, val)
return ps, val, lbl, nil
}
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package params
import (
"slices"
"strings"
)
// Styler must be implemented by any object that parameters are
// applied to, to provide the .Class and #Name selector functionality.
type Styler interface {
// StyleClass returns the space-separated list of class selectors (tags).
// Parameters with a . prefix target class tags.
// Do NOT include the . in the Class tags on Styler objects;
// the . is only present in the Sel selector on [Sel].
StyleClass() string
// StyleName returns the name of this object.
// Parameters with a # prefix target object names, which are typically
// unique. Do NOT include the # prefix in the actual object name,
// which is only present in the Sel selector on [Sel].
StyleName() string
}
// AddClass is a helper function that adds the given class(es) to the current
// class string, skipping any that are already present, and properly
// separating the classes with spaces.
func AddClass(cur string, class ...string) string {
nc := strings.Fields(cur)
for _, cl := range class {
if !slices.Contains(nc, cl) {
nc = append(nc, cl)
}
}
return strings.Join(nc, " ")
}
// Copyright (c) 2024, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package params
import (
"cogentcore.org/core/math32"
)
// TweakTypes are the types of param tweak logic supported.
type TweakTypes int32 //enums:enum
const (
// Increment increments around current value, e.g., if .5, generates .4 and .6
Increment TweakTypes = iota
// Log uses the quasi-log scheme: 1, 2, 5, 10 etc., which only applies if the value
// is one of those numbers.
Log
)
func tweakValue(msd, fact, exp10 float32, isRmdr bool) float32 {
if isRmdr {
return math32.Truncate(msd+fact*math32.Pow(10, exp10), 3)
}
return math32.Truncate(fact*math32.Pow(10, exp10), 3)
}
// Tweak returns parameter [Search] values to try,
// below and above the given value.
// Log: use the quasi-log scheme: 1, 2, 5, 10 etc. Only if val is one of these vals.
// Increment: use increments around current value: e.g., if .5, returns .4 and .6.
// These apply to the 2nd significant digit (remainder after most significant digit)
// if that is present in the original value.
func Tweak(v float32, typ TweakTypes) []float32 {
ex := math32.Floor(math32.Log10(v))
base := math32.Pow(10, ex)
basem1 := math32.Pow(10, ex-1)
fact := math32.Round(v / base)
msd := tweakValue(0, fact, ex, false)
rmdr := math32.Round((v - msd) / basem1)
var vals []float32
sv := fact
isRmdr := false
if rmdr != 0 {
if rmdr < 0 {
msd = tweakValue(0, fact-1, ex, false)
rmdr = math32.Round((v - msd) / basem1)
}
sv = rmdr
ex--
isRmdr = true
}
switch sv {
case 1:
if typ == Log {
vals = append(vals, tweakValue(msd, 5, ex-1, isRmdr), tweakValue(msd, 2, ex, isRmdr))
} else {
vals = append(vals, tweakValue(msd, 9, ex-1, isRmdr), tweakValue(msd, 1.1, ex, isRmdr))
}
case 2:
if typ == Log {
vals = append(vals, tweakValue(msd, 1, ex, isRmdr), tweakValue(msd, 5, ex, isRmdr))
} else {
vals = append(vals, tweakValue(msd, 1, ex, isRmdr), tweakValue(msd, 3, ex, isRmdr))
}
case 5:
if typ == Log {
vals = append(vals, tweakValue(msd, 2, ex, isRmdr), tweakValue(msd, 1, ex+1, isRmdr))
} else {
vals = append(vals, tweakValue(msd, 4, ex, isRmdr), tweakValue(msd, 6, ex, isRmdr))
}
case 9:
vals = append(vals, tweakValue(msd, 8, ex, isRmdr), tweakValue(msd, 1, ex+1, isRmdr))
default:
vals = append(vals, tweakValue(msd, sv-1, ex, isRmdr), tweakValue(msd, sv+1, ex, isRmdr))
}
return vals
}
// TweakPct returns parameter [Search] values to try, at the given percent
// below and above the given value.
func TweakPct(v, pct float32) []float32 {
trunc := 6
return []float32{math32.Truncate(v*(1-pct), trunc), math32.Truncate(v*(1+pct), trunc)}
}
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package paths
import (
"cogentcore.org/core/math32"
"cogentcore.org/core/math32/vecint"
"cogentcore.org/lab/tensor"
"github.com/emer/emergent/v2/edge"
"github.com/emer/emergent/v2/efuns"
)
// Circle implements a circular pattern of connectivity between two layers
// where the center moves in proportion to receiver position with offset
// and multiplier factors, and a given radius is used (with wrap-around
// optionally). A corresponding Gaussian bump of TopoWeights is available as well.
// Makes for a good center-surround connectivity pattern.
// 4D layers are automatically flattened to 2D for this connection.
type Circle struct {
// radius of the circle, in units from center in sending layer
Radius int
// starting offset in sending layer, for computing the corresponding sending center relative to given recv unit position
Start vecint.Vector2i
// scaling to apply to receiving unit position to compute sending center as function of recv unit position
Scale math32.Vector2
// auto-scale sending center positions as function of relative sizes of send and recv layers -- if Start is positive then assumes it is a border, subtracted from sending size
AutoScale bool
// if true, connectivity wraps around edges
Wrap bool
// if true, this path should set gaussian topographic weights, according to following parameters
TopoWeights bool
// gaussian sigma (width) as a proportion of the radius of the circle
Sigma float32
// maximum weight value for GaussWts function -- multiplies values
MaxWt float32
// if true, and connecting layer to itself (self pathway), then make a self-connection from unit to itself
SelfCon bool
}
func NewCircle() *Circle {
cr := &Circle{}
cr.Defaults()
return cr
}
func (cr *Circle) Defaults() {
cr.Wrap = true
cr.Radius = 8
cr.Scale.SetScalar(1)
cr.Sigma = 0.5
cr.MaxWt = 1
}
func (cr *Circle) Name() string {
return "Circle"
}
func (cr *Circle) Connect(send, recv *tensor.Shape, same bool) (sendn, recvn *tensor.Int32, cons *tensor.Bool) {
sendn, recvn, cons = NewTensors(send, recv)
sNy, sNx, _, _ := tensor.Projection2DShape(send, false)
rNy, rNx, _, _ := tensor.Projection2DShape(recv, false)
rnv := recvn.Values
snv := sendn.Values
sNtot := send.Len()
sc := cr.Scale
if cr.AutoScale {
ssz := math32.Vec2(float32(sNx), float32(sNy))
if cr.Start.X >= 0 && cr.Start.Y >= 0 {
ssz.X -= float32(2 * cr.Start.X)
ssz.Y -= float32(2 * cr.Start.Y)
}
rsz := math32.Vec2(float32(rNx), float32(rNy))
sc = ssz.Div(rsz)
}
for ry := 0; ry < rNy; ry++ {
for rx := 0; rx < rNx; rx++ {
sctr := math32.Vec2(float32(rx)*sc.X+float32(cr.Start.X), float32(ry)*sc.Y+float32(cr.Start.Y))
for sy := 0; sy < sNy; sy++ {
for sx := 0; sx < sNx; sx++ {
sp := math32.Vec2(float32(sx), float32(sy))
if cr.Wrap {
sp.X = edge.WrapMinDist(sp.X, float32(sNx), sctr.X)
sp.Y = edge.WrapMinDist(sp.Y, float32(sNy), sctr.Y)
}
d := int(math32.Round(sp.DistanceTo(sctr)))
if d <= cr.Radius {
ri := tensor.Projection2DIndex(recv, false, ry, rx)
si := tensor.Projection2DIndex(send, false, sy, sx)
off := ri*sNtot + si
if !cr.SelfCon && same && ri == si {
continue
}
cons.Values.Set(true, off)
rnv[ri]++
snv[si]++
}
}
}
}
}
return
}
// GaussWts returns gaussian weight value for given unit indexes in
// given send and recv layers according to Gaussian Sigma and MaxWt.
// Can be used for a Path.SetScalesFunc or SetWtsFunc
func (cr *Circle) GaussWts(si, ri int, send, recv *tensor.Shape) float32 {
sNy, sNx, _, _ := tensor.Projection2DShape(send, false)
rNy, rNx, _, _ := tensor.Projection2DShape(recv, false)
ry := ri / rNx // todo: this is not right for 4d!
rx := ri % rNx
sy := si / sNx
sx := si % sNx
fsig := cr.Sigma * float32(cr.Radius)
sc := cr.Scale
if cr.AutoScale {
ssz := math32.Vec2(float32(sNx), float32(sNy))
if cr.Start.X >= 0 && cr.Start.Y >= 0 {
ssz.X -= float32(2 * cr.Start.X)
ssz.Y -= float32(2 * cr.Start.Y)
}
rsz := math32.Vec2(float32(rNx), float32(rNy))
sc = ssz.Div(rsz)
}
sctr := math32.Vec2(float32(rx)*sc.X+float32(cr.Start.X), float32(ry)*sc.Y+float32(cr.Start.Y))
sp := math32.Vec2(float32(sx), float32(sy))
if cr.Wrap {
sp.X = edge.WrapMinDist(sp.X, float32(sNx), sctr.X)
sp.Y = edge.WrapMinDist(sp.Y, float32(sNy), sctr.Y)
}
wt := cr.MaxWt * efuns.GaussVecDistNoNorm(sp, sctr, fsig)
return wt
}
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package paths
import "cogentcore.org/lab/tensor"
// Full implements full all-to-all pattern of connectivity between two layers
type Full struct {
// if true, and connecting layer to itself (self pathway), then make a self-connection from unit to itself
SelfCon bool
}
func NewFull() *Full {
return &Full{}
}
func (fp *Full) Name() string {
return "Full"
}
func (fp *Full) Connect(send, recv *tensor.Shape, same bool) (sendn, recvn *tensor.Int32, cons *tensor.Bool) {
sendn, recvn, cons = NewTensors(send, recv)
cons.Values.SetAll(true)
nsend := send.Len()
nrecv := recv.Len()
if same && !fp.SelfCon {
for i := 0; i < nsend; i++ { // nsend = nrecv
off := i*nsend + i
cons.Values.Set(false, off)
}
nsend--
nrecv--
}
rnv := recvn.Values
for i := range rnv {
rnv[i] = int32(nsend)
}
snv := sendn.Values
for i := range snv {
snv[i] = int32(nrecv)
}
return
}
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package paths
import "cogentcore.org/lab/tensor"
// OneToOne implements point-to-point one-to-one pattern of connectivity between two layers
type OneToOne struct {
// number of recv connections to make (0 for entire size of recv layer)
NCons int
// starting unit index for sending connections
SendStart int
// starting unit index for recv connections
RecvStart int
}
func NewOneToOne() *OneToOne {
return &OneToOne{}
}
func (ot *OneToOne) Name() string {
return "OneToOne"
}
func (ot *OneToOne) Connect(send, recv *tensor.Shape, same bool) (sendn, recvn *tensor.Int32, cons *tensor.Bool) {
sendn, recvn, cons = NewTensors(send, recv)
nsend := send.Len()
nrecv := recv.Len()
rnv := recvn.Values
snv := sendn.Values
ncon := nrecv
if ot.NCons > 0 {
ncon = min(ot.NCons, nrecv)
}
for i := 0; i < ncon; i++ {
ri := ot.RecvStart + i
si := ot.SendStart + i
if ri >= nrecv || si >= nsend {
break
}
off := ri*nsend + si
cons.Values.Set(true, off)
rnv[ri] = 1
snv[si] = 1
}
return
}
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package paths
//go:generate core generate -add-types
import (
"cogentcore.org/lab/tensor"
)
// Pattern defines a pattern of connectivity between two layers.
// The pattern is stored efficiently using a bitslice tensor of binary values indicating
// presence or absence of connection between two items.
// A receiver-based organization is generally assumed but connectivity can go either way.
type Pattern interface {
// Name returns the name of the pattern -- i.e., the "type" name of the actual pattern generator
Name() string
// Connect connects layers with the given shapes, returning the pattern of connectivity
// as a bits tensor with shape = recv + send shapes, using row-major ordering with outer-most
// indexes first (i.e., for each recv unit, there is a full inner-level of sender bits).
// The number of connections for each recv and each send unit are also returned in
// the sendn and recvn tensors, shaped the same as the send and recv layers respectively.
// The same flag should be set to true if the send and recv layers are the same (i.e., a self-connection);
// often there are different options for such connections.
Connect(send, recv *tensor.Shape, same bool) (sendn, recvn *tensor.Int32, cons *tensor.Bool)
}
// NewTensors returns the tensors used for Connect method, based on layer sizes
func NewTensors(send, recv *tensor.Shape) (sendn, recvn *tensor.Int32, cons *tensor.Bool) {
sendn = tensor.NewInt32(send.Sizes...)
recvn = tensor.NewInt32(recv.Sizes...)
csh := tensor.AddShapes(recv, send)
cons = tensor.NewBoolShape(csh)
return
}
// ConsStringFull returns a []byte string showing the full pattern of connectivity
// as a 2D matrix, with one row per recv unit and one 0/1 entry per send unit.
func ConsStringFull(send, recv *tensor.Shape, cons *tensor.Bool) []byte {
nsend := send.Len()
nrecv := recv.Len()
one := []byte("1 ")
zero := []byte("0 ")
sz := nrecv * (nsend*2 + 1)
b := make([]byte, 0, sz)
for ri := 0; ri < nrecv; ri++ {
for si := 0; si < nsend; si++ {
off := ri*nsend + si
cn := cons.Value1D(off)
if cn {
b = append(b, one...)
} else {
b = append(b, zero...)
}
}
b = append(b, byte('\n'))
}
return b
}
// ConsStringPerRecv returns a []byte string showing the pattern of connectivity
// organized by receiving unit, showing the sending connections for each recv unit.
// Not yet implemented: currently returns nil.
func ConsStringPerRecv(send, recv *tensor.Shape, cons *tensor.Bool) []byte {
return nil
}
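// consOffsetExample is an illustrative sketch, not part of the package API:
// it shows the row-major layout convention used by all Connect methods in
// this package: for recv unit index ri and send unit index si, the
// connection bit lives at off = ri*send.Len() + si in the cons tensor.
// It assumes tensor.NewShape(sizes ...int) constructs a *tensor.Shape and
// that "fmt" is imported.
func consOffsetExample() {
	send := tensor.NewShape(3, 4) // 12 sending units
	recv := tensor.NewShape(2, 2) // 4 receiving units
	_, _, cons := NewTensors(send, recv)
	ri, si := 1, 5
	cons.Values.Set(true, ri*send.Len()+si)               // connect recv unit 1 to send unit 5
	fmt.Println(string(ConsStringFull(send, recv, cons))) // 4 rows x 12 columns of 0s and 1s
}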
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package paths
import "cogentcore.org/lab/tensor"
// PoolOneToOne implements one-to-one connectivity between pools within layers.
// Pools are the outer-most two dimensions of a 4D layer shape.
// If either layer does not have pools, then: if its number of individual
// units matches the number of pools in the other layer, those are connected one-to-one;
// otherwise, each pool connects to the entire set of units in the other layer.
// If neither is 4D, then it is equivalent to OneToOne.
type PoolOneToOne struct {
// number of recv pools to connect (0 for entire number of pools in recv layer)
NPools int
// starting pool index for sending connections
SendStart int
// starting pool index for recv connections
RecvStart int
}
func NewPoolOneToOne() *PoolOneToOne {
return &PoolOneToOne{}
}
func (ot *PoolOneToOne) Name() string {
return "PoolOneToOne"
}
func (ot *PoolOneToOne) Connect(send, recv *tensor.Shape, same bool) (sendn, recvn *tensor.Int32, cons *tensor.Bool) {
switch {
case send.NumDims() == 4 && recv.NumDims() == 4:
return ot.ConnectPools(send, recv, same)
case send.NumDims() == 2 && recv.NumDims() == 4:
return ot.ConnectRecvPool(send, recv, same)
case send.NumDims() == 4 && recv.NumDims() == 2:
return ot.ConnectSendPool(send, recv, same)
case send.NumDims() == 2 && recv.NumDims() == 2:
return ot.ConnectOneToOne(send, recv, same)
}
return
}
// ConnectPools is when both recv and send have pools
func (ot *PoolOneToOne) ConnectPools(send, recv *tensor.Shape, same bool) (sendn, recvn *tensor.Int32, cons *tensor.Bool) {
sendn, recvn, cons = NewTensors(send, recv)
sNtot := send.Len()
// rNtot := recv.Len()
sNp := send.DimSize(0) * send.DimSize(1)
rNp := recv.DimSize(0) * recv.DimSize(1)
sNu := send.DimSize(2) * send.DimSize(3)
rNu := recv.DimSize(2) * recv.DimSize(3)
rnv := recvn.Values
snv := sendn.Values
npl := rNp
if ot.NPools > 0 {
npl = min(ot.NPools, rNp)
}
for i := 0; i < npl; i++ {
rpi := ot.RecvStart + i
spi := ot.SendStart + i
if rpi >= rNp || spi >= sNp {
break
}
for rui := 0; rui < rNu; rui++ {
ri := rpi*rNu + rui
for sui := 0; sui < sNu; sui++ {
si := spi*sNu + sui
off := ri*sNtot + si
cons.Values.Set(true, off)
rnv[ri] = int32(sNu)
snv[si] = int32(rNu)
}
}
}
return
}
// ConnectRecvPool is when recv has pools but send doesn't
func (ot *PoolOneToOne) ConnectRecvPool(send, recv *tensor.Shape, same bool) (sendn, recvn *tensor.Int32, cons *tensor.Bool) {
sendn, recvn, cons = NewTensors(send, recv)
sNtot := send.Len()
rNp := recv.DimSize(0) * recv.DimSize(1)
rNu := recv.DimSize(2) * recv.DimSize(3)
rnv := recvn.Values
snv := sendn.Values
npl := rNp
if ot.NPools > 0 {
npl = min(ot.NPools, rNp)
}
if sNtot == rNp { // one-to-one
for i := 0; i < npl; i++ {
rpi := ot.RecvStart + i
si := ot.SendStart + i
if rpi >= rNp || si >= sNtot {
break
}
for rui := 0; rui < rNu; rui++ {
ri := rpi*rNu + rui
off := ri*sNtot + si
cons.Values.Set(true, off)
rnv[ri] = int32(1)
snv[si] = int32(rNu)
}
}
} else { // full
for i := 0; i < npl; i++ {
rpi := ot.RecvStart + i
if rpi >= rNp {
break
}
for rui := 0; rui < rNu; rui++ {
ri := rpi*rNu + rui
for si := 0; si < sNtot; si++ {
off := ri*sNtot + si
cons.Values.Set(true, off)
rnv[ri] = int32(sNtot)
snv[si] = int32(npl * rNu)
}
}
}
}
return
}
// ConnectSendPool is when send has pools but recv doesn't
func (ot *PoolOneToOne) ConnectSendPool(send, recv *tensor.Shape, same bool) (sendn, recvn *tensor.Int32, cons *tensor.Bool) {
sendn, recvn, cons = NewTensors(send, recv)
sNtot := send.Len()
rNtot := recv.Len()
sNp := send.DimSize(0) * send.DimSize(1)
sNu := send.DimSize(2) * send.DimSize(3)
rnv := recvn.Values
snv := sendn.Values
npl := sNp
if ot.NPools > 0 {
npl = min(ot.NPools, sNp)
}
if rNtot == sNp { // one-to-one
for i := 0; i < npl; i++ {
spi := ot.SendStart + i
ri := ot.RecvStart + i
if ri >= rNtot || spi >= sNp {
break
}
for sui := 0; sui < sNu; sui++ {
si := spi*sNu + sui
off := ri*sNtot + si
cons.Values.Set(true, off)
rnv[ri] = int32(sNu)
snv[si] = int32(1)
}
}
} else { // full
for i := 0; i < npl; i++ {
spi := ot.SendStart + i
if spi >= sNp {
break
}
for ri := 0; ri < rNtot; ri++ {
for sui := 0; sui < sNu; sui++ {
si := spi*sNu + sui
off := ri*sNtot + si
cons.Values.Set(true, off)
rnv[ri] = int32(npl * sNu)
snv[si] = int32(rNtot)
}
}
}
}
return
}
// ConnectOneToOne is a copy of OneToOne.Connect, for the case where neither layer has pools
func (ot *PoolOneToOne) ConnectOneToOne(send, recv *tensor.Shape, same bool) (sendn, recvn *tensor.Int32, cons *tensor.Bool) {
sendn, recvn, cons = NewTensors(send, recv)
sNtot := send.Len()
rNtot := recv.Len()
rnv := recvn.Values
snv := sendn.Values
npl := rNtot
if ot.NPools > 0 {
npl = min(ot.NPools, rNtot)
}
for i := 0; i < npl; i++ {
ri := ot.RecvStart + i
si := ot.SendStart + i
if ri >= rNtot || si >= sNtot {
break
}
off := ri*sNtot + si
cons.Values.Set(true, off)
rnv[ri] = 1
snv[si] = 1
}
return
}
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package paths
import (
"cogentcore.org/core/math32"
"cogentcore.org/core/math32/vecint"
"cogentcore.org/lab/tensor"
"github.com/emer/emergent/v2/edge"
)
// PoolRect implements a rectangular pattern of connectivity between
// two 4D layers, in terms of their pool-level shapes,
// where the lower-left corner moves in proportion to receiver
// pool position with offset and multiplier factors (with wrap-around optionally).
type PoolRect struct {
// size of rectangle (of pools) in sending layer that each receiving unit receives from
Size vecint.Vector2i
// starting pool offset in sending layer, for computing the corresponding sending lower-left corner relative to given recv pool position
Start vecint.Vector2i
// scaling to apply to receiving pool position to compute corresponding position in sending layer of the lower-left corner of rectangle
Scale math32.Vector2
// auto-set the Scale as function of the relative pool sizes of send and recv layers (e.g., if sending layer is 2x larger than receiving, Scale = 2)
AutoScale bool
// if true, use Round when applying scaling factor -- otherwise uses Floor which makes Scale work like a grouping factor -- e.g., .25 will effectively group 4 recv pools with same send position
RoundScale bool
// if true, connectivity wraps around all edges if it would otherwise go off the edge -- if false, then edges are clipped
Wrap bool
// if true, and connecting layer to itself (self pathway), then make a self-connection from unit to itself
SelfCon bool
// starting pool position in receiving layer -- if > 0 then pools below this starting point remain unconnected
RecvStart vecint.Vector2i
// number of pools in receiving layer to connect -- if 0 then all (remaining after RecvStart) are connected -- otherwise if < remaining then those beyond this point remain unconnected
RecvN vecint.Vector2i
}
func NewPoolRect() *PoolRect {
cr := &PoolRect{}
cr.Defaults()
return cr
}
func (cr *PoolRect) Defaults() {
cr.Wrap = true
cr.Size.Set(1, 1)
cr.Scale.SetScalar(1)
}
func (cr *PoolRect) Name() string {
return "PoolRect"
}
func (cr *PoolRect) Connect(send, recv *tensor.Shape, same bool) (sendn, recvn *tensor.Int32, cons *tensor.Bool) {
sendn, recvn, cons = NewTensors(send, recv)
sNy := send.DimSize(0)
sNx := send.DimSize(1)
rNy := recv.DimSize(0)
rNx := recv.DimSize(1)
sNn := 1
rNn := 1
if send.NumDims() == 4 {
sNn = send.DimSize(2) * send.DimSize(3)
} else { // 2D
sNn = sNy * sNx
sNy = 1
sNx = 1
}
if recv.NumDims() == 4 {
rNn = recv.DimSize(2) * recv.DimSize(3)
} else { // 2D
rNn = rNy * rNx
rNy = 1
rNx = 1
}
rnv := recvn.Values
snv := sendn.Values
sNtot := send.Len()
sc := cr.Scale
if cr.AutoScale {
ssz := math32.Vec2(float32(sNx), float32(sNy))
rsz := math32.Vec2(float32(rNx), float32(rNy))
sc = ssz.Div(rsz)
}
rNyEff := rNy
if cr.RecvN.Y > 0 {
rNyEff = min(rNy, cr.RecvStart.Y+cr.RecvN.Y)
}
rNxEff := rNx
if cr.RecvN.X > 0 {
rNxEff = min(rNx, cr.RecvStart.X+cr.RecvN.X)
}
for ry := cr.RecvStart.Y; ry < rNyEff; ry++ {
for rx := cr.RecvStart.X; rx < rNxEff; rx++ {
rpi := ry*rNx + rx
ris := rpi * rNn
sst := cr.Start
if cr.RoundScale {
sst.X += int(math32.Round(float32(rx-cr.RecvStart.X) * sc.X))
sst.Y += int(math32.Round(float32(ry-cr.RecvStart.Y) * sc.Y))
} else {
sst.X += int(math32.Floor(float32(rx-cr.RecvStart.X) * sc.X))
sst.Y += int(math32.Floor(float32(ry-cr.RecvStart.Y) * sc.Y))
}
for y := 0; y < cr.Size.Y; y++ {
sy, clipy := edge.Edge(sst.Y+y, sNy, cr.Wrap)
if clipy {
continue
}
for x := 0; x < cr.Size.X; x++ {
sx, clipx := edge.Edge(sst.X+x, sNx, cr.Wrap)
if clipx {
continue
}
spi := sy*sNx + sx
sis := spi * sNn
for r := 0; r < rNn; r++ {
ri := ris + r
for s := 0; s < sNn; s++ {
si := sis + s
off := ri*sNtot + si
if !cr.SelfCon && same && ri == si {
continue
}
cons.Values.Set(true, off)
rnv[ri]++
snv[si]++
}
}
}
}
}
}
return
}
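// poolRectExample is an illustrative sketch, not part of the package API:
// each recv pool receives from a 2x2 rectangle of send pools, and with
// Scale = 0.5 and the default Floor scaling, pairs of adjacent recv pools
// share the same sending rectangle (the grouping behavior described above).
// Assumes tensor.NewShape(sizes ...int) constructs a *tensor.Shape.
func poolRectExample() {
	pr := NewPoolRect()
	pr.Size.Set(2, 2)
	pr.Scale = math32.Vec2(0.5, 0.5)
	send := tensor.NewShape(4, 4, 2, 2) // 4x4 pools of 2x2 units
	recv := tensor.NewShape(8, 8, 2, 2) // 8x8 pools of 2x2 units
	_, _, cons := pr.Connect(send, recv, false)
	_ = cons // recv pools (0,0) and (0,1) both see send pools (0,0)..(1,1)
}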
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package paths
import "cogentcore.org/lab/tensor"
// PoolSameUnit connects a given unit to the unit at the same index
// across all the pools in a layer.
// Pools are the outer-most two dimensions of a 4D layer shape.
// This is most sensible when pools have same numbers of units in send and recv.
// This is typically used for lateral topography-inducing connectivity
// and can also serve to reduce a pooled layer down to a single pool.
// The logic works if either layer does not have pools.
// If neither is 4D, then it is equivalent to OneToOne.
type PoolSameUnit struct {
// if true, and connecting layer to itself (self pathway), then make a self-connection from unit to itself
SelfCon bool
}
func NewPoolSameUnit() *PoolSameUnit {
return &PoolSameUnit{}
}
func (ot *PoolSameUnit) Name() string {
return "PoolSameUnit"
}
func (ot *PoolSameUnit) Connect(send, recv *tensor.Shape, same bool) (sendn, recvn *tensor.Int32, cons *tensor.Bool) {
switch {
case send.NumDims() == 4 && recv.NumDims() == 4:
return ot.ConnectPools(send, recv, same)
case send.NumDims() == 2 && recv.NumDims() == 4:
return ot.ConnectRecvPool(send, recv, same)
case send.NumDims() == 4 && recv.NumDims() == 2:
return ot.ConnectSendPool(send, recv, same)
case send.NumDims() == 2 && recv.NumDims() == 2:
return ot.ConnectOneToOne(send, recv, same)
}
return
}
// ConnectPools is when both recv and send have pools
func (ot *PoolSameUnit) ConnectPools(send, recv *tensor.Shape, same bool) (sendn, recvn *tensor.Int32, cons *tensor.Bool) {
sendn, recvn, cons = NewTensors(send, recv)
sNtot := send.Len()
sNp := send.DimSize(0) * send.DimSize(1)
rNp := recv.DimSize(0) * recv.DimSize(1)
sNu := send.DimSize(2) * send.DimSize(3)
rNu := recv.DimSize(2) * recv.DimSize(3)
rnv := recvn.Values
snv := sendn.Values
for rpi := 0; rpi < rNp; rpi++ {
for rui := 0; rui < rNu; rui++ {
if rui >= sNu {
break
}
ri := rpi*rNu + rui
for spi := 0; spi < sNp; spi++ {
if same && !ot.SelfCon && spi == rpi {
continue
}
si := spi*sNu + rui
off := ri*sNtot + si
cons.Values.Set(true, off)
rnv[ri]++
snv[si]++
}
}
}
return
}
// ConnectRecvPool is when recv has pools but send doesn't
func (ot *PoolSameUnit) ConnectRecvPool(send, recv *tensor.Shape, same bool) (sendn, recvn *tensor.Int32, cons *tensor.Bool) {
sendn, recvn, cons = NewTensors(send, recv)
sNtot := send.Len()
rNp := recv.DimSize(0) * recv.DimSize(1)
sNu := send.DimSize(0) * send.DimSize(1)
rNu := recv.DimSize(2) * recv.DimSize(3)
rnv := recvn.Values
snv := sendn.Values
for rpi := 0; rpi < rNp; rpi++ {
for rui := 0; rui < rNu; rui++ {
if rui >= sNu {
break
}
ri := rpi*rNu + rui
si := rui
off := ri*sNtot + si
cons.Values.Set(true, off)
rnv[ri]++
snv[si]++
}
}
return
}
// ConnectSendPool is when send has pools but recv doesn't
func (ot *PoolSameUnit) ConnectSendPool(send, recv *tensor.Shape, same bool) (sendn, recvn *tensor.Int32, cons *tensor.Bool) {
sendn, recvn, cons = NewTensors(send, recv)
sNtot := send.Len()
sNp := send.DimSize(0) * send.DimSize(1)
sNu := send.DimSize(2) * send.DimSize(3)
rNu := recv.DimSize(0) * recv.DimSize(1)
rnv := recvn.Values
snv := sendn.Values
for rui := 0; rui < rNu; rui++ {
if rui >= sNu {
break
}
ri := rui
for spi := 0; spi < sNp; spi++ {
si := spi*sNu + rui
off := ri*sNtot + si
cons.Values.Set(true, off)
rnv[ri]++
snv[si]++
}
}
return
}
// ConnectOneToOne is a copy of OneToOne.Connect, for the case where neither layer has pools
func (ot *PoolSameUnit) ConnectOneToOne(send, recv *tensor.Shape, same bool) (sendn, recvn *tensor.Int32, cons *tensor.Bool) {
sendn, recvn, cons = NewTensors(send, recv)
sNtot := send.Len()
sNu := send.DimSize(0) * send.DimSize(1)
rNu := recv.DimSize(0) * recv.DimSize(1)
rnv := recvn.Values
snv := sendn.Values
for rui := 0; rui < rNu; rui++ {
if rui >= sNu {
break
}
ri := rui
si := rui
off := ri*sNtot + si
cons.Values.Set(true, off)
rnv[ri] = 1
snv[si] = 1
}
return
}
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package paths
import (
"fmt"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/math32"
"cogentcore.org/core/math32/minmax"
"cogentcore.org/core/math32/vecint"
"cogentcore.org/lab/tensor"
"github.com/emer/emergent/v2/edge"
"github.com/emer/emergent/v2/efuns"
)
// PoolTile implements tiled 2D connectivity between pools within layers, where
// a 2D rectangular receptive field (defined over pools, not units) is tiled
// across the sending layer pools, with specified level of overlap.
// Pools are the outer-most two dimensions of a 4D layer shape.
// 2D layers are assumed to have 1x1 pool.
// This is a standard form of convolutional connectivity, where pools are
// the filters and the outer dims are locations filtered.
// Various initial weight / scaling patterns are also available -- code
// must specifically apply these to the receptive fields.
type PoolTile struct {
// reciprocal topographic connectivity -- logic runs with recv <-> send -- produces symmetric back-pathway or topo path when sending layer is larger than recv
Recip bool
// size of receptive field tile, in terms of pools on the sending layer
Size vecint.Vector2i
// how many pools to skip in tiling over sending layer -- typically 1/2 of Size
Skip vecint.Vector2i
// starting pool offset for lower-left corner of first receptive field in sending layer
Start vecint.Vector2i
// if true, pool coordinates wrap around sending shape -- otherwise truncated at edges, which can lead to asymmetries in connectivity etc
Wrap bool
// gaussian topographic weights / scaling parameters for full receptive field width. multiplies any other factors present
GaussFull GaussTopo
// gaussian topographic weights / scaling parameters within individual sending pools (i.e., unit positions within their parent pool drive distance for gaussian) -- this helps organize / differentiate units more within pools, not just across entire receptive field. multiplies any other factors present
GaussInPool GaussTopo
// sigmoidal topographic weights / scaling parameters for full receptive field width. left / bottom half have increasing sigmoids, and second half decrease. Multiplies any other factors present (only used if Gauss versions are not On!)
SigFull SigmoidTopo
// sigmoidal topographic weights / scaling parameters within individual sending pools (i.e., unit positions within their parent pool drive distance for sigmoid) -- this helps organize / differentiate units more within pools, not just across entire receptive field. multiplies any other factors present (only used if Gauss versions are not On!). left / bottom half have increasing sigmoids, and second half decrease.
SigInPool SigmoidTopo
// min..max range of topographic weight values to generate
TopoRange minmax.F32
}
func NewPoolTile() *PoolTile {
pt := &PoolTile{}
pt.Defaults()
return pt
}
// NewPoolTileRecip creates a new PoolTile that is a reciprocal version of the given feedforward (ff) pathway
func NewPoolTileRecip(ff *PoolTile) *PoolTile {
pt := &PoolTile{}
*pt = *ff
pt.Recip = true
return pt
}
func (pt *PoolTile) Defaults() {
pt.Size.Set(4, 4)
pt.Skip.Set(2, 2)
pt.Start.Set(-1, -1)
pt.Wrap = true
pt.TopoRange.Min = 0.8
pt.TopoRange.Max = 1
pt.GaussFull.Defaults()
pt.GaussInPool.Defaults()
pt.SigFull.Defaults()
pt.SigInPool.Defaults()
pt.GaussFull.On = true
pt.GaussInPool.On = true
}
func (pt *PoolTile) Name() string {
return "PoolTile"
}
func (pt *PoolTile) Connect(send, recv *tensor.Shape, same bool) (sendn, recvn *tensor.Int32, cons *tensor.Bool) {
if pt.Recip {
return pt.ConnectRecip(send, recv, same)
}
sendn, recvn, cons = NewTensors(send, recv)
sNtot := send.Len()
sNpY := send.DimSize(0)
sNpX := send.DimSize(1)
rNpY := recv.DimSize(0)
rNpX := recv.DimSize(1)
sNu := 1
rNu := 1
if send.NumDims() == 4 {
sNu = send.DimSize(2) * send.DimSize(3)
} else {
sNpY = 1
sNpX = 1
sNu = send.DimSize(0) * send.DimSize(1)
}
if recv.NumDims() == 4 {
rNu = recv.DimSize(2) * recv.DimSize(3)
} else {
rNpY = 1
rNpX = 1
rNu = recv.DimSize(0) * recv.DimSize(1)
}
rnv := recvn.Values
snv := sendn.Values
var clip bool
for rpy := 0; rpy < rNpY; rpy++ {
for rpx := 0; rpx < rNpX; rpx++ {
rpi := rpy*rNpX + rpx
ris := rpi * rNu
for fy := 0; fy < pt.Size.Y; fy++ {
spy := pt.Start.Y + rpy*pt.Skip.Y + fy
if spy, clip = edge.Edge(spy, sNpY, pt.Wrap); clip {
continue
}
for fx := 0; fx < pt.Size.X; fx++ {
spx := pt.Start.X + rpx*pt.Skip.X + fx
if spx, clip = edge.Edge(spx, sNpX, pt.Wrap); clip {
continue
}
spi := spy*sNpX + spx
sis := spi * sNu
for rui := 0; rui < rNu; rui++ {
ri := ris + rui
for sui := 0; sui < sNu; sui++ {
si := sis + sui
off := ri*sNtot + si
if off < cons.Len() && ri < len(rnv) && si < len(snv) {
// if !pt.SelfCon && same && ri == si {
// continue
// }
cons.Values.Set(true, off)
rnv[ri]++
snv[si]++
}
}
}
}
}
}
}
return
}
func (pt *PoolTile) ConnectRecip(send, recv *tensor.Shape, same bool) (sendn, recvn *tensor.Int32, cons *tensor.Bool) {
sendn, recvn, cons = NewTensors(send, recv)
// all these variables are swapped: s from recv, r from send
rNtot := send.Len()
sNpY := recv.DimSize(0)
sNpX := recv.DimSize(1)
rNpY := send.DimSize(0)
rNpX := send.DimSize(1)
sNu := 1
rNu := 1
if recv.NumDims() == 4 {
sNu = recv.DimSize(2) * recv.DimSize(3)
} else {
sNpY = 1
sNpX = 1
sNu = recv.DimSize(0) * recv.DimSize(1)
}
if send.NumDims() == 4 {
rNu = send.DimSize(2) * send.DimSize(3)
} else {
rNpY = 1
rNpX = 1
rNu = send.DimSize(0) * send.DimSize(1)
}
snv := recvn.Values
rnv := sendn.Values
var clip bool
for rpy := 0; rpy < rNpY; rpy++ {
for rpx := 0; rpx < rNpX; rpx++ {
rpi := rpy*rNpX + rpx
ris := rpi * rNu
for fy := 0; fy < pt.Size.Y; fy++ {
spy := pt.Start.Y + rpy*pt.Skip.Y + fy
if spy, clip = edge.Edge(spy, sNpY, pt.Wrap); clip {
continue
}
for fx := 0; fx < pt.Size.X; fx++ {
spx := pt.Start.X + rpx*pt.Skip.X + fx
if spx, clip = edge.Edge(spx, sNpX, pt.Wrap); clip {
continue
}
spi := spy*sNpX + spx
sis := spi * sNu
for sui := 0; sui < sNu; sui++ {
si := sis + sui
for rui := 0; rui < rNu; rui++ {
ri := ris + rui
off := si*rNtot + ri
if off < cons.Len() && si < len(snv) && ri < len(rnv) {
cons.Values.Set(true, off)
snv[si]++
rnv[ri]++
}
}
}
}
}
}
}
return
}
// HasTopoWeights returns true if some form of topographic weight patterns are set
func (pt *PoolTile) HasTopoWeights() bool {
return pt.GaussFull.On || pt.GaussInPool.On || pt.SigFull.On || pt.SigInPool.On
}
// TopoWeights sets values in given 4D or 6D tensor according to *Topo settings.
// wts is shaped with first 2 outer-most dims as Y, X of units within layer / pool
// of recv layer (these are units over which topography is defined)
// and the remaining 2D or 4D is for receptive field Size by units within pool size for
// sending layer.
func (pt *PoolTile) TopoWeights(send, recv *tensor.Shape, wts *tensor.Float32) error {
if pt.GaussFull.On || pt.GaussInPool.On {
if send.NumDims() == 2 {
return pt.TopoWeightsGauss2D(send, recv, wts)
} else {
return pt.TopoWeightsGauss4D(send, recv, wts)
}
}
if pt.SigFull.On || pt.SigInPool.On {
if send.NumDims() == 2 {
return pt.TopoWeightsSigmoid2D(send, recv, wts)
} else {
return pt.TopoWeightsSigmoid4D(send, recv, wts)
}
}
err := fmt.Errorf("PoolTile:TopoWeights no Gauss or Sig params turned on")
return errors.Log(err)
}
/////////////////////////////////////////////////////
// GaussTopo Wts
// GaussTopo has parameters for Gaussian topographic weights or scaling factors
type GaussTopo struct {
// use gaussian topographic weights / scaling values
On bool
// gaussian sigma (width) in normalized units where entire distance across relevant dimension is 1.0 -- typical useful values range from .3 to 1.5, with .6 default
Sigma float32 `default:"0.6"`
// wrap the gaussian around on other sides of the receptive field, with the closest distance being used -- this removes strict topography but ensures a more uniform distribution of weight values so edge units don't have weaker overall weights
Wrap bool
// proportion to move gaussian center relative to the position of the receiving unit within its pool: 1.0 = centers span the entire range of the receptive field. Typically want to use 1.0 for Wrap = true, and 0.8 for false
CtrMove float32 `default:"0.8,1"`
}
func (gt *GaussTopo) Defaults() {
gt.Sigma = 0.6
gt.Wrap = true
gt.CtrMove = 1
}
func (gt *GaussTopo) ShouldDisplay(field string) bool {
switch field {
case "On":
return true
default:
return gt.On
}
}
// DefWrap sets the default wrap parameters (which are the overall defaults): Wrap = true, CtrMove = 1
func (gt *GaussTopo) DefWrap() {
gt.Wrap = true
gt.CtrMove = 1
}
// DefNoWrap sets the default no-wrap parameters: Wrap = false, CtrMove = 0.8 (instead of 1)
func (gt *GaussTopo) DefNoWrap() {
gt.Wrap = false
gt.CtrMove = 0.8
}
// GaussOff turns off gaussian weights
func (pt *PoolTile) GaussOff() {
pt.GaussFull.On = false
pt.GaussInPool.On = false
}
// TopoWeightsGauss2D sets values in given 4D tensor according to *Topo settings.
// wts is shaped with first 2 outer-most dims as Y, X of units within layer / pool
// of recv layer (these are units over which topography is defined)
// and the remaining 2D is for sending layer size (2D = sender)
func (pt *PoolTile) TopoWeightsGauss2D(send, recv *tensor.Shape, wts *tensor.Float32) error {
if pt.GaussFull.Sigma == 0 {
pt.GaussFull.Defaults()
}
if pt.GaussInPool.Sigma == 0 {
pt.GaussInPool.Defaults()
}
sNuY := send.DimSize(0)
sNuX := send.DimSize(1)
rNuY := recv.DimSize(0) // ok if recv is 2D
rNuX := recv.DimSize(1)
if recv.NumDims() == 4 {
rNuY = recv.DimSize(2)
rNuX = recv.DimSize(3)
}
wts.SetShapeSizes(rNuY, rNuX, sNuY, sNuX)
fsz := math32.Vec2(float32(sNuX-1), float32(sNuY-1)) // full rf size
hfsz := fsz.MulScalar(0.5) // half rf
fsig := pt.GaussFull.Sigma * hfsz.X // full sigma
if fsig <= 0 {
fsig = pt.GaussFull.Sigma
}
psz := math32.Vec2(float32(sNuX), float32(sNuY)) // within-pool rf size
if sNuX > 1 {
psz.X -= 1
}
if sNuY > 1 {
psz.Y -= 1
}
hpsz := psz.MulScalar(0.5) // half rf
psig := pt.GaussInPool.Sigma * hpsz.X // pool sigma
if psig <= 0 {
psig = pt.GaussInPool.Sigma
}
rsz := math32.Vec2(float32(rNuX), float32(rNuY)) // recv units-in-pool size
if rNuX > 1 {
rsz.X -= 1
}
if rNuY > 1 {
rsz.Y -= 1
}
hrsz := rsz.MulScalar(0.5)
for ruy := 0; ruy < rNuY; ruy++ {
for rux := 0; rux < rNuX; rux++ {
rpos := math32.Vec2(float32(rux), float32(ruy)).Sub(hrsz).Div(hrsz) // -1..1 normalized r unit pos
rfpos := rpos.MulScalar(pt.GaussFull.CtrMove)
rppos := rpos.MulScalar(pt.GaussInPool.CtrMove)
sfctr := rfpos.Mul(hfsz).Add(hfsz) // sending center for full
spctr := rppos.Mul(hpsz).Add(hpsz) // sending center for within-pool
for suy := 0; suy < sNuY; suy++ {
for sux := 0; sux < sNuX; sux++ {
fwt := float32(1)
if pt.GaussFull.On {
sf := math32.Vec2(float32(sux), float32(suy))
if pt.GaussFull.Wrap {
sf.X = edge.WrapMinDist(sf.X, fsz.X, sfctr.X)
sf.Y = edge.WrapMinDist(sf.Y, fsz.Y, sfctr.Y)
}
fwt = efuns.GaussVecDistNoNorm(sf, sfctr, fsig)
}
pwt := float32(1)
if pt.GaussInPool.On {
sp := math32.Vec2(float32(sux), float32(suy))
if pt.GaussInPool.Wrap {
sp.X = edge.WrapMinDist(sp.X, psz.X, spctr.X)
sp.Y = edge.WrapMinDist(sp.Y, psz.Y, spctr.Y)
}
pwt = efuns.GaussVecDistNoNorm(sp, spctr, psig)
}
wt := fwt * pwt
rwt := pt.TopoRange.ProjValue(wt)
wts.Set(rwt, ruy, rux, suy, sux)
}
}
}
}
return nil
}
// TopoWeightsGauss4D sets values in given 6D tensor according to *Topo settings.
// wts is shaped with first 2 outer-most dims as Y, X of units within layer / pool
// of recv layer (these are units over which topography is defined)
// and the remaining 4D is for receptive field Size by units within pool size for
// sending layer.
func (pt *PoolTile) TopoWeightsGauss4D(send, recv *tensor.Shape, wts *tensor.Float32) error {
if pt.GaussFull.Sigma == 0 {
pt.GaussFull.Defaults()
}
if pt.GaussInPool.Sigma == 0 {
pt.GaussInPool.Defaults()
}
sNuY := send.DimSize(2)
sNuX := send.DimSize(3)
rNuY := recv.DimSize(0) // ok if recv is 2D
rNuX := recv.DimSize(1)
if recv.NumDims() == 4 {
rNuY = recv.DimSize(2)
rNuX = recv.DimSize(3)
}
wts.SetShapeSizes(rNuY, rNuX, pt.Size.Y, pt.Size.X, sNuY, sNuX)
fsz := math32.Vec2(float32(pt.Size.X*sNuX-1), float32(pt.Size.Y*sNuY-1)) // full rf size
hfsz := fsz.MulScalar(0.5) // half rf
fsig := pt.GaussFull.Sigma * hfsz.X // full sigma
if fsig <= 0 {
fsig = pt.GaussFull.Sigma
}
psz := math32.Vec2(float32(sNuX), float32(sNuY)) // within-pool rf size
if sNuX > 1 {
psz.X -= 1
}
if sNuY > 1 {
psz.Y -= 1
}
hpsz := psz.MulScalar(0.5) // half rf
psig := pt.GaussInPool.Sigma * hpsz.X // pool sigma
if psig <= 0 {
psig = pt.GaussInPool.Sigma
}
rsz := math32.Vec2(float32(rNuX), float32(rNuY)) // recv units-in-pool size
if rNuX > 1 {
rsz.X -= 1
}
if rNuY > 1 {
rsz.Y -= 1
}
hrsz := rsz.MulScalar(0.5)
for ruy := 0; ruy < rNuY; ruy++ {
for rux := 0; rux < rNuX; rux++ {
rpos := math32.Vec2(float32(rux), float32(ruy)).Sub(hrsz).Div(hrsz) // -1..1 normalized r unit pos
rfpos := rpos.MulScalar(pt.GaussFull.CtrMove)
rppos := rpos.MulScalar(pt.GaussInPool.CtrMove)
sfctr := rfpos.Mul(hfsz).Add(hfsz) // sending center for full
spctr := rppos.Mul(hpsz).Add(hpsz) // sending center for within-pool
for fy := 0; fy < pt.Size.Y; fy++ {
for fx := 0; fx < pt.Size.X; fx++ {
for suy := 0; suy < sNuY; suy++ {
for sux := 0; sux < sNuX; sux++ {
fwt := float32(1)
if pt.GaussFull.On {
sf := math32.Vec2(float32(fx*sNuX+sux), float32(fy*sNuY+suy))
if pt.GaussFull.Wrap {
sf.X = edge.WrapMinDist(sf.X, fsz.X, sfctr.X)
sf.Y = edge.WrapMinDist(sf.Y, fsz.Y, sfctr.Y)
}
fwt = efuns.GaussVecDistNoNorm(sf, sfctr, fsig)
}
pwt := float32(1)
if pt.GaussInPool.On {
sp := math32.Vec2(float32(sux), float32(suy))
if pt.GaussInPool.Wrap {
sp.X = edge.WrapMinDist(sp.X, psz.X, spctr.X)
sp.Y = edge.WrapMinDist(sp.Y, psz.Y, spctr.Y)
}
pwt = efuns.GaussVecDistNoNorm(sp, spctr, psig)
}
wt := fwt * pwt
rwt := pt.TopoRange.ProjValue(wt)
wts.Set(rwt, ruy, rux, fy, fx, suy, sux)
}
}
}
}
}
}
return nil
}
/////////////////////////////////////////////////////
// SigmoidTopo Wts
// SigmoidTopo has parameters for sigmoidal topographic weights or scaling factors
type SigmoidTopo struct {
// use sigmoidal topographic weights / scaling values
On bool
// gain of sigmoid that determines steepness of curve, in normalized units where entire distance across relevant dimension is 1.0 -- typical useful values range from 0.01 to 0.1
Gain float32
// proportion to move the sigmoid center relative to the position of the receiving unit within its pool: 1.0 = centers span the entire range of the receptive field; default is 0.5
CtrMove float32 `default:"0.5,1"`
}
func (gt *SigmoidTopo) Defaults() {
gt.Gain = 0.05
gt.CtrMove = 0.5
}
func (gt *SigmoidTopo) ShouldDisplay(field string) bool {
switch field {
case "On":
return true
default:
return gt.On
}
}
// TopoWeightsSigmoid2D sets values in given 4D tensor according to Topo settings.
// wts is shaped with first 2 outer-most dims as Y, X of units within pool
// of recv layer (these are units over which topography is defined)
// and the remaining 2D is for sending layer (2D = sender).
func (pt *PoolTile) TopoWeightsSigmoid2D(send, recv *tensor.Shape, wts *tensor.Float32) error {
if pt.SigFull.Gain == 0 {
pt.SigFull.Defaults()
}
if pt.SigInPool.Gain == 0 {
pt.SigInPool.Defaults()
}
sNuY := send.DimSize(0)
sNuX := send.DimSize(1)
rNuY := recv.DimSize(0) // ok if recv is 2D
rNuX := recv.DimSize(1)
if recv.NumDims() == 4 {
rNuY = recv.DimSize(2)
rNuX = recv.DimSize(3)
}
wts.SetShapeSizes(rNuY, rNuX, sNuY, sNuX)
fsz := math32.Vec2(float32(sNuX-1), float32(sNuY-1)) // full rf size
hfsz := fsz.MulScalar(0.5) // half rf
fgain := pt.SigFull.Gain * hfsz.X // full gain
psz := math32.Vec2(float32(sNuX), float32(sNuY)) // within-pool rf size
if sNuX > 1 {
psz.X -= 1
}
if sNuY > 1 {
psz.Y -= 1
}
hpsz := psz.MulScalar(0.5) // half rf
pgain := pt.SigInPool.Gain * hpsz.X // pool sigma
rsz := math32.Vec2(float32(rNuX), float32(rNuY)) // recv units-in-pool size
if rNuX > 1 {
rsz.X -= 1
}
if rNuY > 1 {
rsz.Y -= 1
}
hrsz := rsz.MulScalar(0.5)
for ruy := 0; ruy < rNuY; ruy++ {
for rux := 0; rux < rNuX; rux++ {
rpos := math32.Vec2(float32(rux), float32(ruy)).Div(hrsz) // 0..2 normalized r unit pos
sgn := math32.Vec2(1, 1)
rfpos := rpos.SubScalar(0.5).MulScalar(pt.SigFull.CtrMove).AddScalar(0.5)
rppos := rpos.SubScalar(0.5).MulScalar(pt.SigInPool.CtrMove).AddScalar(0.5)
if rpos.X >= 1 { // flip direction half-way through
sgn.X = -1
rpos.X = -rpos.X + 1
rfpos.X = (rpos.X+0.5)*pt.SigFull.CtrMove - 0.5
rppos.X = (rpos.X+0.5)*pt.SigInPool.CtrMove - 0.5
}
if rpos.Y >= 1 {
sgn.Y = -1
rpos.Y = -rpos.Y + 1
rfpos.Y = (rpos.Y+0.5)*pt.SigFull.CtrMove - 0.5
rppos.Y = (rpos.Y+0.5)*pt.SigInPool.CtrMove - 0.5
}
sfctr := rfpos.Mul(fsz) // sending center for full
spctr := rppos.Mul(psz) // sending center for within-pool
for suy := 0; suy < sNuY; suy++ {
for sux := 0; sux < sNuX; sux++ {
fwt := float32(1)
if pt.SigFull.On {
sf := math32.Vec2(float32(sux), float32(suy))
sigx := efuns.Logistic(sgn.X*sf.X, fgain, sfctr.X)
sigy := efuns.Logistic(sgn.Y*sf.Y, fgain, sfctr.Y)
fwt = sigx * sigy
}
pwt := float32(1)
if pt.SigInPool.On {
sp := math32.Vec2(float32(sux), float32(suy))
sigx := efuns.Logistic(sgn.X*sp.X, pgain, spctr.X)
sigy := efuns.Logistic(sgn.Y*sp.Y, pgain, spctr.Y)
pwt = sigx * sigy
}
wt := fwt * pwt
rwt := pt.TopoRange.ProjValue(wt)
wts.Set(rwt, ruy, rux, suy, sux)
}
}
}
}
return nil
}
// TopoWeightsSigmoid4D sets values in given 6D tensor according to Topo settings.
// wts is shaped with first 2 outer-most dims as Y, X of units within pool
// of recv layer (these are units over which topography is defined)
// and the remaining 4D is for receptive field Size by units within pool size for
// sending layer.
func (pt *PoolTile) TopoWeightsSigmoid4D(send, recv *tensor.Shape, wts *tensor.Float32) error {
if pt.SigFull.Gain == 0 {
pt.SigFull.Defaults()
}
if pt.SigInPool.Gain == 0 {
pt.SigInPool.Defaults()
}
sNuY := send.DimSize(2)
sNuX := send.DimSize(3)
rNuY := recv.DimSize(0) // ok if recv is 2D
rNuX := recv.DimSize(1)
if recv.NumDims() == 4 {
rNuY = recv.DimSize(2)
rNuX = recv.DimSize(3)
}
wts.SetShapeSizes(rNuY, rNuX, pt.Size.Y, pt.Size.X, sNuY, sNuX)
fsz := math32.Vec2(float32(pt.Size.X*sNuX-1), float32(pt.Size.Y*sNuY-1)) // full rf size
hfsz := fsz.MulScalar(0.5) // half rf
fgain := pt.SigFull.Gain * hfsz.X // full gain
psz := math32.Vec2(float32(sNuX), float32(sNuY)) // within-pool rf size
if sNuX > 1 {
psz.X -= 1
}
if sNuY > 1 {
psz.Y -= 1
}
hpsz := psz.MulScalar(0.5) // half rf
pgain := pt.SigInPool.Gain * hpsz.X // pool sigma
rsz := math32.Vec2(float32(rNuX), float32(rNuY)) // recv units-in-pool size
if rNuX > 1 {
rsz.X -= 1
}
if rNuY > 1 {
rsz.Y -= 1
}
hrsz := rsz.MulScalar(0.5)
for ruy := 0; ruy < rNuY; ruy++ {
for rux := 0; rux < rNuX; rux++ {
rpos := math32.Vec2(float32(rux), float32(ruy)).Div(hrsz) // 0..2 normalized r unit pos
sgn := math32.Vec2(1, 1)
rfpos := rpos.SubScalar(0.5).MulScalar(pt.SigFull.CtrMove).AddScalar(0.5)
rppos := rpos.SubScalar(0.5).MulScalar(pt.SigInPool.CtrMove).AddScalar(0.5)
if rpos.X >= 1 { // flip direction half-way through
sgn.X = -1
rpos.X = -rpos.X + 1
rfpos.X = (rpos.X+0.5)*pt.SigFull.CtrMove - 0.5
rppos.X = (rpos.X+0.5)*pt.SigInPool.CtrMove - 0.5
}
if rpos.Y >= 1 {
sgn.Y = -1
rpos.Y = -rpos.Y + 1
rfpos.Y = (rpos.Y+0.5)*pt.SigFull.CtrMove - 0.5
rppos.Y = (rpos.Y+0.5)*pt.SigInPool.CtrMove - 0.5
}
sfctr := rfpos.Mul(fsz) // sending center for full
spctr := rppos.Mul(psz) // sending center for within-pool
for fy := 0; fy < pt.Size.Y; fy++ {
for fx := 0; fx < pt.Size.X; fx++ {
for suy := 0; suy < sNuY; suy++ {
for sux := 0; sux < sNuX; sux++ {
fwt := float32(1)
if pt.SigFull.On {
sf := math32.Vec2(float32(fx*sNuX+sux), float32(fy*sNuY+suy))
sigx := efuns.Logistic(sgn.X*sf.X, fgain, sfctr.X)
sigy := efuns.Logistic(sgn.Y*sf.Y, fgain, sfctr.Y)
fwt = sigx * sigy
}
pwt := float32(1)
if pt.SigInPool.On {
sp := math32.Vec2(float32(sux), float32(suy))
sigx := efuns.Logistic(sgn.X*sp.X, pgain, spctr.X)
sigy := efuns.Logistic(sgn.Y*sp.Y, pgain, spctr.Y)
pwt = sigx * sigy
}
wt := fwt * pwt
rwt := pt.TopoRange.ProjValue(wt)
wts.Set(rwt, ruy, rux, fy, fx, suy, sux)
}
}
}
}
}
}
return nil
}
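// poolTileExample is an illustrative sketch, not part of the package API:
// a convolution-style PoolTile pathway tiling a 16x16-pool sending layer
// onto an 8x8-pool receiving layer with the default 4x4 field and skip 2,
// plus Gaussian topographic weights for the resulting receptive fields.
// Assumes tensor.NewShape(sizes ...int) constructs a *tensor.Shape;
// the TopoWeights error is ignored here for brevity.
func poolTileExample() {
	pt := NewPoolTile() // defaults: Size 4x4, Skip 2x2, Start -1,-1, Wrap on
	send := tensor.NewShape(16, 16, 3, 3)
	recv := tensor.NewShape(8, 8, 2, 2)
	_, _, cons := pt.Connect(send, recv, false)
	_ = cons
	var wts tensor.Float32
	_ = pt.TopoWeights(send, recv, &wts) // 6D: 2x2 recv units x 4x4 field x 3x3 send units
}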
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package paths
import (
"fmt"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/math32"
"cogentcore.org/core/math32/minmax"
"cogentcore.org/core/math32/vecint"
"cogentcore.org/lab/tensor"
"github.com/emer/emergent/v2/edge"
"github.com/emer/emergent/v2/efuns"
)
// PoolTileSub implements tiled 2D connectivity between pools within layers, where
// a 2D rectangular receptive field (defined over pools, not units) is tiled
// across the sending layer pools, with specified level of overlap.
// Pools are the outer-most two dimensions of a 4D layer shape.
// Sub version has sub-pools within each pool to encourage more independent
// representations.
// 2D layers are assumed to have 1x1 pool.
// This is a standard form of convolutional connectivity, where pools are
// the filters and the outer dims are locations filtered.
// Various initial weight / scaling patterns are also available -- code
// must specifically apply these to the receptive fields.
type PoolTileSub struct {
// reciprocal topographic connectivity -- logic runs with recv <-> send -- produces symmetric back-pathway or topo path when sending layer is larger than recv
Recip bool
// size of receptive field tile, in terms of pools on the sending layer
Size vecint.Vector2i
// how many pools to skip in tiling over sending layer -- typically 1/2 of Size
Skip vecint.Vector2i
// starting pool offset for lower-left corner of first receptive field in sending layer
Start vecint.Vector2i
// number of sub-pools within each pool
Subs vecint.Vector2i
// sending layer has sub-pools
SendSubs bool
// if true, pool coordinates wrap around sending shape -- otherwise truncated at edges, which can lead to asymmetries in connectivity etc
Wrap bool
// gaussian topographic weights / scaling parameters for full receptive field width. multiplies any other factors present
GaussFull GaussTopo
// gaussian topographic weights / scaling parameters within individual sending pools (i.e., unit positions within their parent pool drive distance for gaussian) -- this helps organize / differentiate units more within pools, not just across entire receptive field. multiplies any other factors present
GaussInPool GaussTopo
// sigmoidal topographic weights / scaling parameters for full receptive field width. left / bottom half have increasing sigmoids, and second half decrease. Multiplies any other factors present (only used if Gauss versions are not On!)
SigFull SigmoidTopo
// sigmoidal topographic weights / scaling parameters within individual sending pools (i.e., unit positions within their parent pool drive distance for sigmoid) -- this helps organize / differentiate units more within pools, not just across entire receptive field. multiplies any other factors present (only used if Gauss versions are not On!). left / bottom half have increasing sigmoids, and second half decrease.
SigInPool SigmoidTopo
// min..max range of topographic weight values to generate
TopoRange minmax.F32
}
func NewPoolTileSub() *PoolTileSub {
pt := &PoolTileSub{}
pt.Defaults()
return pt
}
// NewPoolTileSubRecip creates a new PoolTileSub that is a reciprocal version of the given feedforward (ff) pathway
func NewPoolTileSubRecip(ff *PoolTileSub) *PoolTileSub {
pt := &PoolTileSub{}
*pt = *ff
pt.Recip = true
return pt
}
func (pt *PoolTileSub) Defaults() {
pt.Size.Set(4, 4)
pt.Skip.Set(2, 2)
pt.Start.Set(-1, -1)
pt.Subs.Set(2, 2)
pt.Wrap = true
pt.TopoRange.Min = 0.8
pt.TopoRange.Max = 1
pt.GaussFull.Defaults()
pt.GaussInPool.Defaults()
pt.SigFull.Defaults()
pt.SigInPool.Defaults()
pt.GaussFull.On = true
pt.GaussInPool.On = true
}
func (pt *PoolTileSub) Name() string {
return "PoolTileSub"
}
func (pt *PoolTileSub) Connect(send, recv *tensor.Shape, same bool) (sendn, recvn *tensor.Int32, cons *tensor.Bool) {
if pt.Recip {
return pt.ConnectRecip(send, recv, same)
}
sendn, recvn, cons = NewTensors(send, recv)
sNtot := send.Len()
sNpY := send.DimSize(0)
sNpX := send.DimSize(1)
rNpY := recv.DimSize(0)
rNpX := recv.DimSize(1)
sNu := 1
rNu := 1
if send.NumDims() == 4 {
sNu = send.DimSize(2) * send.DimSize(3)
} else {
sNpY = 1
sNpX = 1
sNu = send.DimSize(0) * send.DimSize(1)
}
if recv.NumDims() == 4 {
rNu = recv.DimSize(2) * recv.DimSize(3)
} else {
rNpY = 1
rNpX = 1
rNu = recv.DimSize(0) * recv.DimSize(1)
}
rnv := recvn.Values
snv := sendn.Values
var clip bool
for rpy := 0; rpy < rNpY; rpy++ {
rpys := rpy / pt.Subs.Y // sub group
rpyi := rpy % pt.Subs.Y // index within subgroup
for rpx := 0; rpx < rNpX; rpx++ {
rpxs := rpx / pt.Subs.X
rpxi := rpx % pt.Subs.X
rpi := rpy*rNpX + rpx
ris := rpi * rNu
for fy := 0; fy < pt.Size.Y; fy++ {
spy := pt.Start.Y + rpys*pt.Skip.Y + fy
if pt.SendSubs {
spy = spy*pt.Subs.Y + rpyi
}
if spy, clip = edge.Edge(spy, sNpY, pt.Wrap); clip {
continue
}
for fx := 0; fx < pt.Size.X; fx++ {
spx := pt.Start.X + rpxs*pt.Skip.X + fx
if pt.SendSubs {
spx = spx*pt.Subs.X + rpxi
}
if spx, clip = edge.Edge(spx, sNpX, pt.Wrap); clip {
continue
}
spi := spy*sNpX + spx
sis := spi * sNu
for rui := 0; rui < rNu; rui++ {
ri := ris + rui
for sui := 0; sui < sNu; sui++ {
si := sis + sui
off := ri*sNtot + si
if off < cons.Len() && ri < len(rnv) && si < len(snv) {
// if !pt.SelfCon && same && ri == si {
// continue
// }
cons.Values.Set(true, off)
rnv[ri]++
snv[si]++
}
}
}
}
}
}
}
return
}
func (pt *PoolTileSub) ConnectRecip(send, recv *tensor.Shape, same bool) (sendn, recvn *tensor.Int32, cons *tensor.Bool) {
sendn, recvn, cons = NewTensors(send, recv)
// all these variables are swapped: s from recv, r from send
rNtot := send.Len()
sNpY := recv.DimSize(0)
sNpX := recv.DimSize(1)
rNpY := send.DimSize(0)
rNpX := send.DimSize(1)
sNu := 1
rNu := 1
if recv.NumDims() == 4 {
sNu = recv.DimSize(2) * recv.DimSize(3)
} else {
sNpY = 1
sNpX = 1
sNu = recv.DimSize(0) * recv.DimSize(1)
}
if send.NumDims() == 4 {
rNu = send.DimSize(2) * send.DimSize(3)
} else {
rNpY = 1
rNpX = 1
rNu = send.DimSize(0) * send.DimSize(1)
}
snv := recvn.Values
rnv := sendn.Values
var clip bool
for rpy := 0; rpy < rNpY; rpy++ {
rpys := rpy / pt.Subs.Y // sub group
rpyi := rpy % pt.Subs.Y // index within subgroup
for rpx := 0; rpx < rNpX; rpx++ {
rpxs := rpx / pt.Subs.X
rpxi := rpx % pt.Subs.X
rpi := rpy*rNpX + rpx
ris := rpi * rNu
for fy := 0; fy < pt.Size.Y; fy++ {
spy := pt.Start.Y + rpys*pt.Skip.Y + fy
if pt.SendSubs {
spy = spy*pt.Subs.Y + rpyi
}
if spy, clip = edge.Edge(spy, sNpY, pt.Wrap); clip {
continue
}
for fx := 0; fx < pt.Size.X; fx++ {
spx := pt.Start.X + rpxs*pt.Skip.X + fx
if pt.SendSubs {
spx = spx*pt.Subs.X + rpxi
}
if spx, clip = edge.Edge(spx, sNpX, pt.Wrap); clip {
continue
}
spi := spy*sNpX + spx
sis := spi * sNu
for sui := 0; sui < sNu; sui++ {
si := sis + sui
for rui := 0; rui < rNu; rui++ {
ri := ris + rui
off := si*rNtot + ri
if off < cons.Len() && si < len(snv) && ri < len(rnv) {
cons.Values.Set(true, off)
snv[si]++
rnv[ri]++
}
}
}
}
}
}
}
return
}
// HasTopoWeights returns true if some form of topographic weight patterns are set
func (pt *PoolTileSub) HasTopoWeights() bool {
return pt.GaussFull.On || pt.GaussInPool.On || pt.SigFull.On || pt.SigInPool.On
}
// TopoWeights sets values in given 4D or 6D tensor according to *Topo settings.
// wts is shaped with first 2 outer-most dims as Y, X of units within layer / pool
// of recv layer (these are units over which topography is defined)
// and the remaining 2D or 4D is for receptive field Size by units within pool size for
// sending layer.
func (pt *PoolTileSub) TopoWeights(send, recv *tensor.Shape, wts *tensor.Float32) error {
if pt.GaussFull.On || pt.GaussInPool.On {
if send.NumDims() == 2 {
return pt.TopoWeightsGauss2D(send, recv, wts)
} else {
return pt.TopoWeightsGauss4D(send, recv, wts)
}
}
if pt.SigFull.On || pt.SigInPool.On {
if send.NumDims() == 2 {
return pt.TopoWeightsSigmoid2D(send, recv, wts)
} else {
return pt.TopoWeightsSigmoid4D(send, recv, wts)
}
}
err := fmt.Errorf("PoolTileSub:TopoWeights no Gauss or Sig params turned on")
return errors.Log(err)
}
// GaussOff turns off gaussian weights
func (pt *PoolTileSub) GaussOff() {
pt.GaussFull.On = false
pt.GaussInPool.On = false
}
// TopoWeightsGauss2D sets values in given 4D tensor according to *Topo settings.
// wts is shaped with first 2 outer-most dims as Y, X of units within layer / pool
// of recv layer (these are units over which topography is defined)
// and the remaining 2D is for sending layer size (2D = sender)
func (pt *PoolTileSub) TopoWeightsGauss2D(send, recv *tensor.Shape, wts *tensor.Float32) error {
if pt.GaussFull.Sigma == 0 {
pt.GaussFull.Defaults()
}
if pt.GaussInPool.Sigma == 0 {
pt.GaussInPool.Defaults()
}
sNuY := send.DimSize(0)
sNuX := send.DimSize(1)
rNuY := recv.DimSize(0) // ok if recv is 2D
rNuX := recv.DimSize(1)
if recv.NumDims() == 4 {
rNuY = recv.DimSize(2)
rNuX = recv.DimSize(3)
}
wts.SetShapeSizes(rNuY, rNuX, sNuY, sNuX)
fsz := math32.Vec2(float32(sNuX-1), float32(sNuY-1)) // full rf size
hfsz := fsz.MulScalar(0.5) // half rf
fsig := pt.GaussFull.Sigma * hfsz.X // full sigma
if fsig <= 0 {
fsig = pt.GaussFull.Sigma
}
psz := math32.Vec2(float32(sNuX), float32(sNuY)) // within-pool rf size
if sNuX > 1 {
psz.X -= 1
}
if sNuY > 1 {
psz.Y -= 1
}
hpsz := psz.MulScalar(0.5) // half rf
psig := pt.GaussInPool.Sigma * hpsz.X // pool sigma
if psig <= 0 {
psig = pt.GaussInPool.Sigma
}
rsz := math32.Vec2(float32(rNuX), float32(rNuY)) // recv units-in-pool size
if rNuX > 1 {
rsz.X -= 1
}
if rNuY > 1 {
rsz.Y -= 1
}
hrsz := rsz.MulScalar(0.5)
for ruy := 0; ruy < rNuY; ruy++ {
for rux := 0; rux < rNuX; rux++ {
rpos := math32.Vec2(float32(rux), float32(ruy)).Sub(hrsz).Div(hrsz) // -1..1 normalized r unit pos
rfpos := rpos.MulScalar(pt.GaussFull.CtrMove)
rppos := rpos.MulScalar(pt.GaussInPool.CtrMove)
sfctr := rfpos.Mul(hfsz).Add(hfsz) // sending center for full
spctr := rppos.Mul(hpsz).Add(hpsz) // sending center for within-pool
for suy := 0; suy < sNuY; suy++ {
for sux := 0; sux < sNuX; sux++ {
fwt := float32(1)
if pt.GaussFull.On {
sf := math32.Vec2(float32(sux), float32(suy))
if pt.GaussFull.Wrap {
sf.X = edge.WrapMinDist(sf.X, fsz.X, sfctr.X)
sf.Y = edge.WrapMinDist(sf.Y, fsz.Y, sfctr.Y)
}
fwt = efuns.GaussVecDistNoNorm(sf, sfctr, fsig)
}
pwt := float32(1)
if pt.GaussInPool.On {
sp := math32.Vec2(float32(sux), float32(suy))
if pt.GaussInPool.Wrap {
sp.X = edge.WrapMinDist(sp.X, psz.X, spctr.X)
sp.Y = edge.WrapMinDist(sp.Y, psz.Y, spctr.Y)
}
pwt = efuns.GaussVecDistNoNorm(sp, spctr, psig)
}
wt := fwt * pwt
rwt := pt.TopoRange.ProjValue(wt)
wts.Set(rwt, ruy, rux, suy, sux)
}
}
}
}
return nil
}
// TopoWeightsGauss4D sets values in given 6D tensor according to *Topo settings.
// wts is shaped with first 2 outer-most dims as Y, X of units within layer / pool
// of recv layer (these are units over which topography is defined)
// and remaining 4D is for receptive field Size by units within pool size for
// sending layer.
func (pt *PoolTileSub) TopoWeightsGauss4D(send, recv *tensor.Shape, wts *tensor.Float32) error {
if pt.GaussFull.Sigma == 0 {
pt.GaussFull.Defaults()
}
if pt.GaussInPool.Sigma == 0 {
pt.GaussInPool.Defaults()
}
sNuY := send.DimSize(2)
sNuX := send.DimSize(3)
rNuY := recv.DimSize(0) // ok if recv is 2D
rNuX := recv.DimSize(1)
if recv.NumDims() == 4 {
rNuY = recv.DimSize(2)
rNuX = recv.DimSize(3)
}
wts.SetShapeSizes(rNuY, rNuX, pt.Size.Y, pt.Size.X, sNuY, sNuX)
fsz := math32.Vec2(float32(pt.Size.X*sNuX-1), float32(pt.Size.Y*sNuY-1)) // full rf size
hfsz := fsz.MulScalar(0.5) // half rf
fsig := pt.GaussFull.Sigma * hfsz.X // full sigma
if fsig <= 0 {
fsig = pt.GaussFull.Sigma
}
psz := math32.Vec2(float32(sNuX), float32(sNuY)) // within-pool rf size
if sNuX > 1 {
psz.X -= 1
}
if sNuY > 1 {
psz.Y -= 1
}
hpsz := psz.MulScalar(0.5) // half rf
psig := pt.GaussInPool.Sigma * hpsz.X // pool sigma
if psig <= 0 {
psig = pt.GaussInPool.Sigma
}
rsz := math32.Vec2(float32(rNuX), float32(rNuY)) // recv units-in-pool size
if rNuX > 1 {
rsz.X -= 1
}
if rNuY > 1 {
rsz.Y -= 1
}
hrsz := rsz.MulScalar(0.5)
for ruy := 0; ruy < rNuY; ruy++ {
for rux := 0; rux < rNuX; rux++ {
rpos := math32.Vec2(float32(rux), float32(ruy)).Sub(hrsz).Div(hrsz) // -1..1 normalized r unit pos
rfpos := rpos.MulScalar(pt.GaussFull.CtrMove)
rppos := rpos.MulScalar(pt.GaussInPool.CtrMove)
sfctr := rfpos.Mul(hfsz).Add(hfsz) // sending center for full
spctr := rppos.Mul(hpsz).Add(hpsz) // sending center for within-pool
for fy := 0; fy < pt.Size.Y; fy++ {
for fx := 0; fx < pt.Size.X; fx++ {
for suy := 0; suy < sNuY; suy++ {
for sux := 0; sux < sNuX; sux++ {
fwt := float32(1)
if pt.GaussFull.On {
sf := math32.Vec2(float32(fx*sNuX+sux), float32(fy*sNuY+suy))
if pt.GaussFull.Wrap {
sf.X = edge.WrapMinDist(sf.X, fsz.X, sfctr.X)
sf.Y = edge.WrapMinDist(sf.Y, fsz.Y, sfctr.Y)
}
fwt = efuns.GaussVecDistNoNorm(sf, sfctr, fsig)
}
pwt := float32(1)
if pt.GaussInPool.On {
sp := math32.Vec2(float32(sux), float32(suy))
if pt.GaussInPool.Wrap {
sp.X = edge.WrapMinDist(sp.X, psz.X, spctr.X)
sp.Y = edge.WrapMinDist(sp.Y, psz.Y, spctr.Y)
}
pwt = efuns.GaussVecDistNoNorm(sp, spctr, psig)
}
wt := fwt * pwt
rwt := pt.TopoRange.ProjValue(wt)
wts.Set(rwt, ruy, rux, fy, fx, suy, sux)
}
}
}
}
}
}
return nil
}
/////////////////////////////////////////////////////
// SigmoidTopo Wts
// TopoWeightsSigmoid2D sets values in given 4D tensor according to Topo settings.
// wts is shaped with first 2 outer-most dims as Y, X of units within pool
// of recv layer (these are units over which topography is defined)
// and remaining 2D is for sending layer (2D = sender).
func (pt *PoolTileSub) TopoWeightsSigmoid2D(send, recv *tensor.Shape, wts *tensor.Float32) error {
if pt.SigFull.Gain == 0 {
pt.SigFull.Defaults()
}
if pt.SigInPool.Gain == 0 {
pt.SigInPool.Defaults()
}
sNuY := send.DimSize(0)
sNuX := send.DimSize(1)
rNuY := recv.DimSize(0) // ok if recv is 2D
rNuX := recv.DimSize(1)
if recv.NumDims() == 4 {
rNuY = recv.DimSize(2)
rNuX = recv.DimSize(3)
}
wts.SetShapeSizes(rNuY, rNuX, sNuY, sNuX)
fsz := math32.Vec2(float32(sNuX-1), float32(sNuY-1)) // full rf size
hfsz := fsz.MulScalar(0.5) // half rf
fgain := pt.SigFull.Gain * hfsz.X // full gain
psz := math32.Vec2(float32(sNuX), float32(sNuY)) // within-pool rf size
if sNuX > 1 {
psz.X -= 1
}
if sNuY > 1 {
psz.Y -= 1
}
hpsz := psz.MulScalar(0.5) // half rf
pgain := pt.SigInPool.Gain * hpsz.X // pool gain
rsz := math32.Vec2(float32(rNuX), float32(rNuY)) // recv units-in-pool size
if rNuX > 1 {
rsz.X -= 1
}
if rNuY > 1 {
rsz.Y -= 1
}
hrsz := rsz.MulScalar(0.5)
for ruy := 0; ruy < rNuY; ruy++ {
for rux := 0; rux < rNuX; rux++ {
rpos := math32.Vec2(float32(rux), float32(ruy)).Div(hrsz) // 0..2 normalized r unit pos
sgn := math32.Vec2(1, 1)
rfpos := rpos.SubScalar(0.5).MulScalar(pt.SigFull.CtrMove).AddScalar(0.5)
rppos := rpos.SubScalar(0.5).MulScalar(pt.SigInPool.CtrMove).AddScalar(0.5)
if rpos.X >= 1 { // flip direction half-way through
sgn.X = -1
rpos.X = -rpos.X + 1
rfpos.X = (rpos.X+0.5)*pt.SigFull.CtrMove - 0.5
rppos.X = (rpos.X+0.5)*pt.SigInPool.CtrMove - 0.5
}
if rpos.Y >= 1 {
sgn.Y = -1
rpos.Y = -rpos.Y + 1
rfpos.Y = (rpos.Y+0.5)*pt.SigFull.CtrMove - 0.5
rppos.Y = (rpos.Y+0.5)*pt.SigInPool.CtrMove - 0.5
}
sfctr := rfpos.Mul(fsz) // sending center for full
spctr := rppos.Mul(psz) // sending center for within-pool
for suy := 0; suy < sNuY; suy++ {
for sux := 0; sux < sNuX; sux++ {
fwt := float32(1)
if pt.SigFull.On {
sf := math32.Vec2(float32(sux), float32(suy))
sigx := efuns.Logistic(sgn.X*sf.X, fgain, sfctr.X)
sigy := efuns.Logistic(sgn.Y*sf.Y, fgain, sfctr.Y)
fwt = sigx * sigy
}
pwt := float32(1)
if pt.SigInPool.On {
sp := math32.Vec2(float32(sux), float32(suy))
sigx := efuns.Logistic(sgn.X*sp.X, pgain, spctr.X)
sigy := efuns.Logistic(sgn.Y*sp.Y, pgain, spctr.Y)
pwt = sigx * sigy
}
wt := fwt * pwt
rwt := pt.TopoRange.ProjValue(wt)
wts.Set(rwt, ruy, rux, suy, sux)
}
}
}
}
return nil
}
// TopoWeightsSigmoid4D sets values in given 6D tensor according to Topo settings.
// wts is shaped with first 2 outer-most dims as Y, X of units within pool
// of recv layer (these are units over which topography is defined)
// and remaining 4D is for receptive field Size by units within pool size for
// sending layer.
func (pt *PoolTileSub) TopoWeightsSigmoid4D(send, recv *tensor.Shape, wts *tensor.Float32) error {
if pt.SigFull.Gain == 0 {
pt.SigFull.Defaults()
}
if pt.SigInPool.Gain == 0 {
pt.SigInPool.Defaults()
}
sNuY := send.DimSize(2)
sNuX := send.DimSize(3)
rNuY := recv.DimSize(0) // ok if recv is 2D
rNuX := recv.DimSize(1)
if recv.NumDims() == 4 {
rNuY = recv.DimSize(2)
rNuX = recv.DimSize(3)
}
wts.SetShapeSizes(rNuY, rNuX, pt.Size.Y, pt.Size.X, sNuY, sNuX)
fsz := math32.Vec2(float32(pt.Size.X*sNuX-1), float32(pt.Size.Y*sNuY-1)) // full rf size
hfsz := fsz.MulScalar(0.5) // half rf
fgain := pt.SigFull.Gain * hfsz.X // full gain
psz := math32.Vec2(float32(sNuX), float32(sNuY)) // within-pool rf size
if sNuX > 1 {
psz.X -= 1
}
if sNuY > 1 {
psz.Y -= 1
}
hpsz := psz.MulScalar(0.5) // half rf
pgain := pt.SigInPool.Gain * hpsz.X // pool gain
rsz := math32.Vec2(float32(rNuX), float32(rNuY)) // recv units-in-pool size
if rNuX > 1 {
rsz.X -= 1
}
if rNuY > 1 {
rsz.Y -= 1
}
hrsz := rsz.MulScalar(0.5)
for ruy := 0; ruy < rNuY; ruy++ {
for rux := 0; rux < rNuX; rux++ {
rpos := math32.Vec2(float32(rux), float32(ruy)).Div(hrsz) // 0..2 normalized r unit pos
sgn := math32.Vec2(1, 1)
rfpos := rpos.SubScalar(0.5).MulScalar(pt.SigFull.CtrMove).AddScalar(0.5)
rppos := rpos.SubScalar(0.5).MulScalar(pt.SigInPool.CtrMove).AddScalar(0.5)
if rpos.X >= 1 { // flip direction half-way through
sgn.X = -1
rpos.X = -rpos.X + 1
rfpos.X = (rpos.X+0.5)*pt.SigFull.CtrMove - 0.5
rppos.X = (rpos.X+0.5)*pt.SigInPool.CtrMove - 0.5
}
if rpos.Y >= 1 {
sgn.Y = -1
rpos.Y = -rpos.Y + 1
rfpos.Y = (rpos.Y+0.5)*pt.SigFull.CtrMove - 0.5
rppos.Y = (rpos.Y+0.5)*pt.SigInPool.CtrMove - 0.5
}
sfctr := rfpos.Mul(fsz) // sending center for full
spctr := rppos.Mul(psz) // sending center for within-pool
for fy := 0; fy < pt.Size.Y; fy++ {
for fx := 0; fx < pt.Size.X; fx++ {
for suy := 0; suy < sNuY; suy++ {
for sux := 0; sux < sNuX; sux++ {
fwt := float32(1)
if pt.SigFull.On {
sf := math32.Vec2(float32(fx*sNuX+sux), float32(fy*sNuY+suy))
sigx := efuns.Logistic(sgn.X*sf.X, fgain, sfctr.X)
sigy := efuns.Logistic(sgn.Y*sf.Y, fgain, sfctr.Y)
fwt = sigx * sigy
}
pwt := float32(1)
if pt.SigInPool.On {
sp := math32.Vec2(float32(sux), float32(suy))
sigx := efuns.Logistic(sgn.X*sp.X, pgain, spctr.X)
sigy := efuns.Logistic(sgn.Y*sp.Y, pgain, spctr.Y)
pwt = sigx * sigy
}
wt := fwt * pwt
rwt := pt.TopoRange.ProjValue(wt)
wts.Set(rwt, ruy, rux, fy, fx, suy, sux)
}
}
}
}
}
}
return nil
}
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package paths
import (
"math"
"sort"
"cogentcore.org/lab/base/randx"
"cogentcore.org/lab/tensor"
)
// PoolUniformRand implements random pattern of connectivity between pools within layers.
// Pools are the outer-most two dimensions of a 4D layer shape.
// If either layer does not have pools, PoolUniformRand works as UniformRand does.
// If probability of connection (PCon) is 1, PoolUniformRand works as PoolOnetoOne does.
type PoolUniformRand struct {
PoolOneToOne
UniformRand
}
func NewPoolUniformRand() *PoolUniformRand {
newur := &PoolUniformRand{}
newur.PCon = 0.5
return newur
}
func (ur *PoolUniformRand) Name() string {
return "PoolUniformRand"
}
func (ur *PoolUniformRand) Connect(send, recv *tensor.Shape, same bool) (sendn, recvn *tensor.Int32, cons *tensor.Bool) {
if send.NumDims() == 4 && recv.NumDims() == 4 {
return ur.ConnectPoolsRand(send, recv, same)
}
return ur.ConnectRand(send, recv, same)
}
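// examplePoolUniformRandConnect is a hypothetical usage sketch (not an API
// function): 50% random connectivity between corresponding pools of two 4D
// layers. tensor.NewShape and the layer sizes are assumptions for illustration.
func examplePoolUniformRandConnect() {
	pr := NewPoolUniformRand() // PCon defaults to 0.5
	send := tensor.NewShape(4, 4, 5, 5) // 4x4 pools of 5x5 units (assumed shape)
	recv := tensor.NewShape(4, 4, 5, 5)
	sendn, recvn, cons := pr.Connect(send, recv, false)
	// sendn / recvn hold per-unit connection counts; cons is the ri*sNtot+si bool matrix
	_, _, _ = sendn, recvn, cons
}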
// ConnectPoolsRand is when both recv and send have pools
func (ur *PoolUniformRand) ConnectPoolsRand(send, recv *tensor.Shape, same bool) (sendn, recvn *tensor.Int32, cons *tensor.Bool) {
if ur.PCon >= 1 {
return ur.ConnectPools(send, recv, same)
}
sendn, recvn, cons = NewTensors(send, recv)
sNtot := send.Len()
// rNtot := recv.Len()
sNp := send.DimSize(0) * send.DimSize(1)
rNp := recv.DimSize(0) * recv.DimSize(1)
sNu := send.DimSize(2) * send.DimSize(3)
rNu := recv.DimSize(2) * recv.DimSize(3)
rnv := recvn.Values
snv := sendn.Values
npl := rNp
noself := same && !ur.SelfCon
var nsend int
if noself {
nsend = int(math.Round(float64(ur.PCon) * float64(sNu-1)))
} else {
nsend = int(math.Round(float64(ur.PCon) * float64(sNu)))
}
ur.InitRand()
sordlen := sNu
if noself {
sordlen--
}
sorder := ur.Rand.Perm(sordlen)
slist := make([]int, nsend)
if ur.NPools > 0 {
npl = min(ur.NPools, rNp)
}
for i := 0; i < npl; i++ {
rpi := ur.RecvStart + i
spi := ur.SendStart + i
if rpi >= rNp || spi >= sNp {
break
}
for rui := 0; rui < rNu; rui++ {
ri := rpi*rNu + rui
rnv[ri] = int32(nsend)
if noself { // need to exclude ri
ix := 0
for j := 0; j < sNu; j++ {
ji := spi*sNu + j
if ji != ri {
sorder[ix] = j
ix++
}
}
randx.PermuteInts(sorder, ur.Rand)
}
copy(slist, sorder)
sort.Ints(slist)
for sui := 0; sui < nsend; sui++ {
si := spi*sNu + slist[sui]
off := ri*sNtot + si
cons.Values.Set(true, off)
}
randx.PermuteInts(sorder, ur.Rand)
}
for sui := 0; sui < sNu; sui++ {
nr := 0
si := spi*sNu + sui
for rui := 0; rui < rNu; rui++ {
ri := rpi*rNu + rui
off := ri*sNtot + si
if cons.Values.Index(off) {
nr++
}
}
snv[si] = int32(nr)
}
}
return
}
// ConnectRand is a copy of UniformRand.Connect with initial if statement modified
func (ur *PoolUniformRand) ConnectRand(send, recv *tensor.Shape, same bool) (sendn, recvn *tensor.Int32, cons *tensor.Bool) {
if ur.PCon >= 1 {
switch {
case send.NumDims() == 2 && recv.NumDims() == 4:
return ur.ConnectRecvPool(send, recv, same)
case send.NumDims() == 4 && recv.NumDims() == 2:
return ur.ConnectSendPool(send, recv, same)
case send.NumDims() == 2 && recv.NumDims() == 2:
return ur.ConnectOneToOne(send, recv, same)
}
}
if ur.Recip {
return ur.ConnectRecip(send, recv, same)
}
sendn, recvn, cons = NewTensors(send, recv)
slen := send.Len()
rlen := recv.Len()
noself := same && !ur.SelfCon
var nsend int
if noself {
nsend = int(math.Round(float64(ur.PCon) * float64(slen-1)))
} else {
nsend = int(math.Round(float64(ur.PCon) * float64(slen)))
}
// NOTE: this is reasonably accurate: mean + 3 * SEM, but we can just use
// empirical values more easily and safely.
// recv number is even distribution across recvs plus some imbalance factor
// nrMean := float32(rlen*nsend) / float32(slen)
// // add 3 * SEM as corrective factor
// nrSEM := nrMean / math32.Sqrt(nrMean)
// nrecv := int(nrMean + 3*nrSEM)
// if nrecv > rlen {
// nrecv = rlen
// }
rnv := recvn.Values
for i := range rnv {
rnv[i] = int32(nsend)
}
ur.InitRand()
sordlen := slen
if noself {
sordlen--
}
sorder := ur.Rand.Perm(sordlen)
slist := make([]int, nsend)
for ri := 0; ri < rlen; ri++ {
if noself { // need to exclude ri
ix := 0
for j := 0; j < slen; j++ {
if j != ri {
sorder[ix] = j
ix++
}
}
randx.PermuteInts(sorder, ur.Rand)
}
copy(slist, sorder)
sort.Ints(slist) // keep list sorted for more efficient memory traversal etc
for si := 0; si < nsend; si++ {
off := ri*slen + slist[si]
cons.Values.Set(true, off)
}
randx.PermuteInts(sorder, ur.Rand)
}
// set send n's empirically
snv := sendn.Values
for si := range snv {
nr := 0
for ri := 0; ri < rlen; ri++ {
off := ri*slen + si
if cons.Values.Index(off) {
nr++
}
}
snv[si] = int32(nr)
}
return
}
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package paths
import (
"cogentcore.org/core/math32"
"cogentcore.org/core/math32/vecint"
"cogentcore.org/lab/tensor"
"github.com/emer/emergent/v2/edge"
)
// Rect implements a rectangular pattern of connectivity between two layers
// where the lower-left corner moves in proportion to receiver position with offset
// and multiplier factors (with wrap-around optionally).
// 4D layers are automatically flattened to 2D for this pathway.
type Rect struct {
// size of rectangle in sending layer that each receiving unit receives from
Size vecint.Vector2i
// starting offset in sending layer, for computing the corresponding sending lower-left corner relative to given recv unit position
Start vecint.Vector2i
// scaling to apply to receiving unit position to compute corresponding position in sending layer of the lower-left corner of rectangle
Scale math32.Vector2
// auto-set the Scale as function of the relative sizes of send and recv layers (e.g., if sending layer is 2x larger than receiving, Scale = 2)
AutoScale bool
// if true, use Round when applying scaling factor -- otherwise uses Floor which makes Scale work like a grouping factor -- e.g., .25 will effectively group 4 recv units with same send position
RoundScale bool
// if true, connectivity wraps around all edges if it would otherwise go off the edge -- if false, then edges are clipped
Wrap bool
// if true, and connecting layer to itself (self pathway), then make a self-connection from unit to itself
SelfCon bool
// make the reciprocal of the specified connections -- i.e., symmetric for swapping recv and send
Recip bool
// starting position in receiving layer -- if > 0 then units below this starting point remain unconnected
RecvStart vecint.Vector2i
// number of units in receiving layer to connect -- if 0 then all (remaining after RecvStart) are connected -- otherwise if < remaining then those beyond this point remain unconnected
RecvN vecint.Vector2i
}
func NewRect() *Rect {
cr := &Rect{}
cr.Defaults()
return cr
}
// NewRectRecip creates a new Rect that is a Recip version of given ff one
func NewRectRecip(ff *Rect) *Rect {
cr := &Rect{}
*cr = *ff
cr.Recip = true
return cr
}
func (cr *Rect) Defaults() {
cr.Wrap = true
cr.Size.Set(2, 2)
cr.Scale.SetScalar(1)
}
func (cr *Rect) Name() string {
return "Rect"
}
func (cr *Rect) Connect(send, recv *tensor.Shape, same bool) (sendn, recvn *tensor.Int32, cons *tensor.Bool) {
if cr.Recip {
return cr.ConnectRecip(send, recv, same)
}
sendn, recvn, cons = NewTensors(send, recv)
sNy, sNx, _, _ := tensor.Projection2DShape(send, false)
rNy, rNx, _, _ := tensor.Projection2DShape(recv, false)
rnv := recvn.Values
snv := sendn.Values
sNtot := send.Len()
rNyEff := rNy
if cr.RecvN.Y > 0 {
rNyEff = min(rNy, cr.RecvN.Y)
}
if cr.RecvStart.Y > 0 {
rNyEff = min(rNyEff, rNy-cr.RecvStart.Y)
}
rNxEff := rNx
if cr.RecvN.X > 0 {
rNxEff = min(rNx, cr.RecvN.X)
}
if cr.RecvStart.X > 0 {
rNxEff = min(rNxEff, rNx-cr.RecvStart.X)
}
sc := cr.Scale
if cr.AutoScale {
ssz := math32.Vec2(float32(sNx), float32(sNy))
rsz := math32.Vec2(float32(rNxEff), float32(rNyEff))
sc = ssz.Div(rsz)
}
for ry := cr.RecvStart.Y; ry < rNyEff+cr.RecvStart.Y; ry++ {
for rx := cr.RecvStart.X; rx < rNxEff+cr.RecvStart.X; rx++ {
ri := tensor.Projection2DIndex(recv, false, ry, rx)
sst := cr.Start
if cr.RoundScale {
sst.X += int(math32.Round(float32(rx-cr.RecvStart.X) * sc.X))
sst.Y += int(math32.Round(float32(ry-cr.RecvStart.Y) * sc.Y))
} else {
sst.X += int(math32.Floor(float32(rx-cr.RecvStart.X) * sc.X))
sst.Y += int(math32.Floor(float32(ry-cr.RecvStart.Y) * sc.Y))
}
for y := 0; y < cr.Size.Y; y++ {
sy, clipy := edge.Edge(sst.Y+y, sNy, cr.Wrap)
if clipy {
continue
}
for x := 0; x < cr.Size.X; x++ {
sx, clipx := edge.Edge(sst.X+x, sNx, cr.Wrap)
if clipx {
continue
}
si := tensor.Projection2DIndex(send, false, sy, sx)
off := ri*sNtot + si
if !cr.SelfCon && same && ri == si {
continue
}
cons.Values.Set(true, off)
rnv[ri]++
snv[si]++
}
}
}
}
return
}
func (cr *Rect) ConnectRecip(send, recv *tensor.Shape, same bool) (sendn, recvn *tensor.Int32, cons *tensor.Bool) {
sendn, recvn, cons = NewTensors(send, recv)
sNy, sNx, _, _ := tensor.Projection2DShape(recv, false) // swapped!
rNy, rNx, _, _ := tensor.Projection2DShape(send, false)
rnv := recvn.Values
snv := sendn.Values
sNtot := send.Len()
rNyEff := rNy
if cr.RecvN.Y > 0 {
rNyEff = min(rNy, cr.RecvN.Y)
}
if cr.RecvStart.Y > 0 {
rNyEff = min(rNyEff, rNy-cr.RecvStart.Y)
}
rNxEff := rNx
if cr.RecvN.X > 0 {
rNxEff = min(rNx, cr.RecvN.X)
}
if cr.RecvStart.X > 0 {
rNxEff = min(rNxEff, rNx-cr.RecvStart.X)
}
sc := cr.Scale
if cr.AutoScale {
ssz := math32.Vec2(float32(sNx), float32(sNy))
rsz := math32.Vec2(float32(rNxEff), float32(rNyEff))
sc = ssz.Div(rsz)
}
for ry := cr.RecvStart.Y; ry < rNyEff+cr.RecvStart.Y; ry++ {
for rx := cr.RecvStart.X; rx < rNxEff+cr.RecvStart.X; rx++ {
ri := tensor.Projection2DIndex(send, false, ry, rx)
sst := cr.Start
if cr.RoundScale {
sst.X += int(math32.Round(float32(rx-cr.RecvStart.X) * sc.X))
sst.Y += int(math32.Round(float32(ry-cr.RecvStart.Y) * sc.Y))
} else {
sst.X += int(math32.Floor(float32(rx-cr.RecvStart.X) * sc.X))
sst.Y += int(math32.Floor(float32(ry-cr.RecvStart.Y) * sc.Y))
}
for y := 0; y < cr.Size.Y; y++ {
sy, clipy := edge.Edge(sst.Y+y, sNy, cr.Wrap)
if clipy {
continue
}
for x := 0; x < cr.Size.X; x++ {
sx, clipx := edge.Edge(sst.X+x, sNx, cr.Wrap)
if clipx {
continue
}
si := tensor.Projection2DIndex(recv, false, sy, sx)
off := si*sNtot + ri
if !cr.SelfCon && same && ri == si {
continue
}
cons.Values.Set(true, off)
rnv[si]++
snv[ri]++
}
}
}
}
return
}
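// exampleRectConnect is a hypothetical usage sketch (not an API function):
// each receiving unit gets a 3x3 window of senders with wrap-around edges.
// tensor.NewShape and the 10x10 layer sizes are assumptions for illustration.
func exampleRectConnect() {
	cr := NewRect()   // Defaults: Wrap = true, Size = 2x2, Scale = 1
	cr.Size.Set(3, 3) // widen the window to 3x3
	send := tensor.NewShape(10, 10)
	recv := tensor.NewShape(10, 10)
	_, recvn, _ := cr.Connect(send, recv, false)
	_ = recvn // with wrap on, every recv unit ends up with 9 connections
}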
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package paths
import (
"math"
"math/rand"
"sort"
"cogentcore.org/lab/base/randx"
"cogentcore.org/lab/tensor"
)
// UniformRand implements uniform random pattern of connectivity between two layers
// using a permuted (shuffled) list for without-replacement randomness,
// and maintains its own local random number source and seed
// which are initialized if Rand == nil -- usually best to keep this
// specific to each instance of a pathway so it is fully reproducible
// and doesn't interfere with other random number streams.
type UniformRand struct {
// probability of connection (0-1)
PCon float32 `min:"0" max:"1"`
// if true, and connecting layer to itself (self pathway), then make a self-connection from unit to itself
SelfCon bool
// reciprocal connectivity: if true, switch the sending and receiving layers to create a symmetric top-down pathway -- ESSENTIAL to use same RandSeed between two paths to ensure symmetry
Recip bool
// random number source -- is created with its own separate source if nil
Rand randx.Rand `display:"-"`
// the current random seed -- will be initialized to a new random number from the global random stream when Rand is created.
RandSeed int64 `display:"-"`
}
func NewUniformRand() *UniformRand {
return &UniformRand{PCon: 0.5}
}
func (ur *UniformRand) Name() string {
return "UniformRand"
}
func (ur *UniformRand) InitRand() {
if ur.Rand != nil {
ur.Rand.Seed(ur.RandSeed)
return
}
if ur.RandSeed == 0 {
ur.RandSeed = int64(rand.Uint64())
}
ur.Rand = randx.NewSysRand(ur.RandSeed)
}
func (ur *UniformRand) Connect(send, recv *tensor.Shape, same bool) (sendn, recvn *tensor.Int32, cons *tensor.Bool) {
if ur.PCon >= 1 {
return ur.ConnectFull(send, recv, same)
}
if ur.Recip {
return ur.ConnectRecip(send, recv, same)
}
sendn, recvn, cons = NewTensors(send, recv)
slen := send.Len()
rlen := recv.Len()
noself := same && !ur.SelfCon
var nsend int
if noself {
nsend = int(math.Round(float64(ur.PCon) * float64(slen-1)))
} else {
nsend = int(math.Round(float64(ur.PCon) * float64(slen)))
}
// NOTE: this is reasonably accurate: mean + 3 * SEM, but we can just use
// empirical values more easily and safely.
// recv number is even distribution across recvs plus some imbalance factor
// nrMean := float32(rlen*nsend) / float32(slen)
// // add 3 * SEM as corrective factor
// nrSEM := nrMean / math32.Sqrt(nrMean)
// nrecv := int(nrMean + 3*nrSEM)
// if nrecv > rlen {
// nrecv = rlen
// }
rnv := recvn.Values
for i := range rnv {
rnv[i] = int32(nsend)
}
ur.InitRand()
sordlen := slen
if noself {
sordlen--
}
sorder := ur.Rand.Perm(sordlen)
slist := make([]int, nsend)
for ri := 0; ri < rlen; ri++ {
if noself { // need to exclude ri
ix := 0
for j := 0; j < slen; j++ {
if j != ri {
sorder[ix] = j
ix++
}
}
randx.PermuteInts(sorder, ur.Rand)
}
copy(slist, sorder)
sort.Ints(slist) // keep list sorted for more efficient memory traversal etc
for si := 0; si < nsend; si++ {
off := ri*slen + slist[si]
cons.Values.Set(true, off)
}
randx.PermuteInts(sorder, ur.Rand)
}
// set send n's empirically
snv := sendn.Values
for si := range snv {
nr := 0
for ri := 0; ri < rlen; ri++ {
off := ri*slen + si
if cons.Values.Index(off) {
nr++
}
}
snv[si] = int32(nr)
}
return
}
// ConnectRecip does reciprocal connectivity
func (ur *UniformRand) ConnectRecip(send, recv *tensor.Shape, same bool) (sendn, recvn *tensor.Int32, cons *tensor.Bool) {
sendn, recvn, cons = NewTensors(send, recv)
slen := recv.Len() // swapped
rlen := send.Len()
slenR := send.Len() // NOT swapped
noself := same && !ur.SelfCon
var nsend int
if noself {
nsend = int(math.Round(float64(ur.PCon) * float64(slen-1)))
} else {
nsend = int(math.Round(float64(ur.PCon) * float64(slen)))
}
rnv := sendn.Values // swapped
for i := range rnv {
rnv[i] = int32(nsend)
}
ur.InitRand()
sordlen := slen
if noself {
sordlen--
}
sorder := ur.Rand.Perm(sordlen)
slist := make([]int, nsend)
for ri := 0; ri < rlen; ri++ {
if noself { // need to exclude ri
ix := 0
for j := 0; j < slen; j++ {
if j != ri {
sorder[ix] = j
ix++
}
}
randx.PermuteInts(sorder, ur.Rand)
}
copy(slist, sorder)
sort.Ints(slist) // keep list sorted for more efficient memory traversal etc
for si := 0; si < nsend; si++ {
off := slist[si]*slenR + ri
cons.Values.Set(true, off)
}
randx.PermuteInts(sorder, ur.Rand)
}
// set send n's empirically
snv := recvn.Values // swapped
for si := range snv {
nr := 0
for ri := 0; ri < rlen; ri++ { // actually si
off := si*slenR + ri
if cons.Values.Index(off) {
nr++
}
}
snv[si] = int32(nr)
}
return
}
func (ur *UniformRand) ConnectFull(send, recv *tensor.Shape, same bool) (sendn, recvn *tensor.Int32, cons *tensor.Bool) {
sendn, recvn, cons = NewTensors(send, recv)
cons.Values.SetAll(true)
nsend := send.Len()
nrecv := recv.Len()
if same && !ur.SelfCon {
for i := 0; i < nsend; i++ { // nsend = nrecv
off := i*nsend + i
cons.Values.Set(false, off)
}
nsend--
nrecv--
}
rnv := recvn.Values
for i := range rnv {
rnv[i] = int32(nsend)
}
snv := sendn.Values
for i := range snv {
snv[i] = int32(nrecv)
}
return
}
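// exampleUniformRandConnect is a hypothetical usage sketch (not an API
// function): reproducible 20% random connectivity between two 6x6 layers.
// tensor.NewShape and the sizes are assumptions for illustration only.
func exampleUniformRandConnect() {
	ur := NewUniformRand()
	ur.PCon = 0.2
	ur.RandSeed = 42 // fixed seed makes the pattern reproducible across runs
	send := tensor.NewShape(6, 6)
	recv := tensor.NewShape(6, 6)
	_, recvn, _ := ur.Connect(send, recv, false)
	_ = recvn // each recv unit gets round(0.2*36) = 7 sending connections
}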
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package popcode
//go:generate core generate -add-types
import (
"sort"
"cogentcore.org/core/math32"
)
type PopCodes int
const (
// GaussBump = gaussian bump, with value = weighted average of tuned unit values
GaussBump PopCodes = iota
// Localist = each unit represents a distinct value; intermediate values are represented by graded activity of neighbors; overall activity is the weighted average across all units
Localist
)
// popcode.OneD provides encoding and decoding of population
// codes, used to represent a single continuous (scalar) value
// across a population of units / neurons (1 dimensional)
type OneD struct {
// how to encode the value
Code PopCodes
// minimum value representable -- for GaussBump, typically include extra to allow mean with activity on either side to represent the lowest value you want to encode
Min float32
// maximum value representable -- for GaussBump, typically include extra to allow mean with activity on either side to represent the highest value you want to encode
Max float32
// sigma parameter of a gaussian specifying the tuning width of the coarse-coded units, in normalized 0-1 range
Sigma float32 `default:"0.2"`
// ensure that encoded and decoded value remains within specified range
Clip bool
// for decoding, threshold to cut off small activation contributions to overall average value (i.e., if unit's activation is below this threshold, it doesn't contribute to weighted average computation)
Thr float32 `default:"0.1"`
// minimum total activity of all the units representing a value: when computing weighted average value, this is used as a minimum for the sum that you divide by
MinSum float32 `default:"0.2"`
}
func (pc *OneD) Defaults() {
pc.Code = GaussBump
pc.Min = -0.5
pc.Max = 1.5
pc.Sigma = 0.2
pc.Clip = true
pc.Thr = 0.1
pc.MinSum = 0.2
}
func (pc *OneD) ShouldDisplay(field string) bool {
switch field {
case "Sigma":
return pc.Code == GaussBump
default:
return true
}
}
// SetRange sets the min, max and sigma values
func (pc *OneD) SetRange(min, max, sigma float32) {
pc.Min = min
pc.Max = max
pc.Sigma = sigma
}
const (
// Add is used for popcode Encode methods, add arg -- indicates to add values
// to any existing values in the target vector / tensor:
// used for encoding additional values (see DecodeN for decoding).
Add = true
// Set is used for popcode Encode methods, add arg -- indicates to set values
// in any existing values in the target vector / tensor:
// used for encoding first / only values.
Set = false
)
// Encode generates a pattern of activation of given size to encode given value.
// n must be 2 or more. pat slice will be constructed if len != n.
// If add == false (use Set const for clarity), values are set to pattern
// else if add == true (Add), then values are added to any existing,
// for encoding additional values in same pattern.
func (pc *OneD) Encode(pat *[]float32, val float32, n int, add bool) {
if len(*pat) != n {
*pat = make([]float32, n)
}
if pc.Clip {
val = math32.Clamp(val, pc.Min, pc.Max)
}
rng := pc.Max - pc.Min
gnrm := 1 / (rng * pc.Sigma)
incr := rng / float32(n-1)
for i := 0; i < n; i++ {
trg := pc.Min + incr*float32(i)
act := float32(0)
switch pc.Code {
case GaussBump:
dist := gnrm * (trg - val)
act = math32.Exp(-(dist * dist))
case Localist:
dist := math32.Abs(trg - val)
if dist > incr {
act = 0
} else {
act = 1.0 - (dist / incr)
}
}
if add {
(*pat)[i] += act
} else {
(*pat)[i] = act
}
}
}
// Decode decodes value from a pattern of activation
// as the activation-weighted-average of the unit's preferred
// tuning values.
// must have 2 or more values in pattern pat.
func (pc *OneD) Decode(pat []float32) float32 {
n := len(pat)
if n < 2 {
return 0
}
rng := pc.Max - pc.Min
incr := rng / float32(n-1)
avg := float32(0)
sum := float32(0)
for i, act := range pat {
if act < pc.Thr {
act = 0
}
trg := pc.Min + incr*float32(i)
avg += trg * act
sum += act
}
sum = math32.Max(sum, pc.MinSum)
avg /= sum
return avg
}
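// exampleOneDRoundTrip is a hypothetical usage sketch (not an API function):
// encode a scalar as a gaussian bump over 12 units, then decode it back.
// The value 0.7 and pattern size 12 are arbitrary choices for illustration.
func exampleOneDRoundTrip() {
	pc := OneD{}
	pc.Defaults() // Min = -0.5, Max = 1.5, Sigma = 0.2
	var pat []float32
	pc.Encode(&pat, 0.7, 12, Set)
	val := pc.Decode(pat)
	_ = val // approximately 0.7, up to coarse-coding resolution
}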
// Values sets the vals slice to the target preferred tuning values
// for each unit, for a distribution of given size n.
// n must be 2 or more.
// vals slice will be constructed if len != n
func (pc *OneD) Values(vals *[]float32, n int) {
if len(*vals) != n {
*vals = make([]float32, n)
}
rng := pc.Max - pc.Min
incr := rng / float32(n-1)
for i := 0; i < n; i++ {
trg := pc.Min + incr*float32(i)
(*vals)[i] = trg
}
}
// DecodeNPeaks decodes N values from a pattern of activation
// using a neighborhood of specified width around local maxima,
// which is the amount on either side of the central point to
// accumulate (0 = localist, single points, 1 = +/- 1 point on
// either side, etc).
// Allocates a temporary slice the size of pat, and sorts it: relatively expensive
func (pc *OneD) DecodeNPeaks(pat []float32, nvals, width int) []float32 {
n := len(pat)
if n < 2 {
return nil
}
rng := pc.Max - pc.Min
incr := rng / float32(n-1)
type navg struct {
avg float32
idx int
}
avgs := make([]navg, n)
for i := range pat {
sum := float32(0)
ns := 0
for d := -width; d <= width; d++ {
di := i + d
if di < 0 || di >= n {
continue
}
act := pat[di]
if act < pc.Thr {
continue
}
sum += pat[di]
ns++
}
avgs[i].avg = sum / float32(ns)
avgs[i].idx = i
}
// sort highest to lowest
sort.Slice(avgs, func(i, j int) bool {
return avgs[i].avg > avgs[j].avg
})
vals := make([]float32, nvals)
for i := range vals {
avg := float32(0)
sum := float32(0)
mxi := avgs[i].idx
for d := -width; d <= width; d++ {
di := mxi + d
if di < 0 || di >= n {
continue
}
act := pat[di]
if act < pc.Thr {
act = 0
}
trg := pc.Min + incr*float32(di)
avg += trg * act
sum += act
}
sum = math32.Max(sum, pc.MinSum)
vals[i] = avg / sum
}
return vals
}
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package popcode
import (
"fmt"
"sort"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/math32"
"cogentcore.org/lab/tensor"
)
// popcode.TwoD provides encoding and decoding of population
// codes, used to represent two continuous (scalar) values
// across a 2D tensor, using row-major XY encoding:
// Y = outer, first dim, X = inner, second dim
type TwoD struct {
// how to encode the value
Code PopCodes
// minimum value representable on each dim -- for GaussBump, typically include extra to allow mean with activity on either side to represent the lowest value you want to encode
Min math32.Vector2
// maximum value representable on each dim -- for GaussBump, typically include extra to allow mean with activity on either side to represent the highest value you want to encode
Max math32.Vector2
// sigma parameters of a gaussian specifying the tuning width of the coarse-coded units, in normalized 0-1 range
Sigma math32.Vector2 `default:"0.2"`
// ensure that encoded and decoded value remains within specified range -- generally not useful with wrap
Clip bool
// x axis wraps around (e.g., for periodic values such as angle) -- encodes and decodes relative to both the min and max values
WrapX bool
// y axis wraps around (e.g., for periodic values such as angle) -- encodes and decodes relative to both the min and max values
WrapY bool
// threshold to cut off small activation contributions to overall average value (i.e., if unit's activation is below this threshold, it doesn't contribute to weighted average computation)
Thr float32 `default:"0.1"`
// minimum total activity of all the units representing a value: when computing weighted average value, this is used as a minimum for the sum that you divide by
MinSum float32 `default:"0.2"`
}
func (pc *TwoD) Defaults() {
pc.Code = GaussBump
pc.Min.Set(-0.5, -0.5)
pc.Max.Set(1.5, 1.5)
pc.Sigma.Set(0.2, 0.2)
pc.Clip = true
pc.Thr = 0.1
pc.MinSum = 0.2
}
func (pc *TwoD) ShouldDisplay(field string) bool {
switch field {
case "Sigma":
return pc.Code == GaussBump
default:
return true
}
}
// SetRange sets the min, max and sigma values to the same scalar values
func (pc *TwoD) SetRange(min, max, sigma float32) {
pc.Min.Set(min, min)
pc.Max.Set(max, max)
pc.Sigma.Set(sigma, sigma)
}
// Encode generates a pattern of activation on the given tensor, which must already
// have an appropriate 2D shape that determines the encoding sizes (returns an error if not).
// If add == false (use Set const for clarity), values are set to pattern
// else if add == true (Add), then values are added to any existing,
// for encoding additional values in same pattern.
func (pc *TwoD) Encode(pat tensor.Tensor, val math32.Vector2, add bool) error {
if pat.NumDims() != 2 {
err := fmt.Errorf("popcode.TwoD Encode: pattern must have 2 dimensions")
return errors.Log(err)
}
if pc.Clip {
val.Clamp(pc.Min, pc.Max)
}
rng := pc.Max.Sub(pc.Min)
sr := pc.Sigma.Mul(rng)
if pc.WrapX || pc.WrapY {
err := pc.EncodeImpl(pat, val, add) // always render first
ev := val
// relative to min
if pc.WrapX && math32.Abs(pc.Max.X-val.X) < sr.X { // has significant vals near max
ev.X = pc.Min.X - (pc.Max.X - val.X) // wrap extra above max around to min
}
if pc.WrapY && math32.Abs(pc.Max.Y-val.Y) < sr.Y {
ev.Y = pc.Min.Y - (pc.Max.Y - val.Y)
}
if ev != val {
err = pc.EncodeImpl(pat, ev, Add) // always add
}
// pev := ev
ev = val
if pc.WrapX && math32.Abs(val.X-pc.Min.X) < sr.X { // has significant vals near min
ev.X = pc.Max.X - (val.X - pc.Min.X) // wrap extra below min around to max
}
if pc.WrapY && math32.Abs(val.Y-pc.Min.Y) < sr.Y {
ev.Y = pc.Max.Y - (val.Y - pc.Min.Y)
}
if ev != val {
err = pc.EncodeImpl(pat, ev, Add) // always add
}
return err
}
return pc.EncodeImpl(pat, val, add)
}
// EncodeImpl is the implementation of encoding -- e.g., used twice for Wrap
func (pc *TwoD) EncodeImpl(pat tensor.Tensor, val math32.Vector2, add bool) error {
rng := pc.Max.Sub(pc.Min)
gnrm := math32.Vector2Scalar(1).Div(rng.Mul(pc.Sigma))
ny := pat.DimSize(0)
nx := pat.DimSize(1)
nf := math32.Vec2(float32(nx-1), float32(ny-1))
incr := rng.Div(nf)
for yi := 0; yi < ny; yi++ {
for xi := 0; xi < nx; xi++ {
fi := math32.Vec2(float32(xi), float32(yi))
trg := pc.Min.Add(incr.Mul(fi))
act := float32(0)
switch pc.Code {
case GaussBump:
dist := trg.Sub(val).Mul(gnrm)
act = math32.Exp(-dist.LengthSquared())
case Localist:
dist := trg.Sub(val)
dist.X = math32.Abs(dist.X)
dist.Y = math32.Abs(dist.Y)
if dist.X > incr.X || dist.Y > incr.Y {
act = 0
} else {
nd := dist.Div(incr)
act = 1.0 - 0.5*(nd.X+nd.Y)
}
}
idx := []int{yi, xi}
if add {
val := float64(act) + pat.Float(idx...)
pat.SetFloat(val, idx...)
} else {
pat.SetFloat(float64(act), idx...)
}
}
}
return nil
}
// Decode decodes 2D value from a pattern of activation
// as the activation-weighted-average of the unit's preferred
// tuning values.
func (pc *TwoD) Decode(pat tensor.Tensor) (math32.Vector2, error) {
if pat.NumDims() != 2 {
err := fmt.Errorf("popcode.TwoD Decode: pattern must have 2 dimensions")
return math32.Vector2{}, errors.Log(err)
}
if pc.WrapX || pc.WrapY {
ny := pat.DimSize(0)
nx := pat.DimSize(1)
ys := make([]float32, ny)
xs := make([]float32, nx)
ydiv := 1.0 / (float32(nx) * pc.Sigma.X)
xdiv := 1.0 / (float32(ny) * pc.Sigma.Y)
for yi := 0; yi < ny; yi++ {
for xi := 0; xi < nx; xi++ {
idx := []int{yi, xi}
act := float32(pat.Float(idx...))
if act < pc.Thr {
act = 0
}
ys[yi] += act * ydiv
xs[xi] += act * xdiv
}
}
var val math32.Vector2
switch {
case pc.WrapX && pc.WrapY:
dx := Ring{}
dx.Defaults()
dx.Min = pc.Min.X
dx.Max = pc.Max.X
dx.Sigma = pc.Sigma.X
dx.Thr = pc.Thr
dx.MinSum = pc.MinSum
dx.Code = pc.Code
dy := Ring{}
dy.Defaults()
dy.Min = pc.Min.Y
dy.Max = pc.Max.Y
dy.Sigma = pc.Sigma.Y
dy.Thr = pc.Thr
dy.MinSum = pc.MinSum
dy.Code = pc.Code
val.X = dx.Decode(xs)
val.Y = dy.Decode(ys)
case pc.WrapX:
dx := Ring{}
dx.Defaults()
dx.Min = pc.Min.X
dx.Max = pc.Max.X
dx.Sigma = pc.Sigma.X
dx.Thr = pc.Thr
dx.MinSum = pc.MinSum
dx.Code = pc.Code
dy := OneD{}
dy.Defaults()
dy.Min = pc.Min.Y
dy.Max = pc.Max.Y
dy.Sigma = pc.Sigma.Y
dy.Thr = pc.Thr
dy.MinSum = pc.MinSum
dy.Code = pc.Code
val.X = dx.Decode(xs)
val.Y = dy.Decode(ys)
case pc.WrapY:
dx := OneD{}
dx.Defaults()
dx.Min = pc.Min.X
dx.Max = pc.Max.X
dx.Sigma = pc.Sigma.X
dx.Thr = pc.Thr
dx.MinSum = pc.MinSum
dx.Code = pc.Code
dy := Ring{}
dy.Defaults()
dy.Min = pc.Min.Y
dy.Max = pc.Max.Y
dy.Sigma = pc.Sigma.Y
dy.Thr = pc.Thr
dy.MinSum = pc.MinSum
dy.Code = pc.Code
val.X = dx.Decode(xs)
val.Y = dy.Decode(ys)
}
return val, nil
} else {
return pc.DecodeImpl(pat)
}
}
// DecodeImpl does direct decoding of x, y simultaneously -- for non-wrap
func (pc *TwoD) DecodeImpl(pat tensor.Tensor) (math32.Vector2, error) {
avg := math32.Vector2{}
rng := pc.Max.Sub(pc.Min)
ny := pat.DimSize(0)
nx := pat.DimSize(1)
nf := math32.Vec2(float32(nx-1), float32(ny-1))
incr := rng.Div(nf)
sum := float32(0)
for yi := 0; yi < ny; yi++ {
for xi := 0; xi < nx; xi++ {
idx := []int{yi, xi}
act := float32(pat.Float(idx...))
if act < pc.Thr {
act = 0
}
fi := math32.Vec2(float32(xi), float32(yi))
trg := pc.Min.Add(incr.Mul(fi))
avg = avg.Add(trg.MulScalar(act))
sum += act
}
}
sum = math32.Max(sum, pc.MinSum)
return avg.DivScalar(sum), nil
}
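// exampleTwoDRoundTrip is a hypothetical usage sketch (not an API function):
// encode an (x, y) value as a 2D gaussian bump on an 8x8 tensor and decode it.
// The tensor.NewFloat32 constructor and the sizes are assumptions for illustration.
func exampleTwoDRoundTrip() {
	pc := TwoD{}
	pc.Defaults() // Min = (-0.5,-0.5), Max = (1.5,1.5), Sigma = 0.2
	pat := tensor.NewFloat32(8, 8) // outer dim = Y, inner dim = X
	if err := pc.Encode(pat, math32.Vec2(0.25, 0.75), Set); err == nil {
		val, _ := pc.Decode(pat)
		_ = val // approximately (0.25, 0.75), up to coarse-coding resolution
	}
}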
// Values sets the vals slices to the target preferred tuning values
// for each unit, for a distribution of given dimensions.
// n's must be 2 or more in each dim.
// vals slice will be constructed if len != n
func (pc *TwoD) Values(valsX, valsY *[]float32, nx, ny int) {
rng := pc.Max.Sub(pc.Min)
nf := math32.Vec2(float32(nx-1), float32(ny-1))
incr := rng.Div(nf)
// X
if len(*valsX) != nx {
*valsX = make([]float32, nx)
}
for i := 0; i < nx; i++ {
trg := pc.Min.X + incr.X*float32(i)
(*valsX)[i] = trg
}
// Y
if len(*valsY) != ny {
*valsY = make([]float32, ny)
}
for i := 0; i < ny; i++ {
trg := pc.Min.Y + incr.Y*float32(i)
(*valsY)[i] = trg
}
}
// DecodeNPeaks decodes N values from a pattern of activation
// using a neighborhood of specified width around local maxima,
// which is the amount on either side of the central point to
// accumulate (0 = localist, single points, 1 = +/- 1 points on
// either side in a square around central point, etc)
// Allocates a temporary slice the size of pat, and sorts it: relatively expensive
func (pc *TwoD) DecodeNPeaks(pat tensor.Tensor, nvals, width int) ([]math32.Vector2, error) {
if pat.NumDims() != 2 {
err := fmt.Errorf("popcode.TwoD DecodeNPeaks: pattern must have 2 dimensions")
return nil, errors.Log(err)
}
rng := pc.Max.Sub(pc.Min)
ny := pat.DimSize(0)
nx := pat.DimSize(1)
nf := math32.Vec2(float32(nx-1), float32(ny-1))
incr := rng.Div(nf)
type navg struct {
avg float32
x, y int
}
avgs := make([]navg, nx*ny) // expensive
idx := 0
for yi := 0; yi < ny; yi++ {
for xi := 0; xi < nx; xi++ {
sum := float32(0)
ns := 0
for dy := -width; dy <= width; dy++ {
y := yi + dy
if y < 0 || y >= ny {
continue
}
for dx := -width; dx <= width; dx++ {
x := xi + dx
if x < 0 || x >= nx {
continue
}
idx := []int{y, x}
act := float32(pat.Float(idx...))
sum += act
ns++
}
}
avgs[idx].avg = sum / float32(ns)
avgs[idx].x = xi
avgs[idx].y = yi
idx++
}
}
// sort highest to lowest
sort.Slice(avgs, func(i, j int) bool {
return avgs[i].avg > avgs[j].avg
})
vals := make([]math32.Vector2, nvals)
for i := range vals {
avg := math32.Vector2{}
sum := float32(0)
mxi := avgs[i].x
myi := avgs[i].y
for dy := -width; dy <= width; dy++ {
y := myi + dy
if y < 0 || y >= ny {
continue
}
for dx := -width; dx <= width; dx++ {
x := mxi + dx
if x < 0 || x >= nx {
continue
}
idx := []int{y, x}
act := float32(pat.Float(idx...))
if act < pc.Thr {
act = 0
}
fi := math32.Vec2(float32(x), float32(y))
trg := pc.Min.Add(incr.Mul(fi))
avg = avg.Add(trg.MulScalar(act))
sum += act
}
}
sum = math32.Max(sum, pc.MinSum)
vals[i] = avg.DivScalar(sum)
}
return vals, nil
}
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package popcode
import (
"cogentcore.org/core/math32"
)
// Ring is a OneD popcode that encodes a circular value such as an angle
// that wraps around at the ends. It uses two internal vectors
// to render the wrapped-around values into, and then adds them into
// the final result. Unlike regular PopCodes, the Min and Max should
// represent the exact range of the value (e.g., 0 to 360 for angle)
// with no extra on the ends, as that extra will wrap around to
// the other side in this case.
type Ring struct {
OneD
// low-end encoding vector
LowVec []float32 `display:"-"`
// high-end encoding vector
HighVec []float32 `display:"-"`
}
// AllocVecs allocates internal LowVec, HighVec storage,
// allowing for variable lengths to be encoded using same object,
// growing capacity to max, but using exact amount each time
func (pc *Ring) AllocVecs(n int) {
if cap(pc.LowVec) < n {
pc.LowVec = make([]float32, n)
pc.HighVec = make([]float32, n)
}
pc.LowVec = pc.LowVec[:n]
pc.HighVec = pc.HighVec[:n]
}
// Encode generates a pattern of activation of given size to encode given value.
// n must be 2 or more.
// pat slice will be constructed if len != n
func (pc *Ring) Encode(pat *[]float32, val float32, n int) {
pc.Clip = false // doesn't work with clip!
if len(*pat) != n {
*pat = make([]float32, n)
}
pc.AllocVecs(n)
rng := pc.Max - pc.Min
sr := pc.Sigma * rng
if math32.Abs(pc.Max-val) < sr { // close to top
pc.EncodeImpl(&pc.LowVec, pc.Min+(val-pc.Max), n) // 0 + (340 - 360) = -20
pc.EncodeImpl(&pc.HighVec, val, n)
} else if math32.Abs(val-pc.Min) < sr { // close to bottom
pc.EncodeImpl(&pc.LowVec, val, n) // value near Min is rendered directly, e.g., 20
pc.EncodeImpl(&pc.HighVec, pc.Max+(val-pc.Min), n) // 360 + (20-0) = 380
} else {
pc.EncodeImpl(pat, val, n)
return
}
for i := 0; i < n; i++ {
(*pat)[i] = pc.LowVec[i] + pc.HighVec[i]
}
}
// EncodeImpl generates a pattern of activation of given size to encode given value.
// n must be 2 or more.
// pat slice will be constructed if len != n
func (pc *Ring) EncodeImpl(pat *[]float32, val float32, n int) {
if len(*pat) != n {
*pat = make([]float32, n)
}
if pc.Clip {
val = math32.Clamp(val, pc.Min, pc.Max)
}
rng := pc.Max - pc.Min
gnrm := 1 / (rng * pc.Sigma)
incr := rng / float32(n-1)
for i := 0; i < n; i++ {
trg := pc.Min + incr*float32(i)
act := float32(0)
switch pc.Code {
case GaussBump:
dist := gnrm * (trg - val)
act = math32.Exp(-(dist * dist))
case Localist:
dist := math32.Abs(trg - val)
if dist > incr {
act = 0
} else {
act = 1.0 - (dist / incr)
}
}
(*pat)[i] = act
}
}
// Decode decodes value from a pattern of activation
// as the activation-weighted-average of the unit's preferred
// tuning values.
// pat pattern must be len >= 2
func (pc *Ring) Decode(pat []float32) float32 {
n := len(pat)
sn := int(pc.Sigma * float32(n)) // amount on each end to blank
hsn := (n - 1) - sn
hn := n / 2
// and record activity in each end region
var lsum, hsum, lend, hend float32
for i := 0; i < n; i++ {
v := pat[i]
if i < sn {
lend += v
} else if i >= hsn {
hend += v
}
if i < hn {
lsum += v
} else {
hsum += v
}
}
rng := pc.Max - pc.Min
half := rng / 2
incr := rng / float32(n-1)
avg := float32(0)
sum := float32(0)
thr := float32(sn) * pc.Thr // threshold activity to count as having something in tail
if lend < thr && hend < thr { // neither has significant activity, use straight decode
for i := 0; i < n; i++ {
act := pat[i]
trg := pc.Min + incr*float32(i)
if act < pc.Thr {
act = 0
}
avg += trg * act
sum += act
}
} else if lsum > hsum { // lower is more active -- wrap high end below low end
for i := 0; i < hn; i++ { // decode lower half as usual
act := pat[i]
trg := pc.Min + incr*float32(i)
if act < pc.Thr {
act = 0
}
avg += trg * act
sum += act
}
min := pc.Min - half
for i := hn; i < n; i++ { // decode upper half as starting below lower
act := pat[i]
trg := min + incr*float32(i-hn)
if act < pc.Thr {
act = 0
}
avg += trg * act
sum += act
}
} else {
for i := hn; i < n; i++ { // decode upper half as usual
act := pat[i]
trg := pc.Min + incr*float32(i)
if act < pc.Thr {
act = 0
}
avg += trg * act
sum += act
}
min := pc.Max
for i := 0; i < hn; i++ { // decode lower half as starting above upper
act := pat[i]
trg := min + incr*float32(i)
if act < pc.Thr {
act = 0
}
avg += trg * act
sum += act
}
}
sum = math32.Max(sum, pc.MinSum)
avg /= sum
return avg
}
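// exampleRingAngle is a hypothetical usage sketch (not an API function):
// encode an angle near the wrap-around point of a 0-360 ring and decode it.
// The 16-unit pattern and the 350 degree value are arbitrary choices.
func exampleRingAngle() {
	pc := Ring{}
	pc.Defaults()
	pc.Min = 0
	pc.Max = 360 // exact range with no extra padding: the ends wrap onto each other
	var pat []float32
	pc.Encode(&pat, 350, 16) // close to Max, so activity wraps onto the low end too
	deg := pc.Decode(pat)
	_ = deg // approximately 350, up to coarse coding and wrap handling
}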
// Values sets the vals slice to the target preferred tuning values
// for each unit, for a distribution of given size n.
// n must be 2 or more.
// vals slice will be constructed if len != n
func (pc *Ring) Values(vals *[]float32, n int) {
if len(*vals) != n {
*vals = make([]float32, n)
}
rng := pc.Max - pc.Min
incr := rng / float32(n-1)
for i := 0; i < n; i++ {
trg := pc.Min + incr*float32(i)
(*vals)[i] = trg
}
}
// Code generated by "core generate -add-types"; DO NOT EDIT.
package relpos
import (
"cogentcore.org/core/enums"
)
var _RelationsValues = []Relations{0, 1, 2, 3, 4, 5, 6}
// RelationsN is the highest valid value for type Relations, plus one.
const RelationsN Relations = 7
var _RelationsValueMap = map[string]Relations{`NoRel`: 0, `RightOf`: 1, `LeftOf`: 2, `Behind`: 3, `FrontOf`: 4, `Above`: 5, `Below`: 6}
var _RelationsDescMap = map[Relations]string{0: ``, 1: ``, 2: ``, 3: ``, 4: ``, 5: ``, 6: ``}
var _RelationsMap = map[Relations]string{0: `NoRel`, 1: `RightOf`, 2: `LeftOf`, 3: `Behind`, 4: `FrontOf`, 5: `Above`, 6: `Below`}
// String returns the string representation of this Relations value.
func (i Relations) String() string { return enums.String(i, _RelationsMap) }
// SetString sets the Relations value from its string representation,
// and returns an error if the string is invalid.
func (i *Relations) SetString(s string) error {
return enums.SetString(i, s, _RelationsValueMap, "Relations")
}
// Int64 returns the Relations value as an int64.
func (i Relations) Int64() int64 { return int64(i) }
// SetInt64 sets the Relations value from an int64.
func (i *Relations) SetInt64(in int64) { *i = Relations(in) }
// Desc returns the description of the Relations value.
func (i Relations) Desc() string { return enums.Desc(i, _RelationsDescMap) }
// RelationsValues returns all possible values for the type Relations.
func RelationsValues() []Relations { return _RelationsValues }
// Values returns all possible values for the type Relations.
func (i Relations) Values() []enums.Enum { return enums.Values(_RelationsValues) }
// MarshalText implements the [encoding.TextMarshaler] interface.
func (i Relations) MarshalText() ([]byte, error) { return []byte(i.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *Relations) UnmarshalText(text []byte) error {
return enums.UnmarshalText(i, text, "Relations")
}
var _XAlignsValues = []XAligns{0, 1, 2}
// XAlignsN is the highest valid value for type XAligns, plus one.
const XAlignsN XAligns = 3
var _XAlignsValueMap = map[string]XAligns{`Left`: 0, `Middle`: 1, `Right`: 2}
var _XAlignsDescMap = map[XAligns]string{0: ``, 1: ``, 2: ``}
var _XAlignsMap = map[XAligns]string{0: `Left`, 1: `Middle`, 2: `Right`}
// String returns the string representation of this XAligns value.
func (i XAligns) String() string { return enums.String(i, _XAlignsMap) }
// SetString sets the XAligns value from its string representation,
// and returns an error if the string is invalid.
func (i *XAligns) SetString(s string) error {
return enums.SetString(i, s, _XAlignsValueMap, "XAligns")
}
// Int64 returns the XAligns value as an int64.
func (i XAligns) Int64() int64 { return int64(i) }
// SetInt64 sets the XAligns value from an int64.
func (i *XAligns) SetInt64(in int64) { *i = XAligns(in) }
// Desc returns the description of the XAligns value.
func (i XAligns) Desc() string { return enums.Desc(i, _XAlignsDescMap) }
// XAlignsValues returns all possible values for the type XAligns.
func XAlignsValues() []XAligns { return _XAlignsValues }
// Values returns all possible values for the type XAligns.
func (i XAligns) Values() []enums.Enum { return enums.Values(_XAlignsValues) }
// MarshalText implements the [encoding.TextMarshaler] interface.
func (i XAligns) MarshalText() ([]byte, error) { return []byte(i.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *XAligns) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "XAligns") }
var _YAlignsValues = []YAligns{0, 1, 2}
// YAlignsN is the highest valid value for type YAligns, plus one.
const YAlignsN YAligns = 3
var _YAlignsValueMap = map[string]YAligns{`Front`: 0, `Center`: 1, `Back`: 2}
var _YAlignsDescMap = map[YAligns]string{0: ``, 1: ``, 2: ``}
var _YAlignsMap = map[YAligns]string{0: `Front`, 1: `Center`, 2: `Back`}
// String returns the string representation of this YAligns value.
func (i YAligns) String() string { return enums.String(i, _YAlignsMap) }
// SetString sets the YAligns value from its string representation,
// and returns an error if the string is invalid.
func (i *YAligns) SetString(s string) error {
return enums.SetString(i, s, _YAlignsValueMap, "YAligns")
}
// Int64 returns the YAligns value as an int64.
func (i YAligns) Int64() int64 { return int64(i) }
// SetInt64 sets the YAligns value from an int64.
func (i *YAligns) SetInt64(in int64) { *i = YAligns(in) }
// Desc returns the description of the YAligns value.
func (i YAligns) Desc() string { return enums.Desc(i, _YAlignsDescMap) }
// YAlignsValues returns all possible values for the type YAligns.
func YAlignsValues() []YAligns { return _YAlignsValues }
// Values returns all possible values for the type YAligns.
func (i YAligns) Values() []enums.Enum { return enums.Values(_YAlignsValues) }
// MarshalText implements the [encoding.TextMarshaler] interface.
func (i YAligns) MarshalText() ([]byte, error) { return []byte(i.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *YAligns) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "YAligns") }
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package relpos defines a position relationship among layers,
in terms of X,Y width and height of layer
and associated position within a given X-Y plane,
and Z vertical stacking of layers above and below each other.
*/
package relpos
//go:generate core generate -add-types
import (
"cogentcore.org/core/math32"
)
// Pos specifies the relative spatial relationship to another
// layer, which determines positioning. Every layer except one
// "anchor" layer should be positioned relative to another,
// e.g., RightOf, Above, etc. This provides robust positioning
// in the face of layer size changes etc.
// Layers are arranged in X-Y planes, stacked vertically along the Z axis.
type Pos struct { //git:add
// spatial relationship between this layer and the other layer
Rel Relations
// horizontal (x-axis) alignment relative to other
XAlign XAligns
// vertical (y-axis) alignment relative to other
YAlign YAligns
// name of the other layer we are in relationship to
Other string
// scaling factor applied to layer size for displaying
Scale float32
// number of unit-spaces between us
Space float32
// for horizontal (x-axis) alignment, amount we are offset relative to perfect alignment
XOffset float32
// for vertical (y-axis) alignment, amount we are offset relative to perfect alignment
YOffset float32
// Pos is the computed position of lower-left-hand corner of layer
// in 3D space, computed from the relation to other layer.
Pos math32.Vector3 `edit:"-"`
}
// Defaults sets default scale, space, offset values.
// The relationship and align must be set specifically.
// These are automatically applied if Scale = 0
func (rp *Pos) Defaults() {
if rp.Scale == 0 {
rp.Scale = 1
}
if rp.Space == 0 {
rp.Space = 5
}
}
func (rp *Pos) ShouldDisplay(field string) bool {
switch field {
case "XAlign":
return rp.Rel == FrontOf || rp.Rel == Behind || rp.Rel == Above || rp.Rel == Below
case "YAlign":
return rp.Rel == LeftOf || rp.Rel == RightOf || rp.Rel == Above || rp.Rel == Below
default:
return true
}
}
// SetRightOf sets a RightOf relationship with default YAlign:
// Front alignment and given spacing.
func (rp *Pos) SetRightOf(other string, space float32) {
rp.Rel = RightOf
rp.Other = other
rp.YAlign = Front
rp.Space = space
rp.Scale = 1
}
// SetBehind sets a Behind relationship with default XAlign:
// Left alignment and given spacing.
func (rp *Pos) SetBehind(other string, space float32) {
rp.Rel = Behind
rp.Other = other
rp.XAlign = Left
rp.Space = space
rp.Scale = 1
}
// SetAbove sets an Above relationship with default XAlign:
// Left, YAlign: Front alignment.
func (rp *Pos) SetAbove(other string) {
rp.Rel = Above
rp.Other = other
rp.XAlign = Left
rp.YAlign = Front
rp.YOffset = 1
rp.Scale = 1
}
// SetPos sets the relative position based on other layer
// position and size, using current settings.
// osz and sz must both have already been scaled by
// relevant Scale factor.
func (rp *Pos) SetPos(op math32.Vector3, osz math32.Vector2, sz math32.Vector2) {
if rp.Scale == 0 {
rp.Defaults()
}
rp.Pos = op
switch rp.Rel {
case NoRel:
return
case RightOf:
rp.Pos.X = op.X + osz.X + rp.Space
rp.Pos.Y = rp.AlignYPos(op.Y, osz.Y, sz.Y)
case LeftOf:
rp.Pos.X = op.X - sz.X - rp.Space
rp.Pos.Y = rp.AlignYPos(op.Y, osz.Y, sz.Y)
case Behind:
rp.Pos.Y = op.Y + osz.Y + rp.Space
rp.Pos.X = rp.AlignXPos(op.X, osz.X, sz.X)
case FrontOf:
rp.Pos.Y = op.Y - sz.Y - rp.Space
rp.Pos.X = rp.AlignXPos(op.X, osz.X, sz.X)
case Above:
rp.Pos.Z += 1
rp.Pos.X = rp.AlignXPos(op.X, osz.X, sz.X)
rp.Pos.Y = rp.AlignYPos(op.Y, osz.Y, sz.Y)
case Below:
rp.Pos.Z -= 1
rp.Pos.X = rp.AlignXPos(op.X, osz.X, sz.X)
rp.Pos.Y = rp.AlignYPos(op.Y, osz.Y, sz.Y)
}
}
// AlignYPos returns the Y-axis (within-plane vertical or height)
// position according to alignment factors.
func (rp *Pos) AlignYPos(yop, yosz, ysz float32) float32 {
switch rp.YAlign {
case Front:
return yop + rp.YOffset
case Center:
return yop + 0.5*yosz - 0.5*ysz + rp.YOffset
case Back:
return yop + yosz - ysz + rp.YOffset
}
return yop
}
// AlignXPos returns the X-axis (within-plane horizontal or width)
// position according to alignment factors.
func (rp *Pos) AlignXPos(xop, xosz, xsz float32) float32 {
switch rp.XAlign {
case Left:
return xop + rp.XOffset
case Middle:
return xop + 0.5*xosz - 0.5*xsz + rp.XOffset
case Right:
return xop + xosz - xsz + rp.XOffset
}
return xop
}
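// examplePosRightOf is a hypothetical usage sketch (not an API function):
// place a layer to the right of another and compute its 3D position.
// The math32.Vec3 constructor and the sizes are assumptions for illustration.
func examplePosRightOf() {
	var rp Pos
	rp.SetRightOf("Input", 2)  // 2 unit-spaces to the right of the "Input" layer
	op := math32.Vec3(0, 0, 0) // other layer's lower-left corner
	osz := math32.Vec2(10, 10) // other layer's (scaled) size
	sz := math32.Vec2(5, 5)    // this layer's (scaled) size
	rp.SetPos(op, osz, sz)
	_ = rp.Pos // X = 0 + 10 + 2 = 12; Y uses Front alignment (op.Y + YOffset)
}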
// Relations are different spatial relationships (of layers)
type Relations int32 //enums:enum
// The relations
const (
NoRel Relations = iota
RightOf
LeftOf
Behind
FrontOf
Above
Below
)
// XAligns are different horizontal alignments
type XAligns int32 //enums:enum
const (
Left XAligns = iota
Middle
Right
)
// YAligns are different vertical alignments
type YAligns int32 //enums:enum
const (
Front YAligns = iota
Center
Back
)
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ringidx
//gosl:start ringidx
// FIx is a fixed-length ring index structure -- does not grow
// or shrink dynamically.
type FIx struct {
// the zero index position -- where logical 0 is in physical buffer
Zi uint32
// the length of the buffer -- wraps around at this modulus
Len uint32
pad, pad1 uint32
}
// Index returns the physical index of the logical index i.
// i must be < Len.
func (fi *FIx) Index(i uint32) uint32 {
i += fi.Zi
if i >= fi.Len {
i -= fi.Len
}
return i
}
// IndexIsValid returns true if given index is valid: >= 0 and < Len
func (fi *FIx) IndexIsValid(i uint32) bool {
return i < fi.Len
}
// Shift moves the zero index up by n.
func (fi *FIx) Shift(n uint32) {
fi.Zi = fi.Index(n)
}
//gosl:end ringidx
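// exampleFIx is a hypothetical usage sketch (not an API function): mapping
// logical indexes to physical buffer positions in a fixed-length ring.
func exampleFIx() {
	fi := FIx{Zi: 3, Len: 5} // logical 0 lives at physical index 3 of a 5-slot buffer
	_ = fi.Index(0)          // 3
	_ = fi.Index(4)          // 2: wraps around past Len
	fi.Shift(2)              // logical 0 now maps to physical index 0
}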
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package ringidx provides circular indexing logic for writing a given
length of data into a fixed-sized buffer and wrapping around this
buffer, overwriting the oldest data. No copying is required, so
it is highly efficient.
*/
package ringidx
//go:generate core generate -add-types
// Index is the ring index structure for a dynamically-sized ring buffer,
// maintaining starting index and length into a ring-buffer with maximum
// length Max. Max must be > 0 and Len <= Max.
// When adding new items would overflow Max, starting index is shifted over
// to overwrite the oldest items with the new ones. No moving is ever
// required: just a fixed-length buffer of size Max.
type Index struct {
// Start the starting index where current data starts.
// The oldest data is at this index, and continues for Len items,
// wrapping around at Max, coming back up at most to Start-1.
Start int
// Len is the number of items stored starting at Start. Capped at Max.
Len int
// Max is the maximum number of items that can be stored in this ring.
Max int
}
// Index returns the index of the i'th item starting from Start.
// i must be < Len.
func (ri *Index) Index(i int) int {
i += ri.Start
if i >= ri.Max {
i -= ri.Max
}
return i
}
// LastIndex returns the index of the last (most recently added) item in the ring.
// Only valid if Len > 0.
func (ri *Index) LastIndex() int {
return ri.Index(ri.Len - 1)
}
// IndexIsValid returns true if the given index is valid: >= 0 and < Len.
func (ri *Index) IndexIsValid(i int) bool {
return i >= 0 && i < ri.Len
}
// Add adds the given number of items to the ring (n <= Max).
// If Len+n would exceed Max, Shift is called for the Len+n-Max extra items to make room.
func (ri *Index) Add(n int) {
over := (ri.Len + n) - ri.Max
if over > 0 {
ri.Shift(over)
}
ri.Len += n
}
// Shift moves the starting index up by n, and decrements the Len by n as well.
// This is called prior to adding new items if doing so would exceed Max length.
func (ri *Index) Shift(n int) {
ri.Start = ri.Index(n)
ri.Len -= n
}
// Reset initializes start index and length to 0
func (ri *Index) Reset() {
ri.Start = 0
ri.Len = 0
}
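// The function below is an illustrative sketch, not part of the original
// file: with Max = 4, adding 3 items and then 3 more overflows by 2, so
// Shift(2) overwrites the 2 oldest slots and the ring holds the 4 newest.
func exampleIndexSketch() (start, last int) {
ri := Index{Max: 4}
ri.Add(3) // Start = 0, Len = 3
ri.Add(3) // over by 2: Shift(2), so Start = 2, Len = 4
return ri.Start, ri.LastIndex() // 2, 1: the last item wraps to physical slot 1
}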
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package weights
import (
"bufio"
"fmt"
"io"
"strconv"
"strings"
"cogentcore.org/core/base/errors"
)
// NetReadCpp reads weights for an entire network from the old emergent C++ text format.
func NetReadCpp(r io.Reader) (*Network, error) {
nw := &Network{}
var (
lw *Layer
pw *Path
rw *Recv
ri int
pi int
skipnext bool
cidx int
err error
errs []error
)
scan := bufio.NewScanner(r) // line at a time
for scan.Scan() {
if skipnext {
skipnext = false
continue
}
b := scan.Bytes()
bs := string(b)
switch {
case strings.HasPrefix(bs, "</"): // don't care about any ending tags
continue
case strings.HasPrefix(bs, "<Fmt "):
continue
case strings.HasPrefix(bs, "<Name "):
continue
case strings.HasPrefix(bs, "<Epoch "):
continue
case bs == "<Network>":
continue
case bs == "<Ug>":
continue
case bs == "<Un>":
skipnext = true // skip over bias weight
continue
case strings.HasPrefix(bs, "<Lay "):
lnm := strings.TrimSuffix(strings.TrimPrefix(bs, "<Lay "), ">")
nw.Layers = append(nw.Layers, Layer{Layer: lnm})
lw = &nw.Layers[len(nw.Layers)-1]
pw = nil
continue
case strings.HasPrefix(bs, "<UgUn "):
us := strings.TrimSuffix(strings.TrimPrefix(bs, "<UgUn "), ">")
uss := strings.Split(us, " ") // includes unit name
ri, err = strconv.Atoi(uss[0])
if err != nil {
errs = append(errs, err)
}
continue
case strings.HasPrefix(bs, "<Cg "):
cs := strings.TrimSuffix(strings.TrimPrefix(bs, "<Cg "), ">")
css := strings.Split(cs, " ")
pi, err = strconv.Atoi(css[0])
if err != nil {
errs = append(errs, err)
}
fm := strings.TrimPrefix(css[1], "From:")
if len(lw.Paths) < pi+1 {
lw.Paths = append(lw.Paths, Path{From: fm})
}
pw = &lw.Paths[pi]
continue
case strings.HasPrefix(bs, "<Cn "):
us := strings.TrimSuffix(strings.TrimPrefix(bs, "<Cn "), ">")
nc, err := strconv.Atoi(us)
if err != nil {
errs = append(errs, err)
}
if len(pw.Rs) < ri+1 {
pw.Rs = append(pw.Rs, Recv{Ri: ri, N: nc})
}
rw = &pw.Rs[ri]
if len(rw.Si) != nc {
rw.Si = make([]int, nc)
rw.Wt = make([]float32, nc)
rw.Wt1 = make([]float32, nc)
}
cidx = 0 // start reading on next ones
continue
case strings.HasPrefix(bs, "<"): // misc meta
kvl := strings.Split(bs, " ")
if len(kvl) != 2 {
err = fmt.Errorf("NetReadCpp: unrecognized input: %v", bs)
errs = append(errs, err)
continue
}
ky := strings.TrimPrefix(kvl[0], "<")
vl := strings.TrimSuffix(kvl[1], ">")
switch ky {
case "acts_m_avg":
ky = "ActMAvg"
case "acts_p_avg":
ky = "ActPAvg"
}
if lw == nil {
nw.SetMetaData(ky, vl)
} else if pw == nil {
lw.SetMetaData(ky, vl)
} else {
pw.SetMetaData(ky, vl)
}
continue
default: // weight values read into current rw
siwts := strings.Split(bs, " ")
switch len(siwts) {
case 2:
si, err := strconv.Atoi(siwts[0])
if err != nil {
errs = append(errs, err)
}
wt, err := strconv.ParseFloat(siwts[1], 32)
if err != nil {
errs = append(errs, err)
}
rw.Si[cidx] = si
rw.Wt[cidx] = float32(wt)
rw.Wt1[cidx] = float32(0)
cidx++
case 3:
si, err := strconv.Atoi(siwts[0])
if err != nil {
errs = append(errs, err)
}
wt, err := strconv.ParseFloat(siwts[1], 32)
if err != nil {
errs = append(errs, err)
}
scale, err := strconv.ParseFloat(siwts[2], 32)
if err != nil {
errs = append(errs, err)
}
rw.Si[cidx] = si
rw.Wt[cidx] = float32(wt)
rw.Wt1[cidx] = float32(scale)
cidx++
default:
err = fmt.Errorf("NetReadCpp: unrecognized input: %v", bs)
errs = append(errs, err)
continue
}
}
}
return nw, errors.Log(errors.Join(errs...))
}
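// The function below is an illustrative sketch, not part of the original
// file: NetReadCpp accepts any io.Reader over the old C++ text format.
// The snippet is a hypothetical fragment showing only the tag structure
// the parser above expects: a <Lay> block, a receiving unit <UgUn>, a
// connection group <Cg>, a connection count <Cn>, then "sendIndex weight"
// value lines.
func exampleNetReadCppSketch() (*Network, error) {
const snippet = `<Lay Hidden>
<UgUn 0 u0>
<Cg 0 From:Input>
<Cn 2>
0 0.5
1 0.25
`
return NetReadCpp(strings.NewReader(snippet))
}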
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package weights
import (
"io"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/iox/jsonx"
)
// Prec is the precision for weight output in text formats.
// The default of 4 is aggressively low but sufficient for Leabra models;
// it may need to be increased for other models.
var Prec = 4
// NetReadJSON reads weights for entire network in a JSON format into Network structure
func NetReadJSON(r io.Reader) (*Network, error) {
nw := &Network{}
err := errors.Log(jsonx.Read(nw, r))
return nw, err
}
// LayReadJSON reads weights for layer in a JSON format into Layer structure
func LayReadJSON(r io.Reader) (*Layer, error) {
lw := &Layer{}
err := errors.Log(jsonx.Read(lw, r))
return lw, err
}
// PathReadJSON reads weights for path in a JSON format into Path structure
func PathReadJSON(r io.Reader) (*Path, error) {
pw := &Path{}
err := errors.Log(jsonx.Read(pw, r))
return pw, err
}
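// The function below is an illustrative sketch, not part of the original
// file: all three readers follow the same jsonx pattern, so a caller just
// hands them an io.Reader over the saved JSON weights (typically an
// opened file) and walks the decoded structure.
func exampleReadJSONSketch(r io.Reader) ([]string, error) {
nw, err := NetReadJSON(r)
if err != nil {
return nil, err
}
names := make([]string, len(nw.Layers)) // collect decoded layer names
for i, ly := range nw.Layers {
names[i] = ly.Layer
}
return names, nil
}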
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package weights
//go:generate core generate -add-types
// Network is temp structure for holding decoded weights
type Network struct {
Network string
MetaData map[string]string // used for optional network-level params, metadata
Layers []Layer
}
func (nt *Network) SetMetaData(key, val string) {
if nt.MetaData == nil {
nt.MetaData = make(map[string]string)
}
nt.MetaData[key] = val
}
// Layer is temp structure for holding decoded weights, one for each layer
type Layer struct {
Layer string
MetaData map[string]string // for optional layer-level params, metadata such as ActMAvg, ActPAvg
Units map[string][]float32 // for unit-level adapting parameters
Paths []Path // receiving pathways
}
func (ly *Layer) SetMetaData(key, val string) {
if ly.MetaData == nil {
ly.MetaData = make(map[string]string)
}
ly.MetaData[key] = val
}
// Path is temp structure for holding decoded weights, one for each pathway
type Path struct {
From string
MetaData map[string]string // used for optional path-level params, metadata such as GScale
MetaValues map[string][]float32 // optional values at the pathway level
Rs []Recv
}
func (pj *Path) SetMetaData(key, val string) {
if pj.MetaData == nil {
pj.MetaData = make(map[string]string)
}
pj.MetaData[key] = val
}
// Recv is temp structure for holding decoded weights, one for each recv unit
type Recv struct {
Ri int
N int
Si []int
Wt []float32
Wt1 []float32 // first extra synapse-level variable (extras are numbered Wt1, Wt2, ...)
Wt2 []float32 // second extra synapse-level variable
}
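// The function below is an illustrative sketch, not part of the original
// file: it shows how the decoded structures nest, Network -> Layers ->
// Paths (receiving pathways) -> Rs (one Recv per receiving unit), with
// the parallel Si/Wt slices holding sending-unit indexes and weights.
// All names and values are hypothetical.
func exampleWeightsSketch() *Network {
return &Network{
Network: "demo",
Layers: []Layer{{
Layer: "Hidden",
Paths: []Path{{
From: "Input",
Rs: []Recv{{
Ri: 0, N: 2,
Si: []int{0, 1},
Wt: []float32{0.5, 0.25},
}},
}},
}},
}
}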
// Code generated by 'yaegi extract github.com/emer/emergent/v2/egui'. DO NOT EDIT.
package yaegiemergent
import (
"github.com/cogentcore/yaegi/interp"
"github.com/emer/emergent/v2/egui"
"reflect"
)
func init() {
Symbols["github.com/emer/emergent/v2/egui/egui"] = map[string]reflect.Value{
// function, constant and variable definitions
"ActiveAlways": reflect.ValueOf(egui.ActiveAlways),
"ActiveRunning": reflect.ValueOf(egui.ActiveRunning),
"ActiveStopped": reflect.ValueOf(egui.ActiveStopped),
"Embed": reflect.ValueOf(interp.GenericFunc("func Embed[S, C any](parent tree.Node) *S { //yaegi:add\n\tcfgC, cfg := NewConfig[C]()\n\n\tcfg.AsBaseConfig().GUI = true // force GUI on\n\n\tsimS := new(S)\n\tsim := any(simS).(Sim[C])\n\n\tsim.SetConfig(cfgC)\n\tsim.ConfigSim()\n\tsim.Init()\n\tsim.ConfigGUI(parent)\n\treturn simS\n}")),
"NewConfig": reflect.ValueOf(interp.GenericFunc("func NewConfig[C any]() (*C, Config) { //yaegi:add\n\tcfgC := new(C)\n\tcfg := any(cfgC).(Config)\n\n\terrors.Log(reflectx.SetFromDefaultTags(cfg))\n\tcfg.AsBaseConfig().BaseDefaults()\n\tcfg.Defaults()\n\treturn cfgC, cfg\n}")),
"NewGUIBody": reflect.ValueOf(egui.NewGUIBody),
"ToolGhostingN": reflect.ValueOf(egui.ToolGhostingN),
"ToolGhostingValues": reflect.ValueOf(egui.ToolGhostingValues),
// type definitions
"BaseConfig": reflect.ValueOf((*egui.BaseConfig)(nil)),
"Config": reflect.ValueOf((*egui.Config)(nil)),
"GUI": reflect.ValueOf((*egui.GUI)(nil)),
"ToolGhosting": reflect.ValueOf((*egui.ToolGhosting)(nil)),
"ToolbarItem": reflect.ValueOf((*egui.ToolbarItem)(nil)),
// interface wrapper definitions
"_Config": reflect.ValueOf((*_github_com_emer_emergent_v2_egui_Config)(nil)),
}
}
// _github_com_emer_emergent_v2_egui_Config is an interface wrapper for Config type
type _github_com_emer_emergent_v2_egui_Config struct {
IValue interface{}
WAsBaseConfig func() *egui.BaseConfig
WDefaults func()
}
func (W _github_com_emer_emergent_v2_egui_Config) AsBaseConfig() *egui.BaseConfig {
return W.WAsBaseConfig()
}
func (W _github_com_emer_emergent_v2_egui_Config) Defaults() { W.WDefaults() }
// Code generated by 'yaegi extract github.com/emer/emergent/v2/netview'. DO NOT EDIT.
package yaegiemergent
import (
"github.com/emer/emergent/v2/netview"
"go/constant"
"go/token"
"reflect"
)
func init() {
Symbols["github.com/emer/emergent/v2/netview/netview"] = map[string]reflect.Value{
// function, constant and variable definitions
"FormDialog": reflect.ValueOf(netview.FormDialog),
"MinUnitHeight": reflect.ValueOf(&netview.MinUnitHeight).Elem(),
"NVarCols": reflect.ValueOf(&netview.NVarCols).Elem(),
"NaNSub": reflect.ValueOf(constant.MakeFromLiteral("-1.109999999999999999999999999999999999999916376162440762314758469872874704797839925826661159333171163347944189256956023240834378505e-37", token.FLOAT, 0)),
"NewLayMesh": reflect.ValueOf(netview.NewLayMesh),
"NewLayName": reflect.ValueOf(netview.NewLayName),
"NewLayObj": reflect.ValueOf(netview.NewLayObj),
"NewNetView": reflect.ValueOf(netview.NewNetView),
"NewScene": reflect.ValueOf(netview.NewScene),
"NilColor": reflect.ValueOf(&netview.NilColor).Elem(),
// type definitions
"LayData": reflect.ValueOf((*netview.LayData)(nil)),
"LayMesh": reflect.ValueOf((*netview.LayMesh)(nil)),
"LayName": reflect.ValueOf((*netview.LayName)(nil)),
"LayObj": reflect.ValueOf((*netview.LayObj)(nil)),
"NetData": reflect.ValueOf((*netview.NetData)(nil)),
"NetView": reflect.ValueOf((*netview.NetView)(nil)),
"Options": reflect.ValueOf((*netview.Options)(nil)),
"PathData": reflect.ValueOf((*netview.PathData)(nil)),
"RasterOptions": reflect.ValueOf((*netview.RasterOptions)(nil)),
"Scene": reflect.ValueOf((*netview.Scene)(nil)),
"VarOptions": reflect.ValueOf((*netview.VarOptions)(nil)),
"ViewUpdate": reflect.ValueOf((*netview.ViewUpdate)(nil)),
}
}