// Copyright (c) 2018, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package atomicx implements misc atomic functions.
package atomicx
import (
"sync/atomic"
)
// Counter implements a basic atomic int64 counter.
// The zero value is ready to use. All operations are safe
// for concurrent use by multiple goroutines.
type Counter int64

// addr returns the underlying int64 pointer for the atomic operations.
func (c *Counter) addr() *int64 {
	return (*int64)(c)
}

// Add adds to counter, returning the new value.
func (c *Counter) Add(inc int64) int64 {
	return atomic.AddInt64(c.addr(), inc)
}

// Sub subtracts from counter, returning the new value.
func (c *Counter) Sub(dec int64) int64 {
	return atomic.AddInt64(c.addr(), -dec)
}

// Inc increments by 1, returning the new value.
func (c *Counter) Inc() int64 {
	return atomic.AddInt64(c.addr(), 1)
}

// Dec decrements by 1, returning the new value.
func (c *Counter) Dec() int64 {
	return atomic.AddInt64(c.addr(), -1)
}

// Value returns the current value.
func (c *Counter) Value() int64 {
	return atomic.LoadInt64(c.addr())
}

// Set sets the counter to a new value.
func (c *Counter) Set(val int64) {
	atomic.StoreInt64(c.addr(), val)
}

// Swap swaps a new value in and returns the old value.
func (c *Counter) Swap(val int64) int64 {
	return atomic.SwapInt64(c.addr(), val)
}
// Copyright (c) 2018, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package atomicx
import "sync/atomic"
// MaxInt32 performs an atomic Max operation: a = max(a, b).
// It retries the compare-and-swap until either the stored value
// is already >= b, or the swap to b succeeds.
func MaxInt32(a *int32, b int32) {
	for {
		cur := atomic.LoadInt32(a)
		if cur >= b {
			return // already at least b; nothing to do
		}
		if atomic.CompareAndSwapInt32(a, cur, b) {
			return // successfully installed the new max
		}
		// lost a race with another writer; reload and retry
	}
}
// Code generated by dummy.gen.go.tmpl. DO NOT EDIT.
// Copyright (c) 2020, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !mpi && !mpich
package mpi
// SendF64 sends values to toProc, using given unique tag identifier.
// This is Blocking. Must have a corresponding Recv call with same tag on toProc, from this proc
func (cm *Comm) SendF64(toProc int, tag int, vals []float64) error {
	return nil
}
// RecvF64 receives values from proc fmProc, using given unique tag identifier
// This is Blocking. Must have a corresponding Send call with same tag on fmProc, to this proc
func (cm *Comm) RecvF64(fmProc int, tag int, vals []float64) error {
	return nil
}
// BcastF64 broadcasts slice from fmProc to all other procs.
// All nodes have the same vals after this call, copied from fmProc.
func (cm *Comm) BcastF64(fmProc int, vals []float64) error {
	return nil
}
// ReduceF64 reduces all values across procs to toProc in orig to dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ReduceF64(toProc int, op Op, dest, orig []float64) error {
	return nil
}
// AllReduceF64 reduces all values across procs to all procs from orig into dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllReduceF64(op Op, dest, orig []float64) error {
	return nil
}
// GatherF64 gathers values from all procs into toProc proc, tiled into dest of size np * len(orig).
// This is inverse of Scatter.
// IMPORTANT: orig and dest must be different slices.
func (cm *Comm) GatherF64(toProc int, dest, orig []float64) error {
	return nil
}
// AllGatherF64 gathers values from all procs into all procs,
// tiled by proc into dest of size np * len(orig).
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllGatherF64(dest, orig []float64) error {
	return nil
}
// ScatterF64 scatters values from fmProc to all procs, distributing len(dest) size chunks to
// each proc from orig slice, which must be of size np * len(dest). This is inverse of Gather.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ScatterF64(fmProc int, dest, orig []float64) error {
	return nil
}
// SendF32 sends values to toProc, using given unique tag identifier.
// This is Blocking. Must have a corresponding Recv call with same tag on toProc, from this proc
func (cm *Comm) SendF32(toProc int, tag int, vals []float32) error {
	return nil
}
// RecvF32 receives values from proc fmProc, using given unique tag identifier
// This is Blocking. Must have a corresponding Send call with same tag on fmProc, to this proc
func (cm *Comm) RecvF32(fmProc int, tag int, vals []float32) error {
	return nil
}
// BcastF32 broadcasts slice from fmProc to all other procs.
// All nodes have the same vals after this call, copied from fmProc.
func (cm *Comm) BcastF32(fmProc int, vals []float32) error {
	return nil
}
// ReduceF32 reduces all values across procs to toProc in orig to dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ReduceF32(toProc int, op Op, dest, orig []float32) error {
	return nil
}
// AllReduceF32 reduces all values across procs to all procs from orig into dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllReduceF32(op Op, dest, orig []float32) error {
	return nil
}
// GatherF32 gathers values from all procs into toProc proc, tiled into dest of size np * len(orig).
// This is inverse of Scatter.
// IMPORTANT: orig and dest must be different slices.
func (cm *Comm) GatherF32(toProc int, dest, orig []float32) error {
	return nil
}
// AllGatherF32 gathers values from all procs into all procs,
// tiled by proc into dest of size np * len(orig).
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllGatherF32(dest, orig []float32) error {
	return nil
}
// ScatterF32 scatters values from fmProc to all procs, distributing len(dest) size chunks to
// each proc from orig slice, which must be of size np * len(dest). This is inverse of Gather.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ScatterF32(fmProc int, dest, orig []float32) error {
	return nil
}
// SendInt sends values to toProc, using given unique tag identifier.
// This is Blocking. Must have a corresponding Recv call with same tag on toProc, from this proc
func (cm *Comm) SendInt(toProc int, tag int, vals []int) error {
	return nil
}
// RecvInt receives values from proc fmProc, using given unique tag identifier
// This is Blocking. Must have a corresponding Send call with same tag on fmProc, to this proc
func (cm *Comm) RecvInt(fmProc int, tag int, vals []int) error {
	return nil
}
// BcastInt broadcasts slice from fmProc to all other procs.
// All nodes have the same vals after this call, copied from fmProc.
func (cm *Comm) BcastInt(fmProc int, vals []int) error {
	return nil
}
// ReduceInt reduces all values across procs to toProc in orig to dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ReduceInt(toProc int, op Op, dest, orig []int) error {
	return nil
}
// AllReduceInt reduces all values across procs to all procs from orig into dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllReduceInt(op Op, dest, orig []int) error {
	return nil
}
// GatherInt gathers values from all procs into toProc proc, tiled into dest of size np * len(orig).
// This is inverse of Scatter.
// IMPORTANT: orig and dest must be different slices.
func (cm *Comm) GatherInt(toProc int, dest, orig []int) error {
	return nil
}
// AllGatherInt gathers values from all procs into all procs,
// tiled by proc into dest of size np * len(orig).
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllGatherInt(dest, orig []int) error {
	return nil
}
// ScatterInt scatters values from fmProc to all procs, distributing len(dest) size chunks to
// each proc from orig slice, which must be of size np * len(dest). This is inverse of Gather.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ScatterInt(fmProc int, dest, orig []int) error {
	return nil
}
// SendI64 sends values to toProc, using given unique tag identifier.
// This is Blocking. Must have a corresponding Recv call with same tag on toProc, from this proc
func (cm *Comm) SendI64(toProc int, tag int, vals []int64) error {
	return nil
}
// RecvI64 receives values from proc fmProc, using given unique tag identifier
// This is Blocking. Must have a corresponding Send call with same tag on fmProc, to this proc
func (cm *Comm) RecvI64(fmProc int, tag int, vals []int64) error {
	return nil
}
// BcastI64 broadcasts slice from fmProc to all other procs.
// All nodes have the same vals after this call, copied from fmProc.
func (cm *Comm) BcastI64(fmProc int, vals []int64) error {
	return nil
}
// ReduceI64 reduces all values across procs to toProc in orig to dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ReduceI64(toProc int, op Op, dest, orig []int64) error {
	return nil
}
// AllReduceI64 reduces all values across procs to all procs from orig into dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllReduceI64(op Op, dest, orig []int64) error {
	return nil
}
// GatherI64 gathers values from all procs into toProc proc, tiled into dest of size np * len(orig).
// This is inverse of Scatter.
// IMPORTANT: orig and dest must be different slices.
func (cm *Comm) GatherI64(toProc int, dest, orig []int64) error {
	return nil
}
// AllGatherI64 gathers values from all procs into all procs,
// tiled by proc into dest of size np * len(orig).
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllGatherI64(dest, orig []int64) error {
	return nil
}
// ScatterI64 scatters values from fmProc to all procs, distributing len(dest) size chunks to
// each proc from orig slice, which must be of size np * len(dest). This is inverse of Gather.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ScatterI64(fmProc int, dest, orig []int64) error {
	return nil
}
// SendU64 sends values to toProc, using given unique tag identifier.
// This is Blocking. Must have a corresponding Recv call with same tag on toProc, from this proc
func (cm *Comm) SendU64(toProc int, tag int, vals []uint64) error {
	return nil
}
// RecvU64 receives values from proc fmProc, using given unique tag identifier
// This is Blocking. Must have a corresponding Send call with same tag on fmProc, to this proc
func (cm *Comm) RecvU64(fmProc int, tag int, vals []uint64) error {
	return nil
}
// BcastU64 broadcasts slice from fmProc to all other procs.
// All nodes have the same vals after this call, copied from fmProc.
func (cm *Comm) BcastU64(fmProc int, vals []uint64) error {
	return nil
}
// ReduceU64 reduces all values across procs to toProc in orig to dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ReduceU64(toProc int, op Op, dest, orig []uint64) error {
	return nil
}
// AllReduceU64 reduces all values across procs to all procs from orig into dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllReduceU64(op Op, dest, orig []uint64) error {
	return nil
}
// GatherU64 gathers values from all procs into toProc proc, tiled into dest of size np * len(orig).
// This is inverse of Scatter.
// IMPORTANT: orig and dest must be different slices.
func (cm *Comm) GatherU64(toProc int, dest, orig []uint64) error {
	return nil
}
// AllGatherU64 gathers values from all procs into all procs,
// tiled by proc into dest of size np * len(orig).
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllGatherU64(dest, orig []uint64) error {
	return nil
}
// ScatterU64 scatters values from fmProc to all procs, distributing len(dest) size chunks to
// each proc from orig slice, which must be of size np * len(dest). This is inverse of Gather.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ScatterU64(fmProc int, dest, orig []uint64) error {
	return nil
}
// SendI32 sends values to toProc, using given unique tag identifier.
// This is Blocking. Must have a corresponding Recv call with same tag on toProc, from this proc
func (cm *Comm) SendI32(toProc int, tag int, vals []int32) error {
	return nil
}
// RecvI32 receives values from proc fmProc, using given unique tag identifier
// This is Blocking. Must have a corresponding Send call with same tag on fmProc, to this proc
func (cm *Comm) RecvI32(fmProc int, tag int, vals []int32) error {
	return nil
}
// BcastI32 broadcasts slice from fmProc to all other procs.
// All nodes have the same vals after this call, copied from fmProc.
func (cm *Comm) BcastI32(fmProc int, vals []int32) error {
	return nil
}
// ReduceI32 reduces all values across procs to toProc in orig to dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ReduceI32(toProc int, op Op, dest, orig []int32) error {
	return nil
}
// AllReduceI32 reduces all values across procs to all procs from orig into dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllReduceI32(op Op, dest, orig []int32) error {
	return nil
}
// GatherI32 gathers values from all procs into toProc proc, tiled into dest of size np * len(orig).
// This is inverse of Scatter.
// IMPORTANT: orig and dest must be different slices.
func (cm *Comm) GatherI32(toProc int, dest, orig []int32) error {
	return nil
}
// AllGatherI32 gathers values from all procs into all procs,
// tiled by proc into dest of size np * len(orig).
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllGatherI32(dest, orig []int32) error {
	return nil
}
// ScatterI32 scatters values from fmProc to all procs, distributing len(dest) size chunks to
// each proc from orig slice, which must be of size np * len(dest). This is inverse of Gather.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ScatterI32(fmProc int, dest, orig []int32) error {
	return nil
}
// SendU32 sends values to toProc, using given unique tag identifier.
// This is Blocking. Must have a corresponding Recv call with same tag on toProc, from this proc
func (cm *Comm) SendU32(toProc int, tag int, vals []uint32) error {
	return nil
}
// RecvU32 receives values from proc fmProc, using given unique tag identifier
// This is Blocking. Must have a corresponding Send call with same tag on fmProc, to this proc
func (cm *Comm) RecvU32(fmProc int, tag int, vals []uint32) error {
	return nil
}
// BcastU32 broadcasts slice from fmProc to all other procs.
// All nodes have the same vals after this call, copied from fmProc.
func (cm *Comm) BcastU32(fmProc int, vals []uint32) error {
	return nil
}
// ReduceU32 reduces all values across procs to toProc in orig to dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ReduceU32(toProc int, op Op, dest, orig []uint32) error {
	return nil
}
// AllReduceU32 reduces all values across procs to all procs from orig into dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllReduceU32(op Op, dest, orig []uint32) error {
	return nil
}
// GatherU32 gathers values from all procs into toProc proc, tiled into dest of size np * len(orig).
// This is inverse of Scatter.
// IMPORTANT: orig and dest must be different slices.
func (cm *Comm) GatherU32(toProc int, dest, orig []uint32) error {
	return nil
}
// AllGatherU32 gathers values from all procs into all procs,
// tiled by proc into dest of size np * len(orig).
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllGatherU32(dest, orig []uint32) error {
	return nil
}
// ScatterU32 scatters values from fmProc to all procs, distributing len(dest) size chunks to
// each proc from orig slice, which must be of size np * len(dest). This is inverse of Gather.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ScatterU32(fmProc int, dest, orig []uint32) error {
	return nil
}
// SendI16 sends values to toProc, using given unique tag identifier.
// This is Blocking. Must have a corresponding Recv call with same tag on toProc, from this proc
func (cm *Comm) SendI16(toProc int, tag int, vals []int16) error {
	return nil
}
// RecvI16 receives values from proc fmProc, using given unique tag identifier
// This is Blocking. Must have a corresponding Send call with same tag on fmProc, to this proc
func (cm *Comm) RecvI16(fmProc int, tag int, vals []int16) error {
	return nil
}
// BcastI16 broadcasts slice from fmProc to all other procs.
// All nodes have the same vals after this call, copied from fmProc.
func (cm *Comm) BcastI16(fmProc int, vals []int16) error {
	return nil
}
// ReduceI16 reduces all values across procs to toProc in orig to dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ReduceI16(toProc int, op Op, dest, orig []int16) error {
	return nil
}
// AllReduceI16 reduces all values across procs to all procs from orig into dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllReduceI16(op Op, dest, orig []int16) error {
	return nil
}
// GatherI16 gathers values from all procs into toProc proc, tiled into dest of size np * len(orig).
// This is inverse of Scatter.
// IMPORTANT: orig and dest must be different slices.
func (cm *Comm) GatherI16(toProc int, dest, orig []int16) error {
	return nil
}
// AllGatherI16 gathers values from all procs into all procs,
// tiled by proc into dest of size np * len(orig).
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllGatherI16(dest, orig []int16) error {
	return nil
}
// ScatterI16 scatters values from fmProc to all procs, distributing len(dest) size chunks to
// each proc from orig slice, which must be of size np * len(dest). This is inverse of Gather.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ScatterI16(fmProc int, dest, orig []int16) error {
	return nil
}
// SendU16 sends values to toProc, using given unique tag identifier.
// This is Blocking. Must have a corresponding Recv call with same tag on toProc, from this proc
func (cm *Comm) SendU16(toProc int, tag int, vals []uint16) error {
	return nil
}
// RecvU16 receives values from proc fmProc, using given unique tag identifier
// This is Blocking. Must have a corresponding Send call with same tag on fmProc, to this proc
func (cm *Comm) RecvU16(fmProc int, tag int, vals []uint16) error {
	return nil
}
// BcastU16 broadcasts slice from fmProc to all other procs.
// All nodes have the same vals after this call, copied from fmProc.
func (cm *Comm) BcastU16(fmProc int, vals []uint16) error {
	return nil
}
// ReduceU16 reduces all values across procs to toProc in orig to dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ReduceU16(toProc int, op Op, dest, orig []uint16) error {
	return nil
}
// AllReduceU16 reduces all values across procs to all procs from orig into dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllReduceU16(op Op, dest, orig []uint16) error {
	return nil
}
// GatherU16 gathers values from all procs into toProc proc, tiled into dest of size np * len(orig).
// This is inverse of Scatter.
// IMPORTANT: orig and dest must be different slices.
func (cm *Comm) GatherU16(toProc int, dest, orig []uint16) error {
	return nil
}
// AllGatherU16 gathers values from all procs into all procs,
// tiled by proc into dest of size np * len(orig).
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllGatherU16(dest, orig []uint16) error {
	return nil
}
// ScatterU16 scatters values from fmProc to all procs, distributing len(dest) size chunks to
// each proc from orig slice, which must be of size np * len(dest). This is inverse of Gather.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ScatterU16(fmProc int, dest, orig []uint16) error {
	return nil
}
// SendI8 sends values to toProc, using given unique tag identifier.
// This is Blocking. Must have a corresponding Recv call with same tag on toProc, from this proc
func (cm *Comm) SendI8(toProc int, tag int, vals []int8) error {
	return nil
}
// RecvI8 receives values from proc fmProc, using given unique tag identifier
// This is Blocking. Must have a corresponding Send call with same tag on fmProc, to this proc
func (cm *Comm) RecvI8(fmProc int, tag int, vals []int8) error {
	return nil
}
// BcastI8 broadcasts slice from fmProc to all other procs.
// All nodes have the same vals after this call, copied from fmProc.
func (cm *Comm) BcastI8(fmProc int, vals []int8) error {
	return nil
}
// ReduceI8 reduces all values across procs to toProc in orig to dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ReduceI8(toProc int, op Op, dest, orig []int8) error {
	return nil
}
// AllReduceI8 reduces all values across procs to all procs from orig into dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllReduceI8(op Op, dest, orig []int8) error {
	return nil
}
// GatherI8 gathers values from all procs into toProc proc, tiled into dest of size np * len(orig).
// This is inverse of Scatter.
// IMPORTANT: orig and dest must be different slices.
func (cm *Comm) GatherI8(toProc int, dest, orig []int8) error {
	return nil
}
// AllGatherI8 gathers values from all procs into all procs,
// tiled by proc into dest of size np * len(orig).
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllGatherI8(dest, orig []int8) error {
	return nil
}
// ScatterI8 scatters values from fmProc to all procs, distributing len(dest) size chunks to
// each proc from orig slice, which must be of size np * len(dest). This is inverse of Gather.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ScatterI8(fmProc int, dest, orig []int8) error {
	return nil
}
// SendU8 sends values to toProc, using given unique tag identifier.
// This is Blocking. Must have a corresponding Recv call with same tag on toProc, from this proc
func (cm *Comm) SendU8(toProc int, tag int, vals []uint8) error {
	return nil
}
// RecvU8 receives values from proc fmProc, using given unique tag identifier
// This is Blocking. Must have a corresponding Send call with same tag on fmProc, to this proc
func (cm *Comm) RecvU8(fmProc int, tag int, vals []uint8) error {
	return nil
}
// BcastU8 broadcasts slice from fmProc to all other procs.
// All nodes have the same vals after this call, copied from fmProc.
func (cm *Comm) BcastU8(fmProc int, vals []uint8) error {
	return nil
}
// ReduceU8 reduces all values across procs to toProc in orig to dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ReduceU8(toProc int, op Op, dest, orig []uint8) error {
	return nil
}
// AllReduceU8 reduces all values across procs to all procs from orig into dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllReduceU8(op Op, dest, orig []uint8) error {
	return nil
}
// GatherU8 gathers values from all procs into toProc proc, tiled into dest of size np * len(orig).
// This is inverse of Scatter.
// IMPORTANT: orig and dest must be different slices.
func (cm *Comm) GatherU8(toProc int, dest, orig []uint8) error {
	return nil
}
// AllGatherU8 gathers values from all procs into all procs,
// tiled by proc into dest of size np * len(orig).
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllGatherU8(dest, orig []uint8) error {
	return nil
}
// ScatterU8 scatters values from fmProc to all procs, distributing len(dest) size chunks to
// each proc from orig slice, which must be of size np * len(dest). This is inverse of Gather.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ScatterU8(fmProc int, dest, orig []uint8) error {
	return nil
}
// SendC128 sends values to toProc, using given unique tag identifier.
// This is Blocking. Must have a corresponding Recv call with same tag on toProc, from this proc
func (cm *Comm) SendC128(toProc int, tag int, vals []complex128) error {
	return nil
}
// RecvC128 receives values from proc fmProc, using given unique tag identifier
// This is Blocking. Must have a corresponding Send call with same tag on fmProc, to this proc
func (cm *Comm) RecvC128(fmProc int, tag int, vals []complex128) error {
	return nil
}
// BcastC128 broadcasts slice from fmProc to all other procs.
// All nodes have the same vals after this call, copied from fmProc.
func (cm *Comm) BcastC128(fmProc int, vals []complex128) error {
	return nil
}
// ReduceC128 reduces all values across procs to toProc in orig to dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ReduceC128(toProc int, op Op, dest, orig []complex128) error {
	return nil
}
// AllReduceC128 reduces all values across procs to all procs from orig into dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllReduceC128(op Op, dest, orig []complex128) error {
	return nil
}
// GatherC128 gathers values from all procs into toProc proc, tiled into dest of size np * len(orig).
// This is inverse of Scatter.
// IMPORTANT: orig and dest must be different slices.
func (cm *Comm) GatherC128(toProc int, dest, orig []complex128) error {
	return nil
}
// AllGatherC128 gathers values from all procs into all procs,
// tiled by proc into dest of size np * len(orig).
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllGatherC128(dest, orig []complex128) error {
	return nil
}
// ScatterC128 scatters values from fmProc to all procs, distributing len(dest) size chunks to
// each proc from orig slice, which must be of size np * len(dest). This is inverse of Gather.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ScatterC128(fmProc int, dest, orig []complex128) error {
	return nil
}
// SendC64 sends values to toProc, using given unique tag identifier.
// This is Blocking. Must have a corresponding Recv call with same tag on toProc, from this proc
func (cm *Comm) SendC64(toProc int, tag int, vals []complex64) error {
	return nil
}
// RecvC64 receives values from proc fmProc, using given unique tag identifier
// This is Blocking. Must have a corresponding Send call with same tag on fmProc, to this proc
func (cm *Comm) RecvC64(fmProc int, tag int, vals []complex64) error {
	return nil
}
// BcastC64 broadcasts slice from fmProc to all other procs.
// All nodes have the same vals after this call, copied from fmProc.
func (cm *Comm) BcastC64(fmProc int, vals []complex64) error {
	return nil
}
// ReduceC64 reduces all values across procs to toProc in orig to dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ReduceC64(toProc int, op Op, dest, orig []complex64) error {
	return nil
}
// AllReduceC64 reduces all values across procs to all procs from orig into dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllReduceC64(op Op, dest, orig []complex64) error {
	return nil
}
// GatherC64 gathers values from all procs into toProc proc, tiled into dest of size np * len(orig).
// This is inverse of Scatter.
// IMPORTANT: orig and dest must be different slices.
func (cm *Comm) GatherC64(toProc int, dest, orig []complex64) error {
	return nil
}
// AllGatherC64 gathers values from all procs into all procs,
// tiled by proc into dest of size np * len(orig).
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllGatherC64(dest, orig []complex64) error {
	return nil
}
// ScatterC64 scatters values from fmProc to all procs, distributing len(dest) size chunks to
// each proc from orig slice, which must be of size np * len(dest). This is inverse of Gather.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ScatterC64(fmProc int, dest, orig []complex64) error {
	return nil
}
// Copyright (c) 2020, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !mpi && !mpich
package mpi
// this file provides dummy versions, built by default, so mpi can be included
// generically without incurring additional complexity.
// set LogErrors to control whether MPI errors are automatically logged or not
var LogErrors = true
// Op is an aggregation operation: Sum, Min, Max, etc
type Op int
const (
	OpSum Op = iota // sum of values across procs
	OpMax // maximum of values across procs
	OpMin // minimum of values across procs
	OpProd // product of values across procs
	OpLAND // logical AND
	OpLOR // logical OR
	OpBAND // bitwise AND
	OpBOR // bitwise OR
)
const (
	// Root is the rank 0 node -- it is more semantic to use this
	Root int = 0
)
// IsOn tells whether MPI is on or not.
//
// NOTE: in the real MPI build this returns true even after Stop;
// this dummy (non-MPI) build always returns false.
func IsOn() bool {
	return false
}
// Init initialises MPI.
// No-op in this dummy (non-MPI) build.
func Init() {
}
// InitThreadSafe initialises MPI thread safe.
// No-op in this dummy (non-MPI) build; always returns nil.
func InitThreadSafe() error {
	return nil
}
// Finalize finalises MPI (frees resources, shuts it down).
// No-op in this dummy (non-MPI) build.
func Finalize() {
}
// WorldRank returns this proc's rank/ID within the World communicator.
// Returns 0 if not yet initialized, so it is always safe to call.
// Always 0 in this dummy (non-MPI) build.
func WorldRank() (rank int) {
	return 0
}
// WorldSize returns the number of procs in the World communicator.
// Returns 1 if not yet initialized, so it is always safe to call.
// Always 1 in this dummy (non-MPI) build.
func WorldSize() (size int) {
	return 1
}
// Comm is the MPI communicator -- all MPI communication operates as methods
// on this struct. In the real MPI build it holds the MPI_Comm communicator
// and MPI_Group for sub-World group communication; this dummy build carries
// no state.
type Comm struct {
}

// NewComm creates a new communicator.
// if ranks is nil, communicator is for World (all active procs).
// otherwise, defines a group-level communicator for the given ranks.
func NewComm(ranks []int) (*Comm, error) {
	return &Comm{}, nil
}

// Rank returns the rank/ID for this proc (always 0 in the dummy build).
func (cm *Comm) Rank() (rank int) {
	return 0
}

// Size returns the number of procs in this communicator
// (always 1 in the dummy build).
func (cm *Comm) Size() (size int) {
	return 1
}

// Abort aborts MPI (no-op here; always returns nil).
func (cm *Comm) Abort() error {
	return nil
}

// Barrier forces synchronisation (no-op here; always returns nil).
func (cm *Comm) Barrier() error {
	return nil
}
// Copyright (c) 2020, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mpi
import "fmt"
// PrintAllProcs causes mpi.Printf to print on all processors -- otherwise
// only rank 0 prints. Useful for debugging.
var PrintAllProcs = false
// Printf does fmt.Printf only on the 0 rank node (see also AllPrintf to do all)
// and PrintAllProcs var to override for debugging, and print all
func Printf(fs string, pars ...any) {
if !PrintAllProcs && WorldRank() > 0 {
return
}
if WorldRank() > 0 {
AllPrintf(fs, pars...)
} else {
fmt.Printf(fs, pars...)
}
}
// AllPrintf does fmt.Printf on all nodes, with node rank printed first.
// This is best for debugging MPI itself.
func AllPrintf(fs string, pars ...any) {
	prefix := fmt.Sprintf("P%d: ", WorldRank())
	fmt.Printf(prefix+fs, pars...)
}
// Println does fmt.Println only on the 0 rank node (see also AllPrintln to do all)
// and PrintAllProcs var to override for debugging, and print all
func Println(fs ...any) {
if !PrintAllProcs && WorldRank() > 0 {
return
}
if WorldRank() > 0 {
AllPrintln(fs...)
} else {
fmt.Println(fs...)
}
}
// AllPrintln does fmt.Println on all nodes, with node rank printed first.
// This is best for debugging MPI itself.
func AllPrintln(fs ...any) {
	// BUG FIX: the slice must hold len(fs)+1 elements (prefix + all args).
	// The previous make([]any, len(fs)) + copy(fsa[1:], fs) silently
	// dropped the last argument.
	fsa := make([]any, 0, len(fs)+1)
	fsa = append(fsa, fmt.Sprintf("P%d: ", WorldRank()))
	fsa = append(fsa, fs...)
	fmt.Println(fsa...)
}
// Copyright (c) 2019, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package randx
// BoolP is a simple method to generate a true value with given probability
// (else false). It is just rand.Float64() < p but this is more readable
// and explicit.
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func BoolP(p float64, randOpt ...Rand) bool {
	var src Rand
	if len(randOpt) > 0 {
		src = randOpt[0]
	} else {
		src = NewGlobalRand()
	}
	return src.Float64() < p
}
// BoolP32 is a simple method to generate a true value with given probability
// (else false). It is just rand.Float32() < p but this is more readable
// and explicit.
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func BoolP32(p float32, randOpt ...Rand) bool {
	var src Rand
	if len(randOpt) > 0 {
		src = randOpt[0]
	} else {
		src = NewGlobalRand()
	}
	return src.Float32() < p
}
// Copyright (c) 2023, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package randx
import (
"math"
)
// note: this file contains random distribution functions
// from gonum.org/v1/gonum/stat/distuv
// which we modified only to use the randx.Rand interface.
// BinomialGen returns binomial with n trials (par) each of probability p (var)
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func BinomialGen(n, p float64, randOpt ...Rand) float64 {
	var rnd Rand
	if len(randOpt) == 0 {
		rnd = NewGlobalRand()
	} else {
		rnd = randOpt[0]
	}
	// NUMERICAL RECIPES IN C: THE ART OF SCIENTIFIC COMPUTING (ISBN 0-521-43108-5)
	// p. 295-6
	// http://www.aip.de/groups/soe/local/numres/bookcpdf/c7-3.pdf
	// Work with p <= 0.5 throughout; porg remembers the original p so the
	// result can be mirrored (n - bnl) at the end when p was flipped.
	porg := p
	if p > 0.5 {
		p = 1 - p
	}
	am := n * p // expected value (mean) of the distribution
	if n < 25 {
		// Use direct method: simulate the n Bernoulli trials directly.
		bnl := 0.0
		for i := 0; i < int(n); i++ {
			if rnd.Float64() < p {
				bnl++
			}
		}
		if p != porg {
			return n - bnl
		}
		return bnl
	}
	if am < 1 {
		// Use rejection method with Poisson proposal.
		const logM = 2.6e-2 // constant for rejection sampling (https://en.wikipedia.org/wiki/Rejection_sampling)
		var bnl float64
		z := -p
		pclog := (1 + 0.5*z) * z / (1 + (1+1.0/6*z)*z) // Padé approximant of log(1 + x)
		for {
			bnl = 0.0
			t := 0.0
			// Count unit-rate exponential arrivals until the cumulative
			// sum reaches am; the count is a Poisson-distributed proposal.
			for i := 0; i < int(n); i++ {
				t += rnd.ExpFloat64()
				if t >= am {
					break
				}
				bnl++
			}
			bnlc := n - bnl
			z = -bnl / n
			log1p := (1 + 0.5*z) * z / (1 + (1+1.0/6*z)*z)
			t = (bnlc+0.5)*log1p + bnl - bnlc*pclog + 1/(12*bnlc) - am + logM // Uses Stirling's expansion of log(n!)
			if rnd.ExpFloat64() >= t {
				break
			}
		}
		if p != porg {
			return n - bnl
		}
		return bnl
	}
	// Original algorithm samples from a Poisson distribution with the
	// appropriate expected value. However, the Poisson approximation is
	// asymptotic such that the absolute deviation in probability is O(1/n).
	// Rejection sampling produces exact variates with at worst less than 3%
	// rejection with miminal additional computation.
	// Use rejection method with Cauchy proposal.
	g, _ := math.Lgamma(n + 1)
	plog := math.Log(p)
	pclog := math.Log1p(-p)
	sq := math.Sqrt(2 * am * (1 - p))
	for {
		var em, y float64
		for {
			// Draw a Cauchy variate centered at am with scale sq,
			// retrying until it lands in [0, n+1).
			y = math.Tan(math.Pi * rnd.Float64())
			em = sq*y + am
			if em >= 0 && em < n+1 {
				break
			}
		}
		em = math.Floor(em)
		lg1, _ := math.Lgamma(em + 1)
		lg2, _ := math.Lgamma(n - em + 1)
		// Acceptance ratio of the binomial pmf to the Cauchy envelope.
		t := 1.2 * sq * (1 + y*y) * math.Exp(g-lg1-lg2+em*plog+(n-em)*pclog)
		if rnd.Float64() <= t {
			if p != porg {
				return n - em
			}
			return em
		}
	}
}
// PoissonGen returns poisson variable, as number of events in interval,
// with event rate (lmb = Var) plus mean
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func PoissonGen(lambda float64, randOpt ...Rand) float64 {
	// NUMERICAL RECIPES IN C: THE ART OF SCIENTIFIC COMPUTING (ISBN 0-521-43108-5)
	// p. 294
	// <http://www.aip.de/groups/soe/local/numres/bookcpdf/c7-3.pdf>
	var rnd Rand
	if len(randOpt) == 0 {
		rnd = NewGlobalRand()
	} else {
		rnd = randOpt[0]
	}
	if lambda < 10.0 {
		// Use direct method: count unit-rate exponential inter-arrival
		// times until their cumulative sum exceeds lambda.
		var em float64
		t := 0.0
		for {
			t += rnd.ExpFloat64()
			if t >= lambda {
				break
			}
			em++
		}
		return em
	}
	// Generate using:
	// W. Hörmann. "The transformed rejection method for generating Poisson
	// random variables." Insurance: Mathematics and Economics
	// 12.1 (1993): 39-45.
	// The constants below come from that paper's algorithm.
	b := 0.931 + 2.53*math.Sqrt(lambda)
	a := -0.059 + 0.02483*b
	invalpha := 1.1239 + 1.1328/(b-3.4)
	vr := 0.9277 - 3.6224/(b-2)
	for {
		U := rnd.Float64() - 0.5
		V := rnd.Float64()
		us := 0.5 - math.Abs(U)
		k := math.Floor((2*a/us+b)*U + lambda + 0.43)
		if us >= 0.07 && V <= vr {
			// Fast acceptance region: no pmf evaluation needed.
			return k
		}
		if k <= 0 || (us < 0.013 && V > us) {
			// Quick rejection.
			continue
		}
		lg, _ := math.Lgamma(k + 1)
		// Full acceptance test in log space against the Poisson pmf.
		if math.Log(V*invalpha/(a/(us*us)+b)) <= k*math.Log(lambda)-lambda-lg {
			return k
		}
	}
}
// GammaGen represents maximum entropy distribution with two parameters:
// a shape parameter (Alpha, Par in RandParams),
// and a scaling parameter (Beta, Var in RandParams).
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
// Panics if beta <= 0 or alpha <= 0.
func GammaGen(alpha, beta float64, randOpt ...Rand) float64 {
	const (
		// The 0.2 threshold is from https://www4.stat.ncsu.edu/~rmartin/Codes/rgamss.R
		// described in detail in https://arxiv.org/abs/1302.1884.
		smallAlphaThresh = 0.2
	)
	var rnd Rand
	if len(randOpt) == 0 {
		rnd = NewGlobalRand()
	} else {
		rnd = randOpt[0]
	}
	if beta <= 0 {
		panic("GammaGen: beta <= 0")
	}
	a := alpha
	b := beta
	// Three generation strategies, selected by the shape parameter a.
	switch {
	case a <= 0:
		panic("gamma: alpha <= 0")
	case a == 1:
		// Generate from exponential
		return rnd.ExpFloat64() / b
	case a < smallAlphaThresh:
		// Generate using
		// Liu, Chuanhai, Martin, Ryan and Syring, Nick. "Simulating from a
		// gamma distribution with small shape parameter"
		// https://arxiv.org/abs/1302.1884
		// use this reference: http://link.springer.com/article/10.1007/s00180-016-0692-0
		// Algorithm adjusted to work in log space as much as possible.
		lambda := 1/a - 1
		lr := -math.Log1p(1 / lambda / math.E)
		for {
			e := rnd.ExpFloat64()
			var z float64
			if e >= -lr {
				z = e + lr
			} else {
				z = -rnd.ExpFloat64() / lambda
			}
			// eza = exp(-z/a) is the candidate value (before 1/b scaling).
			eza := math.Exp(-z / a)
			lh := -z - eza
			var lEta float64
			if z >= 0 {
				lEta = -z
			} else {
				lEta = -1 + lambda*z
			}
			// Accept when the log ratio exceeds a negative Exp(1) draw.
			if lh-lEta > -rnd.ExpFloat64() {
				return eza / b
			}
		}
	case a >= smallAlphaThresh:
		// Generate using:
		// Marsaglia, George, and Wai Wan Tsang. "A simple method for generating
		// gamma variables." ACM Transactions on Mathematical Software (TOMS)
		// 26.3 (2000): 363-372.
		d := a - 1.0/3
		m := 1.0
		if a < 1 {
			// Boosting trick from the same paper: sample with shape a+1
			// and scale the result by U^(1/a).
			d += 1.0
			m = math.Pow(rnd.Float64(), 1/a)
		}
		c := 1 / (3 * math.Sqrt(d))
		for {
			x := rnd.NormFloat64()
			v := 1 + x*c
			if v <= 0.0 {
				continue
			}
			v = v * v * v
			u := rnd.Float64()
			// Cheap squeeze test avoids the log evaluation most of the time.
			if u < 1.0-0.0331*(x*x)*(x*x) {
				return m * d * v / b
			}
			// Full acceptance test in log space.
			if math.Log(u) < 0.5*x*x+d*(1-v+math.Log(v)) {
				return m * d * v / b
			}
		}
	}
	panic("unreachable")
}
// GaussianGen returns gaussian (normal) random number with given
// mean and sigma standard deviation.
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func GaussianGen(mean, sigma float64, randOpt ...Rand) float64 {
	var rng Rand
	if len(randOpt) > 0 {
		rng = randOpt[0]
	} else {
		rng = NewGlobalRand()
	}
	return mean + sigma*rng.NormFloat64()
}
// BetaGen returns beta random number with two shape parameters
// alpha > 0 and beta > 0
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func BetaGen(alpha, beta float64, randOpt ...Rand) float64 {
	var rng Rand
	if len(randOpt) > 0 {
		rng = randOpt[0]
	} else {
		rng = NewGlobalRand()
	}
	// Standard construction: X/(X+Y) with X~Gamma(alpha,1), Y~Gamma(beta,1).
	// Order of the two draws is significant for RNG stream reproducibility.
	ga := GammaGen(alpha, 1, rng)
	gb := GammaGen(beta, 1, rng)
	return ga / (ga + gb)
}
// Code generated by "core generate -add-types"; DO NOT EDIT.
package randx
import (
"cogentcore.org/core/enums"
)
// Generated enum support for RandDists: the literal values below must
// stay in sync with the iota-ordered RandDists constants (Uniform=0 .. Mean=6);
// regenerate with "core generate -add-types" rather than editing by hand.
var _RandDistsValues = []RandDists{0, 1, 2, 3, 4, 5, 6}

// RandDistsN is the highest valid value for type RandDists, plus one.
const RandDistsN RandDists = 7

var _RandDistsValueMap = map[string]RandDists{`Uniform`: 0, `Binomial`: 1, `Poisson`: 2, `Gamma`: 3, `Gaussian`: 4, `Beta`: 5, `Mean`: 6}

var _RandDistsDescMap = map[RandDists]string{0: `Uniform has a uniform probability distribution over Var = range on either side of the Mean`, 1: `Binomial represents number of 1's in n (Par) random (Bernouli) trials of probability p (Var)`, 2: `Poisson represents number of events in interval, with event rate (lambda = Var) plus Mean`, 3: `Gamma represents maximum entropy distribution with two parameters: scaling parameter (Var) and shape parameter k (Par) plus Mean`, 4: `Gaussian normal with Var = stddev plus Mean`, 5: `Beta with Var = alpha and Par = beta shape parameters`, 6: `Mean is just the constant Mean, no randomness`}

var _RandDistsMap = map[RandDists]string{0: `Uniform`, 1: `Binomial`, 2: `Poisson`, 3: `Gamma`, 4: `Gaussian`, 5: `Beta`, 6: `Mean`}

// String returns the string representation of this RandDists value.
func (i RandDists) String() string { return enums.String(i, _RandDistsMap) }

// SetString sets the RandDists value from its string representation,
// and returns an error if the string is invalid.
func (i *RandDists) SetString(s string) error {
	return enums.SetString(i, s, _RandDistsValueMap, "RandDists")
}

// Int64 returns the RandDists value as an int64.
func (i RandDists) Int64() int64 { return int64(i) }

// SetInt64 sets the RandDists value from an int64.
func (i *RandDists) SetInt64(in int64) { *i = RandDists(in) }

// Desc returns the description of the RandDists value.
func (i RandDists) Desc() string { return enums.Desc(i, _RandDistsDescMap) }

// RandDistsValues returns all possible values for the type RandDists.
func RandDistsValues() []RandDists { return _RandDistsValues }

// Values returns all possible values for the type RandDists.
func (i RandDists) Values() []enums.Enum { return enums.Values(_RandDistsValues) }

// MarshalText implements the [encoding.TextMarshaler] interface.
func (i RandDists) MarshalText() ([]byte, error) { return []byte(i.String()), nil }

// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *RandDists) UnmarshalText(text []byte) error {
	return enums.UnmarshalText(i, text, "RandDists")
}
// Copyright (c) 2019, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package randx
// PChoose32 chooses an index in given slice of float32's at random according
// to the probilities of each item (must be normalized to sum to 1).
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func PChoose32(ps []float32, randOpt ...Rand) int {
	var rng Rand
	if len(randOpt) > 0 {
		rng = randOpt[0]
	} else {
		rng = NewGlobalRand()
	}
	pv := rng.Float32()
	// Walk the cumulative distribution until the draw falls below it.
	cum := float32(0)
	for i := range ps {
		cum += ps[i]
		if pv < cum { // note: lower values already excluded
			return i
		}
	}
	// Fallback for rounding error (or a non-normalized distribution).
	return len(ps) - 1
}
// PChoose64 chooses an index in given slice of float64's at random according
// to the probilities of each item (must be normalized to sum to 1)
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func PChoose64(ps []float64, randOpt ...Rand) int {
	var rng Rand
	if len(randOpt) > 0 {
		rng = randOpt[0]
	} else {
		rng = NewGlobalRand()
	}
	pv := rng.Float64()
	// Walk the cumulative distribution until the draw falls below it.
	cum := float64(0)
	for i := range ps {
		cum += ps[i]
		if pv < cum { // note: lower values already excluded
			return i
		}
	}
	// Fallback for rounding error (or a non-normalized distribution).
	return len(ps) - 1
}
// Copyright (c) 2019, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package randx
// SequentialInts initializes slice of ints to sequential start..start+N-1
// numbers -- for cases where permuting the order is optional.
func SequentialInts(ins []int, start int) {
	next := start
	for i := range ins {
		ins[i] = next
		next++
	}
}
// PermuteInts permutes (shuffles) the order of elements in the given int slice
// using the standard Fisher-Yates shuffle
// https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle
// So you don't have to remember how to call rand.Shuffle.
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func PermuteInts(ins []int, randOpt ...Rand) {
	var rng Rand
	if len(randOpt) > 0 {
		rng = randOpt[0]
	} else {
		rng = NewGlobalRand()
	}
	swap := func(i, j int) {
		ins[i], ins[j] = ins[j], ins[i]
	}
	rng.Shuffle(len(ins), swap)
}
// PermuteStrings permutes (shuffles) the order of elements in the given string slice
// using the standard Fisher-Yates shuffle
// https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle
// So you don't have to remember how to call rand.Shuffle
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func PermuteStrings(ins []string, randOpt ...Rand) {
	var rng Rand
	if len(randOpt) > 0 {
		rng = randOpt[0]
	} else {
		rng = NewGlobalRand()
	}
	swap := func(i, j int) {
		ins[i], ins[j] = ins[j], ins[i]
	}
	rng.Shuffle(len(ins), swap)
}
// PermuteFloat32s permutes (shuffles) the order of elements in the given float32 slice
// using the standard Fisher-Yates shuffle
// https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle
// So you don't have to remember how to call rand.Shuffle
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func PermuteFloat32s(ins []float32, randOpt ...Rand) {
	var rng Rand
	if len(randOpt) > 0 {
		rng = randOpt[0]
	} else {
		rng = NewGlobalRand()
	}
	swap := func(i, j int) {
		ins[i], ins[j] = ins[j], ins[i]
	}
	rng.Shuffle(len(ins), swap)
}
// PermuteFloat64s permutes (shuffles) the order of elements in the given float64 slice
// using the standard Fisher-Yates shuffle
// https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle
// So you don't have to remember how to call rand.Shuffle
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func PermuteFloat64s(ins []float64, randOpt ...Rand) {
	var rng Rand
	if len(randOpt) > 0 {
		rng = randOpt[0]
	} else {
		rng = NewGlobalRand()
	}
	swap := func(i, j int) {
		ins[i], ins[j] = ins[j], ins[i]
	}
	rng.Shuffle(len(ins), swap)
}
// Copyright (c) 2023, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package randx
//go:generate core generate -add-types
import "math/rand"
// Rand provides an interface with most of the standard
// rand.Rand methods, to support the use of either the
// global rand generator or a separate Rand source.
// SysRand in this package implements it for both cases;
// all of the *Gen and Permute* functions accept an optional Rand.
type Rand interface {

	// Seed uses the provided seed value to initialize the generator to a deterministic state.
	// Seed should not be called concurrently with any other Rand method.
	Seed(seed int64)

	// Int63 returns a non-negative pseudo-random 63-bit integer as an int64.
	Int63() int64

	// Uint32 returns a pseudo-random 32-bit value as a uint32.
	Uint32() uint32

	// Uint64 returns a pseudo-random 64-bit value as a uint64.
	Uint64() uint64

	// Int31 returns a non-negative pseudo-random 31-bit integer as an int32.
	Int31() int32

	// Int returns a non-negative pseudo-random int.
	Int() int

	// Int63n returns, as an int64, a non-negative pseudo-random number in the half-open interval [0,n).
	// It panics if n <= 0.
	Int63n(n int64) int64

	// Int31n returns, as an int32, a non-negative pseudo-random number in the half-open interval [0,n).
	// It panics if n <= 0.
	Int31n(n int32) int32

	// Intn returns, as an int, a non-negative pseudo-random number in the half-open interval [0,n).
	// It panics if n <= 0.
	Intn(n int) int

	// Float64 returns, as a float64, a pseudo-random number in the half-open interval [0.0,1.0).
	Float64() float64

	// Float32 returns, as a float32, a pseudo-random number in the half-open interval [0.0,1.0).
	Float32() float32

	// NormFloat64 returns a normally distributed float64 in the range
	// [-math.MaxFloat64, +math.MaxFloat64] with
	// standard normal distribution (mean = 0, stddev = 1)
	// from the default Source.
	// To produce a different normal distribution, callers can
	// adjust the output using:
	//
	//	sample = NormFloat64() * desiredStdDev + desiredMean
	NormFloat64() float64

	// ExpFloat64 returns an exponentially distributed float64 in the range
	// (0, +math.MaxFloat64] with an exponential distribution whose rate parameter
	// (lambda) is 1 and whose mean is 1/lambda (1) from the default Source.
	// To produce a distribution with a different rate parameter,
	// callers can adjust the output using:
	//
	//	sample = ExpFloat64() / desiredRateParameter
	ExpFloat64() float64

	// Perm returns, as a slice of n ints, a pseudo-random permutation of the integers
	// in the half-open interval [0,n).
	Perm(n int) []int

	// Shuffle pseudo-randomizes the order of elements.
	// n is the number of elements. Shuffle panics if n < 0.
	// swap swaps the elements with indexes i and j.
	Shuffle(n int, swap func(i, j int))
}
// SysRand supports the system random number generator
// for either a separate rand.Rand source, or, if that
// is nil, the global rand stream.
// The zero value is usable and delegates to the global stream.
type SysRand struct {

	// if non-nil, use this random number source instead of the global default one
	Rand *rand.Rand `display:"-"`
}
// NewGlobalRand returns a new SysRand that implements the
// randx.Rand interface, with the system global rand source.
func NewGlobalRand() *SysRand {
	// A nil Rand field means: delegate to the global rand stream.
	return &SysRand{}
}
// NewSysRand returns a new SysRand with a new
// rand.Rand random source with given initial seed.
func NewSysRand(seed int64) *SysRand {
	return &SysRand{Rand: rand.New(rand.NewSource(seed))}
}
// NewRand sets Rand to a new rand.Rand source using given seed.
func (r *SysRand) NewRand(seed int64) {
	src := rand.NewSource(seed)
	r.Rand = rand.New(src)
}
// Seed uses the provided seed value to initialize the generator to a deterministic state.
// Seed should not be called concurrently with any other Rand method.
// NOTE(review): rand.Seed is deprecated as of Go 1.20; retained here for
// behavioral compatibility with the global stream.
func (r *SysRand) Seed(seed int64) {
	if r.Rand != nil {
		r.Rand.Seed(seed)
		return
	}
	rand.Seed(seed)
}
// Int63 returns a non-negative pseudo-random 63-bit integer as an int64.
func (r *SysRand) Int63() int64 {
	if r.Rand != nil {
		return r.Rand.Int63()
	}
	return rand.Int63()
}
// Uint32 returns a pseudo-random 32-bit value as a uint32.
func (r *SysRand) Uint32() uint32 {
	if r.Rand != nil {
		return r.Rand.Uint32()
	}
	return rand.Uint32()
}
// Uint64 returns a pseudo-random 64-bit value as a uint64.
func (r *SysRand) Uint64() uint64 {
	if r.Rand != nil {
		return r.Rand.Uint64()
	}
	return rand.Uint64()
}
// Int31 returns a non-negative pseudo-random 31-bit integer as an int32.
func (r *SysRand) Int31() int32 {
	if r.Rand != nil {
		return r.Rand.Int31()
	}
	return rand.Int31()
}
// Int returns a non-negative pseudo-random int.
func (r *SysRand) Int() int {
	if r.Rand != nil {
		return r.Rand.Int()
	}
	return rand.Int()
}
// Int63n returns, as an int64, a non-negative pseudo-random number in the half-open interval [0,n).
// It panics if n <= 0.
func (r *SysRand) Int63n(n int64) int64 {
	if r.Rand != nil {
		return r.Rand.Int63n(n)
	}
	return rand.Int63n(n)
}
// Int31n returns, as an int32, a non-negative pseudo-random number in the half-open interval [0,n).
// It panics if n <= 0.
func (r *SysRand) Int31n(n int32) int32 {
	if r.Rand != nil {
		return r.Rand.Int31n(n)
	}
	return rand.Int31n(n)
}
// Intn returns, as an int, a non-negative pseudo-random number in the half-open interval [0,n).
// It panics if n <= 0.
func (r *SysRand) Intn(n int) int {
	if r.Rand != nil {
		return r.Rand.Intn(n)
	}
	return rand.Intn(n)
}
// Float64 returns, as a float64, a pseudo-random number in the half-open interval [0.0,1.0).
func (r *SysRand) Float64() float64 {
	if r.Rand != nil {
		return r.Rand.Float64()
	}
	return rand.Float64()
}
// Float32 returns, as a float32, a pseudo-random number in the half-open interval [0.0,1.0).
func (r *SysRand) Float32() float32 {
	if r.Rand != nil {
		return r.Rand.Float32()
	}
	return rand.Float32()
}
// NormFloat64 returns a normally distributed float64 in the range
// [-math.MaxFloat64, +math.MaxFloat64] with
// standard normal distribution (mean = 0, stddev = 1)
// from the default Source.
// To produce a different normal distribution, callers can
// adjust the output using:
//
//	sample = NormFloat64() * desiredStdDev + desiredMean
func (r *SysRand) NormFloat64() float64 {
	if r.Rand != nil {
		return r.Rand.NormFloat64()
	}
	return rand.NormFloat64()
}
// ExpFloat64 returns an exponentially distributed float64 in the range
// (0, +math.MaxFloat64] with an exponential distribution whose rate parameter
// (lambda) is 1 and whose mean is 1/lambda (1) from the default Source.
// To produce a distribution with a different rate parameter,
// callers can adjust the output using:
//
//	sample = ExpFloat64() / desiredRateParameter
func (r *SysRand) ExpFloat64() float64 {
	if r.Rand != nil {
		return r.Rand.ExpFloat64()
	}
	return rand.ExpFloat64()
}
// Perm returns, as a slice of n ints, a pseudo-random permutation of the integers
// in the half-open interval [0,n).
func (r *SysRand) Perm(n int) []int {
	if r.Rand != nil {
		return r.Rand.Perm(n)
	}
	return rand.Perm(n)
}
// Shuffle pseudo-randomizes the order of elements.
// n is the number of elements. Shuffle panics if n < 0.
// swap swaps the elements with indexes i and j.
func (r *SysRand) Shuffle(n int, swap func(i, j int)) {
	if r.Rand != nil {
		r.Rand.Shuffle(n, swap)
		return
	}
	rand.Shuffle(n, swap)
}
// Copyright (c) 2019, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package randx
// RandParams provides parameterized random number generation according to different distributions
// and variance, mean params. See Gen for how each field is interpreted
// per distribution.
type RandParams struct { //git:add

	// distribution to generate random numbers from
	Dist RandDists

	// mean of random distribution -- typically added to generated random variants
	Mean float64

	// variability parameter for the random numbers (gauss = standard deviation, not variance; uniform = half-range, others as noted in RandDists)
	Var float64

	// extra parameter for distribution (depends on each one)
	Par float64
}
// Defaults sets default parameter values: Var = 1 and Par = 1
// (Dist and Mean are left at their zero values).
func (rp *RandParams) Defaults() {
	rp.Var = 1
	rp.Par = 1
}
// ShouldDisplay restricts display of the Par field to the
// distributions that actually use it (Gamma, Binomial, Beta);
// all other fields are always displayed.
func (rp *RandParams) ShouldDisplay(field string) bool {
	if field != "Par" {
		return true
	}
	return rp.Dist == Gamma || rp.Dist == Binomial || rp.Dist == Beta
}
// Gen generates a random variable according to current parameters.
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func (rp *RandParams) Gen(randOpt ...Rand) float64 {
	var rng Rand
	if len(randOpt) > 0 {
		rng = randOpt[0]
	} else {
		rng = NewGlobalRand()
	}
	switch rp.Dist {
	case Uniform:
		return UniformMeanRange(rp.Mean, rp.Var, rng)
	case Binomial:
		return rp.Mean + BinomialGen(rp.Par, rp.Var, rng)
	case Poisson:
		return rp.Mean + PoissonGen(rp.Var, rng)
	case Gamma:
		return rp.Mean + GammaGen(rp.Par, rp.Var, rng)
	case Gaussian:
		return GaussianGen(rp.Mean, rp.Var, rng)
	case Beta:
		return rp.Mean + BetaGen(rp.Var, rp.Par, rng)
	default:
		// Mean distribution (and any unknown value): just the constant Mean.
		return rp.Mean
	}
}
// RandDists are different random number distributions.
// The value order must not change: the generated enum tables
// (_RandDists* in this package) mirror these iota values.
type RandDists int32 //enums:enum

// The random number distributions
const (
	// Uniform has a uniform probability distribution over Var = range on either side of the Mean
	Uniform RandDists = iota

	// Binomial represents number of 1's in n (Par) random (Bernouli) trials of probability p (Var)
	Binomial

	// Poisson represents number of events in interval, with event rate (lambda = Var) plus Mean
	Poisson

	// Gamma represents maximum entropy distribution with two parameters: scaling parameter (Var)
	// and shape parameter k (Par) plus Mean
	Gamma

	// Gaussian normal with Var = stddev plus Mean
	Gaussian

	// Beta with Var = alpha and Par = beta shape parameters
	Beta

	// Mean is just the constant Mean, no randomness
	Mean
)
// IntZeroN returns uniform random integer in the range between 0 and n, exclusive of n: [0,n).
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func IntZeroN(n int64, randOpt ...Rand) int64 {
	var rng Rand
	if len(randOpt) > 0 {
		rng = randOpt[0]
	} else {
		rng = NewGlobalRand()
	}
	return rng.Int63n(n)
}
// IntMinMax returns uniform random integer in range between min and max, exclusive of max: [min,max).
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func IntMinMax(min, max int64, randOpt ...Rand) int64 {
	var rng Rand
	if len(randOpt) > 0 {
		rng = randOpt[0]
	} else {
		rng = NewGlobalRand()
	}
	return min + rng.Int63n(max-min)
}
// IntMeanRange returns uniform random integer with given range on either side of the mean:
// [mean - range, mean + range]
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func IntMeanRange(mean, rnge int64, randOpt ...Rand) int64 {
	var rng Rand
	if len(randOpt) > 0 {
		rng = randOpt[0]
	} else {
		rng = NewGlobalRand()
	}
	// 2*rnge+1 possible values, shifted so 0 maps to mean-rnge.
	return mean + (rng.Int63n(2*rnge+1) - rnge)
}
// ZeroOne returns a uniform random number between zero and one (exclusive of 1)
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func ZeroOne(randOpt ...Rand) float64 {
	var rng Rand
	if len(randOpt) > 0 {
		rng = randOpt[0]
	} else {
		rng = NewGlobalRand()
	}
	return rng.Float64()
}
// UniformMinMax returns uniform random number between min and max values inclusive
// (Do not use for generating integers - will not include max!)
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func UniformMinMax(min, max float64, randOpt ...Rand) float64 {
	var rng Rand
	if len(randOpt) > 0 {
		rng = randOpt[0]
	} else {
		rng = NewGlobalRand()
	}
	return min + (max-min)*rng.Float64()
}
// UniformMeanRange returns uniform random number with given range on either size of the mean:
// [mean - range, mean + range]
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func UniformMeanRange(mean, rnge float64, randOpt ...Rand) float64 {
	var rng Rand
	if len(randOpt) > 0 {
		rng = randOpt[0]
	} else {
		rng = NewGlobalRand()
	}
	// Map [0,1) onto [-rnge, rnge) around the mean.
	return mean + rnge*2.0*(rng.Float64()-0.5)
}
// Copyright (c) 2019, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package randx
import (
"time"
)
// Seeds is a set of random seeds, typically used one per Run.
// See Init, Set and NewSeeds for typical usage.
type Seeds []int64
// Init allocates given number of seeds and initializes them to
// sequential numbers 1..n
func (rs *Seeds) Init(n int) {
	ss := make([]int64, n)
	for i := 0; i < n; i++ {
		ss[i] = int64(i) + 1
	}
	*rs = ss
}
// Set sets the given seed to either the single Rand
// interface passed, or the system global Rand source.
func (rs *Seeds) Set(idx int, randOpt ...Rand) {
	var rng Rand
	if len(randOpt) > 0 {
		rng = randOpt[0]
	} else {
		rng = NewGlobalRand()
	}
	rng.Seed((*rs)[idx])
}
// NewSeeds sets a new set of random seeds based on current time
func (rs *Seeds) NewSeeds() {
	base := time.Now().UnixNano()
	for i := range *rs {
		(*rs)[i] = base + int64(i)
	}
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"bufio"
"flag"
"fmt"
"os"
"sort"
"cogentcore.org/core/core"
"cogentcore.org/lab/stats/stats"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
"cogentcore.org/lab/tensorfs"
)
var (
	// Output is the output file name; "" means stdout (-o / -output flags).
	Output string
	// Col is the column name used by -colavg (-col flag).
	Col string
	// OutFile is the created output file when Output is set.
	OutFile *os.File
	// OutWriter is the buffered writer wrapping OutFile or stdout.
	OutWriter *bufio.Writer
	// LF is the line separator written after each output line.
	LF = []byte("\n")
	// Delete, if true, removes the source files after concatenation (-d flag).
	Delete bool
	// LogPrec is the numeric precision for output (-prec flag).
	LogPrec = 4
)
// main parses command-line flags and dispatches to the requested mode:
// plain concatenation (default), -avg across files, or -colavg by column.
func main() {
	var help bool
	var avg bool
	var colavg bool
	flag.BoolVar(&help, "help", false, "if true, report usage info")
	flag.BoolVar(&avg, "avg", false, "if true, files must have same cols (ideally rows too, though not necessary), outputs average of any float-type columns across files")
	flag.BoolVar(&colavg, "colavg", false, "if true, outputs average of any float-type columns aggregated by column")
	flag.StringVar(&Col, "col", "", "name of column for colavg")
	flag.StringVar(&Output, "output", "", "name of output file -- stdout if not specified")
	flag.StringVar(&Output, "o", "", "name of output file -- stdout if not specified")
	flag.BoolVar(&Delete, "delete", false, "if true, delete the source files after cat -- careful!")
	flag.BoolVar(&Delete, "d", false, "if true, delete the source files after cat -- careful!")
	flag.IntVar(&LogPrec, "prec", 4, "precision for number output -- defaults to 4")
	flag.Parse()
	files := flag.Args()
	sort.Strings(files) // idiomatic equivalent of sort.StringSlice(files).Sort()
	if Output != "" {
		// Bug fix: use = (not :=) so the package-level OutFile is assigned;
		// the := form declared a shadowing local and left the global nil.
		var err error
		OutFile, err = os.Create(Output)
		if err != nil {
			fmt.Println("Error creating output file:", err)
			os.Exit(1)
		}
		defer OutFile.Close()
		OutWriter = bufio.NewWriter(OutFile)
	} else {
		OutWriter = bufio.NewWriter(os.Stdout)
	}
	switch {
	case help || len(files) == 0:
		fmt.Printf("\netcat is a data table concatenation utility\n\tassumes all files have header lines, and only retains the header for the first file\n\t(otherwise just use regular cat)\n")
		flag.PrintDefaults()
	case colavg:
		AvgByColumn(files, Col)
	case avg:
		AvgCat(files)
	default:
		RawCat(files)
	}
	OutWriter.Flush()
}
// RawCat concatenates all data in one big file
func RawCat(files []string) {
	for fi, fn := range files {
		fp, err := os.Open(fn)
		if err != nil {
			fmt.Println("Error opening file: ", err)
			continue
		}
		scan := bufio.NewScanner(fp)
		for li := 0; scan.Scan(); li++ {
			// Only the first file's header line is kept.
			if li == 0 && fi != 0 {
				continue
			}
			OutWriter.Write(scan.Bytes())
			OutWriter.Write(LF)
		}
		fp.Close()
		if Delete {
			os.Remove(fn)
		}
	}
}
// AvgCat computes average across all runs
func AvgCat(files []string) {
	// Load every non-empty table; files that fail to open or are empty
	// are reported and skipped.
	dts := make([]*table.Table, 0, len(files))
	for _, fn := range files {
		dt := table.New()
		if err := dt.OpenCSV(core.Filename(fn), tensor.Tab); err != nil {
			fmt.Println("Error opening file: ", err)
			continue
		}
		if dt.NumRows() == 0 {
			fmt.Printf("File %v empty\n", fn)
			continue
		}
		dts = append(dts, dt)
	}
	if len(dts) == 0 {
		fmt.Println("No files or files are empty, exiting")
		return
	}
	// todo: need meantables
	// avgdt := stats.MeanTables(dts)
	// tensor.SetPrecision(avgdt, LogPrec)
	// avgdt.SaveCSV(core.Filename(Output), tensor.Tab, table.Headers)
}
// AvgByColumn computes average by given column for given files
// If column is empty, averages across all rows.
// Each input file is processed and saved independently to Output
// (NOTE(review): with multiple inputs every iteration saves to the same
// Output name -- confirm whether the last-writer-wins behavior is intended).
func AvgByColumn(files []string, column string) {
	for _, fn := range files {
		dt := table.New()
		err := dt.OpenCSV(core.Filename(fn), tensor.Tab)
		if err != nil {
			fmt.Println("Error opening file: ", err)
			continue
		}
		if dt.NumRows() == 0 {
			fmt.Printf("File %v empty\n", fn)
			continue
		}
		// Group rows by the given column, or treat all rows as one group
		// when no column is specified.
		dir, _ := tensorfs.NewDir("Groups")
		if column == "" {
			stats.GroupAll(dir, dt.ColumnByIndex(0))
		} else {
			stats.TableGroups(dir, dt, column)
		}
		// Collect the non-string columns (excluding the grouping column)
		// to be averaged.
		var cols []string
		for ci, cl := range dt.Columns.Values {
			if cl.IsString() || dt.Columns.Keys[ci] == column {
				continue
			}
			cols = append(cols, dt.Columns.Keys[ci])
		}
		stats.TableGroupStats(dir, stats.StatMean, dt, cols...)
		std := dir.Node("Stats")
		avgdt := tensorfs.DirTable(std, nil) // todo: has stat name slash
		tensor.SetPrecision(avgdt, LogPrec)
		avgdt.SaveCSV(core.Filename(Output), tensor.Tab, table.Headers)
	}
}
// Copyright (c) 2020, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"bufio"
"os"
"strings"
"time"
)
// File represents one opened file -- all data is read in and maintained here.
// (The `desc:` struct tags duplicate the line comments in legacy form.)
type File struct {

	// file name (either in same dir or include path)
	FName string `desc:"file name (either in same dir or include path)"`

	// mod time of file when last read; used by CheckUpdate
	ModTime time.Time `desc:"mod time of file when last read"`

	// delim is commas, not tabs; auto-detected from the header line in Read
	Commas bool `desc:"delim is commas, not tabs"`

	// rows of data == len(Data)
	Rows int `desc:"rows of data == len(Data)"`

	// width of each column: resized to fit widest element
	Widths []int `desc:"width of each column: resized to fit widest element"`

	// headers
	Heads []string `desc:"headers"`

	// data -- rows 1..end
	Data [][]string `desc:"data -- rows 1..end"`
}
// Files is a slice of open files, in the order they were opened.
type Files []*File

// TheFiles are the set of open files (package-level singleton).
var TheFiles Files
// Open opens file, reads it
func (fl *File) Open(fname string) error {
	fl.FName = fname
	err := fl.Read()
	return err
}
// CheckUpdate checks if file has been modified -- returns true if so
func (fl *File) CheckUpdate() bool {
	if st, err := os.Stat(fl.FName); err == nil {
		return st.ModTime().After(fl.ModTime)
	}
	// Stat failure (e.g. file removed) counts as "not updated".
	return false
}
// Read (re)reads the entire file into Heads, Data, Widths and Rows,
// recording the file's mod time for CheckUpdate.
// The first line is the header; if it contains more commas than tabs,
// the file is treated as comma-delimited from then on.
// Returns any error from stat, open, or scanning.
func (fl *File) Read() error {
	st, err := os.Stat(fl.FName)
	if err != nil {
		return err
	}
	fl.ModTime = st.ModTime()
	f, err := os.Open(fl.FName)
	if err != nil {
		return err
	}
	defer f.Close()
	if fl.Data != nil {
		fl.Data = fl.Data[:0] // keep capacity across re-reads
	}
	scan := bufio.NewScanner(f)
	ln := 0
	for scan.Scan() {
		s := string(scan.Bytes())
		var fd []string
		if fl.Commas {
			fd = strings.Split(s, ",")
		} else {
			fd = strings.Split(s, "\t")
		}
		if ln == 0 {
			// Auto-detect comma delimiter from the header line.
			if len(fd) == 0 || strings.Count(s, ",") > strings.Count(s, "\t") {
				fl.Commas = true
				fd = strings.Split(s, ",")
			}
			fl.Heads = fd
			fl.Widths = make([]int, len(fl.Heads))
			fl.FitWidths(fd)
			ln++
			continue
		}
		fl.Data = append(fl.Data, fd)
		fl.FitWidths(fd)
		ln++
	}
	// Bug fix: clamp to 0 so an empty file doesn't yield Rows == -1.
	fl.Rows = max(ln-1, 0) // skip header
	// Bug fix: report scanner errors (e.g. over-long lines) instead of
	// returning the always-nil err left over from os.Open.
	return scan.Err()
}
// FitWidths expands the per-column Widths to accommodate the given fields;
// fields beyond the number of known columns are ignored.
func (fl *File) FitWidths(fd []string) {
	for i := range fd {
		if i >= len(fl.Widths) {
			break
		}
		if w := len(fd[i]); w > fl.Widths[i] {
			fl.Widths[i] = w
		}
	}
}
/////////////////////////////////////////////////////////////////
//  Files

// Open opens each of the named files, keeping only those that open
// without error.
func (fl *Files) Open(fnms []string) {
	for _, fn := range fnms {
		f := &File{}
		if f.Open(fn) == nil {
			*fl = append(*fl, f)
		}
	}
}
// CheckUpdates re-reads any files that have changed on disk,
// reporting whether any were updated.
func (fl *Files) CheckUpdates() bool {
	updated := false
	for _, f := range *fl {
		if !f.CheckUpdate() {
			continue
		}
		f.Read()
		updated = true
	}
	return updated
}
// Copyright (c) 2020, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import "github.com/nsf/termbox-go"
// Help clears the screen and renders the key-binding help text.
func (tm *Term) Help() {
	termbox.Clear(termbox.ColorDefault, termbox.ColorDefault)
	lines := []string{
		"Key(s) Function",
		"--------------------------------------------------------------",
		"spc,n page down",
		"p page up",
		"f scroll right-hand panel to the right",
		"b scroll right-hand panel to the left",
		"w widen the left-hand panel of columns",
		"s shrink the left-hand panel of columns",
		"t toggle tail-mode (auto updating as file grows) on/off",
		"a jump to top",
		"e jump to end",
		"v rotate down through the list of files (if not all displayed)",
		"u rotate up through the list of files (if not all displayed)",
		"m more minimum lines per file -- increase amount shown of each file",
		"l less minimum lines per file -- decrease amount shown of each file",
		"d toggle display of file names",
		"c toggle display of column numbers instead of names",
		"q quit",
	}
	for ln, s := range lines {
		tm.DrawStringDef(0, ln, s)
	}
	termbox.Flush()
}
// Copyright (c) 2020, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"fmt"
"image"
"sync"
termbox "github.com/nsf/termbox-go"
)
// Term represents the terminal display -- has all drawing routines
// and all display data. See Tail for two diff display modes.
type Term struct {
	// size of terminal
	Size image.Point `desc:"size of terminal"`
	// number of fixed (non-scrolling) columns on left
	FixCols int `desc:"number of fixed (non-scrolling) columns on left"`
	// starting column index -- relative to FixCols
	ColSt int `desc:"starting column index -- relative to FixCols"`
	// starting row index -- for !Tail mode
	RowSt int `desc:"starting row index -- for !Tail mode"`
	// row from end (relative to RowsPer) -- for Tail mode; <= 0
	RowFromEnd int `desc:"row from end (relative to RowsPer) -- for Tail mode"`
	// starting index into files (if too many to display)
	FileSt int `desc:"starting index into files (if too many to display)"`
	// number of files to display (if too many to display)
	NFiles int `desc:"number of files to display (if too many to display)"`
	// minimum number of lines per file
	MinLines int `desc:"minimum number of lines per file"`
	// maximum column width (1/4 of term width)
	MaxWd int `desc:"maximum column width (1/4 of term width)"`
	// max number of rows across all files
	MaxRows int `desc:"max number of rows across all files"`
	// number of Y rows per file total: Size.Y / len(TheFiles)
	YPer int `desc:"number of Y rows per file total: Size.Y / len(TheFiles)"`
	// rows of data per file (subtracting header, filename)
	RowsPer int `desc:"rows of data per file (subtracting header, filename)"`
	// if true, print filename
	ShowFName bool `desc:"if true, print filename"`
	// if true, display is synchronized by the last row for each file, and otherwise it is synchronized by the starting row. Tail also checks for file updates
	Tail bool `desc:"if true, display is synchronized by the last row for each file, and otherwise it is synchronized by the starting row. Tail also checks for file updates"`
	// display column numbers instead of names
	ColNums bool `desc:"display column numbers instead of names"`
	// draw mutex: Draw locks it, so holders must not call Draw (see TailCheck)
	Mu sync.Mutex `desc:"draw mutex"`
}

// TheTerm is the terminal instance
var TheTerm Term
// Draw draws the current terminal display: it computes the layout
// (how many files are shown, rows per file) from the current terminal
// size, renders each visible file, then the status line.
// Returns any error from the underlying termbox calls.
func (tm *Term) Draw() error {
	tm.Mu.Lock()
	defer tm.Mu.Unlock()
	err := termbox.Clear(termbox.ColorDefault, termbox.ColorDefault)
	if err != nil {
		return err
	}
	w, h := termbox.Size()
	tm.Size.X = w
	tm.Size.Y = h
	tm.MaxWd = tm.Size.X / 4
	if tm.MinLines == 0 {
		tm.MinLines = min(5, tm.Size.Y-1)
	}
	nf := len(TheFiles)
	if nf == 0 {
		return fmt.Errorf("No files")
	}
	ysz := tm.Size.Y - 1 // reserve bottom row for status line
	tm.YPer = ysz / nf
	tm.NFiles = nf
	if tm.YPer < tm.MinLines {
		// not enough room for all files: show as many as fit at MinLines each
		tm.NFiles = ysz / tm.MinLines
		tm.YPer = tm.MinLines
	}
	if tm.NFiles+tm.FileSt > nf {
		tm.FileSt = max(0, nf-tm.NFiles)
	}
	tm.RowsPer = tm.YPer - 1 // subtract header row
	if tm.ShowFName {
		tm.RowsPer--
	}
	sty := 0
	mxrows := 0
	for fi := 0; fi < tm.NFiles; fi++ {
		ffi := tm.FileSt + fi
		if ffi >= nf {
			break
		}
		fl := TheFiles[ffi]
		tm.DrawFile(fl, sty)
		sty += tm.YPer
		mxrows = max(mxrows, fl.Rows)
	}
	tm.MaxRows = mxrows
	tm.StatusLine()
	return termbox.Flush() // was ignored; report flush errors to caller
}
// StatusLine renders the reverse-video status line in the bottom row,
// showing current position (row start, or offset-from-end in Tail mode),
// file counts, and a key-help summary.
func (tm *Term) StatusLine() {
	pos := tm.RowSt
	if tm.Tail {
		pos = tm.RowFromEnd
	}
	stat := fmt.Sprintf("Tail: %v\tPos: %d\tMaxRows: %d\tNFile: %d\tFileSt: %d\t h = help [spc,n,p,r,f,l,b,w,s,t,a,e,v,u,m,l,c,q] ", tm.Tail, pos, tm.MaxRows, len(TheFiles), tm.FileSt)
	tm.DrawString(0, tm.Size.Y-1, stat, len(stat), termbox.AttrReverse, termbox.AttrReverse)
}
// NextPage moves the view down by one page (RowsPer rows).
func (tm *Term) NextPage() error {
	if tm.Tail {
		lo := min(-(tm.MaxRows - tm.RowsPer), 0)
		tm.RowFromEnd = max(min(tm.RowFromEnd+tm.RowsPer, 0), lo)
	} else {
		tm.RowSt = max(min(tm.RowSt+tm.RowsPer, tm.MaxRows-tm.RowsPer), 0)
	}
	return tm.Draw()
}

// PrevPage moves the view up by one page (RowsPer rows).
func (tm *Term) PrevPage() error {
	if tm.Tail {
		lo := min(-(tm.MaxRows - tm.RowsPer), 0)
		tm.RowFromEnd = max(min(tm.RowFromEnd-tm.RowsPer, 0), lo)
	} else {
		tm.RowSt = min(max(tm.RowSt-tm.RowsPer, 0), tm.MaxRows-tm.RowsPer)
	}
	return tm.Draw()
}

// NextLine moves the view down by one line.
func (tm *Term) NextLine() error {
	if tm.Tail {
		lo := min(-(tm.MaxRows - tm.RowsPer), 0)
		tm.RowFromEnd = max(min(tm.RowFromEnd+1, 0), lo)
	} else {
		tm.RowSt = max(min(tm.RowSt+1, tm.MaxRows-tm.RowsPer), 0)
	}
	return tm.Draw()
}

// PrevLine moves the view up by one line.
func (tm *Term) PrevLine() error {
	if tm.Tail {
		lo := min(-(tm.MaxRows - tm.RowsPer), 0)
		tm.RowFromEnd = max(min(tm.RowFromEnd-1, 0), lo)
	} else {
		tm.RowSt = min(max(tm.RowSt-1, 0), tm.MaxRows-tm.RowsPer)
	}
	return tm.Draw()
}
// Top jumps to the first row (start of all files).
func (tm *Term) Top() error {
	tm.RowFromEnd = min(-(tm.MaxRows - tm.RowsPer), 0)
	tm.RowSt = 0
	return tm.Draw()
}

// End jumps to the last row position in the longest file.
func (tm *Term) End() error {
	tm.RowFromEnd = 0
	tm.RowSt = max(tm.MaxRows-tm.RowsPer, 0)
	return tm.Draw()
}
// ScrollRight scrolls the scrolling columns one position to the right.
func (tm *Term) ScrollRight() error {
	tm.ColSt++ // no obvious max
	return tm.Draw()
}

// ScrollLeft scrolls the scrolling columns one position to the left.
func (tm *Term) ScrollLeft() error {
	if tm.ColSt > 0 {
		tm.ColSt--
	}
	return tm.Draw()
}

// FixRight adds one more fixed (non-scrolling) column on the left.
func (tm *Term) FixRight() error {
	tm.FixCols++ // no obvious max
	return tm.Draw()
}

// FixLeft removes one fixed (non-scrolling) column on the left.
func (tm *Term) FixLeft() error {
	if tm.FixCols > 0 {
		tm.FixCols--
	}
	return tm.Draw()
}
// FilesNext moves down in the list of files to display.
func (tm *Term) FilesNext() error {
	nf := len(TheFiles)
	st := min(tm.FileSt+1, nf-tm.NFiles)
	tm.FileSt = max(st, 0)
	return tm.Draw()
}

// FilesPrev moves up in the list of files to display.
func (tm *Term) FilesPrev() error {
	nf := len(TheFiles)
	st := max(tm.FileSt-1, 0)
	tm.FileSt = min(st, nf-tm.NFiles)
	return tm.Draw()
}
// MoreMinLines shows more of each file by raising the minimum lines per file.
func (tm *Term) MoreMinLines() error {
	tm.MinLines++
	return tm.Draw()
}

// LessMinLines shows less of each file, keeping a minimum of 3 lines.
func (tm *Term) LessMinLines() error {
	tm.MinLines = max(3, tm.MinLines-1)
	return tm.Draw()
}

// ToggleNames switches the display of file names on or off.
func (tm *Term) ToggleNames() error {
	tm.ShowFName = !tm.ShowFName
	return tm.Draw()
}

// ToggleTail switches Tail mode on or off.
func (tm *Term) ToggleTail() error {
	tm.Tail = !tm.Tail
	return tm.Draw()
}

// ToggleColNums switches between showing column numbers and column names.
func (tm *Term) ToggleColNums() error {
	tm.ColNums = !tm.ColNums
	return tm.Draw()
}
// TailCheck checks for file updates while in Tail mode, redrawing if any
// file changed; reports whether an update occurred.
func (tm *Term) TailCheck() bool {
	if !tm.Tail {
		return false
	}
	// note: must release Mu before Draw, which takes the lock itself
	tm.Mu.Lock()
	changed := TheFiles.CheckUpdates()
	tm.Mu.Unlock()
	if changed {
		tm.Draw()
	}
	return changed
}
// DrawFile draws one file starting at the given y offset: the optional
// file name, the header row, then up to RowsPer rows of data, column by
// column, honoring fixed vs scrolling columns.
func (tm *Term) DrawFile(fl *File, sty int) {
	tdo := (fl.Rows - tm.RowsPer) + tm.RowFromEnd // tail data offset for this file
	tdo = max(0, tdo)
	// rst is the starting data row for non-Tail mode, clamped into range
	rst := min(tm.RowSt, fl.Rows-tm.RowsPer)
	rst = max(0, rst)
	stx := 0
	for ci, hs := range fl.Heads {
		// fixed columns always show; scrolling columns start at FixCols+ColSt
		if !(ci < tm.FixCols || ci >= tm.FixCols+tm.ColSt) {
			continue
		}
		my := sty
		if tm.ShowFName {
			tm.DrawString(0, my, fl.FName, tm.Size.X, termbox.AttrReverse, termbox.AttrReverse)
			my++
		}
		wmax := min(fl.Widths[ci], tm.MaxWd)
		if tm.ColNums {
			hs = fmt.Sprintf("%d", ci)
		}
		tm.DrawString(stx, my, hs, wmax, termbox.AttrReverse, termbox.AttrReverse)
		if ci == tm.FixCols-1 {
			// separator bar after the last fixed column
			tm.DrawString(stx+wmax+1, my, "|", 1, termbox.AttrReverse, termbox.AttrReverse)
		}
		my++
		for ri := 0; ri < tm.RowsPer; ri++ {
			var di int
			if tm.Tail {
				di = tdo + ri
			} else {
				di = rst + ri
			}
			if di >= len(fl.Data) || di < 0 {
				continue
			}
			dr := fl.Data[di]
			if ci >= len(dr) {
				break // this row has fewer fields than the header
			}
			ds := dr[ci]
			tm.DrawString(stx, my+ri, ds, wmax, termbox.ColorDefault, termbox.ColorDefault)
			if ci == tm.FixCols-1 {
				tm.DrawString(stx+wmax+1, my+ri, "|", 1, termbox.AttrReverse, termbox.AttrReverse)
			}
		}
		stx += wmax + 1
		if ci == tm.FixCols-1 {
			stx += 2 // room for the separator bar
		}
		if stx >= tm.Size.X {
			break // no more horizontal room
		}
	}
}
// DrawStringDef draws s at (x, y) using the default colors,
// limited to the full terminal width.
func (tm *Term) DrawStringDef(x, y int, s string) {
	def := termbox.ColorDefault
	tm.DrawString(x, y, s, tm.Size.X, def, def)
}
// DrawString draws s at (x, y) with the given attributes, truncated to
// maxlen cells and clipped to the terminal bounds.
func (tm *Term) DrawString(x, y int, s string, maxlen int, fg, bg termbox.Attribute) {
	if y >= tm.Size.Y || y < 0 {
		return
	}
	// use a cell (rune) counter: the previous range byte-index left gaps
	// and truncated early on multibyte UTF-8 characters.
	i := 0
	for _, r := range s {
		if i >= maxlen {
			break
		}
		xp := x + i
		i++
		if xp >= tm.Size.X || xp < 0 {
			continue
		}
		termbox.SetCell(xp, y, r, fg, bg)
	}
}
// Copyright (c) 2020, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"fmt"
"log"
"os"
"time"
"github.com/nsf/termbox-go"
)
// main opens the files named on the command line, starts a background
// ticker for tail-mode updates, and runs the termbox key event loop.
func main() {
	err := termbox.Init()
	if err != nil {
		log.Println(err)
		panic(err)
	}
	defer termbox.Close()
	TheFiles.Open(os.Args[1:])
	nf := len(TheFiles)
	if nf == 0 {
		fmt.Printf("usage: etail <filename>... (space separated)\n")
		return
	}
	if nf > 1 {
		TheTerm.ShowFName = true
	}
	err = TheTerm.ToggleTail() // start in tail mode
	if err != nil {
		log.Println(err)
		panic(err)
	}
	// poll for file updates twice per second while in tail mode
	Tailer := time.NewTicker(time.Duration(500) * time.Millisecond)
	go func() {
		for {
			<-Tailer.C
			TheTerm.TailCheck()
		}
	}()
loop:
	for {
		switch ev := termbox.PollEvent(); ev.Type {
		case termbox.EventKey:
			switch {
			case ev.Key == termbox.KeyEsc || ev.Ch == 'Q' || ev.Ch == 'q':
				break loop
			case ev.Ch == ' ' || ev.Ch == 'n' || ev.Ch == 'N' || ev.Key == termbox.KeyPgdn || ev.Key == termbox.KeySpace:
				TheTerm.NextPage()
			case ev.Ch == 'p' || ev.Ch == 'P' || ev.Key == termbox.KeyPgup:
				TheTerm.PrevPage()
			case ev.Key == termbox.KeyArrowDown:
				TheTerm.NextLine()
			case ev.Key == termbox.KeyArrowUp:
				TheTerm.PrevLine()
			case ev.Ch == 'f' || ev.Ch == 'F' || ev.Key == termbox.KeyArrowRight:
				TheTerm.ScrollRight()
			case ev.Ch == 'b' || ev.Ch == 'B' || ev.Key == termbox.KeyArrowLeft:
				TheTerm.ScrollLeft()
			case ev.Ch == 'a' || ev.Ch == 'A' || ev.Key == termbox.KeyHome:
				TheTerm.Top()
			case ev.Ch == 'e' || ev.Ch == 'E' || ev.Key == termbox.KeyEnd:
				TheTerm.End()
			case ev.Ch == 'w' || ev.Ch == 'W':
				TheTerm.FixRight()
			case ev.Ch == 's' || ev.Ch == 'S':
				TheTerm.FixLeft()
			case ev.Ch == 'v' || ev.Ch == 'V':
				TheTerm.FilesNext()
			case ev.Ch == 'u' || ev.Ch == 'U':
				TheTerm.FilesPrev()
			case ev.Ch == 'm' || ev.Ch == 'M':
				TheTerm.MoreMinLines()
			case ev.Ch == 'l' || ev.Ch == 'L':
				TheTerm.LessMinLines()
			case ev.Ch == 'd' || ev.Ch == 'D':
				TheTerm.ToggleNames()
			case ev.Ch == 't' || ev.Ch == 'T':
				TheTerm.ToggleTail()
			case ev.Ch == 'c' || ev.Ch == 'C':
				TheTerm.ToggleColNums()
			case ev.Ch == 'h' || ev.Ch == 'H':
				TheTerm.Help()
			}
		case termbox.EventResize:
			// redraw to adapt layout to the new terminal size
			TheTerm.Draw()
		}
	}
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"embed"
"cogentcore.org/core/core"
"cogentcore.org/core/pages"
)
//go:embed content
var content embed.FS
// main runs the Cogent Lab pages app with the embedded content.
func main() {
	body := core.NewBody("Cogent Lab")
	page := pages.NewPage(body).SetContent(content)
	body.AddTopBar(func(bar *core.Frame) {
		core.NewToolbar(bar).Maker(page.MakeToolbar)
	})
	body.RunMainWindow()
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"embed"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/metadata"
"cogentcore.org/core/core"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
"cogentcore.org/lab/tensorcore"
)
//go:embed *.tsv
var tsv embed.FS
// main loads the embedded training patterns and shows them in a
// tensorcore table with a grid styler.
func main() {
	pats := table.New("TrainPats")
	metadata.SetDoc(pats, "Training patterns")
	// todo: meta data for grid size
	errors.Log(pats.OpenFS(tsv, "random_5x5_25.tsv", tensor.Tab))
	tensorcore.AddGridStylerTo(pats, func(s *tensorcore.GridStyle) {
		s.TotalSize = 200
	})
	body := core.NewBody("grids")
	tabs := core.NewTabs(body)
	tab, _ := tabs.NewTab("Patterns")
	tableView := tensorcore.NewTable(tab)
	tableView.SetTable(pats)
	body.AddTopBar(func(bar *core.Frame) {
		core.NewToolbar(bar).Maker(tableView.MakeToolbar)
	})
	body.RunMainWindow()
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"embed"
"math"
"cogentcore.org/core/cli"
"cogentcore.org/core/core"
"cogentcore.org/core/events"
"cogentcore.org/core/icons"
"cogentcore.org/core/tree"
"cogentcore.org/core/yaegicore/coresymbols"
"cogentcore.org/lab/goal/interpreter"
"cogentcore.org/lab/lab"
"cogentcore.org/lab/stats/stats"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
"cogentcore.org/lab/tensorfs"
)
//go:embed *.csv
var csv embed.FS
// AnalyzePlanets analyzes planets.csv data following some of the examples
// in pandas from:
// https://jakevdp.github.io/PythonDataScienceHandbook/03.08-aggregation-and-grouping.html
// Results are written into the given tensorfs dir.
func AnalyzePlanets(dir *tensorfs.Node) {
	Planets := table.New("planets")
	Planets.OpenFS(csv, "planets.csv", tensor.Comma) // NOTE(review): error ignored -- embedded file assumed present
	vals := []string{"number", "orbital_period", "mass", "distance", "year"}
	stats.DescribeTable(dir, Planets, vals...)
	// add a decade column derived from year (floor to nearest 10)
	decade := Planets.AddFloat64Column("decade")
	year := Planets.Column("year")
	for row := range Planets.NumRows() {
		yr := year.FloatRow(row, 0)
		dec := math.Floor(yr/10) * 10
		decade.SetFloatRow(dec, row, 0)
	}
	stats.TableGroups(dir, Planets, "method", "decade")
	stats.TableGroupDescribe(dir, Planets, vals...)
	// the following is older split-based API kept for reference:
	// byMethod := split.GroupBy(PlanetsAll, "method")
	// split.AggColumn(byMethod, "orbital_period", stats.Median)
	// GpMethodOrbit = byMethod.AggsToTable(table.AddAggName)
	// byMethod.DeleteAggs()
	// split.DescColumn(byMethod, "year") // full desc stats of year
	// byMethod.Filter(func(idx int) bool {
	// 	ag := errors.Log1(byMethod.AggByColumnName("year:Std"))
	// 	return ag.Aggs[idx][0] > 0 // exclude results with 0 std
	// })
	// GpMethodYear = byMethod.AggsToTable(table.AddAggName)
	// split.AggColumn(byMethodDecade, "number", stats.Sum)
	// uncomment this to switch to decade first, then method
	// byMethodDecade.ReorderLevels([]int{1, 0})
	// byMethodDecade.SortLevels()
	// decadeOnly := errors.Log1(byMethodDecade.ExtractLevels([]int{1}))
	// split.AggColumn(decadeOnly, "number", stats.Sum)
	// GpDecade = decadeOnly.AggsToTable(table.AddAggName)
	//
	// GpMethodDecade = byMethodDecade.AggsToTable(table.AddAggName) // here to ensure that decadeOnly didn't mess up..
	// todo: need unstack -- should be specific to the splits data because we already have the cols and
	// groups etc -- the ExtractLevels method provides key starting point.
	// todo: pivot table -- neeeds unstack function.
	// todo: could have a generic unstack-like method that takes a column for the data to turn into columns
	// and another that has the data to put in the cells.
}
// main runs the planets analysis and then drops into the interactive
// interpreter via cli.Run.
// important: must be run from an interactive terminal.
// Will quit immediately if not!
func main() {
	dir := tensorfs.Mkdir("Planets")
	AnalyzePlanets(dir)
	opts := cli.DefaultOptions("planets", "interactive data analysis.")
	cfg := &interpreter.Config{InteractiveFunc: Interactive}
	cli.Run(opts, cfg, interpreter.Run, interpreter.Build)
}
// Interactive sets up the GUI window over the tensorfs root and runs the
// interpreter in interactive mode; used as Config.InteractiveFunc.
func Interactive(c *interpreter.Config, in *interpreter.Interpreter) error {
	in.Interp.Use(coresymbols.Symbols) // gui imports
	in.Config()
	b, _ := lab.NewBasicWindow(tensorfs.CurRoot, "Planets")
	b.AddTopBar(func(bar *core.Frame) {
		tb := core.NewToolbar(bar)
		// tb.Maker(tbv.MakeToolbar)
		tb.Maker(func(p *tree.Plan) {
			tree.Add(p, func(w *core.Button) {
				w.SetText("README").SetIcon(icons.FileMarkdown).
					SetTooltip("open README help file").OnClick(func(e events.Event) {
					core.TheApp.OpenURL("https://github.com/cogentcore/core/blob/main/tensor/examples/planets/README.md")
				})
			})
		})
	})
	b.OnShow(func(e events.Event) {
		// run the interpreter in a goroutine so the GUI stays responsive
		go func() {
			if c.Expr != "" {
				in.Eval(c.Expr)
			}
			in.Interactive()
		}()
	})
	b.RunWindow()
	core.Wait()
	return nil
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"embed"
"cogentcore.org/core/core"
"cogentcore.org/lab/plot"
"cogentcore.org/lab/plotcore"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
)
//go:embed *.tsv
var tsv embed.FS
// main loads the ra25 epoch log and displays it in a PlotEditor,
// styling the PctErr column as the plotted Y value.
func main() {
	body := core.NewBody("Plot Example")
	epochs := table.New("epc")
	epochs.OpenFS(tsv, "ra25epoch.tsv", tensor.Tab)
	titler := func(s *plot.Style) {
		s.Plot.Title = "RA25 Epoch Train"
	}
	pctErr := epochs.Column("PctErr")
	plot.SetStylersTo(pctErr, plot.Stylers{titler, func(s *plot.Style) {
		s.On = true
		s.Role = plot.Y
	}})
	editor := plotcore.NewPlotEditor(body)
	editor.SetTable(epochs)
	body.AddTopBar(func(bar *core.Frame) {
		core.NewToolbar(bar).Maker(editor.MakeToolbar)
	})
	body.RunMainWindow()
}
// Copyright (c) 2020, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package random plots histograms of random distributions.
package main
//go:generate core generate
import (
"cogentcore.org/core/base/metadata"
"cogentcore.org/core/core"
"cogentcore.org/core/events"
"cogentcore.org/core/icons"
"cogentcore.org/core/math32/minmax"
"cogentcore.org/core/styles"
"cogentcore.org/core/tree"
"cogentcore.org/lab/base/randx"
"cogentcore.org/lab/plot"
"cogentcore.org/lab/plot/plots"
"cogentcore.org/lab/plotcore"
"cogentcore.org/lab/stats/histogram"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
)
// Random is the random distribution plotter widget.
type Random struct {
	core.Frame
	Data
}

// Data contains the random distribution plotter data and options.
type Data struct { //types:add
	// random params
	Dist randx.RandParams `display:"add-fields"`
	// number of samples
	NumSamples int
	// number of bins in the histogram
	NumBins int
	// range for histogram
	Range minmax.F64
	// table for raw data
	Table *table.Table `display:"no-inline"`
	// histogram of data
	Histogram *table.Table `display:"no-inline"`
	// the plot editor showing the histogram
	plot *plotcore.PlotEditor `display:"-"`
}

// logPrec is the precision for saving float values in logs.
const logPrec = 4
// Init sets default distribution parameters and builds the GUI:
// an options form on the left and the histogram plot on the right.
func (rd *Random) Init() {
	rd.Frame.Init()
	rd.Dist.Defaults()
	rd.Dist.Dist = randx.Gaussian
	rd.Dist.Mean = 0.5
	rd.Dist.Var = 0.15
	rd.NumSamples = 1000000
	rd.NumBins = 100
	rd.Range.Set(0, 1)
	rd.Table = table.New()
	rd.Histogram = table.New()
	rd.ConfigTable(rd.Table)
	rd.Plot()
	rd.Styler(func(s *styles.Style) {
		s.Grow.Set(1, 1)
	})
	tree.AddChild(rd, func(w *core.Splits) {
		w.SetSplits(0.3, 0.7)
		tree.AddChild(w, func(w *core.Form) {
			w.SetStruct(&rd.Data)
			w.OnChange(func(e events.Event) {
				// regenerate and replot whenever any option changes
				rd.Plot()
			})
		})
		tree.AddChild(w, func(w *plotcore.PlotEditor) {
			w.SetTable(rd.Histogram)
			rd.plot = w
		})
	})
}
// Plot generates the data and plots a histogram of results:
// draws NumSamples values from Dist into Table, computes their
// histogram into Histogram, and refreshes the plot if present.
func (rd *Random) Plot() { //types:add
	dt := rd.Table
	dt.SetNumRows(rd.NumSamples)
	for vi := 0; vi < rd.NumSamples; vi++ {
		vl := rd.Dist.Gen()
		dt.Column("Value").SetFloat(float64(vl), vi)
	}
	histogram.F64Table(rd.Histogram, dt.Columns.Values[0].(*tensor.Float64).Values, rd.NumBins, rd.Range.Min, rd.Range.Max)
	if rd.plot != nil {
		rd.plot.UpdatePlot()
	}
}
// ConfigTable configures the raw data table with a single Value column,
// styled to plot as a bar-type X axis.
func (rd *Random) ConfigTable(dt *table.Table) {
	metadata.SetName(dt, "Data")
	// metadata.SetReadOnly(dt, true)
	tensor.SetPrecision(dt, logPrec)
	val := dt.AddFloat64Column("Value")
	plot.SetFirstStylerTo(val, func(s *plot.Style) {
		s.Role = plot.X
		s.Plotter = plots.BarType
		s.Plot.XAxis.Rotation = 45
		s.Plot.Title = "Random distribution histogram"
	})
}
// MakeToolbar adds the Plot action button, a separator, and the
// plot editor's own toolbar items when the plot exists.
func (rd *Random) MakeToolbar(p *tree.Plan) {
	tree.Add(p, func(fb *core.FuncButton) {
		fb.SetFunc(rd.Plot).SetIcon(icons.ScatterPlot)
	})
	tree.Add(p, func(sep *core.Separator) {})
	if pl := rd.plot; pl != nil {
		pl.MakeToolbar(p)
	}
}
// main opens the random distribution plotter window.
func main() {
	body := core.NewBody("Random numbers")
	rnd := NewRandom(body)
	body.AddTopBar(func(bar *core.Frame) {
		core.NewToolbar(bar).Maker(rnd.MakeToolbar)
	})
	body.RunMainWindow()
}
// Code generated by "core generate"; DO NOT EDIT.
package main
import (
"cogentcore.org/core/tree"
"cogentcore.org/core/types"
)
// NOTE(review): generated by "core generate" -- regenerate rather than hand-edit.
var _ = types.AddType(&types.Type{Name: "main.Random", IDName: "random", Doc: "Random is the random distribution plotter widget.", Methods: []types.Method{{Name: "Plot", Doc: "Plot generates the data and plots a histogram of results.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}}, Embeds: []types.Field{{Name: "Frame"}, {Name: "Data"}}})

// NewRandom returns a new [Random] with the given optional parent:
// Random is the random distribution plotter widget.
func NewRandom(parent ...tree.Node) *Random { return tree.New[Random](parent...) }

var _ = types.AddType(&types.Type{Name: "main.Data", IDName: "data", Doc: "Data contains the random distribution plotter data and options.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Fields: []types.Field{{Name: "Dist", Doc: "random params"}, {Name: "NumSamples", Doc: "number of samples"}, {Name: "NumBins", Doc: "number of bins in the histogram"}, {Name: "Range", Doc: "range for histogram"}, {Name: "Table", Doc: "table for raw data"}, {Name: "Histogram", Doc: "histogram of data"}, {Name: "plot", Doc: "the plot"}}})
// Code generated by "core generate"; DO NOT EDIT.
package main
import (
"cogentcore.org/core/enums"
)
// NOTE(review): generated by "core generate" -- regenerate rather than hand-edit.
var _TimesValues = []Times{0, 1, 2}

// TimesN is the highest valid value for type Times, plus one.
const TimesN Times = 3

var _TimesValueMap = map[string]Times{`Trial`: 0, `Epoch`: 1, `Run`: 2}

var _TimesDescMap = map[Times]string{0: ``, 1: ``, 2: ``}

var _TimesMap = map[Times]string{0: `Trial`, 1: `Epoch`, 2: `Run`}

// String returns the string representation of this Times value.
func (i Times) String() string { return enums.String(i, _TimesMap) }

// SetString sets the Times value from its string representation,
// and returns an error if the string is invalid.
func (i *Times) SetString(s string) error { return enums.SetString(i, s, _TimesValueMap, "Times") }

// Int64 returns the Times value as an int64.
func (i Times) Int64() int64 { return int64(i) }

// SetInt64 sets the Times value from an int64.
func (i *Times) SetInt64(in int64) { *i = Times(in) }

// Desc returns the description of the Times value.
func (i Times) Desc() string { return enums.Desc(i, _TimesDescMap) }

// TimesValues returns all possible values for the type Times.
func TimesValues() []Times { return _TimesValues }

// Values returns all possible values for the type Times.
func (i Times) Values() []enums.Enum { return enums.Values(_TimesValues) }

// MarshalText implements the [encoding.TextMarshaler] interface.
func (i Times) MarshalText() ([]byte, error) { return []byte(i.String()), nil }

// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *Times) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "Times") }

var _LoopPhaseValues = []LoopPhase{0, 1}

// LoopPhaseN is the highest valid value for type LoopPhase, plus one.
const LoopPhaseN LoopPhase = 2

var _LoopPhaseValueMap = map[string]LoopPhase{`Start`: 0, `Step`: 1}

var _LoopPhaseDescMap = map[LoopPhase]string{0: `Start is the start of the loop: resets accumulated stats, initializes.`, 1: `Step is each iteration of the loop.`}

var _LoopPhaseMap = map[LoopPhase]string{0: `Start`, 1: `Step`}

// String returns the string representation of this LoopPhase value.
func (i LoopPhase) String() string { return enums.String(i, _LoopPhaseMap) }

// SetString sets the LoopPhase value from its string representation,
// and returns an error if the string is invalid.
func (i *LoopPhase) SetString(s string) error {
	return enums.SetString(i, s, _LoopPhaseValueMap, "LoopPhase")
}

// Int64 returns the LoopPhase value as an int64.
func (i LoopPhase) Int64() int64 { return int64(i) }

// SetInt64 sets the LoopPhase value from an int64.
func (i *LoopPhase) SetInt64(in int64) { *i = LoopPhase(in) }

// Desc returns the description of the LoopPhase value.
func (i LoopPhase) Desc() string { return enums.Desc(i, _LoopPhaseDescMap) }

// LoopPhaseValues returns all possible values for the type LoopPhase.
func LoopPhaseValues() []LoopPhase { return _LoopPhaseValues }

// Values returns all possible values for the type LoopPhase.
func (i LoopPhase) Values() []enums.Enum { return enums.Values(_LoopPhaseValues) }

// MarshalText implements the [encoding.TextMarshaler] interface.
func (i LoopPhase) MarshalText() ([]byte, error) { return []byte(i.String()), nil }

// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *LoopPhase) UnmarshalText(text []byte) error {
	return enums.UnmarshalText(i, text, "LoopPhase")
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
//go:generate core generate
import (
"math/rand/v2"
"cogentcore.org/core/core"
"cogentcore.org/lab/lab"
"cogentcore.org/lab/plot"
"cogentcore.org/lab/stats/stats"
"cogentcore.org/lab/tensor"
"cogentcore.org/lab/tensorfs"
)
// Times are the looping time levels for running and statistics.
type Times int32 //enums:enum

const (
	// Trial is the innermost loop level.
	Trial Times = iota
	// Epoch aggregates over Trials.
	Epoch
	// Run is the outermost loop level, aggregating over Epochs.
	Run
)

// LoopPhase is the phase of loop processing for given time.
type LoopPhase int32 //enums:enum

const (
	// Start is the start of the loop: resets accumulated stats, initializes.
	Start LoopPhase = iota
	// Step is each iteration of the loop.
	Step
)
// Sim is a minimal simulation that keeps all of its state in tensorfs
// and records statistics via registered stat functions.
type Sim struct {
	// Root is the root data dir.
	Root *tensorfs.Node
	// Config has config data.
	Config *tensorfs.Node
	// Stats has all stats data.
	Stats *tensorfs.Node
	// Current has current value of all stats
	Current *tensorfs.Node
	// StatFuncs are statistics functions, per stat, handles everything.
	StatFuncs []func(ltime Times, lphase LoopPhase)
	// Counters are current values of counters: normally in looper.
	Counters [TimesN]int
}
// ConfigAll configures the sim: creates the root and Config dirs,
// sets the per-level Max counter values, and configures the stats.
func (ss *Sim) ConfigAll() {
	ss.Root, _ = tensorfs.NewDir("Root") // NOTE(review): error ignored -- confirm NewDir cannot fail here
	ss.Config = ss.Root.Dir("Config")
	mx := tensorfs.Value[int](ss.Config, "Max", int(TimesN)).(*tensor.Int)
	mx.Set1D(5, int(Trial))
	mx.Set1D(4, int(Epoch))
	mx.Set1D(3, int(Run))
	// todo: failing - assigns 3 to all
	// # mx[Trial] = 5
	// # mx[Epoch] = 4
	// # mx[Run] = 3
	ss.ConfigStats()
}
// AddStat adds a statistics function to the list run by [Sim.RunStats].
func (ss *Sim) AddStat(f func(ltime Times, lphase LoopPhase)) {
	ss.StatFuncs = append(ss.StatFuncs, f)
}
// RunStats runs all registered stat functions for the given time level
// and loop phase.
func (ss *Sim) RunStats(ltime Times, lphase LoopPhase) {
	for i := range ss.StatFuncs {
		ss.StatFuncs[i](ltime, lphase)
	}
}
// ConfigStats sets up the statistics functions: one per counter (recorded
// at its own level and below), plus an SSE stat computed per Trial and
// aggregated (mean) at Epoch and Run, and an Err stat derived from SSE.
func (ss *Sim) ConfigStats() {
	ss.Stats = ss.Root.Dir("Stats")
	ss.Current = ss.Stats.Dir("Current")
	ctrs := []Times{Run, Epoch, Trial}
	for _, ctr := range ctrs {
		ss.AddStat(func(ltime Times, lphase LoopPhase) {
			if ltime > ctr { // don't record counter for time above it
				return
			}
			name := ctr.String() // name of stat = counter
			timeDir := ss.Stats.Dir(ltime.String())
			tsr := tensorfs.Value[int](timeDir, name)
			if lphase == Start {
				tsr.SetNumRows(0)
				// install plot stylers only on first use
				if ps := plot.GetStylersFrom(tsr); ps == nil {
					ps.Add(func(s *plot.Style) {
						s.Range.SetMin(0)
					})
					plot.SetStylersTo(tsr, ps)
				}
				return
			}
			ctv := ss.Counters[ctr]
			tensorfs.Scalar[int](ss.Current, name).SetInt1D(ctv, 0)
			tsr.AppendRowInt(ctv)
		})
	}
	// note: it is essential to only have 1 per func
	// so generic names can be used for everything.
	ss.AddStat(func(ltime Times, lphase LoopPhase) {
		name := "SSE"
		timeDir := ss.Stats.Dir(ltime.String())
		tsr := timeDir.Float64(name)
		if lphase == Start {
			tsr.SetNumRows(0)
			if ps := plot.GetStylersFrom(tsr); ps == nil {
				ps.Add(func(s *plot.Style) {
					s.Range.SetMin(0).SetMax(1)
					s.On = true
				})
				plot.SetStylersTo(tsr, ps)
			}
			return
		}
		switch ltime {
		case Trial:
			// trial-level value is just a random number in this example
			stat := rand.Float64()
			tensorfs.Scalar[float64](ss.Current, name).SetFloat(stat, 0)
			tsr.AppendRowFloat(stat)
		case Epoch:
			// aggregate the mean over the level below (Trial)
			subd := ss.Stats.Dir((ltime - 1).String())
			stat := stats.StatMean.Call(subd.Float64(name))
			tsr.AppendRow(stat)
		case Run:
			// aggregate the mean over the level below (Epoch)
			subd := ss.Stats.Dir((ltime - 1).String())
			stat := stats.StatMean.Call(subd.Float64(name))
			tsr.AppendRow(stat)
		}
	})
	ss.AddStat(func(ltime Times, lphase LoopPhase) {
		name := "Err"
		timeDir := ss.Stats.Dir(ltime.String())
		tsr := tensorfs.Value[float64](timeDir, name)
		if lphase == Start {
			tsr.SetNumRows(0)
			if ps := plot.GetStylersFrom(tsr); ps == nil {
				ps.Add(func(s *plot.Style) {
					s.Range.SetMin(0).SetMax(1)
					s.On = true
				})
				plot.SetStylersTo(tsr, ps)
			}
			return
		}
		switch ltime {
		case Trial:
			// Err = 1 when the current SSE >= 0.5, else 0
			sse := tensorfs.Scalar[float64](ss.Current, "SSE").Float1D(0)
			stat := 1.0
			if sse < 0.5 {
				stat = 0
			}
			tensorfs.Scalar[float64](ss.Current, name).SetFloat(stat, 0)
			tsr.AppendRowFloat(stat)
		case Epoch:
			subd := ss.Stats.Dir((ltime - 1).String())
			stat := stats.StatMean.Call(subd.Value(name))
			tsr.AppendRow(stat)
		case Run:
			subd := ss.Stats.Dir((ltime - 1).String())
			stat := stats.StatMean.Call(subd.Value(name))
			tsr.AppendRow(stat)
		}
	})
}
// Run runs the nested Run / Epoch / Trial loops, using the Max counts
// from Config, invoking the stats at the Start of each loop level and
// at each Step.
func (ss *Sim) Run() {
	mx := ss.Config.Value("Max").(*tensor.Int)
	nrun := mx.Value1D(int(Run))
	nepc := mx.Value1D(int(Epoch))
	ntrl := mx.Value1D(int(Trial))
	ss.RunStats(Run, Start)
	for run := range nrun {
		ss.Counters[Run] = run
		ss.RunStats(Epoch, Start)
		for epc := range nepc {
			ss.Counters[Epoch] = epc
			ss.RunStats(Trial, Start)
			for trl := range ntrl {
				ss.Counters[Trial] = trl
				ss.RunStats(Trial, Step)
			}
			ss.RunStats(Epoch, Step)
		}
		ss.RunStats(Run, Step)
	}
	// todo: could do final analysis here
	// alldt := ss.Logs.Item("AllTrials").GetDirTable(nil)
	// dir := ss.Logs.Dir("Stats")
	// stats.TableGroups(dir, alldt, "Run", "Epoch", "Trial")
	// sts := []string{"SSE", "AvgSSE", "TrlErr"}
	// stats.TableGroupStats(dir, stats.StatMean, alldt, sts...)
	// stats.TableGroupStats(dir, stats.StatSem, alldt, sts...)
}
// main configures and runs the simulation, then opens a basic lab
// browser window on the root data directory and waits on the GUI.
func main() {
	ss := &Sim{}
	ss.ConfigAll()
	ss.Run()
	b, _ := lab.NewBasicWindow(ss.Root, "Root") // NOTE(review): error ignored -- confirm best-effort GUI is intended
	b.RunWindow()
	core.Wait()
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package goal
import (
"context"
"fmt"
"log/slog"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/exec"
"cogentcore.org/core/base/logx"
"cogentcore.org/core/base/sshclient"
"cogentcore.org/core/base/stringsx"
"github.com/mitchellh/go-homedir"
)
// InstallBuiltins adds the builtin goal commands to [Goal.Builtins],
// replacing any existing map.
func (gl *Goal) InstallBuiltins() {
	gl.Builtins = map[string]func(cmdIO *exec.CmdIO, args ...string) error{
		"cd":       gl.Cd,
		"exit":     gl.Exit,
		"jobs":     gl.JobsCmd,
		"kill":     gl.Kill,
		"set":      gl.Set,
		"unset":    gl.Unset,
		"add-path": gl.AddPath,
		"which":    gl.Which,
		"source":   gl.Source,
		"cossh":    gl.CoSSH,
		"scp":      gl.Scp,
		"debug":    gl.Debug,
		"history":  gl.History,
	}
}
// Cd changes the current directory, defaulting to the home directory
// when called with no argument. At most one argument is allowed.
func (gl *Goal) Cd(cmdIO *exec.CmdIO, args ...string) error {
	if len(args) > 1 {
		return fmt.Errorf("no more than one argument can be passed to cd")
	}
	target := ""
	if len(args) == 1 {
		target = args[0]
	}
	target, err := homedir.Expand(target)
	if err != nil {
		return err
	}
	if target == "" {
		// no argument: go home
		if target, err = homedir.Dir(); err != nil {
			return err
		}
	}
	if target, err = filepath.Abs(target); err != nil {
		return err
	}
	if err = os.Chdir(target); err != nil {
		return err
	}
	gl.Config.Dir = target // keep the shell's notion of cwd in sync
	return nil
}
// Exit exits the shell process immediately with status 0.
func (gl *Goal) Exit(cmdIO *exec.CmdIO, args ...string) error {
	os.Exit(0)
	return nil // unreachable: os.Exit does not return; signature required by Builtins map
}
// Set sets the given environment variable to the given value.
// Colon-separated list values (PATH-like), or values containing ~,
// are deduplicated and home-dir expanded before being set.
func (gl *Goal) Set(cmdIO *exec.CmdIO, args ...string) error {
	if len(args) != 2 {
		return fmt.Errorf("expected two arguments, got %d", len(args))
	}
	value := args[1]
	if strings.Count(value, ":") > 1 || strings.Contains(value, "~") {
		parts := stringsx.UniqueList(strings.Split(value, ":"))
		parts = AddHomeExpand([]string{}, parts...)
		value = strings.Join(parts, ":")
	}
	err := os.Setenv(args[0], value)
	if runtime.GOOS == "darwin" {
		// also register with launchd on macOS
		gl.Config.RunIO(cmdIO, "/bin/launchctl", "setenv", args[0], value)
	}
	return err
}
// Unset un-sets the given environment variable (exactly one argument).
func (gl *Goal) Unset(cmdIO *exec.CmdIO, args ...string) error {
	if len(args) != 1 {
		return fmt.Errorf("expected one argument, got %d", len(args))
	}
	name := args[0]
	err := os.Unsetenv(name)
	if runtime.GOOS == "darwin" {
		// keep launchd's environment in sync on macOS
		gl.Config.RunIO(cmdIO, "/bin/launchctl", "unsetenv", name)
	}
	return err
}
// JobsCmd is the builtin jobs command, printing one line per
// background job with its 1-based job number.
func (gl *Goal) JobsCmd(cmdIO *exec.CmdIO, args ...string) error {
	for i := range gl.Jobs {
		cmdIO.Printf("[%d] %s\n", i+1, gl.Jobs[i].String())
	}
	return nil
}
// Kill kills a job by job number or PID.
// Just expands the job id expressions %n into PIDs and calls system kill.
func (gl *Goal) Kill(cmdIO *exec.CmdIO, args ...string) error {
	if len(args) == 0 {
		return fmt.Errorf("goal kill: expected at least one argument")
	}
	gl.JobIDExpand(args) // rewrites %n args to PIDs in place
	gl.Config.RunIO(cmdIO, "kill", args...)
	return nil
}
// Fg foregrounds a background job by its 1-based job id (%n).
//
// Currently this only validates the argument and reports the action;
// re-attaching stdio to a running job is not yet implemented (see the
// notes below).
func (gl *Goal) Fg(cmdIO *exec.CmdIO, args ...string) error {
	if len(args) != 1 {
		return fmt.Errorf("goal fg: requires exactly one job id argument")
	}
	jid := args[0]
	exp := gl.JobIDExpand(args)
	if exp != 1 {
		return fmt.Errorf("goal fg: argument was not a job id in the form %%n")
	}
	jno, _ := strconv.Atoi(jid[1:]) // guaranteed good: JobIDExpand validated it
	// job ids are 1-based (see JobsCmd, JobIDExpand), so index with jno-1;
	// indexing with jno would panic for the last job and grab the wrong one otherwise.
	job := gl.Jobs[jno-1]
	cmdIO.Printf("foregrounding job [%d]\n", jno)
	_ = job
	// todo: the problem here is we need to change the stdio for running job
	// job.Cmd.Wait() // wait
	// * probably need to have wrapper StdIO for every exec so we can flexibly redirect for fg, bg commands.
	// * likewise, need to run everything effectively as a bg job with our own explicit Wait, which we can then communicate with to move from fg to bg.
	return nil
}
// AddHomeExpand adds given strings to the given list of strings,
// expanding any ~ symbols with the home directory, skipping entries
// already present, and returns the updated list.
func AddHomeExpand(list []string, adds ...string) []string {
	for _, raw := range adds {
		expanded, err := homedir.Expand(raw)
		errors.Log(err)
		exists := false
		for _, have := range list {
			if have == expanded {
				exists = true
				break
			}
		}
		if !exists {
			list = append(list, expanded)
		}
	}
	return list
}
// AddPath adds the given path(s) to $PATH, deduplicating existing
// entries and expanding ~ home-directory references.
func (gl *Goal) AddPath(cmdIO *exec.CmdIO, args ...string) error {
	if len(args) == 0 {
		return fmt.Errorf("goal add-path expected at least one argument")
	}
	entries := stringsx.UniqueList(strings.Split(os.Getenv("PATH"), ":"))
	entries = AddHomeExpand(entries, args...)
	err := os.Setenv("PATH", strings.Join(entries, ":"))
	// if runtime.GOOS == "darwin" {
	// this is what would be required to work:
	// sudo launchctl config user path $PATH -- the following does not work:
	// gl.Config.RunIO(cmdIO, "/bin/launchctl", "setenv", "PATH", path)
	// }
	return err
}
// Which reports the executable associated with the given command.
// Processes builtins and commands, and if not found, then passes on
// to exec which.
func (gl *Goal) Which(cmdIO *exec.CmdIO, args ...string) error {
	if len(args) != 1 {
		return fmt.Errorf("goal which: requires one argument")
	}
	cmd := args[0]
	if _, ok := gl.Commands[cmd]; ok {
		cmdIO.Println(cmd, "is a user-defined command")
		return nil
	}
	if _, ok := gl.Builtins[cmd]; ok {
		cmdIO.Println(cmd, "is a goal builtin command")
		return nil
	}
	// not ours: defer to the system which
	gl.Config.RunIO(cmdIO, "which", args...)
	return nil
}
// Source loads and transpiles the given file(s).
// Note that the files are not executed -- just loaded in.
func (gl *Goal) Source(cmdIO *exec.CmdIO, args ...string) error {
	if len(args) == 0 {
		return fmt.Errorf("goal source: requires at least one argument")
	}
	for _, file := range args {
		gl.TranspileCodeFromFile(file)
	}
	return nil
}
// CoSSH manages SSH connections, which are referenced by the @name
// identifier. It handles the following cases:
//   - @name -- switches to using given host for all subsequent commands
//   - host [name] -- connects to a server specified in first arg and switches
//     to using it, with optional name instead of default sequential number.
//   - close -- closes all open connections, or the specified one
func (gl *Goal) CoSSH(cmdIO *exec.CmdIO, args ...string) error {
	if len(args) < 1 {
		return fmt.Errorf("cossh: requires at least one argument")
	}
	cmd := args[0]
	var err error
	host := ""
	name := fmt.Sprintf("%d", 1+len(gl.SSHClients)) // default name: next sequential number
	con := false                                    // whether to open a new connection
	switch {
	case cmd == "close":
		gl.CloseSSH()
		return nil
	case cmd == "@" && len(args) == 2:
		// switch the active connection by name; no new connection
		name = args[1]
	case len(args) == 2:
		// connect to host with an explicit name
		con = true
		host = args[0]
		name = args[1]
	default:
		// connect to host with the auto-generated sequential name
		con = true
		host = args[0]
	}
	if con {
		cl := sshclient.NewClient(gl.SSH)
		err = cl.Connect(host)
		if err != nil {
			return err
		}
		gl.SSHClients[name] = cl
		gl.SSHActive = name
	} else {
		if name == "0" {
			// name 0 means local: deactivate ssh
			gl.SSHActive = ""
		} else {
			gl.SSHActive = name
			cl := gl.ActiveSSH()
			if cl == nil {
				err = fmt.Errorf("goal: ssh connection named: %q not found", name)
			}
		}
	}
	return err
}
// Scp performs file copy over SSH connection, with the remote filename
// prefixed with the @name: and the local filename un-prefixed.
// The order is from -> to, as in standard cp.
// The remote filename is automatically relative to the current working
// directory on the remote host.
func (gl *Goal) Scp(cmdIO *exec.CmdIO, args ...string) error {
	if len(args) != 2 {
		return fmt.Errorf("scp: requires exactly two arguments")
	}
	var lfn, hfn string // local and host (remote) filenames
	toHost := false     // direction of the copy
	if args[0][0] == '@' {
		hfn = args[0]
		lfn = args[1]
	} else if args[1][0] == '@' {
		hfn = args[1]
		lfn = args[0]
		toHost = true
	} else {
		return fmt.Errorf("scp: one of the files must a remote host filename, specified by @name:")
	}
	ci := strings.Index(hfn, ":")
	if ci < 0 {
		return fmt.Errorf("scp: remote host filename does not contain a : after the host name")
	}
	host := hfn[1:ci] // name between @ and :
	hfn = hfn[ci+1:]  // remote path after :
	cl, err := gl.SSHByHost(host)
	if err != nil {
		return err
	}
	ctx := gl.Ctx
	if ctx == nil {
		ctx = context.Background()
	}
	if toHost {
		err = cl.CopyLocalFileToHost(ctx, lfn, hfn)
	} else {
		err = cl.CopyHostToLocalFile(ctx, hfn, lfn)
	}
	return err
}
// Debug changes the log level: with no arguments it toggles between
// Debug and Info; with one argument, "on"/"true"/"1" selects Debug
// and any other value selects Info.
func (gl *Goal) Debug(cmdIO *exec.CmdIO, args ...string) error {
	switch len(args) {
	case 0:
		if logx.UserLevel == slog.LevelDebug {
			logx.UserLevel = slog.LevelInfo
		} else {
			logx.UserLevel = slog.LevelDebug
		}
	case 1:
		switch args[0] {
		case "on", "true", "1":
			logx.UserLevel = slog.LevelDebug
		default:
			logx.UserLevel = slog.LevelInfo
		}
	}
	return nil
}
// History shows the accumulated history of command-line input,
// optionally limited to the given number of most-recent items.
func (gl *Goal) History(cmdIO *exec.CmdIO, args ...string) error {
	total := len(gl.Hist)
	show := total
	switch {
	case len(args) > 1:
		return fmt.Errorf("history: uses at most one argument")
	case len(args) == 1:
		req, err := strconv.Atoi(args[0])
		if err != nil {
			return fmt.Errorf("history: error parsing number of history items: %q, error: %s", args[0], err.Error())
		}
		show = min(total, req)
	}
	for i := total - show; i < total; i++ {
		cmdIO.Printf("%d:\t%s\n", i, gl.Hist[i])
	}
	return nil
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Command goal is an interactive cli for running and compiling Goal code.
package main
import (
"cogentcore.org/core/cli"
"cogentcore.org/lab/goal/interpreter"
)
// main runs the goal command-line interface via the cli framework,
// dispatching to the interpreter's Run or Build commands, with the
// Interactive function used for interactive mode.
func main() { //types:skip
	opts := cli.DefaultOptions("goal", "An interactive tool for running and compiling Goal (Go augmented language).")
	cfg := &interpreter.Config{}
	cfg.InteractiveFunc = interpreter.Interactive
	cli.Run(opts, cfg, interpreter.Run, interpreter.Build)
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package goal
import (
"os"
"path/filepath"
"strings"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/icons"
"cogentcore.org/core/parse/complete"
"github.com/mitchellh/go-homedir"
)
// CompleteMatch is the [complete.MatchFunc] for the shell.
// It completes file and directory names relative to the current
// directory (escaping spaces), and when completing the first word
// also offers builtin and user-defined command names.
func (gl *Goal) CompleteMatch(data any, text string, posLine, posChar int) (md complete.Matches) {
	comps := complete.Completions{}
	text = text[:posChar]
	md.Seed = complete.SeedPath(text)
	fullPath := complete.SeedSpace(text)
	fullPath = errors.Log1(homedir.Expand(fullPath))
	parent := strings.TrimSuffix(fullPath, md.Seed)
	dir := filepath.Join(gl.Config.Dir, parent)
	if filepath.IsAbs(parent) {
		dir = parent
	}
	entries := errors.Log1(os.ReadDir(dir))
	for _, entry := range entries {
		icon := icons.File
		if entry.IsDir() {
			icon = icons.Folder
		}
		name := strings.ReplaceAll(entry.Name(), " ", `\ `) // escape spaces
		comps = append(comps, complete.Completion{
			Text: name,
			Icon: icon,
			Desc: filepath.Join(gl.Config.Dir, name),
		})
	}
	if parent == "" {
		// completing the first word: also offer builtins and commands
		for cmd := range gl.Builtins {
			comps = append(comps, complete.Completion{
				Text: cmd,
				Icon: icons.Terminal,
				Desc: "Builtin command: " + cmd,
			})
		}
		for cmd := range gl.Commands {
			comps = append(comps, complete.Completion{
				Text: cmd,
				Icon: icons.Terminal,
				Desc: "Command: " + cmd,
			})
		}
		// todo: write something that looks up all files on path -- should cache that per
		// path string setting
	}
	md.Matches = complete.MatchSeedCompletion(comps, md.Seed)
	return md
}
// CompleteEdit is the [complete.EditFunc] for the shell: it inserts
// the selected completion in place of the current seed word.
func (gl *Goal) CompleteEdit(data any, text string, cursorPos int, completion complete.Completion, seed string) (ed complete.Edit) {
	return complete.EditWord(text, cursorPos, completion.Text, seed)
}
// ReadlineCompleter implements [github.com/ergochat/readline.AutoCompleter]
// using the shell's own completion logic.
type ReadlineCompleter struct {
	// Goal is the shell whose CompleteMatch function drives the completions.
	Goal *Goal
}
// Do implements readline completion: it returns the candidate suffixes
// that extend the current seed at pos, appending a path separator to
// directories and a trailing space to everything else.
func (rc *ReadlineCompleter) Do(line []rune, pos int) (newLine [][]rune, length int) {
	md := rc.Goal.CompleteMatch(nil, string(line), 0, pos)
	completions := [][]rune{}
	for _, m := range md.Matches {
		suffix := strings.TrimPrefix(m.Text, md.Seed)
		if md.Seed != "" && suffix == m.Text {
			continue // seed is not a prefix of this match: no overlap
		}
		if m.Icon == icons.Folder {
			suffix += string(filepath.Separator)
		} else {
			suffix += " "
		}
		completions = append(completions, []rune(suffix))
	}
	return completions, len(md.Seed)
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package goal
import (
"bytes"
"fmt"
"os"
"path/filepath"
"slices"
"strings"
"sync"
"time"
"cogentcore.org/core/base/exec"
"cogentcore.org/core/base/reflectx"
"cogentcore.org/core/base/sshclient"
"github.com/mitchellh/go-homedir"
)
// Exec handles command execution for all cases, parameterized by the args.
// It executes the given command string, waiting for the command to finish,
// handling the given arguments appropriately.
// If there is any error, it adds it to the goal, and triggers CancelExecution.
//   - errOk = don't call AddError so execution will not stop on error
//   - start = calls Start on the command, which then runs asynchronously, with
//     a goroutine forked to Wait for it and close its IO
//   - output = return the output of the command as a string (otherwise return is "")
func (gl *Goal) Exec(errOk, start, output bool, cmd any, args ...any) string {
	out := ""
	// if a prior command already errored (and errors matter), skip execution
	if !errOk && len(gl.Errors) > 0 {
		return out
	}
	cmdIO := exec.NewCmdIO(&gl.Config)
	cmdIO.StackStart()
	if start {
		cmdIO.PushIn(nil) // no stdin for bg
	}
	cl, scmd, sargs := gl.ExecArgs(cmdIO, errOk, cmd, args...)
	if scmd == "" {
		return out // e.g., a bare @host switch: nothing to execute
	}
	var err error
	if cl != nil {
		// remote execution over the active ssh connection
		switch {
		case start:
			err = cl.Start(&cmdIO.StdIOState, scmd, sargs...)
		case output:
			cmdIO.PushOut(nil)
			out, err = cl.Output(&cmdIO.StdIOState, scmd, sargs...)
		default:
			err = cl.Run(&cmdIO.StdIOState, scmd, sargs...)
		}
		if !errOk {
			gl.AddError(err)
		}
	} else {
		// local: builtins and user-defined commands take precedence over executables
		ran := false
		ran, out = gl.RunBuiltinOrCommand(cmdIO, errOk, start, output, scmd, sargs...)
		if !ran {
			gl.isCommand.Push(false)
			switch {
			case start:
				// fmt.Fprintf(gl.debugTrace, "start exe %s in: %#v out: %#v %v\n ", scmd, cmdIO.In, cmdIO.Out, cmdIO.OutIsPipe())
				err = gl.Config.StartIO(cmdIO, scmd, sargs...)
				job := &Job{CmdIO: cmdIO}
				gl.Jobs.Push(job)
				// background goroutine waits for completion and cleans up the job
				go func() {
					if !cmdIO.OutIsPipe() {
						fmt.Printf("[%d] %s\n", len(gl.Jobs), cmdIO.String())
					}
					cmdIO.Cmd.Wait()
					cmdIO.PopToStart()
					gl.DeleteJob(job)
				}()
			case output:
				cmdIO.PushOut(nil)
				out, err = gl.Config.OutputIO(cmdIO, scmd, sargs...)
			default:
				// fmt.Fprintf(gl.debugTrace, "run exe %s in: %#v out: %#v %v\n ", scmd, cmdIO.In, cmdIO.Out, cmdIO.OutIsPipe())
				err = gl.Config.RunIO(cmdIO, scmd, sargs...)
			}
			if !errOk {
				gl.AddError(err)
			}
			gl.isCommand.Pop()
		}
	}
	if !start {
		cmdIO.PopToStart() // started (bg) jobs pop their IO when they finish instead
	}
	return out
}
// RunBuiltinOrCommand runs a builtin or a command, returning true if it ran,
// and the output string if running in output mode.
func (gl *Goal) RunBuiltinOrCommand(cmdIO *exec.CmdIO, errOk, start, output bool, cmd string, args ...string) (bool, string) {
	out := ""
	cmdFun, hasCmd := gl.Commands[cmd]
	bltFun, hasBlt := gl.Builtins[cmd]
	if !hasCmd && !hasBlt {
		return false, out
	}
	if hasCmd {
		gl.commandArgs.Push(args)
		gl.isCommand.Push(true)
	}
	// note: we need to set both os. and wrapper versions, so it works the same
	// in compiled vs. interpreted mode
	var oldsh, oldwrap, oldstd *exec.StdIO
	// save redirects all std IO through cmdIO for the duration of the call
	save := func() {
		oldsh = gl.Config.StdIO.Set(&cmdIO.StdIO)
		oldwrap = gl.StdIOWrappers.SetWrappers(&cmdIO.StdIO)
		oldstd = cmdIO.SetToOS()
	}
	// done restores the saved IO state and pops the command stacks
	done := func() {
		if hasCmd {
			gl.isCommand.Pop()
			gl.commandArgs.Pop()
		}
		// fmt.Fprintf(gl.debugTrace, "%s restore %#v\n", cmd, oldstd.In)
		oldstd.SetToOS()
		gl.StdIOWrappers.SetWrappers(oldwrap)
		gl.Config.StdIO = *oldsh
	}
	switch {
	case start:
		// run in the background as a job; the WaitGroup signals completion
		wg := sync.WaitGroup{}
		wg.Add(1)
		go func() {
			if !cmdIO.OutIsPipe() {
				fmt.Printf("[%d] %s\n", len(gl.Jobs), cmd)
			}
			if hasCmd {
				oldwrap = gl.StdIOWrappers.SetWrappers(&cmdIO.StdIO)
				// oldstd = cmdIO.SetToOS()
				// fmt.Fprintf(gl.debugTrace, "%s oldstd in: %#v out: %#v\n", cmd, oldstd.In, oldstd.Out)
				cmdFun(args...)
				// oldstd.SetToOS()
				gl.StdIOWrappers.SetWrappers(oldwrap)
				gl.isCommand.Pop()
				gl.commandArgs.Pop()
			} else {
				gl.AddError(bltFun(cmdIO, args...))
			}
			time.Sleep(time.Millisecond)
			wg.Done()
		}()
		// fmt.Fprintf(gl.debugTrace, "%s push: %#v out: %#v %v\n", cmd, cmdIO.In, cmdIO.Out, cmdIO.OutIsPipe())
		job := &Job{CmdIO: cmdIO}
		gl.Jobs.Push(job)
		// second goroutine waits for completion and cleans up the job entry
		go func() {
			wg.Wait()
			cmdIO.PopToStart()
			gl.DeleteJob(job)
		}()
	case output:
		save()
		obuf := &bytes.Buffer{}
		// os.Stdout = obuf // needs a file
		gl.Config.StdIO.Out = obuf
		gl.StdIOWrappers.SetWrappedOut(obuf)
		cmdIO.PushOut(obuf)
		if hasCmd {
			cmdFun(args...)
		} else {
			gl.AddError(bltFun(cmdIO, args...))
		}
		out = strings.TrimSuffix(obuf.String(), "\n")
		done()
	default:
		save()
		if hasCmd {
			cmdFun(args...)
		} else {
			gl.AddError(bltFun(cmdIO, args...))
		}
		done()
	}
	return true, out
}
// HandleArgErr processes the given argument error: if errok, it is
// only printed to stderr; otherwise it is added to the error stack,
// stopping execution. The error is returned unchanged.
func (gl *Goal) HandleArgErr(errok bool, err error) error {
	if err == nil {
		return nil
	}
	if errok {
		gl.Config.StdIO.ErrPrintln(err.Error())
		return err
	}
	return gl.AddError(err)
}
// ExecArgs processes the args to given exec command,
// handling all of the input / output redirection and
// file globbing, homedir expansion, etc.
// It returns the ssh client to use (nil for local execution),
// the command name, and the processed argument list.
func (gl *Goal) ExecArgs(cmdIO *exec.CmdIO, errOk bool, cmd any, args ...any) (*sshclient.Client, string, []string) {
	if len(gl.Jobs) > 0 {
		// if the most recent job is piping its output, consume that pipe as our stdin
		jb := gl.Jobs.Peek()
		if jb.OutIsPipe() && !jb.GotPipe {
			jb.GotPipe = true
			cmdIO.PushIn(jb.PipeIn.Peek())
		}
	}
	scmd := reflectx.ToString(cmd)
	cl := gl.ActiveSSH()
	// isCmd := gl.isCommand.Peek()
	sargs := make([]string, 0, len(args))
	var err error
	for _, a := range args {
		s := reflectx.ToString(a)
		if s == "" {
			continue
		}
		if cl == nil {
			s, err = homedir.Expand(s)
			gl.HandleArgErr(errOk, err)
			// note: handling globbing in a later pass, to not clutter..
		} else {
			// on remote hosts, let the remote shell expand $HOME
			if s[0] == '~' {
				s = "$HOME/" + s[1:]
			}
		}
		sargs = append(sargs, s)
	}
	// NOTE(review): scmd[0] would panic if scmd were empty -- presumably
	// callers never pass an empty command; confirm upstream.
	if scmd[0] == '@' {
		// @host prefix selects the ssh connection used for this command
		newHost := ""
		if scmd == "@0" { // local
			cl = nil
		} else {
			hnm := scmd[1:]
			if scl, ok := gl.SSHClients[hnm]; ok {
				newHost = hnm
				cl = scl
			} else {
				gl.HandleArgErr(errOk, fmt.Errorf("goal: ssh connection named: %q not found", hnm))
			}
		}
		if len(sargs) > 0 {
			// first arg becomes the actual command to run on that host
			scmd = sargs[0]
			sargs = sargs[1:]
		} else { // just a ssh switch
			gl.SSHActive = newHost
			return nil, "", nil
		}
	}
	for i := 0; i < len(sargs); i++ { // we modify so no range
		s := sargs[i]
		switch {
		case s[0] == '>':
			sargs = gl.OutToFile(cl, cmdIO, errOk, sargs, i)
		case s[0] == '|':
			sargs = gl.OutToPipe(cl, cmdIO, errOk, sargs, i)
		case cl == nil && strings.HasPrefix(s, "args"):
			sargs = gl.CmdArgs(errOk, sargs, i)
			i-- // back up because we consume this one
		}
	}
	// do globbing late here so we don't have to wade through everything.
	// only for local.
	if cl == nil {
		gargs := make([]string, 0, len(sargs))
		for _, s := range sargs {
			g, err := filepath.Glob(s)
			if err != nil || len(g) == 0 { // not valid
				gargs = append(gargs, s)
			} else {
				gargs = append(gargs, g...)
			}
		}
		sargs = gargs
	}
	return cl, scmd, sargs
}
// OutToFile processes the > arg that sends output to a file.
// Supported forms: >file, > file, >> (append), >& and >>& (also
// redirect stderr). The consumed redirect args are removed from sargs.
func (gl *Goal) OutToFile(cl *sshclient.Client, cmdIO *exec.CmdIO, errOk bool, sargs []string, i int) []string {
	n := len(sargs)
	s := sargs[i]
	sn := len(s)
	fn := ""
	narg := 1 // number of args consumed: the redirect itself, +1 if the filename is a separate arg
	if i < n-1 {
		fn = sargs[i+1]
		narg = 2
	}
	appn := false // append (>>) rather than truncate
	errf := false // also redirect stderr (&)
	switch {
	case sn > 1 && s[1] == '>':
		appn = true
		if sn > 2 && s[2] == '&' {
			errf = true
		}
	case sn > 1 && s[1] == '&':
		errf = true
	case sn > 1:
		// filename attached directly to the redirect: >file
		fn = s[1:]
		narg = 1
	}
	if fn == "" {
		gl.HandleArgErr(errOk, fmt.Errorf("goal: no output file specified"))
		return sargs
	}
	if cl != nil {
		// on a remote host, only @0: (local machine) targets are handled here
		if !strings.HasPrefix(fn, "@0:") {
			return sargs
		}
		fn = fn[3:]
	}
	sargs = slices.Delete(sargs, i, i+narg)
	// todo: process @n: expressions here -- if @0 then it is the same
	// if @1, then need to launch an ssh "cat >[>] file" with pipe from command as stdin
	var f *os.File
	var err error
	if appn {
		f, err = os.OpenFile(fn, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
	} else {
		f, err = os.Create(fn)
	}
	if err == nil {
		cmdIO.PushOut(f)
		if errf {
			cmdIO.PushErr(f)
		}
	} else {
		gl.HandleArgErr(errOk, err)
	}
	return sargs
}
// OutToPipe processes the | arg that sends output to a pipe;
// the |& form also pipes stderr. The pipe arg is removed from sargs.
func (gl *Goal) OutToPipe(cl *sshclient.Client, cmdIO *exec.CmdIO, errOk bool, sargs []string, i int) []string {
	arg := sargs[i]
	pipeErr := len(arg) > 1 && arg[1] == '&'
	sargs = slices.Delete(sargs, i, i+1)
	cmdIO.PushOutPipe()
	if pipeErr {
		cmdIO.PushErr(cmdIO.Out)
	}
	return sargs
}
// CmdArgs processes expressions involving "args" for commands,
// replacing an "args..." element with the current command's arguments.
func (gl *Goal) CmdArgs(errOk bool, sargs []string, i int) []string {
	cmdArgs := gl.commandArgs.Peek()
	if sargs[i] == "args..." {
		sargs = slices.Delete(sargs, i, i+1)
		sargs = slices.Insert(sargs, i, cmdArgs...)
	}
	return sargs
}
// CancelExecution calls the Cancel() function if set
// (Cancel is set by StartContext while code is interpreting).
func (gl *Goal) CancelExecution() {
	if gl.Cancel != nil {
		gl.Cancel()
	}
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package goal provides the Goal Go augmented language transpiler,
// which combines the best parts of Go, bash, and Python to provide
// an integrated shell and numerical expression processing experience.
package goal
import (
"context"
"fmt"
"io/fs"
"os"
"path/filepath"
"slices"
"strconv"
"strings"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/exec"
"cogentcore.org/core/base/logx"
"cogentcore.org/core/base/reflectx"
"cogentcore.org/core/base/sshclient"
"cogentcore.org/core/base/stack"
"cogentcore.org/lab/goal/transpile"
"github.com/mitchellh/go-homedir"
)
// Goal represents one running Goal language context.
type Goal struct {
	// Config is the [exec.Config] used to run commands.
	Config exec.Config
	// StdIOWrappers are IO wrappers sent to the interpreter, so we can
	// control the IO streams used within the interpreter.
	// Call SetWrappers on this with another StdIO object to update settings.
	StdIOWrappers exec.StdIO
	// SSH is the ssh connection configuration.
	SSH *sshclient.Config
	// SSHClients is the collection of open ssh clients, keyed by name.
	SSHClients map[string]*sshclient.Client
	// SSHActive is the name of the active SSH client;
	// empty means commands run locally.
	SSHActive string
	// Builtins are all the builtin shell commands, keyed by name.
	Builtins map[string]func(cmdIO *exec.CmdIO, args ...string) error
	// Commands are commands that have been defined, which can be run in Exec mode.
	Commands map[string]func(args ...string)
	// Jobs is a stack of commands running in the background
	// (via Start instead of Run).
	Jobs stack.Stack[*Job]
	// Cancel, while the interpreter is running, can be called
	// to stop the code interpreting.
	// It is connected to the Ctx context, by StartContext().
	// Both can be nil.
	Cancel func()
	// Errors is a stack of runtime errors.
	Errors []error
	// CliArgs are input arguments from the command line.
	CliArgs []any
	// Ctx is the context used for cancelling current shell running
	// a single chunk of code, typically from the interpreter.
	// We are not able to pass the context around so it is set here,
	// in the StartContext function. Clear when done with ClearContext.
	Ctx context.Context
	// OrigStdIO is the original standard IO settings, to restore after errors.
	OrigStdIO exec.StdIO
	// Hist is the accumulated list of command-line input,
	// which is displayed with the history builtin command,
	// and saved / restored from ~/.goalhist file.
	Hist []string
	// TrState is the transpiling state.
	TrState transpile.State
	// commandArgs is a stack of args passed to a command, used for simplified
	// processing of args expressions.
	commandArgs stack.Stack[[]string]
	// isCommand is a stack of bools indicating whether the _immediate_ run context
	// is a command, which affects the way that args are processed.
	isCommand stack.Stack[bool]
	// debugTrace is a file written to for debugging.
	debugTrace *os.File
}
// NewGoal returns a new [Goal] with default options:
// working directory from os.Getwd, OS-derived std IO, fresh
// ssh / command maps, and builtins installed.
func NewGoal() *Goal {
	goal := &Goal{
		Config: exec.Config{
			Dir:    errors.Log1(os.Getwd()),
			Env:    map[string]string{},
			Buffer: false,
		},
	}
	goal.TrState.FuncToVar = true
	goal.Config.StdIO.SetFromOS()
	goal.SSH = sshclient.NewConfig(&goal.Config)
	goal.SSHClients = map[string]*sshclient.Client{}
	goal.Commands = map[string]func(args ...string){}
	goal.InstallBuiltins()
	// goal.debugTrace, _ = os.Create("goal.debug") // debugging
	return goal
}
// StartContext starts a processing context,
// setting the Ctx and Cancel fields.
// Call EndContext when the current operation finishes.
func (gl *Goal) StartContext() context.Context {
	gl.Ctx, gl.Cancel = context.WithCancel(context.Background())
	return gl.Ctx
}
// EndContext ends a processing context, clearing the
// Ctx and Cancel fields set by StartContext.
func (gl *Goal) EndContext() {
	gl.Ctx = nil
	gl.Cancel = nil
}
// SaveOrigStdIO saves the current Config.StdIO as the original to revert to
// after an error (see RestoreOrigStdIO), and sets the StdIOWrappers to use them.
func (gl *Goal) SaveOrigStdIO() {
	gl.OrigStdIO = gl.Config.StdIO
	gl.StdIOWrappers.NewWrappers(&gl.OrigStdIO)
}
// RestoreOrigStdIO reverts to using the saved OrigStdIO,
// resetting the OS-level streams and wrappers to match.
func (gl *Goal) RestoreOrigStdIO() {
	gl.Config.StdIO = gl.OrigStdIO
	gl.OrigStdIO.SetToOS()
	gl.StdIOWrappers.SetWrappers(&gl.OrigStdIO)
}
// Close closes any resources associated with the shell,
// including terminating any commands that are not running "nohup"
// in the background. (Currently only ssh connections are closed;
// see the todo below for job termination.)
func (gl *Goal) Close() {
	gl.CloseSSH()
	// todo: kill jobs etc
}
// CloseSSH closes all open ssh client connections and resets the
// active connection to local (empty).
func (gl *Goal) CloseSSH() {
	gl.SSHActive = ""
	for _, client := range gl.SSHClients {
		client.Close()
	}
	gl.SSHClients = make(map[string]*sshclient.Client)
}
// ActiveSSH returns the active ssh client, or nil when running locally.
func (gl *Goal) ActiveSSH() *sshclient.Client {
	if name := gl.SSHActive; name != "" {
		return gl.SSHClients[name]
	}
	return nil
}
// Host returns the name we're running commands on,
// which is empty if localhost (default).
func (gl *Goal) Host() string {
	if cl := gl.ActiveSSH(); cl != nil {
		return "@" + gl.SSHActive + ":" + cl.Host
	}
	return ""
}
// HostAndDir returns the name we're running commands on
// (empty if localhost, the default), and the current directory on
// that host, abbreviated with ~ for paths under the home directory.
func (gl *Goal) HostAndDir() string {
	host := ""
	dir := gl.Config.Dir
	home := errors.Log1(homedir.Dir())
	cl := gl.ActiveSSH()
	if cl != nil {
		host = "@" + gl.SSHActive + ":" + cl.Host + ":"
		dir = cl.Dir
		home = cl.HomeDir
	}
	rel := errors.Log1(filepath.Rel(home, dir))
	// if it has to go back, then it is not in home dir, so no ~
	if strings.Contains(rel, "..") {
		return host + dir + string(filepath.Separator)
	}
	return host + filepath.Join("~", rel) + string(filepath.Separator)
}
// SSHByHost returns the SSH client for the given host name,
// with an error if no connection by that name exists.
func (gl *Goal) SSHByHost(host string) (*sshclient.Client, error) {
	cl, ok := gl.SSHClients[host]
	if !ok {
		return nil, fmt.Errorf("ssh connection named: %q not found", host)
	}
	return cl, nil
}
// TranspileCode processes each line of given code,
// adding the results to the LineStack of the transpiling state.
func (gl *Goal) TranspileCode(code string) {
	gl.TrState.TranspileCode(code)
}
// TranspileCodeFromFile transpiles the code in the given file,
// returning any error from reading it.
func (gl *Goal) TranspileCodeFromFile(file string) error {
	src, err := os.ReadFile(file)
	if err != nil {
		return err
	}
	gl.TranspileCode(string(src))
	return nil
}
// TranspileFile transpiles the given input goal file to the
// given output Go file. If no existing package declaration
// is found, then package main and func main declarations are
// added. This also affects how functions are interpreted.
// It delegates to the transpiling state (TrState).
func (gl *Goal) TranspileFile(in string, out string) error {
	return gl.TrState.TranspileFile(in, out)
}
// AddError adds the given error to the error stack if it is non-nil,
// prints it, and calls the Cancel function if set, to stop execution.
// This is the main way that goal errors are handled.
// It returns the error unchanged.
func (gl *Goal) AddError(err error) error {
	if err != nil {
		gl.Errors = append(gl.Errors, err)
		logx.PrintlnError(err)
		gl.CancelExecution()
	}
	return err
}
// TranspileConfig transpiles the .goal startup config file in the user's
// home directory if it exists; a missing file is not an error.
func (gl *Goal) TranspileConfig() error {
	path, err := homedir.Expand("~/.goal")
	if err != nil {
		return err
	}
	b, err := os.ReadFile(path)
	switch {
	case err == nil:
		gl.TranspileCode(string(b))
		return nil
	case errors.Is(err, fs.ErrNotExist):
		return nil // no config file is fine
	default:
		return err
	}
}
// AddHistory adds given line to the Hist record of commands,
// which is shown by the history builtin and saved via SaveHistory.
func (gl *Goal) AddHistory(line string) {
	gl.Hist = append(gl.Hist, line)
}
// SaveHistory saves up to the given number of lines of current history
// to given file, e.g., ~/.goalhist for the default goal program.
// If n is <= 0 all lines are saved. n is typically 500 by default.
func (gl *Goal) SaveHistory(n int, file string) error {
	path, err := homedir.Expand(file)
	if err != nil {
		return err
	}
	total := len(gl.Hist)
	keep := total
	if n > 0 {
		keep = min(n, total)
	}
	joined := strings.Join(gl.Hist[total-keep:total], "\n")
	return os.WriteFile(path, []byte(joined), 0666)
}
// OpenHistory loads Hist history lines from the given file,
// e.g., ~/.goalhist, replacing any existing history.
func (gl *Goal) OpenHistory(file string) error {
	path, err := homedir.Expand(file)
	if err != nil {
		return err
	}
	data, err := os.ReadFile(path)
	if err != nil {
		return err
	}
	gl.Hist = strings.Split(string(data), "\n")
	return nil
}
// Args returns the command line arguments (the CliArgs field).
func (gl *Goal) Args() []any {
	return gl.CliArgs
}
// AddCommand adds given command to the list of available commands,
// which are looked up by name in RunBuiltinOrCommand.
func (gl *Goal) AddCommand(name string, cmd func(args ...string)) {
	gl.Commands[name] = cmd
}
// RunCommands runs the given command(s). This is typically called
// from a Makefile-style goal script. It stops at the first command
// that is not found, returning (and logging) an error.
func (gl *Goal) RunCommands(cmds []any) error {
	for _, cmd := range cmds {
		cmdFun, ok := gl.Commands[reflectx.ToString(cmd)]
		if !ok {
			return errors.Log(fmt.Errorf("command %q not found", cmd))
		}
		cmdFun()
	}
	return nil
}
// DeleteAllJobs deletes any existing jobs, closing their stdio stacks.
func (gl *Goal) DeleteAllJobs() {
	for len(gl.Jobs) > 0 {
		job := gl.Jobs.Pop()
		job.CmdIO.PopToStart()
	}
}
// DeleteJob deletes the given job from the Jobs stack,
// returning true if it was found.
func (gl *Goal) DeleteJob(job *Job) bool {
	idx := slices.Index(gl.Jobs, job)
	if idx < 0 {
		return false
	}
	gl.Jobs = slices.Delete(gl.Jobs, idx, idx+1)
	return true
}
// JobIDExpand expands %n job id values in args in place with the
// corresponding process PID, returning the number of PIDs expanded.
// Job ids are 1-based; out-of-range numbers are reported via AddError.
func (gl *Goal) JobIDExpand(args []string) int {
	exp := 0
	for i, id := range args {
		// guard empty strings: id[0] would panic on ""
		if len(id) == 0 || id[0] != '%' {
			continue
		}
		idx, err := strconv.Atoi(id[1:])
		if err != nil {
			continue // not a %n expression; leave as-is
		}
		if idx <= 0 || idx > len(gl.Jobs) {
			gl.AddError(fmt.Errorf("goal: job number out of range: %d", idx))
			continue
		}
		jb := gl.Jobs[idx-1] // job ids are 1-based
		if jb.Cmd != nil && jb.Cmd.Process != nil {
			args[i] = fmt.Sprintf("%d", jb.Cmd.Process.Pid)
			exp++
		}
	}
	return exp
}
// Job represents a job that has been started and we're waiting for it to finish.
type Job struct {
	// CmdIO is the job's IO state, embedded for direct access.
	*exec.CmdIO
	// IsExec marks exec jobs. NOTE(review): not set anywhere in the
	// visible code -- confirm intended usage.
	IsExec bool
	// GotPipe indicates this job's output pipe has already been
	// consumed as stdin by a downstream command (see ExecArgs).
	GotPipe bool
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package goalib defines convenient utility functions for
// use in the goal shell, available with the goalib prefix.
package goalib
import (
"io/fs"
"os"
"path/filepath"
"strings"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/fsx"
"cogentcore.org/core/base/slicesx"
"cogentcore.org/core/base/stringsx"
)
// SplitLines returns a slice of given string split by lines
// with any extra whitespace trimmed for each line entry.
func SplitLines(s string) []string {
	lines := stringsx.SplitLines(s)
	for i := range lines {
		lines[i] = strings.TrimSpace(lines[i])
	}
	return lines
}
// FileExists returns true if given file exists, logging any error.
func FileExists(path string) bool {
	return errors.Log1(fsx.FileExists(path))
}
// WriteFile writes string to given file with standard permissions,
// logging and returning any error.
func WriteFile(filename, str string) error {
	if err := os.WriteFile(filename, []byte(str), 0666); err != nil {
		return errors.Log(err)
	}
	return nil
}
// ReadFile reads the string from the given file, logging any errors.
// Returns an empty string if the file could not be read.
func ReadFile(filename string) string {
	b, err := os.ReadFile(filename)
	if err != nil {
		errors.Log(err)
	}
	return string(b)
}
// ReplaceInFile replaces all occurrences of given string with replacement
// in given file, rewriting the file. Also returns the updated string.
func ReplaceInFile(filename, old, new string) string {
	updated := strings.ReplaceAll(ReadFile(filename), old, new)
	WriteFile(filename, updated)
	return updated
}
// StringsToAnys converts a slice of strings to a slice of any,
// using slicesx.As. The interpreter cannot process generics
// yet, so this wrapper is needed. Use for passing args to
// a command, for example.
func StringsToAnys(s []string) []any {
	return slicesx.As[string, any](s)
}
// AllFiles returns a list of all files (excluding directories)
// under the given path. Unreadable entries are skipped so that one
// bad entry does not silently truncate the rest of the walk
// (previously the first error aborted the walk and was discarded).
func AllFiles(path string) []string {
	var files []string
	filepath.WalkDir(path, func(p string, d fs.DirEntry, err error) error {
		if err != nil {
			return nil // skip unreadable entries, keep walking
		}
		if !d.IsDir() {
			files = append(files, p)
		}
		return nil
	})
	return files
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package interpreter
import (
"fmt"
"log/slog"
"os"
"path/filepath"
"strings"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/exec"
"cogentcore.org/core/base/fsx"
"cogentcore.org/core/base/logx"
"cogentcore.org/lab/goal"
"cogentcore.org/lab/goal/goalib"
"github.com/cogentcore/yaegi/interp"
)
//go:generate core generate -add-types -add-funcs
// Config is the configuration information for the goal cli.
type Config struct {
	// Input is the input file to run/compile.
	// If this is provided as the first argument,
	// then the program will exit after running,
	// unless the Interactive mode is flagged.
	Input string `posarg:"0" required:"-"`
	// Expr is an optional expression to evaluate, which can be used
	// in addition to the Input file to run, to execute commands
	// defined within that file for example, or as a command to run
	// prior to starting interactive mode if no Input is specified.
	Expr string `flag:"e,expr"`
	// Args is an optional list of arguments to pass in the run command.
	// These arguments will be turned into an "args" local variable in the goal.
	// These are automatically processed from any leftover arguments passed, so
	// you should not need to specify this flag manually.
	Args []string `cmd:"run" posarg:"leftover" required:"-"`
	// Interactive runs the interactive command line after processing any input file.
	// Interactive mode is the default mode for the run command unless an input file
	// is specified.
	Interactive bool `cmd:"run" flag:"i,interactive"`
	// InteractiveFunc is the function to run in interactive mode.
	// Set it to your own function as needed.
	// It must be set before [Run] is called if interactive mode is used.
	InteractiveFunc func(c *Config, in *Interpreter) error
}
// Run runs the specified goal file. If no file is specified,
// it runs an interactive shell that allows the user to input goal.
// Returns an error if interactive mode is required but
// c.InteractiveFunc has not been set (previously this panicked).
func Run(c *Config) error { //cli:cmd -root
	in := NewInterpreter(interp.Options{})
	if len(c.Args) > 0 {
		in.Goal.CliArgs = goalib.StringsToAnys(c.Args)
	}
	// interactive mode is entered when no input file is given,
	// or explicitly requested; it requires an InteractiveFunc.
	if c.InteractiveFunc == nil && (c.Input == "" || c.Interactive) {
		return fmt.Errorf("goal: Config.InteractiveFunc must be set for interactive mode")
	}
	if c.Input == "" {
		return c.InteractiveFunc(c, in)
	}
	in.Config()
	code := ""
	if errors.Log1(fsx.FileExists(c.Input)) {
		b, err := os.ReadFile(c.Input)
		if err != nil && c.Expr == "" {
			return err
		}
		code = string(b)
	}
	if c.Expr != "" {
		if code != "" {
			code += "\n"
		}
		code += c.Expr + "\n"
	}
	_, _, err := in.Eval(code)
	if err == nil {
		// report unbalanced block depth as an error
		err = in.Goal.TrState.DepthError()
	}
	if c.Interactive {
		return c.InteractiveFunc(c, in)
	}
	return err
}
// Interactive runs an interactive shell that allows the user to input goal.
// Any error from the interactive loop is returned (previously it was discarded).
func Interactive(c *Config, in *Interpreter) error {
	in.Config()
	if c.Expr != "" {
		in.Eval(c.Expr) // Eval logs its own errors
	}
	return in.Interactive()
}
// Build builds the specified input goal file, or all .goal files in the current
// directory if no input is specified, to corresponding .go file name(s).
// If the file does not already contain a "package" specification, then
// "package main; func main()..." wrappers are added, which allows the same
// code to be used in interactive and Go compiled modes.
// go build is run after this.
func Build(c *Config) error {
	verbose := logx.UserLevel <= slog.LevelInfo
	var fns []string
	if c.Input == "" {
		fns = fsx.Filenames(".", ".goal")
	} else {
		fns = []string{c.Input}
	}
	curpkg, _ := exec.Minor().Output("go", "list", "./")
	var errs []error
	for _, fn := range fns {
		if verbose {
			fmt.Println(filepath.Join(curpkg, fn))
		}
		ofn := strings.TrimSuffix(fn, filepath.Ext(fn)) + ".go"
		if err := goal.NewGoal().TranspileFile(fn, ofn); err != nil {
			errs = append(errs, err)
		}
	}
	// cfg := &gotosl.Config{}
	// cfg.Debug = verbose
	// err := gotosl.Run(cfg)
	// if err != nil {
	//	errs = append(errs, err)
	// }
	args := []string{"build"}
	if verbose {
		args = append(args, "-v")
	}
	if err := exec.Verbose().Run("go", args...); err != nil {
		errs = append(errs, err)
	}
	return errors.Join(errs...)
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package interpreter
import (
"reflect"
"github.com/cogentcore/yaegi/interp"
)
// Symbols holds per-package maps of exported symbol values for the yaegi interpreter.
var Symbols = map[string]map[string]reflect.Value{}
// ImportGoal imports special symbols from the goal package,
// binding the interpreter's goal shell methods so interpreted
// code can invoke shell execution directly.
func (in *Interpreter) ImportGoal() {
	in.Interp.Use(interp.Exports{
		"cogentcore.org/lab/goal/goal": map[string]reflect.Value{
			"Run":         reflect.ValueOf(in.Goal.Run),
			"RunErrOK":    reflect.ValueOf(in.Goal.RunErrOK),
			"Output":      reflect.ValueOf(in.Goal.Output),
			"OutputErrOK": reflect.ValueOf(in.Goal.OutputErrOK),
			"Start":       reflect.ValueOf(in.Goal.Start),
			"AddCommand":  reflect.ValueOf(in.Goal.AddCommand),
			"RunCommands": reflect.ValueOf(in.Goal.RunCommands),
			"Args":        reflect.ValueOf(in.Goal.Args),
		},
	})
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package interpreter
import (
"context"
"fmt"
"io"
"log"
"os"
"os/signal"
"reflect"
"strconv"
"strings"
"syscall"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/yaegicore/basesymbols"
"cogentcore.org/lab/goal"
_ "cogentcore.org/lab/stats/metric"
_ "cogentcore.org/lab/stats/stats"
_ "cogentcore.org/lab/tensor/tmath"
"cogentcore.org/lab/tensorfs"
"cogentcore.org/lab/yaegilab/nogui"
"github.com/cogentcore/yaegi/interp"
"github.com/cogentcore/yaegi/stdlib"
"github.com/ergochat/readline"
)
// Interpreter represents one running shell context
type Interpreter struct {
	// Goal is the goal shell for this interpreter.
	Goal *goal.Goal
	// HistFile is the name of the history file to open / save.
	// Defaults to ~/.goal-history for the default goal shell.
	// Update this prior to running Config() to take effect.
	HistFile string
	// Interp is the yaegi interpreter instance.
	Interp *interp.Interpreter
}
// init removes the stdlib errors package from the yaegi symbol set
// so our errors package is used instead.
func init() {
	delete(stdlib.Symbols, "errors/errors") // use our errors package instead
}
// NewInterpreter returns a new [Interpreter] initialized with the given options.
// It automatically imports the standard library and configures necessary shell
// functions. End user app must call [Interp.Config] after importing any additional
// symbols, prior to running the interpreter.
func NewInterpreter(options interp.Options) *Interpreter {
	in := &Interpreter{HistFile: "~/.goal-history"}
	in.Goal = goal.NewGoal()
	// caller-provided stdio overrides the goal shell defaults
	if options.Stdin != nil {
		in.Goal.Config.StdIO.In = options.Stdin
	}
	if options.Stdout != nil {
		in.Goal.Config.StdIO.Out = options.Stdout
	}
	if options.Stderr != nil {
		in.Goal.Config.StdIO.Err = options.Stderr
	}
	// capture the original stdio first, then route the interpreter
	// through the goal wrappers so output can be redirected per command
	in.Goal.SaveOrigStdIO()
	options.Stdout = in.Goal.StdIOWrappers.Out
	options.Stderr = in.Goal.StdIOWrappers.Err
	options.Stdin = in.Goal.StdIOWrappers.In
	in.Interp = interp.New(options)
	errors.Log(in.Interp.Use(basesymbols.Symbols))
	errors.Log(in.Interp.Use(nogui.Symbols))
	in.ImportGoal()
	go in.MonitorSignals()
	return in
}
// Prompt returns the appropriate REPL prompt to show the user:
// the host/dir (or tensorfs path in math mode) plus a prompt char
// at depth 0, or indentation proportional to block depth otherwise.
func (in *Interpreter) Prompt() string {
	depth := in.Goal.TrState.TotalDepth()
	prompt := ">"
	dir := in.Goal.HostAndDir()
	if in.Goal.TrState.MathMode {
		prompt = "#"
		dir = tensorfs.CurDir.Path()
	}
	if depth == 0 {
		return dir + " " + prompt + " "
	}
	res := prompt + " "
	for range depth {
		res += " " // note: /t confuses readline
	}
	return res
}
// Eval evaluates (interprets) the given code,
// returning the value returned from the interpreter.
// HasPrint indicates whether the last line of code
// has the string print in it, which is for determining
// whether to print the result in interactive mode.
// It automatically logs any error in addition to returning it.
func (in *Interpreter) Eval(code string) (v reflect.Value, hasPrint bool, err error) {
	in.Goal.TranspileCode(code)
	source := false
	if in.Goal.SSHActive == "" {
		source = strings.HasPrefix(code, "source")
	}
	// only run when not inside an unterminated block construct
	// (accumulated lines are run once depth returns to 0)
	if in.Goal.TrState.TotalDepth() == 0 {
		nl := len(in.Goal.TrState.Lines)
		if nl > 0 {
			ln := in.Goal.TrState.Lines[nl-1]
			if strings.Contains(strings.ToLower(ln), "print") {
				hasPrint = true
			}
		}
		v, err = in.RunCode()
		in.Goal.Errors = nil
	}
	if source {
		v, err = in.RunCode() // run accumulated code
	}
	return
}
// RunCode runs the accumulated set of code lines
// and clears the stack of code lines.
// It automatically logs any error in addition to returning it.
func (in *Interpreter) RunCode() (reflect.Value, error) {
	// bail early if transpiling already produced errors
	if len(in.Goal.Errors) > 0 {
		return reflect.Value{}, errors.Join(in.Goal.Errors...)
	}
	in.Goal.TrState.AddChunk()
	code := in.Goal.TrState.Chunks
	in.Goal.TrState.ResetCode()
	var v reflect.Value
	var err error
	for _, ch := range code {
		ctx := in.Goal.StartContext()
		v, err = in.Interp.EvalWithContext(ctx, ch)
		in.Goal.EndContext()
		if err != nil {
			cancelled := errors.Is(err, context.Canceled)
			// fmt.Println("cancelled:", cancelled)
			// clean up shell state after a failed / cancelled chunk
			in.Goal.DeleteAllJobs()
			in.Goal.RestoreOrigStdIO()
			in.Goal.TrState.ResetDepth()
			if !cancelled {
				in.Goal.AddError(err)
			} else {
				// cancellation is user-initiated, not an error condition
				in.Goal.Errors = nil
			}
			break
		}
	}
	return v, err
}
// RunConfig runs the .goal startup config file in the user's
// home directory if it exists. Transpile errors are logged;
// the returned error comes from running the resulting code.
func (in *Interpreter) RunConfig() error {
	if err := in.Goal.TranspileConfig(); err != nil {
		errors.Log(err)
	}
	_, err := in.RunCode()
	return err
}
// MonitorSignals monitors the operating system signals to appropriately
// stop the interpreter and prevent the shell from closing on Control+C.
// It is called automatically in another goroutine in [NewInterpreter].
func (in *Interpreter) MonitorSignals() {
	c := make(chan os.Signal, 1)
	// todo: syscall.SIGSEGV not defined on web
	signal.Notify(c, os.Interrupt, syscall.SIGTERM, syscall.SIGINT)
	// signal.Notify(c, os.Interrupt, syscall.SIGTERM, syscall.SIGINT, syscall.SIGSEGV)
	// loop forever: each signal cancels the currently running execution
	// rather than terminating the shell
	for {
		<-c
		in.Goal.CancelExecution()
	}
}
// Config performs final configuration after all the imports have been Use'd,
// importing the used symbols and then running the user's startup config file.
func (in *Interpreter) Config() {
	in.Interp.ImportUsed()
	in.RunConfig()
}
// OpenHistory opens history from the current HistFile
// and loads it into the readline history for given rl instance.
func (in *Interpreter) OpenHistory(rl *readline.Instance) error {
	if err := in.Goal.OpenHistory(in.HistFile); err != nil {
		return err
	}
	for _, h := range in.Goal.Hist {
		rl.SaveToHistory(h)
	}
	return nil
}
// SaveHistory saves last 500 (or HISTFILESIZE env value) lines of history,
// to the current HistFile.
func (in *Interpreter) SaveHistory() error {
	n := 500
	if hfs := os.Getenv("HISTFILESIZE"); hfs != "" {
		if en, err := strconv.Atoi(hfs); err == nil {
			n = en
		} else {
			in.Goal.Config.StdIO.ErrPrintf("SaveHistory: environment variable HISTFILESIZE: %q not a number: %s", hfs, err.Error())
		}
	}
	return in.Goal.SaveHistory(n, in.HistFile)
}
// Interactive runs an interactive shell that allows the user to input goal.
// Must have done in.Config() prior to calling.
// It loops reading lines until EOF (exits the process) or a read error.
func (in *Interpreter) Interactive() error {
	in.Goal.TrState.MathRecord = true
	rl, err := readline.NewFromConfig(&readline.Config{
		AutoComplete: &goal.ReadlineCompleter{Goal: in.Goal},
		Undo:         true,
	})
	if err != nil {
		return err
	}
	in.OpenHistory(rl) // best-effort: missing history file is not fatal
	defer rl.Close()
	log.SetOutput(rl.Stderr()) // redraw the prompt correctly after log output
	for {
		rl.SetPrompt(in.Prompt())
		line, err := rl.ReadLine()
		if errors.Is(err, readline.ErrInterrupt) {
			continue
		}
		if errors.Is(err, io.EOF) {
			in.SaveHistory()
			os.Exit(0)
		}
		if err != nil {
			in.SaveHistory()
			return err
		}
		if len(line) > 0 && line[0] == '!' { // history command: !n reruns line n
			hl, err := strconv.Atoi(line[1:])
			nh := len(in.Goal.Hist)
			if err != nil {
				in.Goal.Config.StdIO.ErrPrintf("history number: %q not a number: %s", line[1:], err.Error())
				line = ""
			} else if hl < 0 || hl >= nh { // hl < 0 guard: negative index would panic on Hist[hl]
				in.Goal.Config.StdIO.ErrPrintf("history number: %d not in range: [0:%d]", hl, nh)
				line = ""
			} else {
				line = in.Goal.Hist[hl]
				fmt.Printf("h:%d\t%s\n", hl, line)
			}
		} else if line != "" && !strings.HasPrefix(line, "history") && line != "h" {
			in.Goal.AddHistory(line)
		}
		in.Goal.Errors = nil
		v, hasPrint, err := in.Eval(line)
		// echo non-zero results unless the line already printed something
		if err == nil && !hasPrint && v.IsValid() && !v.IsZero() && v.Kind() != reflect.Func {
			fmt.Println(v.Interface())
		}
	}
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package goal
// Run executes the given command string, waiting for the command to finish,
// handling the given arguments appropriately.
// If there is any error, it adds it to the goal, and triggers CancelExecution.
// It forwards output to [exec.Config.Stdout] and [exec.Config.Stderr] appropriately.
func (gl *Goal) Run(cmd any, args ...any) {
	// flags: errOk, start, output (inferred from the sibling wrappers below)
	gl.Exec(false, false, false, cmd, args...)
}
// RunErrOK executes the given command string, waiting for the command to finish,
// handling the given arguments appropriately.
// It does not stop execution if there is an error.
// If there is any error, it adds it to the goal. It forwards output to
// [exec.Config.Stdout] and [exec.Config.Stderr] appropriately.
func (gl *Goal) RunErrOK(cmd any, args ...any) {
	// flags: errOk=true, start, output
	gl.Exec(true, false, false, cmd, args...)
}
// Start starts the given command string for running in the background,
// handling the given arguments appropriately.
// If there is any error, it adds it to the goal. It forwards output to
// [exec.Config.Stdout] and [exec.Config.Stderr] appropriately.
func (gl *Goal) Start(cmd any, args ...any) {
	// flags: errOk, start=true, output
	gl.Exec(false, true, false, cmd, args...)
}
// Output executes the given command string, handling the given arguments
// appropriately. If there is any error, it adds it to the goal. It returns
// the stdout as a string and forwards stderr to [exec.Config.Stderr] appropriately.
func (gl *Goal) Output(cmd any, args ...any) string {
	// flags: errOk, start, output=true
	return gl.Exec(false, false, true, cmd, args...)
}
// OutputErrOK executes the given command string, handling the given arguments
// appropriately. If there is any error, it adds it to the goal. It returns
// the stdout as a string and forwards stderr to [exec.Config.Stderr] appropriately.
func (gl *Goal) OutputErrOK(cmd any, args ...any) string {
	// flags: errOk=true, start, output=true
	return gl.Exec(true, false, true, cmd, args...)
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package transpile
import (
"path"
"reflect"
"strings"
"cogentcore.org/lab/tensor"
"cogentcore.org/lab/yaegilab/nogui"
)
// init registers all yaegi-exported tensor package functions
// with the tensor.Funcs map at package load time.
func init() {
	AddYaegiTensorFuncs()
}
// yaegiTensorPackages are the import-path fragments identifying tensor-related packages.
var yaegiTensorPackages = []string{"/lab/tensor", "/lab/stats", "/lab/vector", "/lab/matrix"}
// AddYaegiTensorFuncs grabs all tensor* package functions registered
// in yaegicore and adds them to the `tensor.Funcs` map so we can
// properly convert symbols to either tensors or basic literals,
// depending on the arg types for the current function.
func AddYaegiTensorFuncs() {
for pth, symap := range nogui.Symbols {
has := false
for _, p := range yaegiTensorPackages {
if strings.Contains(pth, p) {
has = true
break
}
}
if !has {
continue
}
_, pkg := path.Split(pth)
for name, val := range symap {
if val.Kind() != reflect.Func {
continue
}
pnm := pkg + "." + name
tensor.AddFunc(pnm, val.Interface())
}
}
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package transpile
import (
"errors"
"go/token"
)
// tensorfsCommands maps tensorfs shell command names to their transpile handlers.
var tensorfsCommands = map[string]func(mp *mathParse) error{
	"cd":    cd,
	"mkdir": mkdir,
	"ls":    ls,
}
// cd transpiles a tensorfs cd command into a tensorfs.Chdir call,
// defaulting to the empty (home) directory when no arg is given.
func cd(mp *mathParse) error {
	dir := ""
	if len(mp.ewords) > 1 {
		dir = mp.ewords[1]
	}
	mp.out.Add(token.IDENT, "tensorfs.Chdir")
	mp.out.Add(token.LPAREN)
	mp.out.Add(token.STRING, `"`+dir+`"`)
	mp.out.Add(token.RPAREN)
	return nil
}
// mkdir transpiles a tensorfs mkdir command into a tensorfs.Mkdir call.
// A directory name argument is required.
func mkdir(mp *mathParse) error {
	if len(mp.ewords) < 2 {
		return errors.New("tensorfs mkdir requires a directory name")
	}
	dir := mp.ewords[1]
	mp.out.Add(token.IDENT, "tensorfs.Mkdir")
	mp.out.Add(token.LPAREN)
	mp.out.Add(token.STRING, `"`+dir+`"`)
	mp.out.Add(token.RPAREN)
	return nil
}
// ls transpiles a tensorfs ls command into a tensorfs.List call,
// passing each remaining word as a quoted string argument.
func ls(mp *mathParse) error {
	mp.out.Add(token.IDENT, "tensorfs.List")
	mp.out.Add(token.LPAREN)
	// NOTE(review): unlike cd/mkdir, multiple args are emitted with no
	// COMMA tokens between them -- confirm the Tokens writer inserts
	// separators, otherwise multi-arg ls would generate invalid code.
	for i := 1; i < len(mp.ewords); i++ {
		mp.out.Add(token.STRING, `"`+mp.ewords[i]+`"`)
	}
	mp.out.Add(token.RPAREN)
	return nil
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package transpile
import (
"fmt"
"strings"
"unicode"
)
// ExecWords splits an exec command line into words, respecting double
// and backtick quotes, backslash escapes, { } blocks (accumulated as a
// single token), a leading [ ... ] bracket form, and redirection /
// separator operators (<, >, >>, |, &, ;) which become their own tokens.
// Returns an error if quotes or brackets are left unterminated.
func ExecWords(ln string) ([]string, error) {
	ln = strings.TrimSpace(ln)
	n := len(ln)
	if n == 0 {
		return nil, nil
	}
	// strip a leading $ (and a matching trailing $) used as exec markup
	if ln[0] == '$' {
		ln = strings.TrimSpace(ln[1:])
		n = len(ln)
		if n == 0 {
			return nil, nil
		}
		if ln[n-1] == '$' {
			ln = strings.TrimSpace(ln[:n-1])
			n = len(ln)
			if n == 0 {
				return nil, nil
			}
		}
	}
	cur := ""           // word being accumulated
	escaped := false    // previous rune was a backslash
	inDouble := false   // inside "..."
	inBacktick := false // inside `...`
	braceDepth := 0
	brackDepth := 0
	afterRedir := false // previous rune was a redirection char
	var out []string
	flush := func() {
		if braceDepth > 0 { // always accum into one token inside brace
			return
		}
		if len(cur) > 0 {
			out = append(out, cur)
			cur = ""
		}
	}
	lineStart := true
	startBrack := ln[0] == '['
	if startBrack {
		cur = "["
		flush()
		brackDepth++
		ln = ln[1:]
		lineStart = false
	}
	for _, r := range ln {
		quoted := inDouble || inBacktick
		if afterRedir {
			afterRedir = false
			if r == '&' { // e.g. >&: complete the operator token
				cur += string(r)
				flush()
				continue
			}
			if r == '>' { // >>: still in redirection
				cur += string(r)
				afterRedir = true
				continue
			}
			flush()
		}
		switch {
		case escaped:
			if braceDepth == 0 && unicode.IsSpace(r) { // we will be quoted later anyway
				cur = cur[:len(cur)-1]
			}
			cur += string(r)
			escaped = false
		case r == '\\':
			escaped = true
			cur += string(r)
		case r == '"':
			if !inBacktick {
				inDouble = !inDouble
			}
			cur += string(r)
		case r == '`':
			if !inDouble {
				inBacktick = !inBacktick
			}
			cur += string(r)
		case quoted: // absorbs quote -- no need to check below
			cur += string(r)
		case unicode.IsSpace(r):
			flush()
			continue // don't reset at start
		case r == '{':
			if braceDepth == 0 {
				flush()
				cur = "{"
				flush()
			}
			braceDepth++
		case r == '}':
			braceDepth--
			if braceDepth == 0 {
				flush()
				cur = "}"
				flush()
			}
		case r == '[':
			cur += string(r)
			if lineStart && brackDepth == 0 {
				startBrack = true
				flush()
			}
			brackDepth++
		case r == ']':
			brackDepth--
			if brackDepth == 0 && startBrack { // only point of tracking brackets is to catch this closer
				flush()
				cur = "]"
				flush()
			} else {
				cur += string(r)
			}
		case r == '<' || r == '>' || r == '|':
			flush()
			cur += string(r)
			afterRedir = true
		case r == '&': // known to not be redir
			flush()
			cur += string(r)
		case r == ';':
			flush()
			cur += string(r)
			flush()
			lineStart = true
			continue // avoid reset
		default:
			cur += string(r)
		}
		lineStart = false
	}
	flush()
	if inDouble || inBacktick || brackDepth > 0 {
		return out, fmt.Errorf("goal: exec command has unterminated quotes (\": %v, `: %v) or brackets [ %v ]", inDouble, inBacktick, brackDepth > 0)
	}
	return out, nil
}
// ExecWordIsCommand returns true if given exec word is a command-like string
// (excluding any paths)
func ExecWordIsCommand(f string) bool {
	for _, marker := range []string{"(", "[", "="} {
		if strings.Contains(f, marker) {
			return false
		}
	}
	return true
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package transpile
import (
"fmt"
"go/ast"
"go/token"
"strings"
"cogentcore.org/core/base/stack"
"cogentcore.org/lab/tensor"
)
// TranspileMath does math mode transpiling. fullLine indicates code should be
// full statement(s).
func (st *State) TranspileMath(toks Tokens, code string, fullLine bool) Tokens {
	nt := len(toks)
	if nt == 0 {
		return nil
	}
	// fmt.Println(nt, toks)
	// reconstruct the source substring covered by the tokens
	str := code[toks[0].Pos-1 : toks[nt-1].Pos]
	if toks[nt-1].Str != "" {
		str += toks[nt-1].Str[1:]
	}
	// fmt.Println(str)
	mp := mathParse{state: st, toks: toks, code: code}
	// mp.trace = true
	mods := AllErrors // | Trace
	if fullLine {
		// first check for tensorfs shell-style commands (cd, mkdir, ls)
		ewords, err := ExecWords(str)
		// NOTE(review): the ExecWords error is not checked before dispatch;
		// confirm whether unterminated quotes should abort here.
		if len(ewords) > 0 {
			if cmd, ok := tensorfsCommands[ewords[0]]; ok {
				mp.ewords = ewords
				err := cmd(&mp)
				if err != nil {
					fmt.Println(ewords[0]+":", err.Error())
					return nil
				} else {
					return mp.out
				}
			}
		}
		stmts, err := ParseLine(str, mods)
		if err != nil {
			fmt.Println("line code:", str)
			fmt.Println("parse err:", err)
		}
		if len(stmts) == 0 {
			return toks
		}
		mp.stmtList(stmts)
	} else {
		ex, err := ParseExpr(str, mods)
		if err != nil {
			fmt.Println("expr:", str)
			fmt.Println("parse err:", err)
		}
		mp.expr(ex)
	}
	// sanity check: every source token should have been consumed
	if mp.idx != len(toks) {
		fmt.Println(code)
		fmt.Println(mp.out.Code())
		fmt.Printf("parsing error: index: %d != len(toks): %d\n", mp.idx, len(toks))
	}
	return mp.out
}
// funcInfo is info about the function being processed
type funcInfo struct {
	tensor.Func
	// current arg index we are processing
	curArg int
}
// mathParse has the parsing state, active only during a parsing pass
// on one specific chunk of code and tokens.
type mathParse struct {
	state  *State
	code   string   // code string
	toks   Tokens   // source tokens we are parsing
	ewords []string // exec words
	idx    int      // current index in source tokens -- critical to sync as we "use" source
	out    Tokens   // output tokens we generate
	trace  bool     // trace of parsing -- turn on to see alignment
	// stack of function info -- top of stack reflects the current function
	funcs stack.Stack[*funcInfo]
}
// curArg returns the current argument for the current function,
// or nil if there is no active function or the index is out of range.
func (mp *mathParse) curArg() *tensor.Arg {
	cfun := mp.funcs.Peek()
	if cfun != nil && cfun.curArg < len(cfun.Args) {
		return cfun.Args[cfun.curArg]
	}
	return nil
}
// nextArg advances the current function's argument index, warning when
// args exceed the registered (non-variadic) function signature.
func (mp *mathParse) nextArg() {
	cfun := mp.funcs.Peek()
	if cfun == nil || len(cfun.Args) == 0 {
		// fmt.Println("next arg no fun or no args")
		return
	}
	last := len(cfun.Args) - 1
	if cfun.curArg < last {
		cfun.curArg++
		return
	}
	// at the last arg: stay here (variadic), or warn if over-supplied
	if !cfun.Args[last].IsVariadic {
		fmt.Println("math transpile: args exceed registered function number:", cfun)
	}
}
// curArgIsTensor reports whether the current function argument expects a tensor.
func (mp *mathParse) curArgIsTensor() bool {
	if carg := mp.curArg(); carg != nil {
		return carg.IsTensor
	}
	return false
}
// curArgIsInts reports whether the current function argument
// expects a variadic list of ints.
func (mp *mathParse) curArgIsInts() bool {
	if carg := mp.curArg(); carg != nil {
		return carg.IsInt && carg.IsVariadic
	}
	return false
}
// startFunc is called when starting a new function.
// empty is "dummy" assign case using Inc.
// optional noLookup indicates to not lookup type and just
// push the name -- for internal cases to prevent arg conversions.
func (mp *mathParse) startFunc(name string, noLookup ...bool) *funcInfo {
	fi := &funcInfo{}
	lookupName := name
	if name == "" || name == "tensor.Tensor" {
		lookupName = "tmath.Inc" // one arg tensor fun
	}
	skipLookup := len(noLookup) == 1 && noLookup[0]
	if skipLookup {
		fi.Name = name
	} else if tf, err := tensor.FuncByName(lookupName); err == nil {
		fi.Func = *tf
	} else {
		fi.Name = name
	}
	mp.funcs.Push(fi)
	if name != "" {
		mp.out.Add(token.IDENT, name)
	}
	return fi
}
// endFunc pops the current function off the function stack.
func (mp *mathParse) endFunc() {
	mp.funcs.Pop()
}
// addToken adds output token and increments idx,
// consuming one source token in the process.
func (mp *mathParse) addToken(tok token.Token) {
	mp.out.Add(tok)
	if mp.trace {
		ctok := &Token{}
		if mp.idx < len(mp.toks) {
			ctok = mp.toks[mp.idx]
		}
		fmt.Printf("%d\ttok: %s \t replaces: %s\n", mp.idx, tok, ctok)
	}
	mp.idx++
}
// addCur copies the current source token to the output and advances idx,
// reporting when the source tokens have been exhausted.
func (mp *mathParse) addCur() {
	if mp.idx >= len(mp.toks) {
		fmt.Println("out of tokens!", mp.idx, mp.toks)
		return
	}
	mp.out.AddTokens(mp.toks[mp.idx])
	mp.idx++
}
// stmtList processes each statement in the list in order.
func (mp *mathParse) stmtList(sts []ast.Stmt) {
	for i := range sts {
		mp.stmt(sts[i])
	}
}
// stmt transpiles a single statement AST node, rewriting math-mode
// constructs (inc/dec, assignments, bool conditions, for-range over
// tensors) into tensor function calls, and passing other statements
// through with their tokens preserved.
func (mp *mathParse) stmt(st ast.Stmt) {
	if st == nil {
		return
	}
	switch x := st.(type) {
	case *ast.BadStmt:
		fmt.Println("bad stmt!")
	case *ast.DeclStmt:
	case *ast.ExprStmt:
		mp.expr(x.X)
	case *ast.SendStmt:
		mp.expr(x.Chan)
		mp.addToken(token.ARROW)
		mp.expr(x.Value)
	case *ast.IncDecStmt:
		// x++ / x-- become tmath.Inc(x) / tmath.Dec(x)
		fn := "Inc"
		if x.Tok == token.DEC {
			fn = "Dec"
		}
		mp.startFunc("tmath." + fn)
		mp.out.Add(token.LPAREN)
		mp.expr(x.X)
		mp.addToken(token.RPAREN)
	case *ast.AssignStmt:
		switch x.Tok {
		case token.DEFINE:
			mp.defineStmt(x)
		default:
			mp.assignStmt(x)
		}
	case *ast.GoStmt:
		mp.addToken(token.GO)
		mp.callExpr(x.Call)
	case *ast.DeferStmt:
		mp.addToken(token.DEFER)
		mp.callExpr(x.Call)
	case *ast.ReturnStmt:
		mp.addToken(token.RETURN)
		mp.exprList(x.Results)
	case *ast.BranchStmt:
		mp.addToken(x.Tok)
		mp.ident(x.Label)
	case *ast.BlockStmt:
		mp.addToken(token.LBRACE)
		mp.stmtList(x.List)
		mp.addToken(token.RBRACE)
	case *ast.IfStmt:
		mp.addToken(token.IF)
		mp.stmt(x.Init)
		if x.Init != nil {
			mp.addToken(token.SEMICOLON)
		}
		mp.expr(x.Cond)
		mp.out.Add(token.IDENT, ".Bool1D(0)") // turn bool expr into actual bool
		if x.Body != nil && len(x.Body.List) > 0 {
			mp.addToken(token.LBRACE)
			mp.stmtList(x.Body.List)
			mp.addToken(token.RBRACE)
		} else {
			mp.addToken(token.LBRACE)
		}
		if x.Else != nil {
			mp.addToken(token.ELSE)
			mp.stmt(x.Else)
		}
	case *ast.ForStmt:
		mp.addToken(token.FOR)
		mp.stmt(x.Init)
		if x.Init != nil {
			mp.addToken(token.SEMICOLON)
		}
		mp.expr(x.Cond)
		if x.Cond != nil {
			mp.out.Add(token.IDENT, ".Bool1D(0)") // turn bool expr into actual bool
			mp.addToken(token.SEMICOLON)
		}
		mp.stmt(x.Post)
		if x.Body != nil && len(x.Body.List) > 0 {
			mp.addToken(token.LBRACE)
			mp.stmtList(x.Body.List)
			mp.addToken(token.RBRACE)
		} else {
			mp.addToken(token.LBRACE)
		}
	case *ast.RangeStmt:
		// rewrite `for k, v := range t` into an indexed loop over t.Len()
		// with v := t.Float1D(k) in the body
		if x.Key == nil || x.Value == nil {
			fmt.Println("for range statement requires both index and value variables")
			return
		}
		ki, _ := x.Key.(*ast.Ident)
		vi, _ := x.Value.(*ast.Ident)
		ei, _ := x.X.(*ast.Ident)
		if ki == nil || vi == nil || ei == nil {
			fmt.Println("for range statement requires all variables (index, value, range) to be variable names, not other expressions")
			return
		}
		knm := ki.Name
		vnm := vi.Name
		enm := ei.Name
		mp.addToken(token.FOR)
		mp.expr(x.Key)
		// idx adjustments below skip the source tokens (`, v := range`)
		// that have no direct counterpart in the rewritten output
		mp.idx += 2
		mp.addToken(token.DEFINE)
		mp.out.Add(token.IDENT, "0")
		mp.out.Add(token.SEMICOLON)
		mp.out.Add(token.IDENT, knm)
		mp.out.Add(token.IDENT, "<")
		mp.out.Add(token.IDENT, enm)
		mp.out.Add(token.PERIOD)
		mp.out.Add(token.IDENT, "Len")
		mp.idx++
		mp.out.AddMulti(token.LPAREN, token.RPAREN)
		mp.idx++
		mp.out.Add(token.SEMICOLON)
		mp.idx++
		mp.out.Add(token.IDENT, knm)
		mp.out.AddMulti(token.INC, token.LBRACE)
		mp.out.Add(token.IDENT, vnm)
		mp.out.Add(token.DEFINE)
		mp.out.Add(token.IDENT, enm)
		mp.out.Add(token.IDENT, ".Float1D")
		mp.out.Add(token.LPAREN)
		mp.out.Add(token.IDENT, knm)
		mp.out.Add(token.RPAREN)
		if x.Body != nil && len(x.Body.List) > 0 {
			mp.stmtList(x.Body.List)
			mp.addToken(token.RBRACE)
		}
		// TODO
		// CaseClause: SwitchStmt:, TypeSwitchStmt:, CommClause:, SelectStmt:
	}
}
// expr transpiles a single expression AST node, dispatching to the
// specialized handlers for each expression kind.
func (mp *mathParse) expr(ex ast.Expr) {
	if ex == nil {
		return
	}
	switch x := ex.(type) {
	case *ast.BadExpr:
		fmt.Println("bad expr!")
	case *ast.Ident:
		mp.ident(x)
	case *ast.UnaryExpr:
		mp.unaryExpr(x)
	case *ast.Ellipsis:
		// inside Reslice, ... means the tensor.Ellipsis sentinel
		cfun := mp.funcs.Peek()
		if cfun != nil && cfun.Name == "tensor.Reslice" {
			mp.out.Add(token.IDENT, "tensor.Ellipsis")
			mp.idx++
		} else {
			mp.addToken(token.ELLIPSIS)
		}
	case *ast.StarExpr:
		mp.addToken(token.MUL)
		mp.expr(x.X)
	case *ast.BinaryExpr:
		mp.binaryExpr(x)
	case *ast.BasicLit:
		mp.basicLit(x)
	case *ast.FuncLit:
	case *ast.ParenExpr:
		mp.addToken(token.LPAREN)
		mp.expr(x.X)
		mp.addToken(token.RPAREN)
	case *ast.SelectorExpr:
		mp.selectorExpr(x)
	case *ast.TypeAssertExpr:
	case *ast.IndexExpr:
		mp.indexExpr(x)
	case *ast.IndexListExpr:
		if x.X == nil { // array literal
			mp.arrayLiteral(x)
		} else {
			mp.indexListExpr(x)
		}
	case *ast.SliceExpr:
		mp.sliceExpr(x)
	case *ast.CallExpr:
		mp.callExpr(x)
	case *ast.ArrayType:
		// note: shouldn't happen normally:
		fmt.Println("array type:", x, x.Len)
		fmt.Printf("%#v\n", x.Len)
	}
}
// exprList processes a list of expressions, emitting a COMMA token
// between successive entries.
func (mp *mathParse) exprList(ex []ast.Expr) {
	n := len(ex)
	switch n {
	case 0:
		return
	case 1:
		mp.expr(ex[0])
		return
	}
	for i, e := range ex {
		mp.expr(e)
		if i < n-1 {
			mp.addToken(token.COMMA)
		}
	}
}
// argsList processes a list of function-call arguments, advancing the
// current function's argument index between entries so arg-type-aware
// conversions apply to the right argument.
func (mp *mathParse) argsList(ex []ast.Expr) {
	n := len(ex)
	switch n {
	case 0:
		return
	case 1:
		mp.expr(ex[0])
		return
	}
	for i, e := range ex {
		// cfun := mp.funcs.Peek()
		// if i != cfun.curArg {
		// 	fmt.Println(cfun, "arg should be:", i, "is:", cfun.curArg)
		// }
		mp.expr(e)
		if i < n-1 {
			mp.nextArg()
			mp.addToken(token.COMMA)
		}
	}
}
// exprIsBool reports whether the expression is a boolean comparison
// (==, <, >, !=, <=, >=), looking through parentheses.
func (mp *mathParse) exprIsBool(ex ast.Expr) bool {
	switch x := ex.(type) {
	case *ast.BinaryExpr:
		return (x.Op >= token.EQL && x.Op <= token.GTR) || (x.Op >= token.NEQ && x.Op <= token.GEQ)
	case *ast.ParenExpr:
		return mp.exprIsBool(x.X)
	}
	return false
}
// exprsAreBool reports whether any expression in the list is a boolean comparison.
func (mp *mathParse) exprsAreBool(ex []ast.Expr) bool {
	for i := range ex {
		if mp.exprIsBool(ex[i]) {
			return true
		}
	}
	return false
}
// binaryExpr transpiles a binary expression into the corresponding
// tmath function call (Add, Sub, Mul, Pow, comparison ops, etc.),
// consuming the source operator token via idx increments.
func (mp *mathParse) binaryExpr(ex *ast.BinaryExpr) {
	if ex.Op == token.ILLEGAL { // @ = matmul
		mp.startFunc("matrix.Mul")
		mp.out.Add(token.LPAREN)
		mp.expr(ex.X)
		mp.out.Add(token.COMMA)
		mp.idx++ // skip the @ operator token in the source
		mp.expr(ex.Y)
		mp.out.Add(token.RPAREN)
		mp.endFunc()
		return
	}
	fn := ""
	switch ex.Op {
	case token.ADD:
		fn = "Add"
	case token.SUB:
		fn = "Sub"
	case token.MUL:
		fn = "Mul"
		if un, ok := ex.Y.(*ast.StarExpr); ok { // ** power operator
			ex.Y = un.X
			fn = "Pow"
		}
	case token.QUO:
		fn = "Div"
	case token.EQL:
		fn = "Equal"
	case token.LSS:
		fn = "Less"
	case token.GTR:
		fn = "Greater"
	case token.NEQ:
		fn = "NotEqual"
	case token.LEQ:
		fn = "LessEqual"
	case token.GEQ:
		fn = "GreaterEqual"
	case token.LOR:
		fn = "Or"
	case token.LAND:
		fn = "And"
	default:
		fmt.Println("binary token:", ex.Op)
	}
	mp.startFunc("tmath." + fn)
	mp.out.Add(token.LPAREN)
	mp.expr(ex.X)
	mp.out.Add(token.COMMA)
	mp.idx++ // skip the operator token in the source
	if fn == "Pow" {
		mp.idx++ // ** consumed two source tokens
	}
	mp.expr(ex.Y)
	mp.out.Add(token.RPAREN)
	mp.endFunc()
}
// unaryExpr transpiles a unary expression: literals keep their operator
// as-is, while ! and - on non-literals become tmath.Not / tmath.Negate calls.
func (mp *mathParse) unaryExpr(ex *ast.UnaryExpr) {
	// a unary op on a plain literal needs no tensor wrapping
	if _, isbl := ex.X.(*ast.BasicLit); isbl {
		mp.addToken(ex.Op)
		mp.expr(ex.X)
		return
	}
	fn := ""
	switch ex.Op {
	case token.NOT:
		fn = "Not"
	case token.SUB:
		fn = "Negate"
	case token.ADD:
		// unary plus is a no-op
		mp.expr(ex.X)
		return
	default: // * goes to StarExpr -- not sure what else could happen here?
		mp.addToken(ex.Op)
		mp.expr(ex.X)
		return
	}
	mp.startFunc("tmath." + fn)
	mp.addToken(token.LPAREN)
	mp.expr(ex.X)
	mp.out.Add(token.RPAREN)
	mp.endFunc()
}
// defineStmt transpiles a := definition, wrapping the right-hand side
// in tensor.Tensor(...) so the result is a tensor, and optionally
// recording the new variable in tensorfs when MathRecord is on.
func (mp *mathParse) defineStmt(as *ast.AssignStmt) {
	firstStmt := mp.idx == 0
	mp.exprList(as.Lhs)
	mp.addToken(as.Tok)
	mp.startFunc("tensor.Tensor")
	mp.out.Add(token.LPAREN)
	mp.exprList(as.Rhs)
	mp.out.Add(token.RPAREN)
	mp.endFunc()
	// record top-level definitions so they show up in the tensorfs browser
	if firstStmt && mp.state.MathRecord {
		nvar, ok := as.Lhs[0].(*ast.Ident)
		if ok {
			mp.out.Add(token.SEMICOLON)
			mp.out.Add(token.IDENT, "tensorfs.Record("+nvar.Name+",`"+nvar.Name+"`)")
		}
	}
}
// assignStmt transpiles an assignment: plain = to a simple identifier
// passes through (with tensor arg conversion), while compound and
// non-identifier assignments become tmath.*Assign function calls.
func (mp *mathParse) assignStmt(as *ast.AssignStmt) {
	if as.Tok == token.ASSIGN {
		if _, ok := as.Lhs[0].(*ast.Ident); ok {
			mp.exprList(as.Lhs)
			mp.addToken(as.Tok)
			mp.startFunc("")
			mp.exprList(as.Rhs)
			mp.endFunc()
			return
		}
	}
	fn := ""
	switch as.Tok {
	case token.ASSIGN:
		fn = "Assign"
	case token.ADD_ASSIGN:
		fn = "AddAssign"
	case token.SUB_ASSIGN:
		fn = "SubAssign"
	case token.MUL_ASSIGN:
		fn = "MulAssign"
	case token.QUO_ASSIGN:
		fn = "DivAssign"
	}
	mp.startFunc("tmath." + fn)
	mp.out.Add(token.LPAREN)
	mp.exprList(as.Lhs)
	mp.out.Add(token.COMMA)
	mp.idx++ // skip the assignment operator token in the source
	mp.exprList(as.Rhs)
	mp.out.Add(token.RPAREN)
	mp.endFunc()
}
// basicLit emits a basic literal, converting it to a scalar tensor
// first if the current function argument expects a tensor.
func (mp *mathParse) basicLit(lit *ast.BasicLit) {
	if mp.curArgIsTensor() {
		mp.tensorLit(lit)
		return
	}
	mp.out.Add(lit.Kind, lit.Value)
	if mp.trace {
		fmt.Printf("%d\ttok: %s literal\n", mp.idx, lit.Value)
	}
	mp.idx++
	// note: removed redundant bare `return` at end of void function (staticcheck S1023)
}
// tensorLit converts a basic literal into the corresponding scalar
// tensor constructor call (int, float, or string); other literal kinds
// are left untouched.
func (mp *mathParse) tensorLit(lit *ast.BasicLit) {
	var fun string
	switch lit.Kind {
	case token.INT:
		fun = "tensor.NewIntScalar"
	case token.FLOAT:
		fun = "tensor.NewFloat64Scalar"
	case token.STRING:
		fun = "tensor.NewStringScalar"
	default:
		return
	}
	mp.out.Add(token.IDENT, fun+"("+lit.Value+")")
	mp.idx++
}
// funWrap is a function wrapper for simple numpy property / functions:
// it pairs the tensor method to call with the wrapping function applied
// to the result.
type funWrap struct {
	fun string // function (method) to call on tensor; may be empty
	wrap string // code for wrapping function for results of call; short codes (nis, niv, ...) are expanded in wrapFunc
}
// numpyProps maps numpy property / function names to the tensor method
// to call and the wrapper applied to its result.
// nis: NewIntScalar, niv: NewIntFromValues, etc
var numpyProps = map[string]funWrap{
	"ndim": {"NumDims()", "nis"},
	"len": {"Len()", "nis"},
	"size": {"Len()", "nis"},
	"shape": {"Shape().Sizes", "niv"},
	"T": {"", "tensor.Transpose"},
}
// wrapFunc starts the wrapping function call for fw, expanding the short
// wrap codes (nis, nfs, nss, niv, nfv, nsv) into the corresponding tensor
// constructor names; any other wrap value is used verbatim. It reports
// whether the call needs a trailing ellipsis (the FromValues variants).
func (fw *funWrap) wrapFunc(mp *mathParse) bool {
	ellip := false
	wrapFun := fw.wrap // non-code wrap strings pass through unchanged
	switch fw.wrap {
	case "nis":
		wrapFun = "tensor.NewIntScalar"
	case "nfs":
		wrapFun = "tensor.NewFloat64Scalar"
	case "nss":
		wrapFun = "tensor.NewStringScalar"
	case "niv":
		wrapFun = "tensor.NewIntFromValues"
		ellip = true
	case "nfv":
		wrapFun = "tensor.NewFloat64FromValues"
		ellip = true
	case "nsv":
		wrapFun = "tensor.NewStringFromValues"
		ellip = true
	}
	mp.startFunc(wrapFun, true) // don't lookup -- don't auto-convert args
	mp.out.Add(token.LPAREN)
	return ellip
}
// selectorExpr processes a selector expression X.Sel. Names found in
// numpyProps (ndim, len, size, shape, T) are rewritten through their
// funWrap entry; anything else is emitted as a plain selector.
func (mp *mathParse) selectorExpr(ex *ast.SelectorExpr) {
	fw, ok := numpyProps[ex.Sel.Name]
	if !ok {
		// regular selector: emit X, the period, and the selector name
		mp.expr(ex.X)
		mp.addToken(token.PERIOD)
		mp.out.Add(token.IDENT, ex.Sel.Name)
		mp.idx++
		return
	}
	ellip := fw.wrapFunc(mp)
	mp.expr(ex.X)
	if fw.fun != "" {
		mp.addToken(token.PERIOD)
		mp.out.Add(token.IDENT, fw.fun)
		mp.idx++
	} else {
		// no method call emitted: skip the source . and selector tokens
		mp.idx += 2
	}
	if ellip {
		mp.out.Add(token.ELLIPSIS)
	}
	mp.out.Add(token.RPAREN)
	mp.endFunc()
}
// indexListExpr is currently a no-op: multi-index expressions are handled
// via indexExpr (basicSlicingExpr) and arrayLiteral instead.
// NOTE(review): confirm this is intentionally empty.
func (mp *mathParse) indexListExpr(il *ast.IndexListExpr) {
}
// indexExpr processes an index expression; only the multi-index form
// X[i, j, ...] (whose Index is an IndexListExpr) is handled, via
// basicSlicingExpr. Other index forms produce no output here.
func (mp *mathParse) indexExpr(il *ast.IndexExpr) {
	if _, ok := il.Index.(*ast.IndexListExpr); ok {
		mp.basicSlicingExpr(il)
	}
}
// basicSlicingExpr rewrites a multi-index slicing expression X[i, j, ...]
// as tensor.Reslice(X, i, j, ...), or tensor.Mask(X, ...) when the index
// expressions are boolean.
func (mp *mathParse) basicSlicingExpr(il *ast.IndexExpr) {
	iil := il.Index.(*ast.IndexListExpr)
	fun := "tensor.Reslice"
	if mp.exprsAreBool(iil.Indices) {
		fun = "tensor.Mask"
	}
	mp.startFunc(fun)
	mp.out.Add(token.LPAREN)
	mp.expr(il.X)
	mp.nextArg()
	mp.addToken(token.COMMA) // use the [ -- can't use ( to preserve X
	mp.exprList(iil.Indices)
	mp.addToken(token.RPAREN) // replaces ]
	mp.endFunc()
}
// sliceExpr rewrites a slice expression [low:high:max] as a
// tensor.Slice{Start:, Stop:, Step:} composite literal, or
// tensor.FullAxis when all parts are empty. The mp.idx increments
// account for the source colon tokens that are not emitted.
func (mp *mathParse) sliceExpr(se *ast.SliceExpr) {
	if se.Low == nil && se.High == nil && se.Max == nil {
		mp.out.Add(token.IDENT, "tensor.FullAxis")
		mp.idx++
		return
	}
	mp.out.Add(token.IDENT, "tensor.Slice")
	mp.out.Add(token.LBRACE)
	prev := false
	if se.Low != nil {
		mp.out.Add(token.IDENT, "Start:")
		mp.expr(se.Low)
		prev = true
		if se.High == nil && se.Max == nil {
			mp.idx++ // trailing : with nothing after it
		}
	}
	if se.High != nil {
		if prev {
			mp.out.Add(token.COMMA)
		}
		mp.out.Add(token.IDENT, "Stop:")
		mp.idx++ // the : before high
		mp.expr(se.High)
		prev = true
	}
	if se.Max != nil {
		if prev {
			mp.out.Add(token.COMMA)
		}
		mp.idx++ // the : before max
		if se.Low == nil && se.High == nil {
			mp.idx++ // both leading colons present with no exprs
		}
		mp.out.Add(token.IDENT, "Step:")
		mp.expr(se.Max)
	}
	mp.out.Add(token.RBRACE)
}
// arrayLiteral rewrites an array literal [a, b, ...] (parsed as an
// IndexListExpr) into a tensor.New<Kind>FromValues call, additionally
// wrapped in tensor.Reshape when the literal is nested (multidimensional).
// When the current argument context wants plain ints, the values are
// emitted directly without any tensor wrapping.
func (mp *mathParse) arrayLiteral(il *ast.IndexListExpr) {
	kind := inferKindExprList(il.Indices)
	if kind == token.ILLEGAL {
		kind = token.FLOAT // default
	}
	// todo: look for sub-arrays etc.
	typ := "float64"
	fun := "Float64"
	switch kind {
	case token.FLOAT:
	case token.INT:
		typ = "int"
		fun = "Int"
	case token.STRING:
		typ = "string"
		fun = "String"
	}
	if mp.curArgIsInts() {
		mp.idx++ // opening brace we're not using
		mp.exprList(il.Indices)
		mp.idx++ // closing brace we're not using
		return
	}
	var sh []int
	mp.arrayShape(il.Indices, &sh)
	if len(sh) > 1 {
		// nested literal: wrap in a Reshape with the full shape
		mp.startFunc("tensor.Reshape")
		mp.out.Add(token.LPAREN)
	}
	mp.startFunc("tensor.New" + fun + "FromValues")
	mp.out.Add(token.LPAREN)
	mp.out.Add(token.IDENT, "[]"+typ)
	mp.addToken(token.LBRACE)
	mp.exprList(il.Indices)
	mp.addToken(token.RBRACE)
	mp.out.AddMulti(token.ELLIPSIS, token.RPAREN)
	mp.endFunc()
	if len(sh) > 1 {
		// append the shape sizes as the Reshape arguments
		mp.out.Add(token.COMMA)
		nsh := len(sh)
		for i, s := range sh {
			mp.out.Add(token.INT, fmt.Sprintf("%d", s))
			if i < nsh-1 {
				mp.out.Add(token.COMMA)
			}
		}
		mp.out.Add(token.RPAREN)
		mp.endFunc()
	}
}
// arrayShape recursively records the shape of a nested array literal,
// appending the element count at each nesting level to sh. Only the
// first nested sub-array at each level is descended into.
func (mp *mathParse) arrayShape(ex []ast.Expr, sh *[]int) {
	if len(ex) == 0 {
		return
	}
	*sh = append(*sh, len(ex))
	for _, e := range ex {
		sub, ok := e.(*ast.IndexListExpr)
		if !ok {
			continue
		}
		mp.arrayShape(sub.Indices, sh)
		return
	}
}
// numpyFuncs maps numpy function names to the tensor / tensorfs function
// that implements them.
// nofun = do not accept a function version, just a method
var numpyFuncs = map[string]funWrap{
	// "array": {"tensor.NewFloatFromValues", ""}, // todo: probably not right, maybe don't have?
	"zeros": {"tensor.NewFloat64", ""},
	"full": {"tensor.NewFloat64Full", ""},
	"ones": {"tensor.NewFloat64Ones", ""},
	"rand": {"tensor.NewFloat64Rand", ""},
	"arange": {"tensor.NewIntRange", ""},
	"linspace": {"tensor.NewFloat64SpacedLinear", ""},
	"reshape": {"tensor.Reshape", ""},
	"copy": {"tensor.Clone", ""},
	"get": {"tensorfs.Get", ""},
	"set": {"tensorfs.Set", ""},
	"flatten": {"tensor.Flatten", "nofun"},
	"squeeze": {"tensor.Squeeze", "nofun"},
}
// callExpr processes a call expression, dispatching on the form of the
// callee: known numpy property functions called as plain names (ndim(x)),
// selector-based calls (a.reshape(...), pkg.Fun(...)), other plain
// names, or arbitrary expressions.
func (mp *mathParse) callExpr(ex *ast.CallExpr) {
	switch x := ex.Fun.(type) {
	case *ast.Ident:
		// NOTE(review): the "nofun" markers are declared in numpyFuncs,
		// not numpyProps; this check never fires for the current
		// numpyProps entries -- confirm which map was intended.
		if fw, ok := numpyProps[x.Name]; ok && fw.wrap != "nofun" {
			mp.callPropFun(ex, fw)
			return
		}
		mp.callName(ex, x.Name, "")
	case *ast.SelectorExpr:
		fun := x.Sel.Name
		if pkg, ok := x.X.(*ast.Ident); ok {
			if fw, ok := numpyFuncs[fun]; ok {
				mp.callPropSelFun(ex, x.X, fw)
				return
			} else {
				// fmt.Println("call name:", fun, pkg.Name)
				mp.callName(ex, fun, pkg.Name)
			}
		} else {
			// non-ident receiver, e.g. a[0].reshape(...)
			if fw, ok := numpyFuncs[fun]; ok {
				mp.callPropSelFun(ex, x.X, fw)
				return
			}
			// todo: dot fun?
			mp.expr(ex)
		}
	default:
		mp.expr(ex.Fun)
	}
	mp.argsList(ex.Args)
	// todo: ellipsis
	mp.addToken(token.RPAREN)
	mp.endFunc()
}
// callPropFun calls a "prop" function like ndim(a) as a method on the
// object: the argument becomes the receiver of fw.fun, wrapped in the
// fw.wrap function. The idx skips cover the function name and ( tokens.
func (mp *mathParse) callPropFun(cf *ast.CallExpr, fw funWrap) {
	ellip := fw.wrapFunc(mp)
	mp.idx += 2 // skip the source function name and opening paren
	mp.exprList(cf.Args) // this is the tensor
	mp.addToken(token.PERIOD)
	mp.out.Add(token.IDENT, fw.fun)
	if ellip {
		mp.out.Add(token.ELLIPSIS)
	}
	mp.out.Add(token.RPAREN)
	mp.endFunc()
}
// callPropSelFun calls a global function through a selector like
// a.reshape(...), rewriting it as fw.fun(a, args...).
func (mp *mathParse) callPropSelFun(cf *ast.CallExpr, ex ast.Expr, fw funWrap) {
	mp.startFunc(fw.fun)
	mp.out.Add(token.LPAREN) // use the (
	mp.expr(ex)
	mp.idx += 2 // skip the source . and method name tokens
	if len(cf.Args) > 0 {
		mp.nextArg() // did first
		mp.addToken(token.COMMA)
		mp.argsList(cf.Args)
	} else {
		mp.idx++ // no args: just skip the (
	}
	mp.addToken(token.RPAREN)
	mp.endFunc()
}
// callName emits a call to the named function, optionally package
// qualified. numpyFuncs entries are mapped to their tensor equivalents;
// otherwise the name is validated against the registered tensor
// functions. Non-qualified names are looked up in tmath only, also
// trying the name with its first letter uppercased. Names that are not
// registered are emitted as plain function calls.
func (mp *mathParse) callName(cf *ast.CallExpr, funName, pkgName string) {
	if fw, ok := numpyFuncs[funName]; ok {
		mp.startFunc(fw.fun)
		mp.addToken(token.LPAREN) // use the (
		mp.idx++ // paren too
		return
	}
	var err error // validate name
	if pkgName != "" {
		funName = pkgName + "." + funName
		_, err = tensor.FuncByName(funName)
	} else { // non-package qualified names are _only_ in tmath! can be lowercase
		_, err = tensor.FuncByName("tmath." + funName)
		if err != nil {
			funName = strings.ToUpper(funName[:1]) + funName[1:] // first letter uppercased
			_, err = tensor.FuncByName("tmath." + funName)
		}
		if err == nil { // registered, must be in tmath
			funName = "tmath." + funName
		}
	}
	if err != nil { // not a registered tensor function
		// fmt.Println("regular fun", funName)
		mp.startFunc(funName)
		mp.addToken(token.LPAREN) // use the (
		mp.idx += 3
		return
	}
	mp.startFunc(funName)
	mp.idx += 1 // the function name token
	if pkgName != "" {
		mp.idx += 2 // . and selector
	}
	mp.addToken(token.LPAREN)
}
// consts maps special identifier names to replacement output code.
// basic ident replacements
var consts = map[string]string{
	"newaxis": "tensor.NewAxis",
	"pi": "tensor.NewFloat64Scalar(math.Pi)",
}
// ident processes an identifier, substituting entries from the consts
// map, and wrapping in tensor.AsIntSlice(...)... when the current
// argument context requires plain ints. Nil identifiers are ignored.
func (mp *mathParse) ident(id *ast.Ident) {
	if id == nil {
		return
	}
	if cn, ok := consts[id.Name]; ok {
		mp.out.Add(token.IDENT, cn)
		mp.idx++
		return
	}
	if mp.curArgIsInts() {
		// convert tensor value to int slice for int args
		mp.out.Add(token.IDENT, "tensor.AsIntSlice")
		mp.out.Add(token.LPAREN)
		mp.addCur()
		mp.out.AddMulti(token.RPAREN, token.ELLIPSIS)
	} else {
		mp.addCur()
	}
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// mparse is a hacked version of go/parser:
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package transpile
import (
"fmt"
"go/ast"
"go/build/constraint"
"go/scanner"
"go/token"
"strings"
)
// ParseLine parses a line of code that could contain one or more statements.
// Errors are collected in the parser's error list; a bailout panic is
// recovered here and converted into an error.
func ParseLine(code string, mode Mode) (stmts []ast.Stmt, err error) {
	fset := token.NewFileSet()
	var p parser
	defer func() {
		if e := recover(); e != nil {
			// resume same panic if it's not a bailout
			bail, ok := e.(bailout)
			if !ok {
				panic(e)
			} else if bail.msg != "" {
				p.errors.Add(p.file.Position(bail.pos), bail.msg)
			}
		}
		p.errors.Sort()
		err = p.errors.Err()
	}()
	p.init(fset, "", []byte(code), mode)
	stmts = p.parseStmtList()
	// If a semicolon was inserted, consume it;
	// report an error if there's more tokens.
	if p.tok == token.SEMICOLON && p.lit == "\n" {
		p.next()
	}
	// allow the line to end on a closing brace without complaint
	// (partial code lines) -- NOTE(review): confirm this is intended.
	if p.tok == token.RBRACE {
		return
	}
	p.expect(token.EOF)
	return
}
// ParseExpr parses a single expression from code. Errors are collected
// in the parser's error list; a bailout panic is recovered here and
// converted into an error.
func ParseExpr(code string, mode Mode) (expr ast.Expr, err error) {
	fset := token.NewFileSet()
	var p parser
	defer func() {
		if e := recover(); e != nil {
			// resume same panic if it's not a bailout
			bail, ok := e.(bailout)
			if !ok {
				panic(e)
			} else if bail.msg != "" {
				p.errors.Add(p.file.Position(bail.pos), bail.msg)
			}
		}
		p.errors.Sort()
		err = p.errors.Err()
	}()
	p.init(fset, "", []byte(code), mode)
	expr = p.parseRhs()
	// If a semicolon was inserted, consume it;
	// report an error if there's more tokens.
	if p.tok == token.SEMICOLON && p.lit == "\n" {
		p.next()
	}
	p.expect(token.EOF)
	return
}
// A Mode value is a set of flags (or 0).
// They control the amount of source code parsed and other optional
// parser functionality. Flags may be or'ed together.
type Mode uint
const (
	ParseComments Mode = 1 << iota // parse comments and add them to AST
	Trace // print a trace of parsed productions
	DeclarationErrors // report declaration errors
	SpuriousErrors // same as AllErrors, for backward-compatibility
	AllErrors = SpuriousErrors // report all errors (not just the first 10 on different lines)
)
// The parser structure holds the parser's internal state.
type parser struct {
	file *token.File // file being parsed (positions only; no contents)
	errors scanner.ErrorList // accumulated parse errors
	scanner scanner.Scanner // token source
	// Tracing/debugging
	mode Mode // parsing mode
	trace bool // == (mode&Trace != 0)
	indent int // indentation used for tracing output
	// Comments
	comments []*ast.CommentGroup
	leadComment *ast.CommentGroup // last lead comment
	lineComment *ast.CommentGroup // last line comment
	top bool // in top of file (before package clause)
	goVersion string // minimum Go version found in //go:build comment
	// Next token
	pos token.Pos // token position
	tok token.Token // one token look-ahead
	lit string // token literal
	// Error recovery
	// (used to limit the number of calls to parser.advance
	// w/o making scanning progress - avoids potential endless
	// loops across multiple parser functions during error recovery)
	syncPos token.Pos // last synchronization position
	syncCnt int // number of parser.advance calls without progress
	// Non-syntactic parser control
	exprLev int // < 0: in control clause, >= 0: in expression
	inRhs bool // if set, the parser is parsing a rhs expression
	imports []*ast.ImportSpec // list of imports
	// nestLev is used to track and limit the recursion depth
	// during parsing.
	nestLev int
}
// init prepares the parser to scan src, installing an error handler
// and reading the first token.
func (p *parser) init(fset *token.FileSet, filename string, src []byte, mode Mode) {
	p.file = fset.AddFile(filename, -1, len(src))
	eh := func(pos token.Position, msg string) {
		// suppress errors whose message contains "@"
		// NOTE(review): presumably filters shell/selector syntax noise -- confirm.
		if !strings.Contains(msg, "@") {
			p.errors.Add(pos, msg)
		}
	}
	p.scanner.Init(p.file, src, eh, scanner.ScanComments)
	p.top = true
	p.mode = mode
	p.trace = mode&Trace != 0 // for convenience (p.trace is used frequently)
	p.next()
}
// ----------------------------------------------------------------------------
// Parsing support
// printTrace prints a trace line at the current position, indented by
// the current trace depth (wrapping the dots when deeply nested).
func (p *parser) printTrace(a ...any) {
	const dots = ". . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . "
	const n = len(dots)
	pos := p.file.Position(p.pos)
	fmt.Printf("%5d:%3d: ", pos.Line, pos.Column)
	i := 2 * p.indent
	for i > n {
		fmt.Print(dots)
		i -= n
	}
	// i <= n
	fmt.Print(dots[0:i])
	fmt.Println(a...)
}
// trace prints an opening trace message and increases the indent level;
// it returns p so it can be combined with un in a defer (see below).
func trace(p *parser, msg string) *parser {
	p.printTrace(msg, "(")
	p.indent++
	return p
}
// un closes a trace opened by trace, decreasing the indent level.
// Usage pattern: defer un(trace(p, "..."))
func un(p *parser) {
	p.indent--
	p.printTrace(")")
}
// maxNestLev is the deepest we're willing to recurse during parsing
const maxNestLev int = 1e5
// incNestLev increments the nesting depth, bailing out with a panic
// when maxNestLev is exceeded; paired with decNestLev via defer.
func incNestLev(p *parser) *parser {
	p.nestLev++
	if p.nestLev > maxNestLev {
		p.error(p.pos, "exceeded max nesting depth")
		panic(bailout{})
	}
	return p
}
// decNestLev is used to track nesting depth during parsing to prevent stack exhaustion.
// It is used along with incNestLev in a similar fashion to how un and trace are used:
// defer decNestLev(incNestLev(p)).
func decNestLev(p *parser) {
	p.nestLev--
}
// next0 advances to the next token, skipping comments unless
// ParseComments is set, and capturing any //go:build constraint seen
// at the top of the file.
func (p *parser) next0() {
	// Because of one-token look-ahead, print the previous token
	// when tracing as it provides a more readable output. The
	// very first token (!p.pos.IsValid()) is not initialized
	// (it is token.ILLEGAL), so don't print it.
	if p.trace && p.pos.IsValid() {
		s := p.tok.String()
		switch {
		case p.tok.IsLiteral():
			p.printTrace(s, p.lit)
		case p.tok.IsOperator(), p.tok.IsKeyword():
			p.printTrace("\"" + s + "\"")
		default:
			p.printTrace(s)
		}
	}
	for {
		p.pos, p.tok, p.lit = p.scanner.Scan()
		if p.tok == token.COMMENT {
			// record the go version from a top-of-file build constraint
			if p.top && strings.HasPrefix(p.lit, "//go:build") {
				if x, err := constraint.Parse(p.lit); err == nil {
					p.goVersion = constraint.GoVersion(x)
				}
			}
			if p.mode&ParseComments == 0 {
				continue
			}
		} else {
			// Found a non-comment; top of file is over.
			p.top = false
		}
		break
	}
}
// Consume a comment and return it and the line on which it ends.
func (p *parser) consumeComment() (comment *ast.Comment, endline int) {
	// General (/*-style) comments may end on a later line than they
	// start on: bump endline once for each newline in the literal.
	// (Byte-wise scan: no Unicode decoding is needed to find '\n'.)
	endline = p.file.Line(p.pos)
	if p.lit[1] == '*' {
		for _, b := range []byte(p.lit) {
			if b == '\n' {
				endline++
			}
		}
	}
	comment = &ast.Comment{Slash: p.pos, Text: p.lit}
	p.next0()
	return
}
// Consume a group of adjacent comments, add it to the parser's
// comments list, and return it together with the line at which
// the last comment in the group ends. A non-comment token or n
// empty lines terminate a comment group.
func (p *parser) consumeCommentGroup(n int) (comments *ast.CommentGroup, endline int) {
	var list []*ast.Comment
	endline = p.file.Line(p.pos)
	// collect comments while each starts within n lines of the previous end
	for p.tok == token.COMMENT && p.file.Line(p.pos) <= endline+n {
		var comment *ast.Comment
		comment, endline = p.consumeComment()
		list = append(list, comment)
	}
	// add comment group to the comments list
	comments = &ast.CommentGroup{List: list}
	p.comments = append(p.comments, comments)
	return
}
// Advance to the next non-comment token. In the process, collect
// any comment groups encountered, and remember the last lead and
// line comments.
//
// A lead comment is a comment group that starts and ends in a
// line without any other tokens and that is followed by a non-comment
// token on the line immediately after the comment group.
//
// A line comment is a comment group that follows a non-comment
// token on the same line, and that has no tokens after it on the line
// where it ends.
//
// Lead and line comments may be considered documentation that is
// stored in the AST.
func (p *parser) next() {
	p.leadComment = nil
	p.lineComment = nil
	prev := p.pos
	p.next0()
	if p.tok == token.COMMENT {
		var comment *ast.CommentGroup
		var endline int
		if p.file.Line(p.pos) == p.file.Line(prev) {
			// The comment is on same line as the previous token; it
			// cannot be a lead comment but may be a line comment.
			comment, endline = p.consumeCommentGroup(0)
			if p.file.Line(p.pos) != endline || p.tok == token.SEMICOLON || p.tok == token.EOF {
				// The next token is on a different line, thus
				// the last comment group is a line comment.
				p.lineComment = comment
			}
		}
		// consume successor comments, if any
		endline = -1
		for p.tok == token.COMMENT {
			comment, endline = p.consumeCommentGroup(1)
		}
		if endline+1 == p.file.Line(p.pos) {
			// The next token is following on the line immediately after the
			// comment group, thus the last comment group is a lead comment.
			p.leadComment = comment
		}
	}
}
// A bailout panic is raised to indicate early termination. pos and msg are
// only populated when bailing out of object resolution; otherwise the
// zero value signals silent termination (e.g. too many errors).
type bailout struct {
	pos token.Pos
	msg string
}
// error records a parse error at pos, limiting the number and density
// of reported errors unless AllErrors is set.
func (p *parser) error(pos token.Pos, msg string) {
	if p.trace {
		defer un(trace(p, "error: "+msg))
	}
	epos := p.file.Position(pos)
	// If AllErrors is not set, discard errors reported on the same line
	// as the last recorded error and stop parsing if there are more than
	// 10 errors.
	if p.mode&AllErrors == 0 {
		n := len(p.errors)
		if n > 0 && p.errors[n-1].Pos.Line == epos.Line {
			return // discard - likely a spurious error
		}
		if n > 10 {
			panic(bailout{}) // silent early termination
		}
	}
	p.errors.Add(epos, msg)
}
// errorExpected reports an "expected ..." error at pos, adding detail
// about the actually-found token when pos is the current position.
func (p *parser) errorExpected(pos token.Pos, msg string) {
	msg = "expected " + msg
	if pos == p.pos {
		// the error happened at the current position;
		// make the error message more specific
		switch {
		case p.tok == token.SEMICOLON && p.lit == "\n":
			msg += ", found newline"
		case p.tok.IsLiteral():
			// print 123 rather than 'INT', etc.
			msg += ", found " + p.lit
		default:
			msg += ", found '" + p.tok.String() + "'"
		}
	}
	p.error(pos, msg)
}
// expect consumes the current token, reporting an error if it is not
// tok, and returns its position. It always advances, so parsing makes
// progress even on error.
func (p *parser) expect(tok token.Token) token.Pos {
	pos := p.pos
	if p.tok != tok {
		p.errorExpected(pos, "'"+tok.String()+"'")
	}
	p.next() // make progress
	return pos
}
// expect2 is like expect, but it returns an invalid position
// if the expected token is not found.
func (p *parser) expect2(tok token.Token) (pos token.Pos) {
	if p.tok == tok {
		pos = p.pos
	} else {
		p.errorExpected(p.pos, "'"+tok.String()+"'")
	}
	p.next() // make progress
	return
}
// expectClosing is like expect but provides a better error message
// for the common case of a missing comma before a newline.
func (p *parser) expectClosing(tok token.Token, context string) token.Pos {
	if p.tok != tok && p.tok == token.SEMICOLON && p.lit == "\n" {
		// an inserted semicolon here means the comma was forgotten
		p.error(p.pos, "missing ',' before newline in "+context)
		p.next()
	}
	return p.expect(tok)
}
// expectSemi consumes a semicolon and returns the applicable line comment.
// Unlike the standard go/parser, unexpected endings are tolerated
// (see the math: note in the default case).
func (p *parser) expectSemi() (comment *ast.CommentGroup) {
	// semicolon is optional before a closing ')' or '}'
	if p.tok != token.RPAREN && p.tok != token.RBRACE {
		switch p.tok {
		case token.COMMA:
			// permit a ',' instead of a ';' but complain
			p.errorExpected(p.pos, "';'")
			fallthrough
		case token.SEMICOLON:
			if p.lit == ";" {
				// explicit semicolon
				p.next()
				comment = p.lineComment // use following comments
			} else {
				// artificial semicolon
				comment = p.lineComment // use preceding comments
				p.next()
			}
			return comment
		default:
			// math: allow unexpected endings..
			// p.errorExpected(p.pos, "';'")
			// p.advance(stmtStart)
		}
	}
	return nil
}
// atComma reports whether the parser should continue a comma-separated
// list in the given context: true at a real comma, and also true (with
// an error recorded) when neither the comma nor the follow token is
// present, effectively inserting the missing comma.
func (p *parser) atComma(context string, follow token.Token) bool {
	if p.tok == token.COMMA {
		return true
	}
	if p.tok != follow {
		msg := "missing ','"
		if p.tok == token.SEMICOLON && p.lit == "\n" {
			msg += " before newline"
		}
		p.error(p.pos, msg+" in "+context)
		return true // "insert" comma and continue
	}
	return false
}
// passert panics with an internal-error message when cond is false.
func passert(cond bool, msg string) {
	if cond {
		return
	}
	panic("go/parser internal error: " + msg)
}
// advance consumes tokens until the current token p.tok
// is in the 'to' set, or token.EOF. For error recovery.
func (p *parser) advance(to map[token.Token]bool) {
	for ; p.tok != token.EOF; p.next() {
		if to[p.tok] {
			// Return only if parser made some progress since last
			// sync or if it has not reached 10 advance calls without
			// progress. Otherwise consume at least one token to
			// avoid an endless parser loop (it is possible that
			// both parseOperand and parseStmt call advance and
			// correctly do not advance, thus the need for the
			// invocation limit p.syncCnt).
			if p.pos == p.syncPos && p.syncCnt < 10 {
				p.syncCnt++
				return
			}
			if p.pos > p.syncPos {
				p.syncPos = p.pos
				p.syncCnt = 0
				return
			}
			// Reaching here indicates a parser bug, likely an
			// incorrect token list in this function, but it only
			// leads to skipping of possibly correct code if a
			// previous error is present, and thus is preferred
			// over a non-terminating parse.
		}
	}
}
// stmtStart is the set of tokens that can begin a statement,
// used as a recovery target by advance.
var stmtStart = map[token.Token]bool{
	token.BREAK: true,
	token.CONST: true,
	token.CONTINUE: true,
	token.DEFER: true,
	token.FALLTHROUGH: true,
	token.FOR: true,
	token.GO: true,
	token.GOTO: true,
	token.IF: true,
	token.RETURN: true,
	token.SELECT: true,
	token.SWITCH: true,
	token.TYPE: true,
	token.VAR: true,
}
// declStart is the set of tokens that can begin a declaration,
// used as a recovery target by advance.
var declStart = map[token.Token]bool{
	token.IMPORT: true,
	token.CONST: true,
	token.TYPE: true,
	token.VAR: true,
}
// exprEnd is the set of tokens that can follow an expression,
// used as a recovery target by advance.
var exprEnd = map[token.Token]bool{
	token.COMMA: true,
	token.COLON: true,
	token.SEMICOLON: true,
	token.RPAREN: true,
	token.RBRACK: true,
	token.RBRACE: true,
}
// safePos returns a valid file position for a given position: If pos
// is valid to begin with, safePos returns pos. If pos is out-of-range,
// safePos returns the EOF position.
//
// This is hack to work around "artificial" end positions in the AST which
// are computed by adding 1 to (presumably valid) token positions. If the
// token positions are invalid due to parse errors, the resulting end position
// may be past the file's EOF position, which would lead to panics if used
// later on.
func (p *parser) safePos(pos token.Pos) (res token.Pos) {
	defer func() {
		if recover() != nil {
			res = token.Pos(p.file.Base() + p.file.Size()) // EOF position
		}
	}()
	_ = p.file.Offset(pos) // trigger a panic if position is out-of-range
	return pos
}
// ----------------------------------------------------------------------------
// Identifiers
// parseIdent parses an identifier, returning a "_" ident (with an error
// recorded) when the current token is not an identifier.
func (p *parser) parseIdent() *ast.Ident {
	pos := p.pos
	name := "_"
	if p.tok == token.IDENT {
		name = p.lit
		p.next()
	} else {
		p.expect(token.IDENT) // use expect() error handling
	}
	return &ast.Ident{NamePos: pos, Name: name}
}
// parseIdentList parses a comma-separated list of one or more identifiers.
func (p *parser) parseIdentList() (list []*ast.Ident) {
	if p.trace {
		defer un(trace(p, "IdentList"))
	}
	list = append(list, p.parseIdent())
	for p.tok == token.COMMA {
		p.next()
		list = append(list, p.parseIdent())
	}
	return
}
// ----------------------------------------------------------------------------
// Common productions
// parseExprList parses a comma-separated list of one or more expressions.
// If lhs is set, result list elements which are identifiers are not resolved.
func (p *parser) parseExprList() (list []ast.Expr) {
	if p.trace {
		defer un(trace(p, "ExpressionList"))
	}
	list = append(list, p.parseExpr())
	for p.tok == token.COMMA {
		p.next()
		list = append(list, p.parseExpr())
	}
	return
}
// parseList parses an expression list with p.inRhs temporarily set to
// inRhs, restoring the previous value afterwards.
func (p *parser) parseList(inRhs bool) []ast.Expr {
	saved := p.inRhs
	defer func() { p.inRhs = saved }()
	p.inRhs = inRhs
	return p.parseExprList()
}
// math: allow full array list expressions
// parseArrayList parses the comma-separated contents of an array
// literal after the already-consumed '[', through the closing ']'.
func (p *parser) parseArrayList(lbrack token.Pos) *ast.IndexListExpr {
	if p.trace {
		defer un(trace(p, "ArrayList"))
	}
	p.exprLev++
	x := p.parseExprList()
	p.exprLev--
	rbrack := p.expect(token.RBRACK)
	return &ast.IndexListExpr{Lbrack: lbrack, Indices: x, Rbrack: rbrack}
}
// ----------------------------------------------------------------------------
// Types
// parseType parses a type, returning a BadExpr (with an error recorded
// and recovery performed) when no type is present.
func (p *parser) parseType() ast.Expr {
	if p.trace {
		defer un(trace(p, "Type"))
	}
	typ := p.tryIdentOrType()
	if typ == nil {
		pos := p.pos
		p.errorExpected(pos, "type")
		p.advance(exprEnd)
		return &ast.BadExpr{From: pos, To: p.pos}
	}
	return typ
}
// parseQualifiedIdent parses a (possibly package-qualified) type name,
// optionally followed by a type instantiation.
func (p *parser) parseQualifiedIdent(ident *ast.Ident) ast.Expr {
	if p.trace {
		defer un(trace(p, "QualifiedIdent"))
	}
	typ := p.parseTypeName(ident)
	if p.tok == token.LBRACK {
		typ = p.parseTypeInstance(typ)
	}
	return typ
}
// parseTypeName parses a type name, using ident as the already-parsed
// first identifier when non-nil.
// If the result is an identifier, it is not resolved.
func (p *parser) parseTypeName(ident *ast.Ident) ast.Expr {
	if p.trace {
		defer un(trace(p, "TypeName"))
	}
	if ident == nil {
		ident = p.parseIdent()
	}
	if p.tok == token.PERIOD {
		// ident is a package name
		p.next()
		sel := p.parseIdent()
		return &ast.SelectorExpr{X: ident, Sel: sel}
	}
	return ident
}
// parseArrayType parses an array or slice type.
// "[" has already been consumed, and lbrack is its position.
// If len != nil it is the already consumed array length.
func (p *parser) parseArrayType(lbrack token.Pos, len ast.Expr) *ast.ArrayType {
	if p.trace {
		defer un(trace(p, "ArrayType"))
	}
	if len == nil {
		p.exprLev++
		// always permit ellipsis for more fault-tolerant parsing
		if p.tok == token.ELLIPSIS {
			len = &ast.Ellipsis{Ellipsis: p.pos}
			p.next()
		} else if p.tok != token.RBRACK {
			len = p.parseRhs()
		}
		p.exprLev--
	}
	if p.tok == token.COMMA {
		// Trailing commas are accepted in type parameter
		// lists but not in array type declarations.
		// Accept for better error handling but complain.
		p.error(p.pos, "unexpected comma; expecting ]")
		p.next()
	}
	p.expect(token.RBRACK)
	elt := p.parseType()
	return &ast.ArrayType{Lbrack: lbrack, Len: len, Elt: elt}
}
// parseArrayFieldOrTypeInstance disambiguates, after an identifier x
// followed by '[', between an array-typed field (x []E or x [P]E) and a
// generic type instantiation (x[P], x[P1, P2]). A nil first result
// indicates the instantiation case.
func (p *parser) parseArrayFieldOrTypeInstance(x *ast.Ident) (*ast.Ident, ast.Expr) {
	if p.trace {
		defer un(trace(p, "ArrayFieldOrTypeInstance"))
	}
	lbrack := p.expect(token.LBRACK)
	trailingComma := token.NoPos // if valid, the position of a trailing comma preceding the ']'
	var args []ast.Expr
	if p.tok != token.RBRACK {
		p.exprLev++
		args = append(args, p.parseRhs())
		for p.tok == token.COMMA {
			comma := p.pos
			p.next()
			if p.tok == token.RBRACK {
				trailingComma = comma
				break
			}
			args = append(args, p.parseRhs())
		}
		p.exprLev--
	}
	rbrack := p.expect(token.RBRACK)
	_ = rbrack // unused: the instantiation form below is not constructed here
	if len(args) == 0 {
		// x []E
		elt := p.parseType()
		return x, &ast.ArrayType{Lbrack: lbrack, Elt: elt}
	}
	// x [P]E or x[P]
	if len(args) == 1 {
		elt := p.tryIdentOrType()
		if elt != nil {
			// x [P]E
			if trailingComma.IsValid() {
				// Trailing commas are invalid in array type fields.
				p.error(trailingComma, "unexpected comma; expecting ]")
			}
			return x, &ast.ArrayType{Lbrack: lbrack, Len: args[0], Elt: elt}
		}
	}
	// x[P], x[P1, P2], ...
	// NOTE(review): generic instantiations are not built in this hacked
	// parser; the standard parser returns typeparams.PackIndexExpr here.
	return nil, nil // typeparams.PackIndexExpr(x, lbrack, args, rbrack)
}
// parseFieldDecl parses a single field declaration within a struct type:
// named fields (name1, name2 T), embedded types (T, pkg.T, *T, with
// parenthesized forms accepted but flagged), an optional tag string,
// and the terminating semicolon.
func (p *parser) parseFieldDecl() *ast.Field {
	if p.trace {
		defer un(trace(p, "FieldDecl"))
	}
	doc := p.leadComment
	var names []*ast.Ident
	var typ ast.Expr
	switch p.tok {
	case token.IDENT:
		name := p.parseIdent()
		if p.tok == token.PERIOD || p.tok == token.STRING || p.tok == token.SEMICOLON || p.tok == token.RBRACE {
			// embedded type
			typ = name
			if p.tok == token.PERIOD {
				typ = p.parseQualifiedIdent(name)
			}
		} else {
			// name1, name2, ... T
			names = []*ast.Ident{name}
			for p.tok == token.COMMA {
				p.next()
				names = append(names, p.parseIdent())
			}
			// Careful dance: We don't know if we have an embedded instantiated
			// type T[P1, P2, ...] or a field T of array type []E or [P]E.
			if len(names) == 1 && p.tok == token.LBRACK {
				name, typ = p.parseArrayFieldOrTypeInstance(name)
				if name == nil {
					names = nil
				}
			} else {
				// T P
				typ = p.parseType()
			}
		}
	case token.MUL:
		star := p.pos
		p.next()
		if p.tok == token.LPAREN {
			// *(T)
			p.error(p.pos, "cannot parenthesize embedded type")
			p.next()
			typ = p.parseQualifiedIdent(nil)
			// expect closing ')' but no need to complain if missing
			if p.tok == token.RPAREN {
				p.next()
			}
		} else {
			// *T
			typ = p.parseQualifiedIdent(nil)
		}
		typ = &ast.StarExpr{Star: star, X: typ}
	case token.LPAREN:
		p.error(p.pos, "cannot parenthesize embedded type")
		p.next()
		if p.tok == token.MUL {
			// (*T)
			star := p.pos
			p.next()
			typ = &ast.StarExpr{Star: star, X: p.parseQualifiedIdent(nil)}
		} else {
			// (T)
			typ = p.parseQualifiedIdent(nil)
		}
		// expect closing ')' but no need to complain if missing
		if p.tok == token.RPAREN {
			p.next()
		}
	default:
		pos := p.pos
		p.errorExpected(pos, "field name or embedded type")
		p.advance(exprEnd)
		typ = &ast.BadExpr{From: pos, To: p.pos}
	}
	var tag *ast.BasicLit
	if p.tok == token.STRING {
		tag = &ast.BasicLit{ValuePos: p.pos, Kind: p.tok, Value: p.lit}
		p.next()
	}
	comment := p.expectSemi()
	field := &ast.Field{Doc: doc, Names: names, Type: typ, Tag: tag, Comment: comment}
	return field
}
// parseStructType parses a struct type: the struct keyword, a braced
// list of field declarations, and the closing brace.
func (p *parser) parseStructType() *ast.StructType {
	if p.trace {
		defer un(trace(p, "StructType"))
	}
	pos := p.expect(token.STRUCT)
	lbrace := p.expect(token.LBRACE)
	var list []*ast.Field
	for p.tok == token.IDENT || p.tok == token.MUL || p.tok == token.LPAREN {
		// a field declaration cannot start with a '(' but we accept
		// it here for more robust parsing and better error messages
		// (parseFieldDecl will check and complain if necessary)
		list = append(list, p.parseFieldDecl())
	}
	rbrace := p.expect(token.RBRACE)
	return &ast.StructType{
		Struct: pos,
		Fields: &ast.FieldList{
			Opening: lbrace,
			List: list,
			Closing: rbrace,
		},
	}
}
// parsePointerType parses a pointer type: '*' followed by the base type.
func (p *parser) parsePointerType() *ast.StarExpr {
	if p.trace {
		defer un(trace(p, "PointerType"))
	}
	star := p.expect(token.MUL)
	base := p.parseType()
	return &ast.StarExpr{Star: star, X: base}
}
// parseDotsType parses a variadic parameter type: '...' followed by the
// element type.
func (p *parser) parseDotsType() *ast.Ellipsis {
	if p.trace {
		defer un(trace(p, "DotsType"))
	}
	pos := p.expect(token.ELLIPSIS)
	elt := p.parseType()
	return &ast.Ellipsis{Ellipsis: pos, Elt: elt}
}
// field is an intermediate parameter representation used by
// parseParamDecl and parseParameterList before conversion to ast.Field.
type field struct {
	name *ast.Ident
	typ ast.Expr
}
// parseParamDecl parses a single parameter declaration in a parameter or
// type-parameter list: a name and/or a type, including variadic (...),
// qualified, and (for type parameter lists) embedded/union constraint
// forms. A non-nil name is treated as the already-parsed first identifier.
func (p *parser) parseParamDecl(name *ast.Ident, typeSetsOK bool) (f field) {
	// TODO(rFindley) refactor to be more similar to paramDeclOrNil in the syntax
	// package
	if p.trace {
		defer un(trace(p, "ParamDeclOrNil"))
	}
	ptok := p.tok
	if name != nil {
		p.tok = token.IDENT // force token.IDENT case in switch below
	} else if typeSetsOK && p.tok == token.TILDE {
		// "~" ...
		return field{nil, p.embeddedElem(nil)}
	}
	switch p.tok {
	case token.IDENT:
		// name
		if name != nil {
			f.name = name
			p.tok = ptok // restore the real current token
		} else {
			f.name = p.parseIdent()
		}
		switch p.tok {
		case token.IDENT, token.MUL, token.ARROW, token.FUNC, token.CHAN, token.MAP, token.STRUCT, token.INTERFACE, token.LPAREN:
			// name type
			f.typ = p.parseType()
		case token.LBRACK:
			// name "[" type1, ..., typeN "]" or name "[" n "]" type
			f.name, f.typ = p.parseArrayFieldOrTypeInstance(f.name)
		case token.ELLIPSIS:
			// name "..." type
			f.typ = p.parseDotsType()
			return // don't allow ...type "|" ...
		case token.PERIOD:
			// name "." ...
			f.typ = p.parseQualifiedIdent(f.name)
			f.name = nil
		case token.TILDE:
			if typeSetsOK {
				f.typ = p.embeddedElem(nil)
				return
			}
		case token.OR:
			if typeSetsOK {
				// name "|" typeset
				f.typ = p.embeddedElem(f.name)
				f.name = nil
				return
			}
		}
	case token.MUL, token.ARROW, token.FUNC, token.LBRACK, token.CHAN, token.MAP, token.STRUCT, token.INTERFACE, token.LPAREN:
		// type
		f.typ = p.parseType()
	case token.ELLIPSIS:
		// "..." type
		// (always accepted)
		f.typ = p.parseDotsType()
		return // don't allow ...type "|" ...
	default:
		// TODO(rfindley): this is incorrect in the case of type parameter lists
		// (should be "']'" in that case)
		p.errorExpected(p.pos, "')'")
		p.advance(exprEnd)
	}
	// [name] type "|"
	if typeSetsOK && p.tok == token.OR && f.typ != nil {
		f.typ = p.embeddedElem(f.typ)
	}
	return
}
// parseParameterList parses a comma-separated list of parameters or type
// parameters, stopping before the given closing token (')' for ordinary
// parameters, ']' for type parameters). name0 and typ0, if non-nil, are
// the first name/type of the list, already consumed by the caller.
// It reports errors for mixed named/unnamed parameters and for missing
// type-parameter names or constraints.
func (p *parser) parseParameterList(name0 *ast.Ident, typ0 ast.Expr, closing token.Token) (params []*ast.Field) {
	if p.trace {
		defer un(trace(p, "ParameterList"))
	}
	// Type parameters are the only parameter list closed by ']'.
	tparams := closing == token.RBRACK
	// pos0 is used below to position errors at the start of the list.
	pos0 := p.pos
	if name0 != nil {
		pos0 = name0.Pos()
	} else if typ0 != nil {
		pos0 = typ0.Pos()
	}
	// Note: The code below matches the corresponding code in the syntax
	// parser closely. Changes must be reflected in either parser.
	// For the code to match, we use the local []field list that
	// corresponds to []syntax.Field. At the end, the list must be
	// converted into an []*ast.Field.
	var list []field
	var named int // number of parameters that have an explicit name and type
	var typed int // number of parameters that have an explicit type
	for name0 != nil || p.tok != closing && p.tok != token.EOF {
		var par field
		if typ0 != nil {
			if tparams {
				// a type-parameter constraint may be a union element list
				typ0 = p.embeddedElem(typ0)
			}
			par = field{name0, typ0}
		} else {
			par = p.parseParamDecl(name0, tparams)
		}
		name0 = nil // 1st name was consumed if present
		typ0 = nil  // 1st typ was consumed if present
		if par.name != nil || par.typ != nil {
			list = append(list, par)
			if par.name != nil && par.typ != nil {
				named++
			}
			if par.typ != nil {
				typed++
			}
		}
		if !p.atComma("parameter list", closing) {
			break
		}
		p.next()
	}
	if len(list) == 0 {
		return // not uncommon
	}
	// distribute parameter types (len(list) > 0)
	if named == 0 {
		// all unnamed => found names are type names
		for i := 0; i < len(list); i++ {
			par := &list[i]
			if typ := par.name; typ != nil {
				par.typ = typ
				par.name = nil
			}
		}
		if tparams {
			// This is the same error handling as below, adjusted for type parameters only.
			// See comment below for details. (go.dev/issue/64534)
			var errPos token.Pos
			var msg string
			if named == typed /* same as typed == 0 */ {
				errPos = p.pos // position error at closing ]
				msg = "missing type constraint"
			} else {
				errPos = pos0 // position at opening [ or first name
				msg = "missing type parameter name"
				if len(list) == 1 {
					msg += " or invalid array length"
				}
			}
			p.error(errPos, msg)
		}
	} else if named != len(list) {
		// some named or we're in a type parameter list => all must be named
		var errPos token.Pos // left-most error position (or invalid)
		var typ ast.Expr     // current type (from right to left)
		for i := len(list) - 1; i >= 0; i-- {
			if par := &list[i]; par.typ != nil {
				typ = par.typ
				if par.name == nil {
					errPos = typ.Pos()
					// synthesize a blank name so the AST stays well-formed
					n := ast.NewIdent("_")
					n.NamePos = errPos // correct position
					par.name = n
				}
			} else if typ != nil {
				// inherit the type from the parameter to the right
				par.typ = typ
			} else {
				// par.typ == nil && typ == nil => we only have a par.name
				errPos = par.name.Pos()
				par.typ = &ast.BadExpr{From: errPos, To: p.pos}
			}
		}
		if errPos.IsValid() {
			var msg string
			if tparams {
				// Not all parameters are named because named != len(list).
				// If named == typed we must have parameters that have no types,
				// and they must be at the end of the parameter list, otherwise
				// the types would have been filled in by the right-to-left sweep
				// above and we wouldn't have an error. Since we are in a type
				// parameter list, the missing types are constraints.
				if named == typed {
					errPos = p.pos // position error at closing ]
					msg = "missing type constraint"
				} else {
					msg = "missing type parameter name"
					// go.dev/issue/60812
					if len(list) == 1 {
						msg += " or invalid array length"
					}
				}
			} else {
				msg = "mixed named and unnamed parameters"
			}
			p.error(errPos, msg)
		}
	}
	// Convert list to []*ast.Field.
	// If list contains types only, each type gets its own ast.Field.
	if named == 0 {
		// parameter list consists of types only
		for _, par := range list {
			passert(par.typ != nil, "nil type in unnamed parameter list")
			params = append(params, &ast.Field{Type: par.typ})
		}
		return
	}
	// If the parameter list consists of named parameters with types,
	// collect all names with the same types into a single ast.Field.
	var names []*ast.Ident
	var typ ast.Expr
	addParams := func() {
		passert(typ != nil, "nil type in named parameter list")
		field := &ast.Field{Names: names, Type: typ}
		params = append(params, field)
		names = nil
	}
	for _, par := range list {
		if par.typ != typ {
			if len(names) > 0 {
				addParams()
			}
			typ = par.typ
		}
		names = append(names, par.name)
	}
	if len(names) > 0 {
		addParams()
	}
	return
}
// parseParameters parses an optional type parameter list (only when
// acceptTParams is set and a '[' follows) and the mandatory
// parenthesized ordinary parameter list.
func (p *parser) parseParameters(acceptTParams bool) (tparams, params *ast.FieldList) {
	if p.trace {
		defer un(trace(p, "Parameters"))
	}
	if acceptTParams && p.tok == token.LBRACK {
		lbrack := p.pos
		p.next()
		// [T any](params) syntax
		tlist := p.parseParameterList(nil, nil, token.RBRACK)
		rbrack := p.expect(token.RBRACK)
		tparams = &ast.FieldList{Opening: lbrack, List: tlist, Closing: rbrack}
		// Type parameter lists must not be empty.
		if tparams.NumFields() == 0 {
			p.error(tparams.Closing, "empty type parameter list")
			tparams = nil // avoid follow-on errors
		}
	}
	lparen := p.expect(token.LPAREN)
	var list []*ast.Field
	if p.tok != token.RPAREN {
		list = p.parseParameterList(nil, nil, token.RPAREN)
	}
	rparen := p.expect(token.RPAREN)
	params = &ast.FieldList{Opening: lparen, List: list, Closing: rparen}
	return tparams, params
}
// parseResult parses a function result: either a parenthesized
// parameter list, a single bare type, or nothing (nil).
func (p *parser) parseResult() *ast.FieldList {
	if p.trace {
		defer un(trace(p, "Result"))
	}
	if p.tok == token.LPAREN {
		_, results := p.parseParameters(false)
		return results
	}
	if typ := p.tryIdentOrType(); typ != nil {
		// single unparenthesized result type
		return &ast.FieldList{List: []*ast.Field{{Type: typ}}}
	}
	return nil
}
// parseFuncType parses a function type; type parameters are parsed
// (for error recovery) but rejected with an error.
func (p *parser) parseFuncType() *ast.FuncType {
	if p.trace {
		defer un(trace(p, "FuncType"))
	}
	funcPos := p.expect(token.FUNC)
	tparams, params := p.parseParameters(true)
	if tparams != nil {
		p.error(tparams.Pos(), "function type must have no type parameters")
	}
	results := p.parseResult()
	return &ast.FuncType{Func: funcPos, Params: params, Results: results}
}
// parseMethodSpec parses one element of an interface body: a method
// signature, an embedded type name, or an embedded instantiated type.
// Generic methods are parsed (for better errors) but rejected.
func (p *parser) parseMethodSpec() *ast.Field {
	if p.trace {
		defer un(trace(p, "MethodSpec"))
	}
	doc := p.leadComment
	var idents []*ast.Ident
	var typ ast.Expr
	x := p.parseTypeName(nil)
	if ident, _ := x.(*ast.Ident); ident != nil {
		switch {
		case p.tok == token.LBRACK:
			// generic method or embedded instantiated type
			lbrack := p.pos
			p.next()
			p.exprLev++
			x := p.parseExpr()
			p.exprLev--
			if name0, _ := x.(*ast.Ident); name0 != nil && p.tok != token.COMMA && p.tok != token.RBRACK {
				// generic method m[T any]
				//
				// Interface methods do not have type parameters. We parse them for a
				// better error message and improved error recovery.
				_ = p.parseParameterList(name0, nil, token.RBRACK)
				_ = p.expect(token.RBRACK)
				p.error(lbrack, "interface method must have no type parameters")
				// TODO(rfindley) refactor to share code with parseFuncType.
				_, params := p.parseParameters(false)
				results := p.parseResult()
				idents = []*ast.Ident{ident}
				typ = &ast.FuncType{
					Func:    token.NoPos,
					Params:  params,
					Results: results,
				}
			} else {
				// embedded instantiated type
				// TODO(rfindley) should resolve all identifiers in x.
				list := []ast.Expr{x}
				if p.atComma("type argument list", token.RBRACK) {
					p.exprLev++
					p.next()
					for p.tok != token.RBRACK && p.tok != token.EOF {
						list = append(list, p.parseType())
						if !p.atComma("type argument list", token.RBRACK) {
							break
						}
						p.next()
					}
					p.exprLev--
				}
				// NOTE(review): the packed index expression is disabled below,
				// so typ remains the plain type name in this branch and the
				// closing ']' is not consumed here — confirm intended.
				// rbrack := p.expectClosing(token.RBRACK, "type argument list")
				// typ = typeparams.PackIndexExpr(ident, lbrack, list, rbrack)
			}
		case p.tok == token.LPAREN:
			// ordinary method
			// TODO(rfindley) refactor to share code with parseFuncType.
			_, params := p.parseParameters(false)
			results := p.parseResult()
			idents = []*ast.Ident{ident}
			typ = &ast.FuncType{Func: token.NoPos, Params: params, Results: results}
		default:
			// embedded type
			typ = x
		}
	} else {
		// embedded, possibly instantiated type
		typ = x
		if p.tok == token.LBRACK {
			// embedded instantiated interface
			typ = p.parseTypeInstance(typ)
		}
	}
	// Comment is added at the callsite: the field below may joined with
	// additional type specs using '|'.
	// TODO(rfindley) this should be refactored.
	// TODO(rfindley) add more tests for comment handling.
	return &ast.Field{Doc: doc, Names: idents, Type: typ}
}
// embeddedElem parses a union element: x (or a fresh term when x is
// nil) followed by any number of '|' terms, left-folded into nested
// binary OR expressions.
func (p *parser) embeddedElem(x ast.Expr) ast.Expr {
	if p.trace {
		defer un(trace(p, "EmbeddedElem"))
	}
	if x == nil {
		x = p.embeddedTerm()
	}
	for p.tok == token.OR {
		union := &ast.BinaryExpr{OpPos: p.pos, Op: token.OR}
		p.next()
		union.X = x
		union.Y = p.embeddedTerm()
		x = union
	}
	return x
}
// embeddedTerm parses a single union term: an optional '~' followed
// by a type, or a bare type. On failure it advances to a recovery
// point and returns a BadExpr.
func (p *parser) embeddedTerm() ast.Expr {
	if p.trace {
		defer un(trace(p, "EmbeddedTerm"))
	}
	if p.tok == token.TILDE {
		underlying := &ast.UnaryExpr{OpPos: p.pos, Op: token.TILDE}
		p.next()
		underlying.X = p.parseType()
		return underlying
	}
	if t := p.tryIdentOrType(); t != nil {
		return t
	}
	pos := p.pos
	p.errorExpected(pos, "~ term or type")
	p.advance(exprEnd)
	return &ast.BadExpr{From: pos, To: p.pos}
}
// parseInterfaceType parses an interface type: methods, embedded
// types, and '~'/union type-set elements, until the closing brace.
func (p *parser) parseInterfaceType() *ast.InterfaceType {
	if p.trace {
		defer un(trace(p, "InterfaceType"))
	}
	pos := p.expect(token.INTERFACE)
	lbrace := p.expect(token.LBRACE)
	var list []*ast.Field
	for done := false; !done; {
		switch p.tok {
		case token.IDENT:
			// method spec or embedded (possibly union) type
			f := p.parseMethodSpec()
			if f.Names == nil {
				f.Type = p.embeddedElem(f.Type)
			}
			f.Comment = p.expectSemi()
			list = append(list, f)
		case token.TILDE:
			// '~T | ...' type-set element
			typ := p.embeddedElem(nil)
			comment := p.expectSemi()
			list = append(list, &ast.Field{Type: typ, Comment: comment})
		default:
			t := p.tryIdentOrType()
			if t == nil {
				done = true
				break
			}
			typ := p.embeddedElem(t)
			comment := p.expectSemi()
			list = append(list, &ast.Field{Type: typ, Comment: comment})
		}
	}
	// TODO(rfindley): the error produced here could be improved, since we could
	// accept an identifier, 'type', or a '}' at this point.
	rbrace := p.expect(token.RBRACE)
	return &ast.InterfaceType{
		Interface: pos,
		Methods: &ast.FieldList{
			Opening: lbrace,
			List:    list,
			Closing: rbrace,
		},
	}
}
// parseMapType parses "map[K]V".
func (p *parser) parseMapType() *ast.MapType {
	if p.trace {
		defer un(trace(p, "MapType"))
	}
	mapPos := p.expect(token.MAP)
	p.expect(token.LBRACK)
	keyType := p.parseType()
	p.expect(token.RBRACK)
	valType := p.parseType()
	return &ast.MapType{Map: mapPos, Key: keyType, Value: valType}
}
// parseChanType parses a channel type: "chan T", "chan<- T", or
// "<-chan T", recording the arrow position and direction.
func (p *parser) parseChanType() *ast.ChanType {
	begin := p.pos
	if p.trace {
		defer un(trace(p, "ChanType"))
	}
	dir := ast.SEND | ast.RECV
	var arrow token.Pos
	if p.tok != token.CHAN {
		// "<-chan T": receive-only
		arrow = p.expect(token.ARROW)
		p.expect(token.CHAN)
		dir = ast.RECV
	} else {
		p.next()
		if p.tok == token.ARROW {
			// "chan<- T": send-only
			arrow = p.pos
			p.next()
			dir = ast.SEND
		}
	}
	elem := p.parseType()
	return &ast.ChanType{Begin: begin, Arrow: arrow, Dir: dir, Value: elem}
}
// parseTypeInstance parses the bracketed type-argument list following
// typ. On an empty argument list it reports an error and returns an
// IndexExpr wrapping a BadExpr for tolerance.
func (p *parser) parseTypeInstance(typ ast.Expr) ast.Expr {
	if p.trace {
		defer un(trace(p, "TypeInstance"))
	}
	opening := p.expect(token.LBRACK)
	p.exprLev++
	var list []ast.Expr
	for p.tok != token.RBRACK && p.tok != token.EOF {
		list = append(list, p.parseType())
		if !p.atComma("type argument list", token.RBRACK) {
			break
		}
		p.next()
	}
	p.exprLev--
	closing := p.expectClosing(token.RBRACK, "type argument list")
	if len(list) == 0 {
		p.errorExpected(closing, "type argument list")
		return &ast.IndexExpr{
			X:      typ,
			Lbrack: opening,
			Index:  &ast.BadExpr{From: opening + 1, To: closing},
			Rbrack: closing,
		}
	}
	// NOTE(review): the packed index expression is deliberately disabled
	// here, so the non-error path returns a nil ast.Expr — confirm that
	// callers (e.g. tryIdentOrType, parseMethodSpec) handle nil.
	return nil // typeparams.PackIndexExpr(typ, opening, list, closing)
}
// tryIdentOrType attempts to parse a type at the current token and
// returns nil (without reporting an error) when none is present.
func (p *parser) tryIdentOrType() ast.Expr {
	defer decNestLev(incNestLev(p))
	switch p.tok {
	case token.IDENT:
		name := p.parseTypeName(nil)
		if p.tok != token.LBRACK {
			return name
		}
		return p.parseTypeInstance(name)
	case token.LBRACK:
		// math: full array exprs
		return p.parseArrayList(p.expect(token.LBRACK))
	case token.STRUCT:
		return p.parseStructType()
	case token.MUL:
		return p.parsePointerType()
	case token.FUNC:
		return p.parseFuncType()
	case token.INTERFACE:
		return p.parseInterfaceType()
	case token.MAP:
		return p.parseMapType()
	case token.CHAN, token.ARROW:
		return p.parseChanType()
	case token.LPAREN:
		// parenthesized type
		lparen := p.pos
		p.next()
		inner := p.parseType()
		rparen := p.expect(token.RPAREN)
		return &ast.ParenExpr{Lparen: lparen, X: inner, Rparen: rparen}
	}
	return nil // no type found
}
// ----------------------------------------------------------------------------
// Blocks
// parseStmtList parses statements until a case/default label, closing
// brace, or end of input.
func (p *parser) parseStmtList() (list []ast.Stmt) {
	if p.trace {
		defer un(trace(p, "StatementList"))
	}
	for {
		switch p.tok {
		case token.CASE, token.DEFAULT, token.RBRACE, token.EOF:
			return list
		}
		list = append(list, p.parseStmt())
	}
}
// parseBody parses a brace-delimited function body.
func (p *parser) parseBody() *ast.BlockStmt {
	if p.trace {
		defer un(trace(p, "Body"))
	}
	open := p.expect(token.LBRACE)
	stmts := p.parseStmtList()
	return &ast.BlockStmt{Lbrace: open, List: stmts, Rbrace: p.expect2(token.RBRACE)}
}
// parseBlockStmt parses a brace-delimited block statement; in math
// mode, input may end right after the opening brace.
func (p *parser) parseBlockStmt() *ast.BlockStmt {
	if p.trace {
		defer un(trace(p, "BlockStmt"))
	}
	open := p.expect(token.LBRACE)
	if p.tok == token.EOF { // math: allow start only
		return &ast.BlockStmt{Lbrace: open}
	}
	stmts := p.parseStmtList()
	return &ast.BlockStmt{Lbrace: open, List: stmts, Rbrace: p.expect2(token.RBRACE)}
}
// ----------------------------------------------------------------------------
// Expressions
// parseFuncTypeOrLit parses a function type, and when a body follows,
// the enclosing function literal.
func (p *parser) parseFuncTypeOrLit() ast.Expr {
	if p.trace {
		defer un(trace(p, "FuncTypeOrLit"))
	}
	ftyp := p.parseFuncType()
	if p.tok != token.LBRACE {
		return ftyp // function type only
	}
	p.exprLev++
	fbody := p.parseBody()
	p.exprLev--
	return &ast.FuncLit{Type: ftyp, Body: fbody}
}
// parseOperand may return an expression or a raw type (incl. array
// types of the form [...]T). Callers must verify the result.
func (p *parser) parseOperand() ast.Expr {
	if p.trace {
		defer un(trace(p, "Operand"))
	}
	switch p.tok {
	case token.IDENT:
		return p.parseIdent()
	case token.INT, token.FLOAT, token.IMAG, token.CHAR, token.STRING:
		lit := &ast.BasicLit{ValuePos: p.pos, Kind: p.tok, Value: p.lit}
		p.next()
		if p.tok == token.COLON {
			// math: a literal followed by ':' begins a slice expression
			return p.parseSliceExpr(lit)
		}
		return lit
	case token.LPAREN:
		lparen := p.pos
		p.next()
		p.exprLev++
		inner := p.parseRhs() // types may be parenthesized: (some type)
		p.exprLev--
		rparen := p.expect(token.RPAREN)
		return &ast.ParenExpr{Lparen: lparen, X: inner, Rparen: rparen}
	case token.FUNC:
		return p.parseFuncTypeOrLit()
	case token.COLON:
		// math: a bare ':' begins a slice expression
		p.expect(token.COLON)
		return p.parseSliceExpr(nil)
	}
	if typ := p.tryIdentOrType(); typ != nil { // do not consume trailing type parameters
		// could be type for composite literal or conversion
		_, isIdent := typ.(*ast.Ident)
		passert(!isIdent, "type cannot be identifier")
		return typ
	}
	// we have an error
	pos := p.pos
	p.errorExpected(pos, "operand")
	p.advance(stmtStart)
	return &ast.BadExpr{From: pos, To: p.pos}
}
// parseSelector parses the identifier after a '.' and wraps x in a
// selector expression.
func (p *parser) parseSelector(x ast.Expr) ast.Expr {
	if p.trace {
		defer un(trace(p, "Selector"))
	}
	return &ast.SelectorExpr{X: x, Sel: p.parseIdent()}
}
// parseTypeAssertion parses "x.(T)" or the type-switch form
// "x.(type)" (in which case Type is nil).
func (p *parser) parseTypeAssertion(x ast.Expr) ast.Expr {
	if p.trace {
		defer un(trace(p, "TypeAssertion"))
	}
	lparen := p.expect(token.LPAREN)
	var assertedType ast.Expr
	if p.tok != token.TYPE {
		assertedType = p.parseType()
	} else {
		// type switch: typ == nil
		p.next()
	}
	rparen := p.expect(token.RPAREN)
	return &ast.TypeAssertExpr{X: x, Type: assertedType, Lparen: lparen, Rparen: rparen}
}
// parseIndexOrSliceOrInstance parses the bracketed expression list
// following x (index, slice, or instantiation), delegating the
// interior to parseArrayList.
func (p *parser) parseIndexOrSliceOrInstance(x ast.Expr) ast.Expr {
	if p.trace {
		defer un(trace(p, "parseIndexOrSliceOrInstance"))
	}
	lbrack := p.expect(token.LBRACK)
	if p.tok == token.RBRACK {
		// empty index, slice or index expressions are not permitted;
		// accept them for parsing tolerance, but complain
		p.errorExpected(p.pos, "operand")
		rbrack := p.pos
		p.next()
		return &ast.IndexExpr{
			X:      x,
			Lbrack: lbrack,
			Index:  &ast.BadExpr{From: rbrack, To: rbrack},
			Rbrack: rbrack,
		}
	}
	contents := p.parseArrayList(lbrack)
	return &ast.IndexExpr{
		X:      x,
		Lbrack: lbrack,
		Index:  contents,
		Rbrack: contents.Rbrack,
	}
}
// parseSliceExpr parses a (math-mode) slice expression whose first
// operand ex may already have been consumed by the caller (nil when
// the expression starts with ':'). It fills up to three index slots
// (low:high:max) and returns the resulting SliceExpr; Rbrack is only
// set when the expression ends at a ',' or ']'.
//
// Fix: removed the unreachable tail after "return se" (a dead
// "return nil" plus commented-out instance-expression code), which
// go vet/staticcheck flag as unreachable code.
func (p *parser) parseSliceExpr(ex ast.Expr) *ast.SliceExpr {
	if p.trace {
		defer un(trace(p, "parseSliceExpr"))
	}
	lbrack := p.pos
	p.exprLev++
	const N = 3 // change the 3 to 2 to disable 3-index slices
	var index [N]ast.Expr
	index[0] = ex
	var colons [N - 1]token.Pos
	ncolons := 0
	if ex == nil {
		// no low expression: the first slot is implicitly empty
		ncolons++
	}
	var rpos token.Pos
	switch p.tok {
	case token.COLON:
		// slice expression: consume ':' separators and the expressions between them
		for p.tok == token.COLON && ncolons < len(colons) {
			colons[ncolons] = p.pos
			ncolons++
			p.next()
			if p.tok != token.COMMA && p.tok != token.COLON && p.tok != token.RBRACK && p.tok != token.EOF {
				ix := p.parseRhs()
				if se, ok := ix.(*ast.SliceExpr); ok {
					// a nested slice expr carries the remaining components
					index[ncolons] = se.Low
					if ncolons == 1 && se.High != nil {
						ncolons++
						index[ncolons] = se.High
					}
				} else {
					// ignore bad expressions; keep the slot empty
					if _, ok := ix.(*ast.BadExpr); !ok {
						index[ncolons] = ix
					}
				}
			}
		}
	case token.COMMA:
		rpos = p.pos // expect(token.COMMA)
	case token.RBRACK:
		rpos = p.pos // expect(token.RBRACK)
	default:
		// single expression in the current slot
		ix := p.parseRhs()
		index[ncolons] = ix
	}
	p.exprLev--
	// two colons seen => 3-index slice (low:high:max)
	slice3 := ncolons == 2
	return &ast.SliceExpr{Lbrack: lbrack, Low: index[0], High: index[1], Max: index[2], Slice3: slice3, Rbrack: rpos}
}
// parseCallOrConversion parses the parenthesized argument list of a
// call or conversion applied to fun, including a trailing "...".
func (p *parser) parseCallOrConversion(fun ast.Expr) *ast.CallExpr {
	if p.trace {
		defer un(trace(p, "CallOrConversion"))
	}
	lparen := p.expect(token.LPAREN)
	p.exprLev++
	var args []ast.Expr
	var ellipsis token.Pos
	for p.tok != token.RPAREN && p.tok != token.EOF && !ellipsis.IsValid() {
		args = append(args, p.parseRhs()) // builtins may expect a type: make(some type, ...)
		if p.tok == token.ELLIPSIS {
			// "..." may only follow the final argument
			ellipsis = p.pos
			p.next()
		}
		if !p.atComma("argument list", token.RPAREN) {
			break
		}
		p.next()
	}
	p.exprLev--
	rparen := p.expectClosing(token.RPAREN, "argument list")
	return &ast.CallExpr{Fun: fun, Lparen: lparen, Args: args, Ellipsis: ellipsis, Rparen: rparen}
}
// parseValue parses a composite-literal value: either a nested
// literal value or an expression.
//
// Fix: the trace label was "Element" (copied from parseElement),
// which made trace output indistinguishable between the two
// functions; use "Value" instead.
func (p *parser) parseValue() ast.Expr {
	if p.trace {
		defer un(trace(p, "Value"))
	}
	if p.tok == token.LBRACE {
		return p.parseLiteralValue(nil)
	}
	x := p.parseExpr()
	return x
}
// parseElement parses one composite-literal element, which may be a
// plain value or a "key: value" pair.
func (p *parser) parseElement() ast.Expr {
	if p.trace {
		defer un(trace(p, "Element"))
	}
	elem := p.parseValue()
	if p.tok != token.COLON {
		return elem
	}
	colon := p.pos
	p.next()
	return &ast.KeyValueExpr{Key: elem, Colon: colon, Value: p.parseValue()}
}
// parseElementList parses comma-separated composite-literal elements
// up to the closing brace.
func (p *parser) parseElementList() (list []ast.Expr) {
	if p.trace {
		defer un(trace(p, "ElementList"))
	}
	for p.tok != token.RBRACE && p.tok != token.EOF {
		list = append(list, p.parseElement())
		if !p.atComma("composite literal", token.RBRACE) {
			return
		}
		p.next()
	}
	return
}
// parseLiteralValue parses the brace-delimited element list of a
// composite literal with the given (possibly nil) type.
func (p *parser) parseLiteralValue(typ ast.Expr) ast.Expr {
	defer decNestLev(incNestLev(p))
	if p.trace {
		defer un(trace(p, "LiteralValue"))
	}
	lbrace := p.expect(token.LBRACE)
	p.exprLev++
	var elems []ast.Expr
	if p.tok != token.RBRACE {
		elems = p.parseElementList()
	}
	p.exprLev--
	rbrace := p.expectClosing(token.RBRACE, "composite literal")
	return &ast.CompositeLit{Type: typ, Lbrace: lbrace, Elts: elems, Rbrace: rbrace}
}
// parsePrimaryExpr parses a primary expression: an operand followed by
// any number of selectors, type assertions, index/slice expressions,
// calls, and composite literals. If x is non-nil it is used as the
// already-parsed operand.
//
// Fix: the math-mode ellipsis case recorded the position AFTER
// p.next(), so ast.Ellipsis.Ellipsis pointed at the following token;
// capture the position of the "..." token itself before advancing.
func (p *parser) parsePrimaryExpr(x ast.Expr) ast.Expr {
	if p.trace {
		defer un(trace(p, "PrimaryExpr"))
	}
	// math: ellipses can show up in index expression.
	if p.tok == token.ELLIPSIS {
		pos := p.pos // position of the "..." token itself
		p.next()
		return &ast.Ellipsis{Ellipsis: pos}
	}
	if x == nil {
		x = p.parseOperand()
	}
	// We track the nesting here rather than at the entry for the function,
	// since it can iteratively produce a nested output, and we want to
	// limit how deep a structure we generate.
	var n int
	defer func() { p.nestLev -= n }()
	for n = 1; ; n++ {
		incNestLev(p)
		switch p.tok {
		case token.PERIOD:
			p.next()
			switch p.tok {
			case token.IDENT:
				x = p.parseSelector(x)
			case token.LPAREN:
				x = p.parseTypeAssertion(x)
			default:
				pos := p.pos
				p.errorExpected(pos, "selector or type assertion")
				// TODO(rFindley) The check for token.RBRACE below is a targeted fix
				// to error recovery sufficient to make the x/tools tests to
				// pass with the new parsing logic introduced for type
				// parameters. Remove this once error recovery has been
				// more generally reconsidered.
				if p.tok != token.RBRACE {
					p.next() // make progress
				}
				sel := &ast.Ident{NamePos: pos, Name: "_"}
				x = &ast.SelectorExpr{X: x, Sel: sel}
			}
		case token.LBRACK:
			x = p.parseIndexOrSliceOrInstance(x)
		case token.LPAREN:
			x = p.parseCallOrConversion(x)
		case token.LBRACE:
			// operand may have returned a parenthesized complit
			// type; accept it but complain if we have a complit
			t := ast.Unparen(x)
			// determine if '{' belongs to a composite literal or a block statement
			switch t.(type) {
			case *ast.BadExpr, *ast.Ident, *ast.SelectorExpr:
				if p.exprLev < 0 {
					return x
				}
				// x is possibly a composite literal type
			case *ast.IndexExpr, *ast.IndexListExpr:
				if p.exprLev < 0 {
					return x
				}
				// x is possibly a composite literal type
			case *ast.ArrayType, *ast.StructType, *ast.MapType:
				// x is a composite literal type
			default:
				return x
			}
			if t != x {
				p.error(t.Pos(), "cannot parenthesize type in composite literal")
				// already progressed, no need to advance
			}
			x = p.parseLiteralValue(x)
		default:
			return x
		}
	}
}
// parseUnaryExpr parses a unary expression: a prefix operator applied
// to a unary expression, a receive/channel-type form starting with
// "<-", a pointer/deref form starting with "*", or a primary
// expression.
func (p *parser) parseUnaryExpr() ast.Expr {
	defer decNestLev(incNestLev(p))
	if p.trace {
		defer un(trace(p, "UnaryExpr"))
	}
	switch p.tok {
	case token.ADD, token.SUB, token.NOT, token.XOR, token.AND, token.TILDE:
		pos, op := p.pos, p.tok
		p.next()
		x := p.parseUnaryExpr()
		return &ast.UnaryExpr{OpPos: pos, Op: op, X: x}
	case token.ARROW:
		// channel type or receive expression
		arrow := p.pos
		p.next()
		// If the next token is token.CHAN we still don't know if it
		// is a channel type or a receive operation - we only know
		// once we have found the end of the unary expression. There
		// are two cases:
		//
		//   <- type  => (<-type) must be channel type
		//   <- expr  => <-(expr) is a receive from an expression
		//
		// In the first case, the arrow must be re-associated with
		// the channel type parsed already:
		//
		//   <- (chan type)    =>  (<-chan type)
		//   <- (chan<- type)  =>  (<-chan (<-type))
		x := p.parseUnaryExpr()
		// determine which case we have
		if typ, ok := x.(*ast.ChanType); ok {
			// (<-type)
			// re-associate position info and <-
			// walk down nested chan types, pushing the RECV direction inward
			dir := ast.SEND
			for ok && dir == ast.SEND {
				if typ.Dir == ast.RECV {
					// error: (<-type) is (<-(<-chan T))
					p.errorExpected(typ.Arrow, "'chan'")
				}
				arrow, typ.Begin, typ.Arrow = typ.Arrow, arrow, arrow
				dir, typ.Dir = typ.Dir, ast.RECV
				typ, ok = typ.Value.(*ast.ChanType)
			}
			if dir == ast.SEND {
				p.errorExpected(arrow, "channel type")
			}
			return x
		}
		// <-(expr)
		return &ast.UnaryExpr{OpPos: arrow, Op: token.ARROW, X: x}
	case token.MUL:
		// pointer type or unary "*" expression
		pos := p.pos
		p.next()
		x := p.parseUnaryExpr()
		return &ast.StarExpr{Star: pos, X: x}
	}
	return p.parsePrimaryExpr(nil)
}
// tokPrec returns the effective operator token at the current
// position and its precedence: '=' is treated as '==' on a right-hand
// side, and the math-mode '@' (lexed as ILLEGAL) gets
// multiplicative-level precedence.
func (p *parser) tokPrec() (token.Token, int) {
	tok := p.tok
	switch {
	case p.inRhs && tok == token.ASSIGN:
		tok = token.EQL
	case tok == token.ILLEGAL && p.lit == "@":
		return token.ILLEGAL, 5
	}
	return tok, tok.Precedence()
}
// parseBinaryExpr parses a (possibly) binary expression.
// If x is non-nil, it is used as the left operand.
//
// TODO(rfindley): parseBinaryExpr has become overloaded. Consider refactoring.
func (p *parser) parseBinaryExpr(x ast.Expr, prec1 int) ast.Expr {
	if p.trace {
		defer un(trace(p, "BinaryExpr"))
	}
	if x == nil {
		x = p.parseUnaryExpr()
	}
	// We track the nesting here rather than at the entry for the function,
	// since it can iteratively produce a nested output, and we want to
	// limit how deep a structure we generate.
	var depth int
	defer func() { p.nestLev -= depth }()
	for depth = 1; ; depth++ {
		incNestLev(p)
		op, oprec := p.tokPrec()
		if oprec < prec1 {
			return x
		}
		opPos := p.pos
		if op != token.ILLEGAL {
			opPos = p.expect(op)
		} else {
			// math-mode '@' operator: just consume it
			p.next()
		}
		rhs := p.parseBinaryExpr(nil, oprec+1)
		x = &ast.BinaryExpr{X: x, OpPos: opPos, Op: op, Y: rhs}
	}
}
// parseExpr parses a full expression.
// The result may be a type or even a raw type ([...]int).
func (p *parser) parseExpr() ast.Expr {
	if p.trace {
		defer un(trace(p, "Expression"))
	}
	// start just above the lowest precedence so any binary operator binds
	return p.parseBinaryExpr(nil, token.LowestPrec+1)
}
// parseRhs parses an expression with the inRhs flag set (so '=' is
// interpreted as '=='), restoring the previous flag on return.
func (p *parser) parseRhs() ast.Expr {
	defer func(prev bool) { p.inRhs = prev }(p.inRhs)
	p.inRhs = true
	return p.parseExpr()
}
// ----------------------------------------------------------------------------
// Statements
// Parsing modes for parseSimpleStmt.
const (
	basic   = iota // plain simple statement; labels and range clauses rejected
	labelOk        // a labeled statement is permitted
	rangeOk        // a range clause is permitted on the right of := or =
)
// parseSimpleStmt returns true as 2nd result if it parsed the assignment
// of a range clause (with mode == rangeOk). The returned statement is an
// assignment with a right-hand side that is a single unary expression of
// the form "range x". No guarantees are given for the left-hand side.
func (p *parser) parseSimpleStmt(mode int) (ast.Stmt, bool) {
	if p.trace {
		defer un(trace(p, "SimpleStmt"))
	}
	// parse the left-hand expression list first; what follows decides
	// whether this is an assignment, label, send, inc/dec, or expression
	x := p.parseList(false)
	switch p.tok {
	case
		token.DEFINE, token.ASSIGN, token.ADD_ASSIGN,
		token.SUB_ASSIGN, token.MUL_ASSIGN, token.QUO_ASSIGN,
		token.REM_ASSIGN, token.AND_ASSIGN, token.OR_ASSIGN,
		token.XOR_ASSIGN, token.SHL_ASSIGN, token.SHR_ASSIGN, token.AND_NOT_ASSIGN:
		// assignment statement, possibly part of a range clause
		pos, tok := p.pos, p.tok
		p.next()
		var y []ast.Expr
		isRange := false
		// a range clause is only valid after := or = and only in rangeOk mode
		if mode == rangeOk && p.tok == token.RANGE && (tok == token.DEFINE || tok == token.ASSIGN) {
			pos := p.pos
			p.next()
			y = []ast.Expr{&ast.UnaryExpr{OpPos: pos, Op: token.RANGE, X: p.parseRhs()}}
			isRange = true
		} else {
			y = p.parseList(true)
		}
		return &ast.AssignStmt{Lhs: x, TokPos: pos, Tok: tok, Rhs: y}, isRange
	}
	if len(x) > 1 {
		p.errorExpected(x[0].Pos(), "1 expression")
		// continue with first expression
	}
	switch p.tok {
	case token.COLON:
		// labeled statement
		colon := p.pos
		p.next()
		if label, isIdent := x[0].(*ast.Ident); mode == labelOk && isIdent {
			// Go spec: The scope of a label is the body of the function
			// in which it is declared and excludes the body of any nested
			// function.
			stmt := &ast.LabeledStmt{Label: label, Colon: colon, Stmt: p.parseStmt()}
			return stmt, false
		}
		// The label declaration typically starts at x[0].Pos(), but the label
		// declaration may be erroneous due to a token after that position (and
		// before the ':'). If SpuriousErrors is not set, the (only) error
		// reported for the line is the illegal label error instead of the token
		// before the ':' that caused the problem. Thus, use the (latest) colon
		// position for error reporting.
		p.error(colon, "illegal label declaration")
		return &ast.BadStmt{From: x[0].Pos(), To: colon + 1}, false
	case token.ARROW:
		// send statement
		arrow := p.pos
		p.next()
		y := p.parseRhs()
		return &ast.SendStmt{Chan: x[0], Arrow: arrow, Value: y}, false
	case token.INC, token.DEC:
		// increment or decrement
		s := &ast.IncDecStmt{X: x[0], TokPos: p.pos, Tok: p.tok}
		p.next()
		return s, false
	}
	// expression
	return &ast.ExprStmt{X: x[0]}, false
}
// parseCallExpr parses the expression following a go/defer keyword;
// callType ("go" or "defer") is only used in error messages. It
// returns nil when the expression is not a function call.
func (p *parser) parseCallExpr(callType string) *ast.CallExpr {
	expr := p.parseRhs() // could be a conversion: (some type)(x)
	if inner := ast.Unparen(expr); inner != expr {
		p.error(expr.Pos(), fmt.Sprintf("expression in %s must not be parenthesized", callType))
		expr = inner
	}
	call, isCall := expr.(*ast.CallExpr)
	if isCall {
		return call
	}
	if _, isBad := expr.(*ast.BadExpr); !isBad {
		// only report error if it's a new one
		p.error(p.safePos(expr.End()), fmt.Sprintf("expression in %s must be function call", callType))
	}
	return nil
}
// parseGoStmt parses a "go" statement.
func (p *parser) parseGoStmt() ast.Stmt {
	if p.trace {
		defer un(trace(p, "GoStmt"))
	}
	goPos := p.expect(token.GO)
	call := p.parseCallExpr("go")
	p.expectSemi()
	if call != nil {
		return &ast.GoStmt{Go: goPos, Call: call}
	}
	return &ast.BadStmt{From: goPos, To: goPos + 2} // len("go")
}
// parseDeferStmt parses a "defer" statement.
func (p *parser) parseDeferStmt() ast.Stmt {
	if p.trace {
		defer un(trace(p, "DeferStmt"))
	}
	deferPos := p.expect(token.DEFER)
	call := p.parseCallExpr("defer")
	p.expectSemi()
	if call != nil {
		return &ast.DeferStmt{Defer: deferPos, Call: call}
	}
	return &ast.BadStmt{From: deferPos, To: deferPos + 5} // len("defer")
}
// parseReturnStmt parses a "return" statement with its optional
// result expression list.
func (p *parser) parseReturnStmt() *ast.ReturnStmt {
	if p.trace {
		defer un(trace(p, "ReturnStmt"))
	}
	retPos := p.pos
	p.expect(token.RETURN)
	var results []ast.Expr
	if p.tok != token.SEMICOLON && p.tok != token.RBRACE {
		results = p.parseList(true)
	}
	p.expectSemi()
	return &ast.ReturnStmt{Return: retPos, Results: results}
}
// parseBranchStmt parses a break/continue/goto/fallthrough statement
// for the given keyword token, with its optional label.
func (p *parser) parseBranchStmt(tok token.Token) *ast.BranchStmt {
	if p.trace {
		defer un(trace(p, "BranchStmt"))
	}
	kwPos := p.expect(tok)
	var label *ast.Ident
	// fallthrough never takes a label
	if p.tok == token.IDENT && tok != token.FALLTHROUGH {
		label = p.parseIdent()
	}
	p.expectSemi()
	return &ast.BranchStmt{TokPos: kwPos, Tok: tok, Label: label}
}
// makeExpr converts statement s into an expression, reporting an
// error (mentioning want) when s is not an expression statement.
func (p *parser) makeExpr(s ast.Stmt, want string) ast.Expr {
	if s == nil {
		return nil
	}
	if es, ok := s.(*ast.ExprStmt); ok {
		return es.X
	}
	var found string
	switch s.(type) {
	case *ast.AssignStmt:
		found = "assignment"
	default:
		found = "simple statement"
	}
	p.error(s.Pos(), fmt.Sprintf("expected %s, found %s (missing parentheses around composite literal?)", want, found))
	return &ast.BadExpr{From: s.Pos(), To: p.safePos(s.End())}
}
// parseIfHeader is an adjusted version of parser.header
// in cmd/compile/internal/syntax/parser.go, which has
// been tuned for better error handling.
// It parses the optional init statement and the condition of an if
// statement, up to (but not including) the opening '{'.
func (p *parser) parseIfHeader() (init ast.Stmt, cond ast.Expr) {
	if p.tok == token.LBRACE {
		p.error(p.pos, "missing condition in if statement")
		cond = &ast.BadExpr{From: p.pos, To: p.pos}
		return
	}
	// p.tok != token.LBRACE
	// disable composite literals while scanning the header
	prevLev := p.exprLev
	p.exprLev = -1
	if p.tok != token.SEMICOLON {
		// accept potential variable declaration but complain
		if p.tok == token.VAR {
			p.next()
			p.error(p.pos, "var declaration not allowed in if initializer")
		}
		init, _ = p.parseSimpleStmt(basic)
	}
	var condStmt ast.Stmt
	// semi records the separator between init and condition, used to
	// tailor the error message when the condition turns out missing
	var semi struct {
		pos token.Pos
		lit string // ";" or "\n"; valid if pos.IsValid()
	}
	if p.tok != token.LBRACE {
		if p.tok == token.SEMICOLON {
			semi.pos = p.pos
			semi.lit = p.lit
			p.next()
		} else {
			p.expect(token.SEMICOLON)
		}
		if p.tok != token.LBRACE {
			condStmt, _ = p.parseSimpleStmt(basic)
		}
	} else {
		// no ';' seen: what we parsed was the condition, not an init stmt
		condStmt = init
		init = nil
	}
	if condStmt != nil {
		cond = p.makeExpr(condStmt, "boolean expression")
	} else if semi.pos.IsValid() {
		if semi.lit == "\n" {
			p.error(semi.pos, "unexpected newline, expecting { after if clause")
		} else {
			p.error(semi.pos, "missing condition in if statement")
		}
	}
	// make sure we have a valid AST
	if cond == nil {
		cond = &ast.BadExpr{From: p.pos, To: p.pos}
	}
	p.exprLev = prevLev
	return
}
// parseIfStmt parses an if statement including any else-if chain or
// else block.
func (p *parser) parseIfStmt() *ast.IfStmt {
	defer decNestLev(incNestLev(p))
	if p.trace {
		defer un(trace(p, "IfStmt"))
	}
	ifPos := p.expect(token.IF)
	init, cond := p.parseIfHeader()
	body := p.parseBlockStmt()
	var elseBranch ast.Stmt
	if p.tok != token.ELSE {
		p.expectSemi()
	} else {
		p.next()
		switch p.tok {
		case token.IF:
			elseBranch = p.parseIfStmt()
		case token.LBRACE:
			elseBranch = p.parseBlockStmt()
			p.expectSemi()
		default:
			p.errorExpected(p.pos, "if statement or block")
			elseBranch = &ast.BadStmt{From: p.pos, To: p.pos}
		}
	}
	return &ast.IfStmt{If: ifPos, Init: init, Cond: cond, Body: body, Else: elseBranch}
}
// parseCaseClause parses one "case x, y:" or "default:" clause of a
// switch statement, including its statement list.
func (p *parser) parseCaseClause() *ast.CaseClause {
	if p.trace {
		defer un(trace(p, "CaseClause"))
	}
	casePos := p.pos
	var exprs []ast.Expr
	if p.tok != token.CASE {
		p.expect(token.DEFAULT)
	} else {
		p.next()
		exprs = p.parseList(true)
	}
	colon := p.expect(token.COLON)
	return &ast.CaseClause{Case: casePos, List: exprs, Colon: colon, Body: p.parseStmtList()}
}
func isTypeSwitchAssert(x ast.Expr) bool {
a, ok := x.(*ast.TypeAssertExpr)
return ok && a.Type == nil
}
// isTypeSwitchGuard reports whether s is a type-switch guard:
// x.(type) or v := x.(type); v = x.(type) is accepted with an error.
func (p *parser) isTypeSwitchGuard(s ast.Stmt) bool {
	switch t := s.(type) {
	case *ast.ExprStmt:
		// x.(type)
		return isTypeSwitchAssert(t.X)
	case *ast.AssignStmt:
		// v := x.(type)
		if len(t.Lhs) != 1 || len(t.Rhs) != 1 || !isTypeSwitchAssert(t.Rhs[0]) {
			return false
		}
		if t.Tok == token.ASSIGN {
			// permit v = x.(type) but complain
			p.error(t.TokPos, "expected ':=', found '='")
			return true
		}
		return t.Tok == token.DEFINE
	}
	return false
}
// parseSwitchStmt parses an expression switch or a type switch,
// returning *ast.TypeSwitchStmt when the guard has the x.(type) form
// and *ast.SwitchStmt otherwise.
func (p *parser) parseSwitchStmt() ast.Stmt {
	if p.trace {
		defer un(trace(p, "SwitchStmt"))
	}
	pos := p.expect(token.SWITCH)
	var s1, s2 ast.Stmt
	if p.tok != token.LBRACE {
		// disable composite literals while scanning the header
		prevLev := p.exprLev
		p.exprLev = -1
		if p.tok != token.SEMICOLON {
			s2, _ = p.parseSimpleStmt(basic)
		}
		if p.tok == token.SEMICOLON {
			p.next()
			// what we parsed was the init statement; the tag/guard follows
			s1 = s2
			s2 = nil
			if p.tok != token.LBRACE {
				// A TypeSwitchGuard may declare a variable in addition
				// to the variable declared in the initial SimpleStmt.
				// Introduce extra scope to avoid redeclaration errors:
				//
				//	switch t := 0; t := x.(T) { ... }
				//
				// (this code is not valid Go because the first t
				// cannot be accessed and thus is never used, the extra
				// scope is needed for the correct error message).
				//
				// If we don't have a type switch, s2 must be an expression.
				// Having the extra nested but empty scope won't affect it.
				s2, _ = p.parseSimpleStmt(basic)
			}
		}
		p.exprLev = prevLev
	}
	typeSwitch := p.isTypeSwitchGuard(s2)
	lbrace := p.expect(token.LBRACE)
	var list []ast.Stmt
	for p.tok == token.CASE || p.tok == token.DEFAULT {
		list = append(list, p.parseCaseClause())
	}
	rbrace := p.expect(token.RBRACE)
	p.expectSemi()
	body := &ast.BlockStmt{Lbrace: lbrace, List: list, Rbrace: rbrace}
	if typeSwitch {
		return &ast.TypeSwitchStmt{Switch: pos, Init: s1, Assign: s2, Body: body}
	}
	return &ast.SwitchStmt{Switch: pos, Init: s1, Tag: p.makeExpr(s2, "switch expression"), Body: body}
}
// parseCommClause parses one communication clause of a select
// statement: "case send-or-receive:" or "default:", with its body.
func (p *parser) parseCommClause() *ast.CommClause {
	if p.trace {
		defer un(trace(p, "CommClause"))
	}
	pos := p.pos
	var comm ast.Stmt
	if p.tok == token.CASE {
		p.next()
		lhs := p.parseList(false)
		if p.tok == token.ARROW {
			// SendStmt
			if len(lhs) > 1 {
				p.errorExpected(lhs[0].Pos(), "1 expression")
				// continue with first expression
			}
			arrow := p.pos
			p.next()
			rhs := p.parseRhs()
			comm = &ast.SendStmt{Chan: lhs[0], Arrow: arrow, Value: rhs}
		} else {
			// RecvStmt
			if tok := p.tok; tok == token.ASSIGN || tok == token.DEFINE {
				// RecvStmt with assignment
				if len(lhs) > 2 {
					p.errorExpected(lhs[0].Pos(), "1 or 2 expressions")
					// continue with first two expressions
					lhs = lhs[0:2]
				}
				pos := p.pos
				p.next()
				rhs := p.parseRhs()
				comm = &ast.AssignStmt{Lhs: lhs, TokPos: pos, Tok: tok, Rhs: []ast.Expr{rhs}}
			} else {
				// lhs must be single receive operation
				if len(lhs) > 1 {
					p.errorExpected(lhs[0].Pos(), "1 expression")
					// continue with first expression
				}
				comm = &ast.ExprStmt{X: lhs[0]}
			}
		}
	} else {
		p.expect(token.DEFAULT)
	}
	colon := p.expect(token.COLON)
	body := p.parseStmtList()
	return &ast.CommClause{Case: pos, Comm: comm, Colon: colon, Body: body}
}
// parseSelectStmt parses a select statement:
//
//	"select" "{" { CommClause } "}"
func (p *parser) parseSelectStmt() *ast.SelectStmt {
	if p.trace {
		defer un(trace(p, "SelectStmt"))
	}
	pos := p.expect(token.SELECT)
	lbrace := p.expect(token.LBRACE)
	var clauses []ast.Stmt
	for p.tok == token.CASE || p.tok == token.DEFAULT {
		clauses = append(clauses, p.parseCommClause())
	}
	rbrace := p.expect(token.RBRACE)
	p.expectSemi()
	return &ast.SelectStmt{
		Select: pos,
		Body:   &ast.BlockStmt{Lbrace: lbrace, List: clauses, Rbrace: rbrace},
	}
}
// parseForStmt parses a for statement in any of its forms: condition-only,
// 3-clause (init; cond; post), or range. It returns an *ast.RangeStmt for
// the range forms and an *ast.ForStmt otherwise.
func (p *parser) parseForStmt() ast.Stmt {
	if p.trace {
		defer un(trace(p, "ForStmt"))
	}
	pos := p.expect(token.FOR)
	var s1, s2, s3 ast.Stmt
	var isRange bool
	if p.tok != token.LBRACE {
		// lower exprLev while parsing the header so "{" is taken as the
		// start of the loop body (parser exprLev convention)
		prevLev := p.exprLev
		p.exprLev = -1
		if p.tok != token.SEMICOLON {
			if p.tok == token.RANGE {
				// "for range x" (nil lhs in assignment)
				pos := p.pos
				p.next()
				// the range clause is represented internally as a unary
				// RANGE expression on the assignment's rhs
				y := []ast.Expr{&ast.UnaryExpr{OpPos: pos, Op: token.RANGE, X: p.parseRhs()}}
				s2 = &ast.AssignStmt{Rhs: y}
				isRange = true
			} else {
				s2, isRange = p.parseSimpleStmt(rangeOk)
			}
		}
		if !isRange && p.tok == token.SEMICOLON {
			// 3-clause form: what was parsed is actually the init statement
			p.next()
			s1 = s2
			s2 = nil
			if p.tok != token.SEMICOLON {
				s2, _ = p.parseSimpleStmt(basic)
			}
			p.expectSemi()
			if p.tok != token.LBRACE {
				s3, _ = p.parseSimpleStmt(basic)
			}
		}
		p.exprLev = prevLev
	}
	body := p.parseBlockStmt()
	p.expectSemi()
	if isRange {
		as := s2.(*ast.AssignStmt)
		// check lhs: at most "key, value" on the left of = / :=
		var key, value ast.Expr
		switch len(as.Lhs) {
		case 0:
			// nothing to do
		case 1:
			key = as.Lhs[0]
		case 2:
			key, value = as.Lhs[0], as.Lhs[1]
		default:
			p.errorExpected(as.Lhs[len(as.Lhs)-1].Pos(), "at most 2 expressions")
			return &ast.BadStmt{From: pos, To: p.safePos(body.End())}
		}
		// parseSimpleStmt returned a right-hand side that
		// is a single unary expression of the form "range x"
		x := as.Rhs[0].(*ast.UnaryExpr).X
		return &ast.RangeStmt{
			For:    pos,
			Key:    key,
			Value:  value,
			TokPos: as.TokPos,
			Tok:    as.Tok,
			Range:  as.Rhs[0].Pos(),
			X:      x,
			Body:   body,
		}
	}
	// regular for statement
	return &ast.ForStmt{
		For:  pos,
		Init: s1,
		Cond: p.makeExpr(s2, "boolean or range expression"),
		Post: s3,
		Body: body,
	}
}
// parseStmt parses a single statement, dispatching on the current token.
// On failure it reports an error, advances to the next statement start,
// and returns an *ast.BadStmt covering the skipped source.
func (p *parser) parseStmt() (s ast.Stmt) {
	defer decNestLev(incNestLev(p))
	if p.trace {
		defer un(trace(p, "Statement"))
	}
	switch p.tok {
	case token.CONST, token.TYPE, token.VAR:
		s = &ast.DeclStmt{Decl: p.parseDecl(stmtStart)}
	case
		// tokens that may start an expression
		token.IDENT, token.INT, token.FLOAT, token.IMAG, token.CHAR, token.STRING, token.FUNC, token.LPAREN, // operands
		token.LBRACK, token.STRUCT, token.MAP, token.CHAN, token.INTERFACE, // composite types
		token.ADD, token.SUB, token.MUL, token.AND, token.XOR, token.ARROW, token.NOT: // unary operators
		s, _ = p.parseSimpleStmt(labelOk)
		// because of the required look-ahead, labeled statements are
		// parsed by parseSimpleStmt - don't expect a semicolon after
		// them
		if _, isLabeledStmt := s.(*ast.LabeledStmt); !isLabeledStmt {
			p.expectSemi()
		}
	case token.GO:
		s = p.parseGoStmt()
	case token.DEFER:
		s = p.parseDeferStmt()
	case token.RETURN:
		s = p.parseReturnStmt()
	case token.BREAK, token.CONTINUE, token.GOTO, token.FALLTHROUGH:
		s = p.parseBranchStmt(p.tok)
	case token.LBRACE:
		s = p.parseBlockStmt()
		p.expectSemi()
	case token.IF:
		s = p.parseIfStmt()
	case token.SWITCH:
		s = p.parseSwitchStmt()
	case token.SELECT:
		s = p.parseSelectStmt()
	case token.FOR:
		s = p.parseForStmt()
	case token.SEMICOLON:
		// Is it ever possible to have an implicit semicolon
		// producing an empty statement in a valid program?
		// (handle correctly anyway)
		s = &ast.EmptyStmt{Semicolon: p.pos, Implicit: p.lit == "\n"}
		p.next()
	case token.RBRACE:
		// a semicolon may be omitted before a closing "}"
		s = &ast.EmptyStmt{Semicolon: p.pos, Implicit: true}
	default:
		// no statement found
		pos := p.pos
		p.errorExpected(pos, "statement")
		p.advance(stmtStart)
		s = &ast.BadStmt{From: pos, To: p.pos}
	}
	return
}
// ----------------------------------------------------------------------------
// Declarations
// parseSpecFunction parses one declaration spec (import, const, type, or
// var) given its doc comment, the declaration keyword, and the iota value
// (the spec's index within a parenthesized group).
type parseSpecFunction func(doc *ast.CommentGroup, keyword token.Token, iota int) ast.Spec
// parseImportSpec parses a single import spec: an optional local name
// (identifier or "."), followed by the import path string. The spec is
// also appended to p.imports. Implements parseSpecFunction; the keyword
// and iota arguments are unused for imports.
func (p *parser) parseImportSpec(doc *ast.CommentGroup, _ token.Token, _ int) ast.Spec {
	if p.trace {
		defer un(trace(p, "ImportSpec"))
	}
	var ident *ast.Ident
	switch p.tok {
	case token.IDENT:
		ident = p.parseIdent()
	case token.PERIOD:
		// dot import: synthesize an identifier named "."
		ident = &ast.Ident{NamePos: p.pos, Name: "."}
		p.next()
	}
	pos := p.pos
	var path string
	if p.tok == token.STRING {
		path = p.lit
		p.next()
	} else if p.tok.IsLiteral() {
		// some other literal in path position: report but keep going
		p.error(pos, "import path must be a string")
		p.next()
	} else {
		p.error(pos, "missing import path")
		p.advance(exprEnd)
	}
	comment := p.expectSemi()
	// collect imports
	spec := &ast.ImportSpec{
		Doc:     doc,
		Name:    ident,
		Path:    &ast.BasicLit{ValuePos: pos, Kind: token.STRING, Value: path},
		Comment: comment,
	}
	p.imports = append(p.imports, spec)
	return spec
}
// parseValueSpec parses a single const or var spec: an identifier list,
// an optional type, and an optional "=" value list. Implements
// parseSpecFunction; keyword must be token.CONST or token.VAR (anything
// else panics).
func (p *parser) parseValueSpec(doc *ast.CommentGroup, keyword token.Token, iota int) ast.Spec {
	if p.trace {
		defer un(trace(p, keyword.String()+"Spec"))
	}
	idents := p.parseIdentList()
	var typ ast.Expr
	var values []ast.Expr
	switch keyword {
	case token.CONST:
		// always permit optional type and initialization for more tolerant parsing
		if p.tok != token.EOF && p.tok != token.SEMICOLON && p.tok != token.RPAREN {
			typ = p.tryIdentOrType()
			if p.tok == token.ASSIGN {
				p.next()
				values = p.parseList(true)
			}
		}
	case token.VAR:
		// var requires a type unless values are assigned directly
		if p.tok != token.ASSIGN {
			typ = p.parseType()
		}
		if p.tok == token.ASSIGN {
			p.next()
			values = p.parseList(true)
		}
	default:
		panic("unreachable")
	}
	comment := p.expectSemi()
	spec := &ast.ValueSpec{
		Doc:     doc,
		Names:   idents,
		Type:    typ,
		Values:  values,
		Comment: comment,
	}
	return spec
}
// parseGenericType finishes parsing a generic type declaration whose "["
// (at openPos) and first parameter name/type (name0/typ0, possibly nil)
// have already been consumed by the caller. It fills in spec.TypeParams,
// an optional "=" alias marker, and spec.Type.
func (p *parser) parseGenericType(spec *ast.TypeSpec, openPos token.Pos, name0 *ast.Ident, typ0 ast.Expr) {
	if p.trace {
		defer un(trace(p, "parseGenericType"))
	}
	list := p.parseParameterList(name0, typ0, token.RBRACK)
	closePos := p.expect(token.RBRACK)
	spec.TypeParams = &ast.FieldList{Opening: openPos, List: list, Closing: closePos}
	// Let the type checker decide whether to accept type parameters on aliases:
	// see go.dev/issue/46477.
	if p.tok == token.ASSIGN {
		// type alias
		spec.Assign = p.pos
		p.next()
	}
	spec.Type = p.parseType()
}
// parseTypeSpec parses a single type declaration spec:
//
//	TypeSpec = identifier [ TypeParameters ] [ "=" ] Type .
//
// The main subtlety is disambiguating, after "name [", between an array
// length expression and a type parameter list; see the inline comments.
// Implements parseSpecFunction; the keyword and iota arguments are unused.
func (p *parser) parseTypeSpec(doc *ast.CommentGroup, _ token.Token, _ int) ast.Spec {
	if p.trace {
		defer un(trace(p, "TypeSpec"))
	}
	name := p.parseIdent()
	spec := &ast.TypeSpec{Doc: doc, Name: name}
	if p.tok == token.LBRACK {
		// spec.Name "[" ...
		// array/slice type or type parameter list
		lbrack := p.pos
		p.next()
		if p.tok == token.IDENT {
			// We may have an array type or a type parameter list.
			// In either case we expect an expression x (which may
			// just be a name, or a more complex expression) which
			// we can analyze further.
			//
			// A type parameter list may have a type bound starting
			// with a "[" as in: P []E. In that case, simply parsing
			// an expression would lead to an error: P[] is invalid.
			// But since index or slice expressions are never constant
			// and thus invalid array length expressions, if the name
			// is followed by "[" it must be the start of an array or
			// slice constraint. Only if we don't see a "[" do we
			// need to parse a full expression. Notably, name <- x
			// is not a concern because name <- x is a statement and
			// not an expression.
			var x ast.Expr = p.parseIdent()
			if p.tok != token.LBRACK {
				// To parse the expression starting with name, expand
				// the call sequence we would get by passing in name
				// to parser.expr, and pass in name to parsePrimaryExpr.
				p.exprLev++
				lhs := p.parsePrimaryExpr(x)
				x = p.parseBinaryExpr(lhs, token.LowestPrec+1)
				p.exprLev--
			}
			// Analyze expression x. If we can split x into a type parameter
			// name, possibly followed by a type parameter type, we consider
			// this the start of a type parameter list, with some caveats:
			// a single name followed by "]" tilts the decision towards an
			// array declaration; a type parameter type that could also be
			// an ordinary expression but which is followed by a comma tilts
			// the decision towards a type parameter list.
			if pname, ptype := extractName(x, p.tok == token.COMMA); pname != nil && (ptype != nil || p.tok != token.RBRACK) {
				// spec.Name "[" pname ...
				// spec.Name "[" pname ptype ...
				// spec.Name "[" pname ptype "," ...
				p.parseGenericType(spec, lbrack, pname, ptype) // ptype may be nil
			} else {
				// spec.Name "[" pname "]" ...
				// spec.Name "[" x ...
				spec.Type = p.parseArrayType(lbrack, x)
			}
		} else {
			// array type
			spec.Type = p.parseArrayType(lbrack, nil)
		}
	} else {
		// no type parameters
		if p.tok == token.ASSIGN {
			// type alias
			spec.Assign = p.pos
			p.next()
		}
		spec.Type = p.parseType()
	}
	spec.Comment = p.expectSemi()
	return spec
}
// extractName splits the expression x into (name, expr) if syntactically
// x can be written as name expr. The split only happens if expr is a type
// element (per the isTypeElem predicate) or if force is set.
// If x is just a name, the result is (name, nil). If the split succeeds,
// the result is (name, expr). Otherwise the result is (nil, x).
// Examples:
//
//	x           force  name  expr
//	------------------------------------
//	P*[]int     T/F    P     *[]int
//	P*E         T      P     *E
//	P*E         F      nil   P*E
//	P([]int)    T/F    P     []int
//	P(E)        T      P     E
//	P(E)        F      nil   P(E)
//	P*E|F|~G    T/F    P     *E|F|~G
//	P*E|F|G     T      P     *E|F|G
//	P*E|F|G     F      nil   P*E|F|G
func extractName(x ast.Expr, force bool) (*ast.Ident, ast.Expr) {
	switch e := x.(type) {
	case *ast.Ident:
		// just a name, nothing to split off
		return e, nil
	case *ast.BinaryExpr:
		if e.Op == token.MUL {
			// "name *Y" reads as name followed by a pointer type
			if name, _ := e.X.(*ast.Ident); name != nil && (force || isTypeElem(e.Y)) {
				return name, &ast.StarExpr{Star: e.OpPos, X: e.Y}
			}
		} else if e.Op == token.OR {
			// "name lhs|Y": recursively split the left operand of the union
			if name, lhs := extractName(e.X, force || isTypeElem(e.Y)); name != nil && lhs != nil {
				rest := *e // shallow copy, then substitute the split-off lhs
				rest.X = lhs
				return name, &rest
			}
		}
	case *ast.CallExpr:
		// "name(arg)" reads as name followed by a parenthesized type
		if name, _ := e.Fun.(*ast.Ident); name != nil &&
			len(e.Args) == 1 && e.Ellipsis == token.NoPos && (force || isTypeElem(e.Args[0])) {
			return name, e.Args[0]
		}
	}
	return nil, x
}
// isTypeElem reports whether x is a (possibly parenthesized) type element expression.
// The result is false if x could be a type element OR an ordinary (value) expression.
func isTypeElem(x ast.Expr) bool {
switch x := x.(type) {
case *ast.ArrayType, *ast.StructType, *ast.FuncType, *ast.InterfaceType, *ast.MapType, *ast.ChanType:
return true
case *ast.BinaryExpr:
return isTypeElem(x.X) || isTypeElem(x.Y)
case *ast.UnaryExpr:
return x.Op == token.TILDE
case *ast.ParenExpr:
return isTypeElem(x.X)
}
return false
}
// parseGenDecl parses a general declaration (import, const, type, or var):
// the keyword followed by either a single spec or a parenthesized group of
// specs, each parsed by f. Within a group, f receives the spec's index as
// its iota value.
func (p *parser) parseGenDecl(keyword token.Token, f parseSpecFunction) *ast.GenDecl {
	if p.trace {
		defer un(trace(p, "GenDecl("+keyword.String()+")"))
	}
	doc := p.leadComment
	pos := p.expect(keyword)
	var lparen, rparen token.Pos
	var list []ast.Spec
	if p.tok == token.LPAREN {
		// parenthesized group: keyword ( spec; spec; ... )
		lparen = p.pos
		p.next()
		for iota := 0; p.tok != token.RPAREN && p.tok != token.EOF; iota++ {
			list = append(list, f(p.leadComment, keyword, iota))
		}
		rparen = p.expect(token.RPAREN)
		p.expectSemi()
	} else {
		// single ungrouped spec
		list = append(list, f(nil, keyword, 0))
	}
	return &ast.GenDecl{
		Doc:    doc,
		TokPos: pos,
		Tok:    keyword,
		Lparen: lparen,
		Specs:  list,
		Rparen: rparen,
	}
}
// parseFuncDecl parses a function or method declaration: "func", an
// optional receiver, the name, optional type parameters (rejected on
// methods), parameters, results, and an optional body.
func (p *parser) parseFuncDecl() *ast.FuncDecl {
	if p.trace {
		defer un(trace(p, "FunctionDecl"))
	}
	doc := p.leadComment
	pos := p.expect(token.FUNC)
	var recv *ast.FieldList
	if p.tok == token.LPAREN {
		// "(" after func: a method receiver
		_, recv = p.parseParameters(false)
	}
	ident := p.parseIdent()
	tparams, params := p.parseParameters(true)
	if recv != nil && tparams != nil {
		// Method declarations do not have type parameters. We parse them for a
		// better error message and improved error recovery.
		p.error(tparams.Opening, "method must have no type parameters")
		tparams = nil
	}
	results := p.parseResult()
	var body *ast.BlockStmt
	switch p.tok {
	case token.LBRACE:
		body = p.parseBody()
		p.expectSemi()
	case token.SEMICOLON:
		p.next()
		if p.tok == token.LBRACE {
			// opening { of function declaration on next line
			p.error(p.pos, "unexpected semicolon or newline before {")
			body = p.parseBody()
			p.expectSemi()
		}
	default:
		// bodyless declaration (e.g. assembly-implemented function)
		p.expectSemi()
	}
	decl := &ast.FuncDecl{
		Doc:  doc,
		Recv: recv,
		Name: ident,
		Type: &ast.FuncType{
			Func:       pos,
			TypeParams: tparams,
			Params:     params,
			Results:    results,
		},
		Body: body,
	}
	return decl
}
// parseDecl parses a single top-level declaration. On failure it reports
// an error, advances to the given sync token set, and returns a BadDecl
// spanning the skipped source.
func (p *parser) parseDecl(sync map[token.Token]bool) ast.Decl {
	if p.trace {
		defer un(trace(p, "Declaration"))
	}
	switch p.tok {
	case token.FUNC:
		return p.parseFuncDecl()
	case token.IMPORT:
		return p.parseGenDecl(p.tok, p.parseImportSpec)
	case token.CONST, token.VAR:
		return p.parseGenDecl(p.tok, p.parseValueSpec)
	case token.TYPE:
		return p.parseGenDecl(p.tok, p.parseTypeSpec)
	}
	// no declaration keyword found: report and resynchronize
	pos := p.pos
	p.errorExpected(pos, "declaration")
	p.advance(sync)
	return &ast.BadDecl{From: pos, To: p.pos}
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package transpile
import (
"go/token"
)
// ReplaceIdentAt replaces an identifier spanning n tokens
// starting at given index, with a single identifier with given string.
// This is used in Exec mode for dealing with identifiers and paths that are
// separately-parsed by Go.
// The result is freshly allocated, so the receiver and any other slices
// sharing its backing array are left unmodified.
func (tk Tokens) ReplaceIdentAt(at int, str string, n int) Tokens {
	// Allocate a new slice instead of appending to tk[:at]: when
	// cap(tk) > at, append(tk[:at], ...) would overwrite tk[at] in the
	// shared backing array, mutating the caller's original Tokens.
	ntk := make(Tokens, 0, len(tk)-n+1)
	ntk = append(ntk, tk[:at]...)
	ntk = append(ntk, &Token{Tok: token.IDENT, Str: str})
	ntk = append(ntk, tk[at+n:]...)
	return ntk
}
// Path extracts a standard path or URL expression from the current
// list of tokens (starting at index 0), returning the path string
// and the number of tokens included in the path.
// Restricts processing to contiguous elements with no spaces!
// If it is not a path, returns nil string, 0
func (tk Tokens) Path(idx0 bool) (string, int) {
	n := len(tk)
	if n == 0 {
		return "", 0
	}
	t0 := tk[0]
	// a path can start directly with a delimiter (. or /) or with ~
	ispath := (t0.IsPathDelim() || t0.Tok == token.TILDE)
	if n == 1 {
		if ispath {
			return t0.String(), 1
		}
		return "", 0
	}
	str := tk[0].String()
	// lastEnd tracks the source position just past the last token consumed,
	// enforcing the no-spaces (contiguity) requirement against token.Pos
	lastEnd := int(tk[0].Pos) + len(str)
	ci := 1
	if !ispath {
		// not starting with a delimiter: may still be a path of the form
		// ident/... or ident:... (presumably URL schemes / drive prefixes
		// -- TODO confirm against callers)
		lastEnd = int(tk[0].Pos)
		ci = 0
		if t0.Tok != token.IDENT {
			return "", 0
		}
		tin := 1
		tid := t0.Str
		tindelim := tk[tin].IsPathDelim()
		if idx0 {
			// in first (command) position only / counts as the delimiter
			tindelim = tk[tin].Tok == token.QUO
		}
		// second token must immediately follow the ident and be : or a delim
		if (int(tk[tin].Pos) > lastEnd+len(tid)) || !(tk[tin].Tok == token.COLON || tindelim) {
			return "", 0
		}
		ci += tin + 1
		str = tid + tk[tin].String()
		lastEnd += len(str)
		if ci >= n || int(tk[ci].Pos) > lastEnd { // just 2 or 2 and a space
			if tk[tin].Tok == token.COLON { // go Ident: static initializer
				return "", 0
			}
		}
	}
	// consume the remaining contiguous tokens, alternating delimiters
	// and word-like elements
	prevWasDelim := true
	for {
		if ci >= n || int(tk[ci].Pos) > lastEnd {
			// ran out of tokens, or hit a gap (space): path ends here
			return str, ci
		}
		ct := tk[ci]
		if ct.IsPathDelim() || ct.IsPathExtraDelim() {
			prevWasDelim = true
			str += ct.String()
			lastEnd += len(ct.String())
			ci++
			continue
		}
		if ct.Tok == token.STRING {
			// embedded string literal: escape its quotes for re-emission
			prevWasDelim = true
			str += EscapeQuotes(ct.String())
			lastEnd += len(ct.String())
			ci++
			continue
		}
		if !prevWasDelim {
			// two words in a row is only allowed via an escaped space `\ `
			if ct.Tok == token.ILLEGAL && ct.Str == `\` && ci+1 < n && int(tk[ci+1].Pos) == lastEnd+2 {
				prevWasDelim = true
				str += " "
				ci++
				lastEnd += 2
				continue
			}
			return str, ci
		}
		if ct.IsWord() {
			prevWasDelim = false
			str += ct.String()
			lastEnd += len(ct.String())
			ci++
			continue
		}
		// anything else terminates the path
		return str, ci
	}
}
// IsPathDelim returns true if the token is a primary path
// delimiter: "." or "/".
func (tk *Token) IsPathDelim() bool {
	switch tk.Tok {
	case token.PERIOD, token.QUO:
		return true
	}
	return false
}
// IsPathExtraDelim returns true if the token is an additional character
// accepted inside paths: "-", "=", "%", or the scanner-illegal "?" / "#".
func (tk *Token) IsPathExtraDelim() bool {
	switch tk.Tok {
	case token.SUB, token.ASSIGN, token.REM:
		return true
	case token.ILLEGAL:
		return tk.Str == "?" || tk.Str == "#"
	}
	return false
}
// IsWord returns true if the token is some kind of word-like entity,
// including IDENT, STRING, CHAR, or one of the Go keywords.
// This is for exec filtering.
func (tk *Token) IsWord() bool {
	switch tk.Tok {
	case token.IDENT, token.STRING, token.CHAR:
		return true
	}
	return tk.IsGo()
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package transpile
import (
"errors"
"fmt"
"go/format"
"log/slog"
"os"
"strings"
"cogentcore.org/core/base/logx"
"cogentcore.org/core/base/num"
"cogentcore.org/core/base/stringsx"
"golang.org/x/tools/imports"
)
// State holds the transpiling state
type State struct {
	// FuncToVar translates function definitions into variable definitions,
	// which is the default for interactive use of random code fragments
	// without the complete go formatting.
	// For pure transpiling of a complete codebase with full proper Go formatting
	// this should be turned off.
	FuncToVar bool

	// MathMode is on when math mode is turned on.
	MathMode bool

	// MathRecord is state of auto-recording of data into current directory
	// in math mode.
	MathRecord bool

	// depth of each delimiter / construct at the end of the current line:
	// if 0, the construct was complete on that line. Paren/Brace/Brack
	// count (), {}, []; TypeDepth tracks an open "type" block and
	// DeclDepth an open import/var/const block.
	ParenDepth, BraceDepth, BrackDepth, TypeDepth, DeclDepth int

	// Chunks of code lines that are accumulated during Transpile,
	// each of which should be evaluated separately, to avoid
	// issues with contextual effects from import, package etc.
	Chunks []string

	// current stack of transpiled lines, that are accumulated into
	// code Chunks.
	Lines []string

	// stack of runtime errors.
	Errors []error

	// if this is non-empty, it is the name of the last command defined.
	// triggers insertion of the AddCommand call to add to list of defined commands.
	lastCommand string
}
// NewState returns a new transpiling state; mostly for testing
func NewState() *State {
	return &State{FuncToVar: true}
}
// TranspileCode processes each line of given code,
// adding the results to the LineStack
func (st *State) TranspileCode(code string) {
	lns := strings.Split(code, "\n")
	n := len(lns)
	if n == 0 {
		return
	}
	for _, ln := range lns {
		// remember whether a decl (import/var/const) block was open
		// before this line, to detect its closing below
		hasDecl := st.DeclDepth > 0
		tl := st.TranspileLine(ln)
		st.AddLine(tl)
		// ParenDepth == 1 here is presumably the still-open "(" of the
		// goal.AddCommand( call inserted when lastCommand was defined;
		// close it now that the command body is complete -- TODO confirm
		// against the command handling in TranspileLineTokens
		if st.BraceDepth == 0 && st.BrackDepth == 0 && st.ParenDepth == 1 && st.lastCommand != "" {
			st.lastCommand = ""
			nl := len(st.Lines)
			st.Lines[nl-1] = st.Lines[nl-1] + ")"
			st.ParenDepth--
		}
		if hasDecl && st.DeclDepth == 0 { // break at decl
			// a decl block just closed: isolate it in its own chunk so
			// following code is evaluated separately
			st.AddChunk()
		}
	}
}
// TranspileFile transpiles the given input goal file to the
// given output Go file. If no existing package declaration
// is found, then package main and func main declarations are
// added. This also affects how functions are interpreted.
// The output file is written even when formatting fails, so the
// user can inspect the raw transpiled source.
func (st *State) TranspileFile(in string, out string) error {
	b, err := os.ReadFile(in)
	if err != nil {
		return err
	}
	code := string(b)
	lns := stringsx.SplitLines(code)
	hasPackage := false
	for _, ln := range lns {
		if strings.HasPrefix(ln, "package ") {
			hasPackage = true
			break
		}
	}
	if hasPackage {
		st.FuncToVar = false // use raw functions
	}
	st.TranspileCode(code)
	st.FuncToVar = true
	// (a redundant re-check of the ReadFile err was removed here: err is
	// provably nil at this point and is not reassigned until formatting below)
	hdr := `package main
import (
"cogentcore.org/lab/goal"
"cogentcore.org/lab/goal/goalib"
"cogentcore.org/lab/tensor"
_ "cogentcore.org/lab/tensor/tmath"
_ "cogentcore.org/lab/stats/stats"
_ "cogentcore.org/lab/stats/metric"
)
func main() {
goal := goal.NewGoal()
_ = goal
`
	src := st.Code()
	res := []byte(src)
	bsrc := res
	// //line directive maps errors in the generated Go back to the goal source
	gen := fmt.Sprintf("// Code generated by \"goal build\"; DO NOT EDIT.\n//line %s:1\n", in)
	if hasPackage {
		bsrc = []byte(gen + src)
		res, err = format.Source(bsrc)
	} else {
		// no package declaration: wrap in package main / func main and let
		// goimports resolve the imports of the wrapped code
		bsrc = []byte(gen + hdr + src + "\n}")
		res, err = imports.Process(out, bsrc, nil)
	}
	if err != nil {
		// formatting failed: fall back to the unformatted source so the
		// user can see what was generated, and report the error
		res = bsrc
		fmt.Println(err.Error())
	} else {
		err = st.DepthError()
	}
	werr := os.WriteFile(out, res, 0666)
	return errors.Join(err, werr)
}
// TotalDepth returns the sum of any unresolved paren, brace, or bracket depths.
func (st *State) TotalDepth() int {
	total := num.Abs(st.ParenDepth)
	total += num.Abs(st.BraceDepth)
	total += num.Abs(st.BrackDepth)
	return total
}
// ResetCode resets the stack of transpiled code
func (st *State) ResetCode() {
	st.Chunks, st.Lines = nil, nil
}
// ResetDepth resets the current depths to 0
func (st *State) ResetDepth() {
	st.ParenDepth = 0
	st.BraceDepth = 0
	st.BrackDepth = 0
	st.TypeDepth = 0
	st.DeclDepth = 0
}
// DepthError reports an error if any of the parsing depths are not zero,
// to be called at the end of transpiling a complete block of code.
// Returns nil when all depths balance.
func (st *State) DepthError() error {
	if st.TotalDepth() == 0 {
		return nil
	}
	str := ""
	if st.ParenDepth != 0 {
		str += fmt.Sprintf("Incomplete parentheses (), remaining depth: %d\n", st.ParenDepth)
	}
	if st.BraceDepth != 0 {
		// braces are {} (the glyphs were previously swapped with brackets)
		str += fmt.Sprintf("Incomplete braces {}, remaining depth: %d\n", st.BraceDepth)
	}
	if st.BrackDepth != 0 {
		// brackets are []
		str += fmt.Sprintf("Incomplete brackets [], remaining depth: %d\n", st.BrackDepth)
	}
	if str != "" {
		slog.Error(str)
		return errors.New(str)
	}
	return nil
}
// AddLine adds the given transpiled line to the current Lines stack.
func (st *State) AddLine(line string) {
	st.Lines = append(st.Lines, line)
}
// Code returns the current transpiled lines,
// split into chunks that should be compiled separately.
func (st *State) Code() string {
	st.AddChunk()
	if len(st.Chunks) > 0 {
		return strings.Join(st.Chunks, "\n")
	}
	return ""
}
// AddChunk adds current lines into a chunk of code
// that should be compiled separately, and clears Lines.
// It is a no-op when there are no pending lines.
func (st *State) AddChunk() {
	if len(st.Lines) == 0 {
		return
	}
	chunk := strings.Join(st.Lines, "\n")
	st.Chunks = append(st.Chunks, chunk)
	st.Lines = nil
}
// AddError adds the given error to the error stack if it is non-nil,
// and calls the Cancel function if set, to stop execution.
// This is the main way that goal errors are handled.
// It also prints the error.
func (st *State) AddError(err error) error {
	if err != nil {
		st.Errors = append(st.Errors, err)
		logx.PrintlnError(err)
	}
	return err
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package transpile
import (
"go/scanner"
"go/token"
"log/slog"
"slices"
"strings"
"cogentcore.org/core/base/logx"
)
// Token provides full data for one token
type Token struct {
	// Tok is the Go token classification from the go/token scanner.
	Tok token.Token

	// Str is the literal string, when the token has one (identifiers,
	// literals); empty for pure operator / delimiter tokens.
	Str string

	// Pos is the position in the original string.
	// this is only set for the original parse,
	// not for transpiled additions.
	Pos token.Pos
}
// TokensFromString converts the string into tokens using the standard
// go/scanner, keeping comments and suppressing automatic semicolon
// insertion. Scan errors are logged (errHandler) but do not stop scanning.
func TokensFromString(ln string) Tokens {
	fset := token.NewFileSet()
	f := fset.AddFile("", fset.Base(), len(ln))
	var sc scanner.Scanner
	sc.Init(f, []byte(ln), errHandler, scanner.ScanComments|2) // 2 is non-exported dontInsertSemis
	// note to Go team: just export this stuff. seriously.
	var toks Tokens
	for {
		pos, tok, lit := sc.Scan()
		if tok == token.EOF {
			break
		}
		// logx.PrintfDebug("	token: %s\t%s\t%q\n", fset.Position(pos), tok, lit)
		toks = append(toks, &Token{Tok: tok, Pos: pos, Str: lit})
	}
	return toks
}
// errHandler logs scanner errors at debug level only: goal input is
// often intentionally not valid Go, so scan errors are non-fatal.
func errHandler(pos token.Position, msg string) {
	logx.PrintlnDebug("Scan Error:", pos, msg)
}
// Tokens is a slice of Token pointers, representing one scanned line
// (or generated fragment) of goal code.
type Tokens []*Token
// NewToken returns a new token, for generated tokens without Pos.
// An optional single string argument supplies the literal.
func NewToken(tok token.Token, str ...string) *Token {
	var lit string
	if len(str) > 0 {
		lit = str[0]
	}
	return &Token{Tok: tok, Str: lit}
}
// Add appends a new generated token (without Pos) and returns it.
func (tk *Tokens) Add(tok token.Token, str ...string) *Token {
	nt := NewToken(tok, str...)
	tk.AddTokens(nt)
	return nt
}
// AddMulti appends new basic tokens (not IDENT), one per argument.
func (tk *Tokens) AddMulti(tok ...token.Token) {
	for i := range tok {
		tk.Add(tok[i])
	}
}
// AddTokens appends the given tokens to our list, returning the receiver
// for chaining.
func (tk *Tokens) AddTokens(toks ...*Token) *Tokens {
	for _, t := range toks {
		*tk = append(*tk, t)
	}
	return tk
}
// Insert inserts a new generated token (without Pos) at index i,
// shifting later tokens right, and returns the new token.
func (tk *Tokens) Insert(i int, tok token.Token, str ...string) *Token {
	inserted := NewToken(tok, str...)
	*tk = slices.Insert(*tk, i, inserted)
	return inserted
}
// Last returns the final token in the list, or nil if empty.
func (tk Tokens) Last() *Token {
	if len(tk) == 0 {
		return nil
	}
	return tk[len(tk)-1]
}
// DeleteLastComma removes any final Comma.
// easier to generate and delete at the end
func (tk *Tokens) DeleteLastComma() {
	n := len(*tk)
	if n == 0 {
		return
	}
	if (*tk)[n-1].Tok == token.COMMA {
		*tk = (*tk)[:n-1]
	}
}
// String returns the literal string for the token if it has one,
// otherwise the token's own name (operator / keyword spelling).
func (tk *Token) String() string {
	if tk.Str == "" {
		return tk.Tok.String()
	}
	return tk.Str
}
// IsBacktickString returns true if the given STRING uses backticks
// (raw string literal).
func (tk *Token) IsBacktickString() bool {
	return tk.Tok == token.STRING && tk.Str[0] == '`'
}
// IsGo returns true if the given token is a Go Keyword or Comment
func (tk *Token) IsGo() bool {
	switch {
	case tk.Tok == token.COMMENT:
		return true
	case tk.Tok >= token.BREAK && tk.Tok <= token.VAR:
		// the keyword range in go/token
		return true
	}
	return false
}
// IsValidExecIdent returns true if the given token is a valid component
// of an Exec mode identifier
func (tk *Token) IsValidExecIdent() bool {
	switch tk.Tok {
	case token.IDENT, token.SUB, token.DEC, token.INT, token.FLOAT, token.ASSIGN:
		return true
	}
	return tk.IsGo()
}
// String is the stringer version which includes the token ID
// in addition to the string literal.
// Uses a strings.Builder instead of repeated string concatenation,
// which was quadratic in the number of tokens.
func (tk Tokens) String() string {
	var b strings.Builder
	for _, tok := range tk {
		b.WriteString("[" + tok.Tok.String() + "] ")
		if tok.Str != "" {
			b.WriteString(tok.Str + " ")
		}
	}
	str := b.String()
	if len(str) == 0 {
		return str
	}
	return str[:len(str)-1] // remove trailing space
}
// Code returns concatenated Str values of the tokens,
// to generate a surface-valid code string.
// Spacing is decided per token class (operators get surrounding spaces,
// brackets bind tightly, etc.); prvIdent tracks whether the previous
// token was word-like so adjacent words get separated.
func (tk Tokens) Code() string {
	n := len(tk)
	if n == 0 {
		return ""
	}
	str := ""
	prvIdent := false
	for _, tok := range tk {
		switch {
		case tok.IsOp():
			if tok.Tok == token.INC || tok.Tok == token.DEC {
				// postfix ++/--: no leading space
				str += tok.String() + " "
			} else if tok.Tok == token.MUL {
				// * may be pointer/deref: leading space only
				str += " " + tok.String()
			} else {
				str += " " + tok.String() + " "
			}
			prvIdent = false
		case tok.Tok == token.ELLIPSIS:
			str += " " + tok.String()
			prvIdent = false
		case tok.IsBracket() || tok.Tok == token.PERIOD:
			if tok.Tok == token.RBRACE || tok.Tok == token.LBRACE {
				// braces get a separating space on both sides
				if len(str) > 0 && str[len(str)-1] != ' ' {
					str += " "
				}
				str += tok.String() + " "
			} else {
				// (), [], . bind tightly: no spaces
				str += tok.String()
			}
			prvIdent = false
		case tok.Tok == token.COMMA || tok.Tok == token.COLON || tok.Tok == token.SEMICOLON:
			str += tok.String() + " "
			prvIdent = false
		case tok.Tok == token.STRUCT:
			str += " " + tok.String() + " "
		case tok.Tok == token.FUNC:
			if prvIdent {
				str += " "
			}
			str += tok.String()
			// func is treated as word-like so a following ident is separated
			prvIdent = true
		case tok.Tok == token.COMMENT:
			if str != "" {
				str += " "
			}
			str += tok.String()
		case tok.IsGo():
			if prvIdent {
				str += " "
			}
			str += tok.String()
			if tok.Tok != token.MAP {
				// map is followed directly by "[": no trailing space
				str += " "
			}
			prvIdent = false
		case tok.Tok == token.IDENT || tok.Tok == token.STRING:
			if prvIdent {
				str += " "
			}
			str += tok.String()
			prvIdent = true
		default:
			str += tok.String()
			prvIdent = false
		}
	}
	if len(str) == 0 {
		return str
	}
	// trim a single trailing space left by the per-token rules
	if str[len(str)-1] == ' ' {
		return str[:len(str)-1]
	}
	return str
}
// IsOp returns true if the given token is an operator
// (the ADD..DEFINE range in go/token).
func (tk *Token) IsOp() bool {
	return tk.Tok >= token.ADD && tk.Tok <= token.DEFINE
}
// Contains returns true if the token string contains any of the given token(s)
func (tk Tokens) Contains(toks ...token.Token) bool {
	if len(toks) == 0 {
		slog.Error("programmer error: tokens.Contains with no args")
		return false
	}
	for _, want := range toks {
		for _, t := range tk {
			if t.Tok == want {
				return true
			}
		}
	}
	return false
}
// EscapeQuotes replaces any " with \"
func EscapeQuotes(str string) string {
	return strings.NewReplacer(`"`, `\"`).Replace(str)
}
// AddQuotes surrounds given string with quotes,
// also escaping any contained quotes
func AddQuotes(str string) string {
	escaped := EscapeQuotes(str)
	return `"` + escaped + `"`
}
// IsBracket returns true if the given token is a bracket delimiter:
// paren, brace, bracket
func (tk *Token) IsBracket() bool {
	// equivalent to the LPAREN..LBRACE and RPAREN..RBRACE ranges in go/token
	switch tk.Tok {
	case token.LPAREN, token.LBRACK, token.LBRACE,
		token.RPAREN, token.RBRACK, token.RBRACE:
		return true
	}
	return false
}
// RightMatching returns the position (or -1 if not found) for the
// right matching [paren, bracket, brace] given the left one that
// is at the 0 position of the current set of tokens.
func (tk Tokens) RightMatching() int {
	if len(tk) == 0 {
		return -1
	}
	lb := tk[0].Tok
	rb := token.RPAREN // default when tk[0] is not a recognized left bracket
	switch lb {
	case token.LPAREN:
		rb = token.RPAREN
	case token.LBRACK:
		rb = token.RBRACK
	case token.LBRACE:
		rb = token.RBRACE
	}
	depth := 0
	for off, t := range tk[1:] {
		switch t.Tok {
		case rb: // rb checked before lb, preserving precedence when lb == rb
			if depth <= 0 {
				return off + 1
			}
			depth--
		case lb:
			depth++
		}
	}
	return -1
}
// BracketDepths returns the depths for the three bracket delimiters
// [paren, bracket, brace], based on unmatched right versions.
func (tk Tokens) BracketDepths() (paren, brace, brack int) {
	for _, t := range tk {
		switch t.Tok {
		case token.LPAREN:
			paren++
		case token.RPAREN:
			paren--
		case token.LBRACE:
			brace++
		case token.RBRACE:
			brace--
		case token.LBRACK:
			brack++
		case token.RBRACK:
			brack--
		}
	}
	return
}
// ModeEnd returns the position (or -1 if not found) for the
// next ILLEGAL mode token ($ or #) given the starting one that
// is at the 0 position of the current set of tokens.
func (tk Tokens) ModeEnd() int {
	if len(tk) == 0 {
		return -1
	}
	marker := tk[0].Str
	for i := 1; i < len(tk); i++ {
		if tk[i].Tok == token.ILLEGAL && tk[i].Str == marker {
			return i
		}
	}
	return -1
}
// IsAssignExpr checks if there are any Go assignment or define tokens
// outside of { } Go code.
func (tk Tokens) IsAssignExpr() bool {
	n := len(tk)
	if n == 0 {
		return false
	}
	// NOTE(review): scanning starts at index 1, so an assignment token in
	// the very first position is never detected -- presumably deliberate
	// (a line cannot start with = or :=); confirm against callers.
	for i := 1; i < n; i++ {
		tok := tk[i].Tok
		// ADD_ASSIGN..AND_NOT_ASSIGN covers the compound operators (+=, etc.)
		if tok == token.ASSIGN || tok == token.DEFINE || (tok >= token.ADD_ASSIGN && tok <= token.AND_NOT_ASSIGN) {
			return true
		}
		if tok == token.LBRACE { // skip Go mode
			// jump past the matching right brace so assignments inside
			// embedded { } Go code are ignored
			rp := tk[i:n].RightMatching()
			if rp > 0 {
				i += rp
			}
		}
	}
	return false
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package transpile
import (
"fmt"
"go/token"
"slices"
"strings"
"cogentcore.org/core/base/logx"
)
// TranspileLine is the main function for parsing a single line of goal input,
// returning a new transpiled line of code that converts Exec code into corresponding
// Go function calls.
func (st *State) TranspileLine(code string) string {
	if len(code) == 0 {
		return code
	}
	// drop "#!" shebang lines entirely
	if strings.HasPrefix(code, "#!") {
		return ""
	}
	toks := st.TranspileLineTokens(code)
	// accumulate bracket depths across lines, so multi-line constructs
	// are recognized as still open
	paren, brace, brack := toks.BracketDepths()
	st.ParenDepth += paren
	st.BraceDepth += brace
	st.BrackDepth += brack
	// logx.PrintlnDebug("depths: ", st.ParenDepth, st.BraceDepth, st.BrackDepth, st.TypeDepth, st.DeclDepth)
	// a type block ends when its braces close; a var/const/import block
	// ends when both parens and braces close
	if st.TypeDepth > 0 && st.BraceDepth == 0 {
		st.TypeDepth = 0
	}
	if st.DeclDepth > 0 && (st.ParenDepth == 0 && st.BraceDepth == 0) {
		st.DeclDepth = 0
	}
	// logx.PrintlnDebug("depths: ", st.ParenDepth, st.BraceDepth, st.BrackDepth, st.TypeDepth, st.DeclDepth)
	return toks.Code()
}
// TranspileLineTokens returns the tokens for the full line,
// dispatching between math, exec, and Go transpilation based on the
// leading tokens and words of the line.
func (st *State) TranspileLineTokens(code string) Tokens {
	if code == "" {
		return nil
	}
	toks := TokensFromString(code)
	n := len(toks)
	if n == 0 {
		return toks
	}
	if st.MathMode {
		// a solo "##" line toggles math mode back off
		if len(toks) >= 2 {
			if toks[0].Tok == token.ILLEGAL && toks[0].Str == "#" && toks[1].Tok == token.ILLEGAL && toks[1].Str == "#" {
				st.MathMode = false
				return nil
			}
		}
		return st.TranspileMath(toks, code, true)
	}
	ewords, err := ExecWords(code)
	if err != nil {
		st.AddError(err)
		return nil
	}
	logx.PrintlnDebug("\n########## line:\n", code, "\nTokens:", len(toks), "\n", toks.String(), "\nWords:", len(ewords), "\n", ewords)
	if toks[0].Tok == token.TYPE {
		st.TypeDepth++
	}
	if toks[0].Tok == token.IMPORT || toks[0].Tok == token.VAR || toks[0].Tok == token.CONST {
		st.DeclDepth++
	}
	// logx.PrintlnDebug("depths: ", st.ParenDepth, st.BraceDepth, st.BrackDepth, st.TypeDepth, st.DeclDepth)
	if st.TypeDepth > 0 || st.DeclDepth > 0 {
		logx.PrintlnDebug("go: type / decl defn")
		return st.TranspileGo(toks, code)
	}
	t0 := toks[0]
	_, t0pn := toks.Path(true) // true = first position
	en := len(ewords)
	f0exec := (t0.Tok == token.IDENT && ExecWordIsCommand(ewords[0]))
	if f0exec && n > 1 && toks[1].Tok == token.COLON { // go Ident: static initializer
		f0exec = false
	}
	switch {
	case t0.Tok == token.ILLEGAL:
		if t0.Str == "#" {
			logx.PrintlnDebug("math #")
			// bug fix: n > 1 guard -- a line consisting of a bare "#"
			// previously panicked here on the toks[1] access.
			if n > 1 && toks[1].Tok == token.ILLEGAL && toks[1].Str == "#" {
				st.MathMode = true
				return nil
			}
			return st.TranspileMath(toks[1:], code, true)
		}
		return st.TranspileExec(ewords, false)
	case t0.Tok == token.LBRACE:
		logx.PrintlnDebug("go: { } line")
		return st.TranspileGoRange(toks, code, 1, n-1)
	case t0.Tok == token.LBRACK:
		logx.PrintlnDebug("exec: [ ] line")
		return st.TranspileExec(ewords, false) // it processes the [ ]
	case t0.Tok == token.IDENT && t0.Str == "command":
		st.lastCommand = toks[1].Str // 1 is the name -- triggers AddCommand
		toks = toks[2:]              // get rid of first
		// wrap the command body as: goal.AddCommand("name", func(args ...string) { ... })
		toks.Insert(0, token.IDENT, "goal.AddCommand")
		toks.Insert(1, token.LPAREN)
		toks.Insert(2, token.STRING, `"`+st.lastCommand+`"`)
		toks.Insert(3, token.COMMA)
		toks.Insert(4, token.FUNC)
		toks.Insert(5, token.LPAREN)
		toks.Insert(6, token.IDENT, "args")
		toks.Insert(7, token.ELLIPSIS)
		toks.Insert(8, token.IDENT, "string")
		toks.Insert(9, token.RPAREN)
		toks.AddTokens(st.TranspileGo(toks[11:], code)...)
	case t0.IsGo():
		if t0.Tok == token.GO {
			if !toks.Contains(token.LPAREN) {
				// "go" with no call parens is the go command executable
				logx.PrintlnDebug("exec: go command")
				return st.TranspileExec(ewords, false)
			}
		}
		logx.PrintlnDebug("go keyword")
		return st.TranspileGo(toks, code)
	case toks[n-1].Tok == token.INC || toks[n-1].Tok == token.DEC:
		logx.PrintlnDebug("go ++ / --")
		return st.TranspileGo(toks, code)
	case t0pn > 0: // path expr
		logx.PrintlnDebug("exec: path...")
		return st.TranspileExec(ewords, false)
	case t0.Tok == token.STRING:
		logx.PrintlnDebug("exec: string...")
		return st.TranspileExec(ewords, false)
	case f0exec && en == 1:
		logx.PrintlnDebug("exec: 1 word")
		return st.TranspileExec(ewords, false)
	case !f0exec: // exec must be IDENT
		logx.PrintlnDebug("go: not ident")
		return st.TranspileGo(toks, code)
	case f0exec && en > 1 && ewords[0] != "set" && toks.IsAssignExpr():
		// note: an exact duplicate of this case followed here previously;
		// it was unreachable and has been removed.
		logx.PrintlnDebug("go: assignment or defn")
		return st.TranspileGo(toks, code)
	case f0exec: // now any ident
		logx.PrintlnDebug("exec: ident..")
		return st.TranspileExec(ewords, false)
	default:
		logx.PrintlnDebug("go: default")
		return st.TranspileGo(toks, code)
	}
	// only the "command" case falls through to here
	return toks
}
// TranspileGoRange returns transpiled tokens assuming Go code,
// for given start, end (exclusive) range of given tokens and code.
// In general the positions in the tokens applies to the _original_ code
// so you should just keep the original code string. However, this is
// needed for a specific case.
func (st *State) TranspileGoRange(toks Tokens, code string, start, end int) Tokens {
	// token Pos values are 1-based offsets into the code string
	lo := toks[start].Pos - 1
	var hi token.Pos
	if end <= len(toks)-1 {
		hi = toks[end].Pos - 1
	} else {
		hi = token.Pos(len(code)) // no token after range: take rest of line
	}
	return st.TranspileGo(toks[start:end], code[lo:hi])
}
// TranspileGo returns transpiled tokens assuming Go code.
// Unpacks any encapsulated shell or math expressions.
func (st *State) TranspileGo(toks Tokens, code string) Tokens {
	n := len(toks)
	if n == 0 {
		return toks
	}
	if st.FuncToVar && toks[0].Tok == token.FUNC { // reorder as an assignment
		if len(toks) > 1 && toks[1].Tok == token.IDENT {
			// rewrite "func name(..." as "name := func(..." -- presumably so
			// the func can be redefined interactively; confirm with State docs.
			toks[0] = toks[1]
			toks.Insert(1, token.DEFINE)
			toks[2] = &Token{Tok: token.FUNC}
			n = len(toks)
		}
	}
	gtoks := make(Tokens, 0, len(toks)) // return tokens
	for i := 0; i < n; i++ {
		tok := toks[i]
		switch {
		case tok.Tok == token.ILLEGAL:
			// ILLEGAL marks an embedded-mode delimiter ("#" = math;
			// anything else is treated as embedded exec).
			et := toks[i:].ModeEnd()
			if et > 0 {
				if tok.Str == "#" {
					gtoks.AddTokens(st.TranspileMath(toks[i+1:i+et], code, false)...)
				} else {
					gtoks.AddTokens(st.TranspileExecTokens(toks[i+1:i+et+1], code, true)...)
				}
				i += et // skip past the embedded expression
				continue
			} else {
				gtoks = append(gtoks, tok)
			}
		case tok.Tok == token.LBRACK && i > 0 && toks[i-1].Tok == token.IDENT: // index expr
			ixtoks := toks[i:]
			rm := ixtoks.RightMatching()
			if rm < 3 { // too short to be ident [ expr ]
				gtoks = append(gtoks, tok)
				continue
			}
			// try to rewrite multi-dim index expr as tensor Value/Set call
			idx := st.TranspileGoNDimIndex(toks, code, &gtoks, i-1, rm+i)
			if idx > 0 {
				i = idx // resume after the rewritten expression
			} else {
				gtoks = append(gtoks, tok)
			}
		default:
			gtoks = append(gtoks, tok)
		}
	}
	return gtoks
}
// TranspileExecString returns transpiled tokens assuming Exec code,
// from a string, with the given bool indicating whether [Output] is needed.
func (st *State) TranspileExecString(str string, output bool) Tokens {
	if len(str) <= 1 { // nothing meaningful to execute
		return nil
	}
	words, err := ExecWords(str)
	if err != nil {
		// record the parse error, but still transpile whatever was parsed
		st.AddError(err)
	}
	return st.TranspileExec(words, output)
}
// TranspileExecTokens returns transpiled tokens assuming Exec code,
// from given tokens, with the given bool indicating
// whether [Output] is needed.
func (st *State) TranspileExecTokens(toks Tokens, code string, output bool) Tokens {
	if len(toks) == 0 {
		return nil
	}
	// token Pos values are 1-based offsets into the code string
	beg := toks[0].Pos - 1
	end := toks[len(toks)-1].Pos - 1
	return st.TranspileExecString(code[beg:end], output)
}
// TranspileExec returns transpiled tokens assuming Exec code,
// with the given bools indicating the type of run to execute.
func (st *State) TranspileExec(ewords []string, output bool) Tokens {
	n := len(ewords)
	if n == 0 {
		return nil
	}
	etoks := make(Tokens, 0, n+5) // return tokens
	var execTok *Token // the goal.Run / Output ident token; patched to Start later
	bgJob := false
	noStop := false
	if ewords[0] == "[" { // leading [ : tolerate errors (no stop) for whole line
		ewords = ewords[1:]
		n--
		noStop = true
	}
	// startExec emits the "goal.<Method>(" call prefix; the method chosen
	// depends on whether output is captured and whether errors are tolerated.
	startExec := func() {
		bgJob = false
		etoks.Add(token.IDENT, "goal")
		etoks.Add(token.PERIOD)
		switch {
		case output && noStop:
			execTok = etoks.Add(token.IDENT, "OutputErrOK")
		case output && !noStop:
			execTok = etoks.Add(token.IDENT, "Output")
		case !output && noStop:
			execTok = etoks.Add(token.IDENT, "RunErrOK")
		case !output && !noStop:
			execTok = etoks.Add(token.IDENT, "Run")
		}
		etoks.Add(token.LPAREN)
	}
	// endExec closes the call; a pending & retroactively turns the call
	// into a background Start.
	endExec := func() {
		if bgJob {
			execTok.Str = "Start"
		}
		etoks.DeleteLastComma()
		etoks.Add(token.RPAREN)
	}
	startExec()
	for i := 0; i < n; i++ {
		f := ewords[i]
		switch {
		case f == "{": // embedded go
			if n < i+3 {
				st.AddError(fmt.Errorf("goal: no matching right brace } found in exec command line"))
			} else {
				// ExecWords puts the brace contents in the next word,
				// followed by the closing }, hence the +2 skip.
				gstr := ewords[i+1]
				etoks.AddTokens(st.TranspileGo(TokensFromString(gstr), gstr)...)
				etoks.Add(token.COMMA)
				i += 2
			}
		case f == "[":
			noStop = true
		case f == "]": // solo is def end
			// just skip
			noStop = false
		case f == "&":
			bgJob = true
		case f[0] == '|':
			// pipe: current command becomes a Start, pipe word is passed as
			// an arg, and the rest of the line is a fresh exec call.
			execTok.Str = "Start"
			etoks.Add(token.IDENT, AddQuotes(f))
			etoks.Add(token.COMMA)
			endExec()
			etoks.Add(token.SEMICOLON)
			etoks.AddTokens(st.TranspileExec(ewords[i+1:], output)...)
			return etoks
		case f == ";":
			// command separator: close this call, transpile the remainder
			endExec()
			etoks.Add(token.SEMICOLON)
			etoks.AddTokens(st.TranspileExec(ewords[i+1:], output)...)
			return etoks
		default:
			if f[0] == '"' || f[0] == '`' {
				etoks.Add(token.STRING, f)
			} else {
				etoks.Add(token.IDENT, AddQuotes(f)) // mark as an IDENT but add quotes!
			}
			etoks.Add(token.COMMA)
		}
	}
	endExec()
	return etoks
}
// TranspileGoNDimIndex processes an ident[*] sequence of tokens,
// translating it into a corresponding tensor Value or Set expression,
// if it is a multi-dimensional indexing expression which is not valid in Go,
// to support simple n-dimensional tensor indexing in Go (not math) mode.
// Gets the current sequence of toks tokens, where the ident starts at idIdx
// and the ] is at rbIdx. It puts the results in gtoks generated tokens.
// Returns a positive index to resume processing at, if it is actually an
// n-dimensional expr, and -1 if not, in which case the normal process resumes.
func (st *State) TranspileGoNDimIndex(toks Tokens, code string, gtoks *Tokens, idIdx, rbIdx int) int {
	// find top-level commas between the brackets; nested (), {}, []
	// groups are skipped over so their internal commas don't count.
	var commas []int
	for i := idIdx + 2; i < rbIdx; i++ {
		tk := toks[i]
		if tk.Tok == token.COMMA {
			commas = append(commas, i)
		}
		if tk.Tok == token.LPAREN || tk.Tok == token.LBRACE || tk.Tok == token.LBRACK {
			rp := toks[i:rbIdx].RightMatching()
			if rp > 0 {
				i += rp
			}
		}
	}
	if len(commas) == 0 { // not multidim
		return -1
	}
	// &ident[...] takes a pointer to the element: becomes ValuePtr, and the
	// already-emitted & token must be removed from the output.
	isPtr := false
	if idIdx > 0 && toks[idIdx-1].Tok == token.AND {
		isPtr = true
		lgt := len(*gtoks)
		*gtoks = slices.Delete(*gtoks, lgt-2, lgt-1) // get rid of &
	}
	// now we need to determine if it is a Set based on what happens after rb
	isSet := false
	stok := token.ILLEGAL
	n := len(toks)
	hasComment := false
	if toks[n-1].Tok == token.COMMENT { // exclude trailing comment from the rhs
		hasComment = true
		n--
	}
	if n-rbIdx > 1 {
		ntk := toks[rbIdx+1].Tok
		// plain = or one of the += -= *= /= assignment operators
		// (ADD_ASSIGN..QUO_ASSIGN form a contiguous token range)
		if ntk == token.ASSIGN || (ntk >= token.ADD_ASSIGN && ntk <= token.QUO_ASSIGN) {
			isSet = true
			stok = ntk
		}
	}
	fun := "Value"
	if isPtr {
		fun = "ValuePtr"
		isSet = false
	} else if isSet {
		fun = "Set"
		switch stok {
		case token.ADD_ASSIGN:
			fun += "Add"
		case token.SUB_ASSIGN:
			fun += "Sub"
		case token.MUL_ASSIGN:
			fun += "Mul"
		case token.QUO_ASSIGN:
			fun += "Div"
		}
	}
	// emit: .<fun>( [rhs,] int(idx0), int(idx1), ... )
	gtoks.Add(token.PERIOD)
	gtoks.Add(token.IDENT, fun)
	gtoks.Add(token.LPAREN)
	if isSet {
		// Set* takes the assigned value first: transpile everything after the op
		gtoks.AddTokens(st.TranspileGo(toks[rbIdx+2:n], code)...)
		gtoks.Add(token.COMMA)
	}
	sti := idIdx + 2
	for _, cp := range commas {
		gtoks.Add(token.IDENT, "int")
		gtoks.Add(token.LPAREN)
		gtoks.AddTokens(st.TranspileGo(toks[sti:cp], code)...)
		gtoks.Add(token.RPAREN)
		gtoks.Add(token.COMMA)
		sti = cp + 1
	}
	gtoks.Add(token.IDENT, "int")
	gtoks.Add(token.LPAREN)
	gtoks.AddTokens(st.TranspileGo(toks[sti:rbIdx], code)...)
	gtoks.Add(token.RPAREN)
	gtoks.Add(token.RPAREN)
	if isSet {
		// Set consumed the rest of the line; re-attach any trailing comment
		if hasComment {
			gtoks.AddTokens(toks[len(toks)-1])
		}
		return len(toks)
	} else {
		return rbIdx
	}
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package transpile
import (
"go/ast"
"go/token"
)
// inferKindExpr infers the basic Kind level type from given expression.
// Returns token.ILLEGAL when the kind cannot be determined locally.
func inferKindExpr(ex ast.Expr) token.Token {
	if ex == nil {
		return token.ILLEGAL
	}
	switch x := ex.(type) {
	case *ast.BasicLit:
		return x.Kind // key grounding: literals carry their kind directly
	case *ast.ParenExpr:
		return inferKindExpr(x.X)
	case *ast.BinaryExpr:
		// prefer the left operand's kind; fall back to the right
		if ta := inferKindExpr(x.X); ta != token.ILLEGAL {
			return ta
		}
		return inferKindExpr(x.Y)
	case *ast.IndexListExpr:
		if x.X == nil { // array literal
			return inferKindExprList(x.Indices)
		}
		return inferKindExpr(x.X)
	}
	// Idents, selectors, calls, func lits, slices, type asserts:
	// the kind is not determinable without full type information.
	return token.ILLEGAL
}
func inferKindExprList(ex []ast.Expr) token.Token {
n := len(ex)
for i := range n {
t := inferKindExpr(ex[i])
if t != token.ILLEGAL {
return t
}
}
return token.ILLEGAL
}
// Copyright (c) 2022, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
package alignsl performs 16-byte alignment checking of struct fields
and total size modulus checking of struct types to ensure WGSL
(and GSL) compatibility.
Checks that struct sizes are an even multiple of 16 bytes
(4 float32's), fields are 32 bit types: [U]Int32, Float32,
and that fields that are other struct types are aligned
at even 16 byte multiples.
*/
package alignsl
import (
"errors"
"fmt"
"go/types"
"strings"
"golang.org/x/tools/go/packages"
)
// Context holds the state for checking one package: the package's size
// calculations, the set of already-checked structs, structs pending a
// second pass, and accumulated error messages.
type Context struct {
	Sizes types.Sizes // from package -- used for Offsetsof / Sizeof
	Structs map[*types.Struct]string // structs that have been processed already -- value is name
	Stack map[*types.Struct]string // structs to process in a second pass -- structs encountered during processing of other structs
	Errs []string // accumulating list of error strings -- empty if all good
}
// NewContext returns a new Context using the given package size
// computations, with its struct maps initialized.
func NewContext(sz types.Sizes) *Context {
	return &Context{
		Sizes:   sz,
		Structs: make(map[*types.Struct]string),
		Stack:   make(map[*types.Struct]string),
	}
}
// IsNewStruct reports whether the given struct has not been seen before,
// recording it as seen when it is new.
func (cx *Context) IsNewStruct(st *types.Struct) bool {
	_, seen := cx.Structs[st]
	if !seen {
		cx.Structs[st] = st.String()
	}
	return !seen
}
// AddError appends an error line to Errs, prefixing it with the struct's
// name as a header line if this is the struct's first error.
// Always returns true (the new hasErr value).
func (cx *Context) AddError(ers string, hasErr bool, stName string) bool {
	if !hasErr { // first error for this struct: emit its name as a header
		cx.Errs = append(cx.Errs, stName, ers)
		return true
	}
	cx.Errs = append(cx.Errs, ers)
	return true
}
func TypeName(tp types.Type) string {
switch x := tp.(type) {
case *types.Named:
return x.Obj().Name()
}
return tp.String()
}
// CheckStruct is the primary checker -- returns hasErr = true if there
// are any mis-aligned fields or total size of struct is not an
// even multiple of 16 bytes -- adds details to Errs
func CheckStruct(cx *Context, st *types.Struct, stName string) bool {
	if !cx.IsNewStruct(st) { // already checked
		return false
	}
	var flds []*types.Var
	nf := st.NumFields()
	if nf == 0 {
		return false
	}
	hasErr := false
	for i := 0; i < nf; i++ {
		fl := st.Field(i)
		flds = append(flds, fl)
		ft := fl.Type()
		ut := ft.Underlying()
		if bt, isBasic := ut.(*types.Basic); isBasic {
			kind := bt.Kind()
			// NOTE(review): Uint64 is accepted here although the error
			// message only mentions 32-bit types -- confirm intent.
			if !(kind == types.Uint32 || kind == types.Int32 || kind == types.Float32 || kind == types.Uint64) {
				hasErr = cx.AddError(fmt.Sprintf(" %s: basic type != [U]Int32 or Float32: %s", fl.Name(), bt.String()), hasErr, stName)
			}
		} else {
			// struct-typed fields are queued for a later pass; anything
			// else (slices, pointers, etc) is unsupported in WGSL structs
			if sst, is := ut.(*types.Struct); is {
				cx.Stack[sst] = TypeName(ft)
			} else {
				hasErr = cx.AddError(fmt.Sprintf(" %s: unsupported type: %s", fl.Name(), ft.String()), hasErr, stName)
			}
		}
	}
	// total size = offset of last field + size of last field
	offs := cx.Sizes.Offsetsof(flds)
	last := cx.Sizes.Sizeof(flds[nf-1].Type())
	totsz := int(offs[nf-1] + last)
	mod := totsz % 16
	vectyp := strings.Contains(strings.ToLower(stName), "vec") // vector types are ok
	if !vectyp && mod != 0 {
		needs := 4 - (mod / 4) // number of 32-bit fields to reach next 16-byte multiple
		hasErr = cx.AddError(fmt.Sprintf(" total size: %d not even multiple of 16 -- needs %d extra 32bit padding fields", totsz, needs), hasErr, stName)
	}
	// check that struct starts at mod 16 byte offset
	for i, fl := range flds {
		ft := fl.Type()
		ut := ft.Underlying()
		if _, is := ut.(*types.Struct); is {
			off := offs[i]
			if off%16 != 0 {
				hasErr = cx.AddError(fmt.Sprintf(" %s: struct type: %s is not at mod-16 byte offset: %d", fl.Name(), TypeName(ft), off), hasErr, stName)
			}
		}
	}
	return hasErr
}
// CheckPackage is main entry point for checking a package
// returns error string if any errors found.
func CheckPackage(pkg *packages.Package) error {
	cx := NewContext(pkg.TypesSizes)
	sc := pkg.Types.Scope()
	// first pass: all named struct types in package scope (and children)
	hasErr := CheckScope(cx, sc, 0)
	// second pass: struct-typed fields queued during the first pass
	er := CheckStack(cx)
	if hasErr || er {
		str := `
WARNING: in struct type alignment checking:
Checks that struct sizes are an even multiple of 16 bytes (4 float32's),
and fields are 32 bit types: [U]Int32, Float32 or other struct,
and that fields that are other struct types are aligned at even 16 byte multiples.
List of errors found follow below, by struct type name:
` + strings.Join(cx.Errs, "\n")
		return errors.New(str)
	}
	return nil
}
// CheckStack repeatedly processes the pending stack of struct types until
// no new structs are discovered, returning true if any errors were found.
func CheckStack(cx *Context) bool {
	hasErr := false
	for len(cx.Stack) > 0 {
		pending := cx.Stack
		cx.Stack = make(map[*types.Struct]string) // fresh stack for newly-found structs
		for st, nm := range pending {
			if CheckStruct(cx, st, nm) {
				hasErr = true
			}
		}
	}
	return hasErr
}
// CheckScope checks all named struct types declared in the given scope,
// recursing into child scopes only when the scope itself declares no
// named types. Returns true if any errors were found.
func CheckScope(cx *Context, sc *types.Scope, level int) bool {
	hasErr := false
	ntyp := 0
	for _, nm := range sc.Names() {
		ob := sc.Lookup(nm)
		tp := ob.Type()
		if tp == nil {
			continue
		}
		nt, isNamed := tp.(*types.Named)
		if !isNamed {
			continue
		}
		ut := nt.Underlying()
		if ut == nil {
			continue
		}
		if st, isStruct := ut.(*types.Struct); isStruct {
			if CheckStruct(cx, st, nt.Obj().Name()) {
				hasErr = true
			}
			ntyp++
		}
	}
	if ntyp == 0 { // no named types here: look in nested scopes instead
		for i := 0; i < sc.NumChildren(); i++ {
			if CheckScope(cx, sc.Child(i), level+1) {
				hasErr = true
			}
		}
	}
	return hasErr
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import "sync/atomic"
//gosl:start
// Atomic does an atomic computation on the data.
// Note: IntData[i, Integ] is gosl's multi-dimensional tensor indexing
// syntax, converted by the gosl transpiler -- this is not standard Go.
func Atomic(i uint32) { //gosl:kernel
	atomic.AddInt32(&IntData[i, Integ], 1)
}
//gosl:end
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"cogentcore.org/core/math32"
"cogentcore.org/lab/tensor"
)
//gosl:start
//gosl:import "cogentcore.org/core/math32"
// The //gosl:* directives below declare these globals as GPU variables,
// organized into named groups, for the gosl code generator.
//gosl:vars
var (
	// Params are the parameters for the computation.
	//gosl:group Params
	//gosl:read-only
	Params []ParamStruct
	// Data is the data on which the computation operates.
	// 2D: outer index is data, inner index is: Raw, Integ, Exp vars.
	//gosl:group Data
	//gosl:dims 2
	Data *tensor.Float32
	// IntData is the int data on which the computation operates.
	// 2D: outer index is data, inner index is: Raw, Integ, Exp vars.
	//gosl:dims 2
	IntData *tensor.Int32
)
// Indexes into the inner (variable) dimension of Data / IntData rows.
const (
	Raw int = iota // raw input value
	Integ // integrated value
	Exp // exp(-integ), computed in IntegFromRaw
	NVars // total number of per-row variables
)
// ParamStruct has the test params
type ParamStruct struct {
	// rate constant in msec
	Tau float32
	// 1/Tau
	Dt float32
	// pad, pad1 fill the struct out to 16 bytes (4 x float32),
	// required for GPU (WGSL) struct alignment.
	pad float32
	pad1 float32
}
// IntegFromRaw computes integrated value from current raw value,
// using exponential Euler integration with rate Dt = 1/Tau,
// then stores exp(-integ) in the Exp variable.
// Note: Data[idx, Integ] is gosl multi-dimensional tensor indexing,
// converted by the gosl transpiler -- not standard Go.
func (ps *ParamStruct) IntegFromRaw(idx int) {
	integ := Data[idx, Integ]
	integ += ps.Dt * (Data[idx, Raw] - integ)
	Data[idx, Integ] = integ
	Data[idx, Exp] = math32.FastExp(-integ)
}
// Compute does the main computation.
// The //gosl:kernel directive marks this as a GPU kernel entry point,
// with i as the per-invocation data index.
func Compute(i uint32) { //gosl:kernel
	params := GetParams(0)
	params.IntegFromRaw(int(i))
}
//gosl:end
// note: only core compute code needs to be in shader -- all init is done CPU-side
// Defaults sets default parameter values and computes derived values.
func (ps *ParamStruct) Defaults() {
	ps.Tau = 5
	ps.Update()
}
// Update computes derived parameter values: Dt = 1/Tau.
func (ps *ParamStruct) Update() {
	ps.Dt = 1.0 / ps.Tau
}
// Code generated by "gosl"; DO NOT EDIT
package main
import (
"embed"
"unsafe"
"cogentcore.org/core/gpu"
"cogentcore.org/lab/tensor"
)
// NOTE: this file is generated by gosl -- comments only; do not edit code.
//go:embed shaders/*.wgsl
var shaders embed.FS
// ComputeGPU is the compute gpu device
var ComputeGPU *gpu.GPU
// UseGPU indicates whether to use GPU vs. CPU.
var UseGPU bool
// GPUSystem is a GPU compute System with kernels operating on the
// same set of data variables.
var GPUSystem *gpu.ComputeSystem
// GPUVars is an enum for GPU variables, for specifying what to sync.
type GPUVars int32 //enums:enum
const (
	ParamsVar GPUVars = 0
	DataVar GPUVars = 1
	IntDataVar GPUVars = 2
)
// TensorStrides holds the tensor stride values that are copied to the
// GPU by ToGPUTensorStrides, for index computations in shaders.
var TensorStrides tensor.Uint32
// GPUInit initializes the GPU compute system,
// configuring system(s), variables and kernels.
// It is safe to call multiple times: detects if already run.
func GPUInit() {
	if ComputeGPU != nil { // already initialized
		return
	}
	gp := gpu.NewComputeGPU()
	ComputeGPU = gp
	{
		sy := gpu.NewComputeSystem(gp, "Default")
		GPUSystem = sy
		gpu.NewComputePipelineShaderFS(shaders, "shaders/Atomic.wgsl", sy)
		gpu.NewComputePipelineShaderFS(shaders, "shaders/Compute.wgsl", sy)
		vars := sy.Vars()
		{
			// group 0: read-only tensor strides + params
			sgp := vars.AddGroup(gpu.Storage, "Params")
			var vr *gpu.Var
			_ = vr
			vr = sgp.Add("TensorStrides", gpu.Uint32, 1, gpu.ComputeShader)
			vr.ReadOnly = true
			vr = sgp.AddStruct("Params", int(unsafe.Sizeof(ParamStruct{})), 1, gpu.ComputeShader)
			vr.ReadOnly = true
			sgp.SetNValues(1)
		}
		{
			// group 1: read-write data tensors
			sgp := vars.AddGroup(gpu.Storage, "Data")
			var vr *gpu.Var
			_ = vr
			vr = sgp.Add("Data", gpu.Float32, 1, gpu.ComputeShader)
			vr = sgp.Add("IntData", gpu.Int32, 1, gpu.ComputeShader)
			sgp.SetNValues(1)
		}
		sy.Config()
	}
}
// GPURelease releases the GPU compute system resources.
// Call this at program exit.
func GPURelease() {
	if GPUSystem != nil {
		GPUSystem.Release()
		GPUSystem = nil
	}
	if ComputeGPU != nil {
		ComputeGPU.Release()
		ComputeGPU = nil
	}
}
// RunAtomic runs the Atomic kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// Can call multiple Run* kernels in a row, which are then all launched
// in the same command submission on the GPU, which is by far the most efficient.
// MUST call RunDone (with optional vars to sync) after all Run calls.
// Alternatively, a single-shot RunOneAtomic call does Run and Done for a
// single run-and-sync case.
func RunAtomic(n int) {
	if UseGPU {
		RunAtomicGPU(n)
	} else {
		RunAtomicCPU(n)
	}
}
// RunAtomicGPU runs the Atomic kernel on the GPU. See [RunAtomic] for more info.
func RunAtomicGPU(n int) {
	sy := GPUSystem
	pl := sy.ComputePipelines["Atomic"]
	ce, _ := sy.BeginComputePass()
	pl.Dispatch1D(ce, n, 64) // 64 = 1D workgroup size
}
// RunAtomicCPU runs the Atomic kernel on the CPU.
func RunAtomicCPU(n int) {
	// VectorizeFunc calls Atomic(i) for i in 0..n-1, presumably in
	// parallel across threads -- confirm in gpu package docs.
	gpu.VectorizeFunc(0, n, Atomic)
}
// RunOneAtomic runs the Atomic kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// This version then calls RunDone with the given variables to sync
// after the Run, for a single-shot Run-and-Done call. If multiple kernels
// can be run in sequence, it is much more efficient to do multiple Run*
// calls followed by a RunDone call.
func RunOneAtomic(n int, syncVars ...GPUVars) {
	if UseGPU {
		RunAtomicGPU(n)
		RunDone(syncVars...)
	} else {
		RunAtomicCPU(n)
	}
}
// RunCompute runs the Compute kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// Can call multiple Run* kernels in a row, which are then all launched
// in the same command submission on the GPU, which is by far the most efficient.
// MUST call RunDone (with optional vars to sync) after all Run calls.
// Alternatively, a single-shot RunOneCompute call does Run and Done for a
// single run-and-sync case.
func RunCompute(n int) {
	if UseGPU {
		RunComputeGPU(n)
	} else {
		RunComputeCPU(n)
	}
}
// RunComputeGPU runs the Compute kernel on the GPU. See [RunCompute] for more info.
func RunComputeGPU(n int) {
	sy := GPUSystem
	pl := sy.ComputePipelines["Compute"]
	ce, _ := sy.BeginComputePass()
	pl.Dispatch1D(ce, n, 64) // 64 = 1D workgroup size
}
// RunComputeCPU runs the Compute kernel on the CPU.
func RunComputeCPU(n int) {
	gpu.VectorizeFunc(0, n, Compute)
}
// RunOneCompute runs the Compute kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// This version then calls RunDone with the given variables to sync
// after the Run, for a single-shot Run-and-Done call. If multiple kernels
// can be run in sequence, it is much more efficient to do multiple Run*
// calls followed by a RunDone call.
func RunOneCompute(n int, syncVars ...GPUVars) {
	if UseGPU {
		RunComputeGPU(n)
		RunDone(syncVars...)
	} else {
		RunComputeCPU(n)
	}
}
// RunDone must be called after Run* calls to start compute kernels.
// This actually submits the kernel jobs to the GPU, and adds commands
// to synchronize the given variables back from the GPU to the CPU.
// After this function completes, the GPU results will be available in
// the specified variables.
func RunDone(syncVars ...GPUVars) {
	if !UseGPU {
		return
	}
	sy := GPUSystem
	// order matters: end the compute encoder, queue the GPU->read copies,
	// submit the pass, then block on the readback sync.
	sy.ComputeEncoder.End()
	ReadFromGPU(syncVars...)
	sy.EndComputePass()
	SyncFromGPU(syncVars...)
}
// ToGPU copies given variables to the GPU for the system.
func ToGPU(vars ...GPUVars) {
	if !UseGPU {
		return
	}
	sy := GPUSystem
	syVars := sy.Vars()
	for _, vr := range vars {
		switch vr {
		case ParamsVar:
			v, _ := syVars.ValueByIndex(0, "Params", 0)
			gpu.SetValueFrom(v, Params)
		case DataVar:
			v, _ := syVars.ValueByIndex(1, "Data", 0)
			gpu.SetValueFrom(v, Data.Values)
		case IntDataVar:
			v, _ := syVars.ValueByIndex(1, "IntData", 0)
			gpu.SetValueFrom(v, IntData.Values)
		}
	}
}
// RunGPUSync can be called to synchronize data between CPU and GPU.
// Any prior ToGPU* calls will execute to send data to the GPU,
// and any subsequent RunDone* calls will copy data back from the GPU.
func RunGPUSync() {
	if !UseGPU {
		return
	}
	sy := GPUSystem
	sy.BeginComputePass()
}
// ToGPUTensorStrides gets tensor strides and starts copying to the GPU.
func ToGPUTensorStrides() {
	if !UseGPU {
		return
	}
	sy := GPUSystem
	syVars := sy.Vars()
	// 10 stride slots are reserved per tensor: Data at 0.., IntData at 10..
	TensorStrides.SetShapeSizes(20)
	TensorStrides.SetInt1D(Data.Shape().Strides[0], 0)
	TensorStrides.SetInt1D(Data.Shape().Strides[1], 1)
	TensorStrides.SetInt1D(IntData.Shape().Strides[0], 10)
	TensorStrides.SetInt1D(IntData.Shape().Strides[1], 11)
	v, _ := syVars.ValueByIndex(0, "TensorStrides", 0)
	gpu.SetValueFrom(v, TensorStrides.Values)
}
// ReadFromGPU starts the process of copying vars to the GPU.
func ReadFromGPU(vars ...GPUVars) {
	sy := GPUSystem
	syVars := sy.Vars()
	for _, vr := range vars {
		switch vr {
		case ParamsVar:
			v, _ := syVars.ValueByIndex(0, "Params", 0)
			v.GPUToRead(sy.CommandEncoder)
		case DataVar:
			v, _ := syVars.ValueByIndex(1, "Data", 0)
			v.GPUToRead(sy.CommandEncoder)
		case IntDataVar:
			v, _ := syVars.ValueByIndex(1, "IntData", 0)
			v.GPUToRead(sy.CommandEncoder)
		}
	}
}
// SyncFromGPU synchronizes vars from the GPU to the actual variable.
func SyncFromGPU(vars ...GPUVars) {
	sy := GPUSystem
	syVars := sy.Vars()
	for _, vr := range vars {
		switch vr {
		case ParamsVar:
			v, _ := syVars.ValueByIndex(0, "Params", 0)
			v.ReadSync()
			gpu.ReadToBytes(v, Params)
		case DataVar:
			v, _ := syVars.ValueByIndex(1, "Data", 0)
			v.ReadSync()
			gpu.ReadToBytes(v, Data.Values)
		case IntDataVar:
			v, _ := syVars.ValueByIndex(1, "IntData", 0)
			v.ReadSync()
			gpu.ReadToBytes(v, IntData.Values)
		}
	}
}
// GetParams returns a pointer to the given global variable:
// [Params] []ParamStruct at given index.
// To ensure that values are updated on the GPU, you must call [SetParams].
// after all changes have been made.
func GetParams(idx uint32) *ParamStruct {
	return &Params[idx]
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This example just does some basic calculations on data structures and
// reports the time difference between the CPU and GPU.
package main
import (
"fmt"
"math/rand"
"runtime"
"cogentcore.org/core/base/timer"
"cogentcore.org/core/gpu"
"cogentcore.org/lab/tensor"
)
//go:generate gosl
func init() {
	// must lock main thread for gpu!
	// (GPU APIs require all calls to come from the same OS thread)
	runtime.LockOSThread()
}
// main runs the Atomic and Compute kernels on both CPU and GPU with
// identical input data, prints sample results for comparison, and
// reports the CPU vs. GPU timing.
func main() {
	gpu.Debug = true
	GPUInit()
	rand.Seed(0)
	// gpu.NumThreads = 1 // to restrict to sequential for loop
	n := 16_000_000
	// n := 2_000_000
	Params = make([]ParamStruct, 1)
	Params[0].Defaults()
	Data = tensor.NewFloat32()
	Data.SetShapeSizes(n, 3)
	nt := Data.Len()
	IntData = tensor.NewInt32()
	IntData.SetShapeSizes(n, 3)
	for i := range nt {
		Data.Set1D(rand.Float32(), i)
	}
	// sd / sid are the GPU copies, seeded with identical values,
	// so CPU and GPU results can be compared afterward.
	sid := tensor.NewInt32()
	sid.SetShapeSizes(n, 3)
	sd := tensor.NewFloat32()
	sd.SetShapeSizes(n, 3)
	for i := range nt {
		sd.Set1D(Data.Value1D(i), i)
	}
	cpuTmr := timer.Time{}
	cpuTmr.Start()
	RunOneAtomic(n)
	RunOneCompute(n)
	cpuTmr.Stop()
	// swap the GPU copies in as the globals used by ToGPU / sync
	cd := Data
	cid := IntData
	Data = sd
	IntData = sid
	gpuFullTmr := timer.Time{}
	gpuFullTmr.Start()
	UseGPU = true
	ToGPUTensorStrides()
	ToGPU(ParamsVar, DataVar, IntDataVar)
	gpuTmr := timer.Time{}
	gpuTmr.Start()
	RunAtomic(n)
	RunCompute(n)
	gpuTmr.Stop()
	RunDone(DataVar, IntDataVar)
	gpuFullTmr.Stop()
	mx := min(n, 5)
	for i := 0; i < mx; i++ {
		// bug fix: was cid.Value(1, Integ), which always printed row 1
		// in the CPU column instead of the current row i.
		fmt.Printf("%d\t CPU IntData: %d\t GPU: %d\n", i, cid.Value(i, Integ), sid.Value(i, Integ))
	}
	fmt.Println()
	for i := 0; i < mx; i++ {
		d := cd.Value(i, Exp) - sd.Value(i, Exp)
		fmt.Printf("CPU:\t%d\t Raw: %6.4g\t Integ: %6.4g\t Exp: %6.4g\tGPU: %6.4g\tDiff: %g\n", i, cd.Value(i, Raw), cd.Value(i, Integ), cd.Value(i, Exp), sd.Value(i, Exp), d)
		fmt.Printf("GPU:\t%d\t Raw: %6.4g\t Integ: %6.4g\t Exp: %6.4g\tCPU: %6.4g\tDiff: %g\n\n", i, sd.Value(i, Raw), sd.Value(i, Integ), sd.Value(i, Exp), cd.Value(i, Exp), d)
	}
	fmt.Printf("\n")
	cpu := cpuTmr.Total
	gpu := gpuTmr.Total
	gpuFull := gpuFullTmr.Total
	fmt.Printf("N: %d\t CPU: %v\t GPU: %v\t Full: %v\t CPU/GPU: %6.4g\n", n, cpu, gpu, gpuFull, float64(cpu)/float64(gpu))
	GPURelease()
}
// Code generated by "gosl"; DO NOT EDIT
package main
import (
"embed"
"unsafe"
"cogentcore.org/core/gpu"
"cogentcore.org/lab/tensor"
)
// NOTE: this file is generated by gosl -- comments only; do not edit code.
//go:embed shaders/*.wgsl
var shaders embed.FS
// ComputeGPU is the compute gpu device
var ComputeGPU *gpu.GPU
// UseGPU indicates whether to use GPU vs. CPU.
var UseGPU bool
// GPUSystem is a GPU compute System with kernels operating on the
// same set of data variables.
var GPUSystem *gpu.ComputeSystem
// GPUVars is an enum for GPU variables, for specifying what to sync.
type GPUVars int32 //enums:enum
const (
	SeedVar GPUVars = 0
	DataVar GPUVars = 1
)
// Dummy tensor stride variable to avoid import error
var __TensorStrides tensor.Uint32
// GPUInit initializes the GPU compute system,
// configuring system(s), variables and kernels.
// It is safe to call multiple times: detects if already run.
func GPUInit() {
	if ComputeGPU != nil { // already initialized
		return
	}
	gp := gpu.NewComputeGPU()
	ComputeGPU = gp
	{
		sy := gpu.NewComputeSystem(gp, "Default")
		GPUSystem = sy
		gpu.NewComputePipelineShaderFS(shaders, "shaders/Compute.wgsl", sy)
		vars := sy.Vars()
		{
			// single group: read-only Seed + read-write Data
			sgp := vars.AddGroup(gpu.Storage, "Group_0")
			var vr *gpu.Var
			_ = vr
			vr = sgp.AddStruct("Seed", int(unsafe.Sizeof(Seeds{})), 1, gpu.ComputeShader)
			vr.ReadOnly = true
			vr = sgp.AddStruct("Data", int(unsafe.Sizeof(Rnds{})), 1, gpu.ComputeShader)
			sgp.SetNValues(1)
		}
		sy.Config()
	}
}
// GPURelease releases the GPU compute system resources.
// Call this at program exit.
func GPURelease() {
	if GPUSystem != nil {
		GPUSystem.Release()
		GPUSystem = nil
	}
	if ComputeGPU != nil {
		ComputeGPU.Release()
		ComputeGPU = nil
	}
}
// RunCompute runs the Compute kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// Can call multiple Run* kernels in a row, which are then all launched
// in the same command submission on the GPU, which is by far the most efficient.
// MUST call RunDone (with optional vars to sync) after all Run calls.
// Alternatively, a single-shot RunOneCompute call does Run and Done for a
// single run-and-sync case.
func RunCompute(n int) {
	if UseGPU {
		RunComputeGPU(n)
	} else {
		RunComputeCPU(n)
	}
}
// RunComputeGPU runs the Compute kernel on the GPU. See [RunCompute] for more info.
func RunComputeGPU(n int) {
	sy := GPUSystem
	pl := sy.ComputePipelines["Compute"]
	ce, _ := sy.BeginComputePass()
	pl.Dispatch1D(ce, n, 64) // 64 = 1D workgroup size
}
// RunComputeCPU runs the Compute kernel on the CPU.
func RunComputeCPU(n int) {
	gpu.VectorizeFunc(0, n, Compute)
}
// RunOneCompute runs the Compute kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// This version then calls RunDone with the given variables to sync
// after the Run, for a single-shot Run-and-Done call. If multiple kernels
// can be run in sequence, it is much more efficient to do multiple Run*
// calls followed by a RunDone call.
func RunOneCompute(n int, syncVars ...GPUVars) {
	if UseGPU {
		RunComputeGPU(n)
		RunDone(syncVars...)
	} else {
		RunComputeCPU(n)
	}
}
// RunDone must be called after Run* calls to start compute kernels.
// This actually submits the kernel jobs to the GPU, and adds commands
// to synchronize the given variables back from the GPU to the CPU.
// After this function completes, the GPU results will be available in
// the specified variables.
func RunDone(syncVars ...GPUVars) {
if !UseGPU {
return
}
sy := GPUSystem
sy.ComputeEncoder.End()
ReadFromGPU(syncVars...)
sy.EndComputePass()
SyncFromGPU(syncVars...)
}
// ToGPU copies given variables to the GPU for the system.
func ToGPU(vars ...GPUVars) {
if !UseGPU {
return
}
sy := GPUSystem
syVars := sy.Vars()
for _, vr := range vars {
switch vr {
case SeedVar:
v, _ := syVars.ValueByIndex(0, "Seed", 0)
gpu.SetValueFrom(v, Seed)
case DataVar:
v, _ := syVars.ValueByIndex(0, "Data", 0)
gpu.SetValueFrom(v, Data)
}
}
}
// RunGPUSync can be called to synchronize data between CPU and GPU.
// Any prior ToGPU* calls will execute to send data to the GPU,
// and any subsequent RunDone* calls will copy data back from the GPU.
func RunGPUSync() {
if !UseGPU {
return
}
sy := GPUSystem
sy.BeginComputePass()
}
// ReadFromGPU starts the process of copying vars to the GPU.
func ReadFromGPU(vars ...GPUVars) {
sy := GPUSystem
syVars := sy.Vars()
for _, vr := range vars {
switch vr {
case SeedVar:
v, _ := syVars.ValueByIndex(0, "Seed", 0)
v.GPUToRead(sy.CommandEncoder)
case DataVar:
v, _ := syVars.ValueByIndex(0, "Data", 0)
v.GPUToRead(sy.CommandEncoder)
}
}
}
// SyncFromGPU synchronizes vars from the GPU to the actual variable.
func SyncFromGPU(vars ...GPUVars) {
sy := GPUSystem
syVars := sy.Vars()
for _, vr := range vars {
switch vr {
case SeedVar:
v, _ := syVars.ValueByIndex(0, "Seed", 0)
v.ReadSync()
gpu.ReadToBytes(v, Seed)
case DataVar:
v, _ := syVars.ValueByIndex(0, "Data", 0)
v.ReadSync()
gpu.ReadToBytes(v, Data)
}
}
}
// GetSeed returns a pointer to the given global variable:
// [Seed] []Seeds at given index. The pointer aliases the global slice.
// To ensure that values are updated on the GPU, you must call [SetSeed]
// after all changes have been made.
func GetSeed(idx uint32) *Seeds {
	return &Seed[idx]
}
// GetData returns a pointer to the given global variable:
// [Data] []Rnds at given index. The pointer aliases the global slice.
// To ensure that values are updated on the GPU, you must call [SetData]
// after all changes have been made.
func GetData(idx uint32) *Rnds {
	return &Data[idx]
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"fmt"
"runtime"
"log/slog"
"cogentcore.org/core/base/timer"
)
//go:generate gosl
func init() {
	// GPU access must happen from the main OS thread, so pin the main
	// goroutine before main() runs.
	runtime.LockOSThread()
}
// main benchmarks the random-number Compute kernel on CPU vs. GPU and
// verifies that both produce the same results: exactly (excluding
// Gauss) and within Tol (including Gauss).
func main() {
	GPUInit()
	// n := 10
	// n := 16_000_000 // max for macbook M*
	n := 200_000
	// CPU pass: fills dataC via the global Data variable.
	UseGPU = false
	Seed = make([]Seeds, 1)
	dataC := make([]Rnds, n)
	dataG := make([]Rnds, n)
	Data = dataC
	cpuTmr := timer.Time{}
	cpuTmr.Start()
	RunOneCompute(n)
	cpuTmr.Stop()
	// GPU pass: same zero seed, fills dataG.
	UseGPU = true
	Data = dataG
	gpuFullTmr := timer.Time{} // includes host<->device transfers
	gpuFullTmr.Start()
	ToGPU(SeedVar, DataVar)
	gpuTmr := timer.Time{} // kernel submission only
	gpuTmr.Start()
	RunCompute(n)
	gpuTmr.Stop()
	RunDone(DataVar) // submits the kernel and syncs Data back
	gpuFullTmr.Stop()
	// Compare every element; print only the first few rows.
	anyDiffEx := false
	anyDiffTol := false
	mx := min(n, 5)
	fmt.Printf("Index\tDif(Ex,Tol)\t CPU \t then GPU\n")
	for i := 0; i < n; i++ {
		dc := &dataC[i]
		dg := &dataG[i]
		smEx, smTol := dc.IsSame(dg)
		if !smEx {
			anyDiffEx = true
		}
		if !smTol {
			anyDiffTol = true
		}
		if i > mx {
			continue
		}
		// mark differing rows with * in the Ex / Tol columns
		exS := " "
		if !smEx {
			exS = "*"
		}
		tolS := " "
		if !smTol {
			tolS = "*"
		}
		fmt.Printf("%d\t%s %s\t%s\n\t\t%s\n", i, exS, tolS, dc.String(), dg.String())
	}
	fmt.Printf("\n")
	if anyDiffEx {
		slog.Error("Differences between CPU and GPU detected at Exact level (excludes Gauss)")
	}
	if anyDiffTol {
		slog.Error("Differences between CPU and GPU detected at Tolerance level", "tolerance", Tol)
	}
	cpu := cpuTmr.Total
	gpu := gpuTmr.Total
	fmt.Printf("N: %d\t CPU: %v\t GPU: %v\t Full: %v\t CPU/GPU: %6.4g\n", n, cpu, gpu, gpuFullTmr.Total, float64(cpu)/float64(gpu))
	GPURelease()
}
package main
import (
"fmt"
"cogentcore.org/core/math32"
"cogentcore.org/lab/gosl/slrand"
"cogentcore.org/lab/gosl/sltype"
)
//gosl:start
//gosl:vars
var (
	// Seed is the global random counter seed; read-only on the GPU.
	//gosl:read-only
	Seed []Seeds
	// Data is the output array of generated random values,
	// written by the Compute kernel.
	Data []Rnds
)
// Seeds holds the global random counter seed used by the Compute kernel.
type Seeds struct {
	// Seed is the counter seed passed to RndGen for every element.
	Seed uint64
	// presumably pads the struct to 16 bytes for GPU layout -- confirm
	pad, pad1 int32
}
// Rnds holds one result set from each of the slrand calls exercised in
// RndGen. Each vec2 field is followed by two int32 pads, presumably for
// 16-byte GPU struct alignment -- confirm.
type Rnds struct {
	// Uints is the raw pair from slrand.Uint32Vec2.
	Uints      sltype.Uint32Vec2
	pad, pad1  int32
	// Floats is the pair from slrand.Float32Vec2.
	Floats     sltype.Float32Vec2
	pad2, pad3 int32
	// Floats11 is the pair from slrand.Float32Range11Vec2
	// (range -1..1, per its name -- confirm).
	Floats11   sltype.Float32Vec2
	pad4, pad5 int32
	// Gauss is the pair from slrand.Float32NormVec2.
	Gauss      sltype.Float32Vec2
	pad6, pad7 int32
}
// RndGen calls random function calls to test generator.
// Note that the counter to the outer-most computation function
// is passed by *value*, so the same counter goes to each element
// as it is computed, but within this scope, counter is passed by
// reference (as a pointer) so subsequent calls get a new counter value.
// The counter should be incremented by the number of random calls
// outside of the overall update function.
func (r *Rnds) RndGen(counter uint64, idx uint32) {
	// each call uses a distinct second argument (0..3), giving each
	// field its own stream at the same counter / idx.
	r.Uints = slrand.Uint32Vec2(counter, uint32(0), idx)
	r.Floats = slrand.Float32Vec2(counter, uint32(1), idx)
	r.Floats11 = slrand.Float32Range11Vec2(counter, uint32(2), idx)
	r.Gauss = slrand.Float32NormVec2(counter, uint32(3), idx)
}
// Compute is the GPU kernel entry point: generates all random values
// for element i of Data, using the shared global Seed counter.
func Compute(i uint32) { //gosl:kernel
	Data[i].RndGen(Seed[0].Seed, i)
}

//gosl:end
// Tol is the maximum absolute difference tolerated between CPU and GPU
// float results.
const Tol = 1.0e-4 // fails at lower tol eventually -- -6 works for many

// FloatSame reports whether f1 and f2 are bit-identical (exact) and
// whether they agree to within Tol (tol).
func FloatSame(f1, f2 float32) (exact, tol bool) {
	return f1 == f2, math32.Abs(f1-f2) < Tol
}
// Float32Vec2Same reports component-wise exact and within-Tol equality
// of two 2-vectors: both components must agree at each level.
func Float32Vec2Same(f1, f2 sltype.Float32Vec2) (exact, tol bool) {
	xe, xt := FloatSame(f1.X, f2.X)
	ye, yt := FloatSame(f1.Y, f2.Y)
	return xe && ye, xt && yt
}
// IsSame compares r and o at two levels: exact equality and agreement
// within Tol. Gauss is excluded from the exact comparison because it is
// known to differ between CPU and GPU.
func (r *Rnds) IsSame(o *Rnds) (exact, tol bool) {
	ue := r.Uints == o.Uints
	fe, ft := Float32Vec2Same(r.Floats, o.Floats)
	f1e, f1t := Float32Vec2Same(r.Floats11, o.Floats11)
	_, gt := Float32Vec2Same(r.Gauss, o.Gauss) // exact level skipped
	return ue && fe && f1e, ft && f1t && gt
}
// String returns a tab-separated representation: Uints in hex, then the
// Floats, Floats11, and Gauss pairs.
func (r *Rnds) String() string {
	return fmt.Sprintf("U: %x\t%x\tF: %g\t%g\tF11: %g\t%g\tG: %g\t%g", r.Uints.X, r.Uints.Y, r.Floats.X, r.Floats.Y, r.Floats11.X, r.Floats11.Y, r.Gauss.X, r.Gauss.Y)
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"cogentcore.org/core/cli"
"cogentcore.org/lab/gosl/gotosl"
)
// main runs the gosl command-line tool: it parses the standard
// [gotosl.Config] options and invokes [gotosl.Run] as the root command.
func main() { //types:skip
	opts := cli.DefaultOptions("gosl", "Go as a shader language converts Go code to WGSL WebGPU shader code, which can be run on the GPU through WebGPU.")
	cfg := &gotosl.Config{}
	cli.Run(opts, cfg, gotosl.Run)
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gotosl
import (
"fmt"
"sort"
"golang.org/x/exp/maps"
)
// Function represents one node in the call graph of functions.
type Function struct {
	// Name is the function name.
	Name string
	// Funcs are the functions called directly by this function.
	Funcs map[string]*Function
	// Atomics are variables that have atomic operations in this function.
	Atomics map[string]*Var
}
// NewFunction returns a new call-graph node with the given name and an
// empty callee map (Atomics is left nil until needed).
func NewFunction(name string) *Function {
	fn := &Function{Name: name}
	fn.Funcs = make(map[string]*Function)
	return fn
}
// RecycleFunc returns the function-graph node of the given name,
// creating and registering a new one if not already present.
func (st *State) RecycleFunc(name string) *Function {
	if fn, ok := st.FuncGraph[name]; ok {
		return fn
	}
	fn := NewFunction(name)
	st.FuncGraph[name] = fn
	return fn
}
// getAllFuncs recursively accumulates all callees of f into all,
// skipping entries already present (which also terminates call cycles).
func getAllFuncs(f *Function, all map[string]*Function) {
	for name, callee := range f.Funcs {
		if _, seen := all[name]; seen {
			continue
		}
		all[name] = callee
		getAllFuncs(callee, all)
	}
}
// AllFuncs returns the set of all functions transitively called by the
// given function, including itself, keyed by name. Returns nil (and
// prints an error) if the name is not in the function graph.
func (st *State) AllFuncs(name string) map[string]*Function {
	fn, ok := st.FuncGraph[name]
	if !ok {
		fmt.Printf("gosl: ERROR kernel function named: %q not found\n", name)
		return nil
	}
	all := map[string]*Function{name: fn}
	getAllFuncs(fn, all)
	return all
}
// AtomicVars returns the union of all variables marked as atomic
// within the given set of functions, keyed by variable name.
func (st *State) AtomicVars(funcs map[string]*Function) map[string]*Var {
	avars := make(map[string]*Var)
	for _, fn := range funcs {
		// ranging over a nil Atomics map is a no-op
		for vn, v := range fn.Atomics {
			avars[vn] = v
		}
	}
	return avars
}
// PrintFuncGraph prints the function graph to stdout: each function in
// alphabetical order, followed by its callees, tab-indented.
func (st *State) PrintFuncGraph() {
	names := maps.Keys(st.FuncGraph)
	sort.Strings(names)
	for _, name := range names {
		fmt.Println(name)
		callees := maps.Keys(st.FuncGraph[name].Funcs)
		sort.Strings(callees)
		for _, callee := range callees {
			fmt.Println("\t" + callee)
		}
	}
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file is largely copied from the Go source,
// src/go/printer/comment.go:
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gotosl
import (
"go/ast"
"go/doc/comment"
"strings"
)
// formatDocComment reformats the doc comment list,
// returning the canonical formatting.
//
// NOTE: largely copied from the Go source, src/go/printer/comment.go;
// prefer keeping it in sync with upstream over local restructuring.
func formatDocComment(list []*ast.Comment) []*ast.Comment {
	// Extract comment text (removing comment markers).
	var kind, text string
	var directives []*ast.Comment
	if len(list) == 1 && strings.HasPrefix(list[0].Text, "/*") {
		kind = "/*"
		text = list[0].Text
		if !strings.Contains(text, "\n") || allStars(text) {
			// Single-line /* .. */ comment in doc comment position,
			// or multiline old-style comment like
			// /*
			// * Comment
			// * text here.
			// */
			// Should not happen, since it will not work well as a
			// doc comment, but if it does, just ignore:
			// reformatting it will only make the situation worse.
			return list
		}
		text = text[2 : len(text)-2] // cut /* and */
	} else if strings.HasPrefix(list[0].Text, "//") {
		kind = "//"
		var b strings.Builder
		for _, c := range list {
			after, found := strings.CutPrefix(c.Text, "//")
			if !found {
				return list
			}
			// Accumulate //go:build etc lines separately.
			if isDirective(after) {
				directives = append(directives, c)
				continue
			}
			b.WriteString(strings.TrimPrefix(after, " "))
			b.WriteString("\n")
		}
		text = b.String()
	} else {
		// Not sure what this is, so leave alone.
		return list
	}
	if text == "" {
		return list
	}
	// Parse comment and reformat as text.
	var p comment.Parser
	d := p.Parse(text)
	var pr comment.Printer
	text = string(pr.Comment(d))
	// For /* */ comment, return one big comment with text inside.
	slash := list[0].Slash
	if kind == "/*" {
		c := &ast.Comment{
			Slash: slash,
			Text:  "/*\n" + text + "*/",
		}
		return []*ast.Comment{c}
	}
	// For // comment, return sequence of // lines.
	var out []*ast.Comment
	for text != "" {
		var line string
		line, text, _ = strings.Cut(text, "\n")
		if line == "" {
			line = "//"
		} else if strings.HasPrefix(line, "\t") {
			line = "//" + line
		} else {
			line = "// " + line
		}
		out = append(out, &ast.Comment{
			Slash: slash,
			Text:  line,
		})
	}
	// Re-append the accumulated directives after a blank // separator.
	if len(directives) > 0 {
		out = append(out, &ast.Comment{
			Slash: slash,
			Text:  "//",
		})
		for _, c := range directives {
			out = append(out, &ast.Comment{
				Slash: slash,
				Text:  c.Text,
			})
		}
	}
	return out
}
// isDirective reports whether c is a comment directive.
// See go.dev/issue/37974.
// This code is also in go/ast.
// The leading "//" has already been removed from c.
func isDirective(c string) bool {
	// Space-terminated legacy directives: "//line ", "//extern " (gccgo),
	// "//export " (cgo).
	for _, prefix := range []string{"line ", "extern ", "export "} {
		if strings.HasPrefix(c, prefix) {
			return true
		}
	}
	// Otherwise require the shape "[a-z0-9]+:[a-z0-9]" at the start.
	colon := strings.Index(c, ":")
	if colon <= 0 || colon+1 >= len(c) {
		return false
	}
	for i := 0; i <= colon+1; i++ {
		if i == colon {
			continue
		}
		ch := c[i]
		if !('a' <= ch && ch <= 'z' || '0' <= ch && ch <= '9') {
			return false
		}
	}
	return true
}
// allStars reports whether text is the interior of an old-style /* */
// comment in which every line after a newline begins with a '*'
// (after any leading spaces or tabs).
func allStars(text string) bool {
	for i := 0; i < len(text); i++ {
		if text[i] != '\n' {
			continue
		}
		// scan past leading blanks on the next line
		j := i + 1
		for j < len(text) && (text[j] == ' ' || text[j] == '\t') {
			j++
		}
		if j < len(text) && text[j] != '*' {
			return false
		}
	}
	return true
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gotosl
//go:generate core generate -add-types -add-funcs
// Keep these in sync with go/format/format.go.
const (
	// tabWidth is the tab width assumed by the printer.
	tabWidth = 8

	// printerMode configures go/printer output: spaces with tab-based
	// indent, plus normalized number literals.
	printerMode = UseSpaces | TabIndent | printerNormalizeNumbers

	// printerNormalizeNumbers means to canonicalize number literal prefixes
	// and exponents while printing. See https://golang.org/doc/go1.13#gosl.
	//
	// This value is defined in go/printer specifically for go/format and cmd/gosl.
	printerNormalizeNumbers = 1 << 30
)
// Config has the configuration info for the gosl system.
type Config struct {
	// Output is the output directory for shader code,
	// relative to where gosl is invoked; must not be an empty string.
	Output string `flag:"out" default:"shaders"`

	// Exclude is a comma-separated list of names of functions to exclude
	// from exporting to WGSL.
	Exclude string `default:"Update,Defaults"`

	// Keep keeps temporary converted versions of the source files,
	// for debugging.
	Keep bool

	// Debug enables debugging messages while running.
	Debug bool
}
//cli:cmd -root
// Run is the root gosl command: it translates the tagged Go files in
// the current directory into WGSL shaders per the given config.
func Run(cfg *Config) error { //types:add
	st := &State{}
	st.Init(cfg)
	return st.Run()
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gotosl
import (
"bytes"
"fmt"
"path/filepath"
"strings"
"slices"
)
// ExtractFiles processes all the package files and saves the corresponding
// .go files with simple go header. Files that contain a //gosl:vars
// block are moved from GoFiles to GoVarsFiles.
func (st *State) ExtractFiles() {
	// record imported package names whose prefixes should be stripped
	// (math32 is excluded, so math32. prefixes are kept).
	st.ImportPackages = make(map[string]bool)
	for impath := range st.GoImports {
		_, pkg := filepath.Split(impath)
		if pkg != "math32" {
			st.ImportPackages[pkg] = true
		}
	}
	for fn, fl := range st.GoFiles {
		hasVars := false
		fl.Lines, hasVars = st.ExtractGosl(fl.Lines)
		if hasVars {
			// deleting from a map during range is safe in Go
			st.GoVarsFiles[fn] = fl
			delete(st.GoFiles, fn)
		}
		WriteFileLines(filepath.Join(st.ImportsDir, fn), st.AppendGoHeader(fl.Lines))
	}
}
// ExtractImports processes all the imported files and saves the
// corresponding .go files, named "<pkg>-<file>", with a simple go header.
func (st *State) ExtractImports() {
	if len(st.GoImports) == 0 {
		return
	}
	for impath, files := range st.GoImports {
		_, pkg := filepath.Split(impath)
		for name, f := range files {
			f.Lines, _ = st.ExtractGosl(f.Lines)
			outfn := filepath.Join(st.ImportsDir, pkg+"-"+name)
			WriteFileLines(outfn, st.AppendGoHeader(f.Lines))
		}
	}
}
// ExtractGosl extracts the gosl comment-directive tagged regions from
// the given file lines, returning the extracted lines and whether a
// //gosl:vars block was seen. Regions are delimited by
// //gosl:start, //gosl:wgsl, or //gosl:nowgsl ... //gosl:end.
// Within a region, prefixes of imported gosl packages are stripped,
// and //gosl:kernel-tagged functions are registered with their system.
func (st *State) ExtractGosl(lines [][]byte) (outLines [][]byte, hasVars bool) {
	key := []byte("//gosl:")
	start := []byte("start")
	wgsl := []byte("wgsl")
	nowgsl := []byte("nowgsl")
	end := []byte("end")
	vars := []byte("vars")
	imp := []byte("import")
	kernel := []byte("//gosl:kernel")
	fnc := []byte("func")
	inReg := false
	inHlsl := false
	inNoHlsl := false
	for li, ln := range lines {
		tln := bytes.TrimSpace(ln)
		isKey := bytes.HasPrefix(tln, key)
		var keyStr []byte
		if isKey {
			keyStr = tln[len(key):]
			// fmt.Printf("key: %s\n", string(keyStr))
		}
		switch {
		case inReg && isKey && bytes.HasPrefix(keyStr, end):
			// wgsl / nowgsl regions keep their end marker; start does not
			if inHlsl || inNoHlsl {
				outLines = append(outLines, ln)
			}
			inReg = false
			inHlsl = false
			inNoHlsl = false
		case inReg && isKey && bytes.HasPrefix(keyStr, vars):
			hasVars = true
			outLines = append(outLines, ln)
		case inReg:
			for pkg := range st.ImportPackages { // remove package prefixes
				// leave import lines themselves untouched
				if !bytes.Contains(ln, imp) {
					ln = bytes.ReplaceAll(ln, []byte(pkg+"."), []byte{})
				}
			}
			if bytes.HasPrefix(ln, fnc) && bytes.Contains(ln, kernel) {
				// "func Name(args) { //gosl:kernel SysName" -- register the
				// kernel under the named system, capturing the raw body
				// lines up to the closing top-level brace.
				sysnm := strings.TrimSpace(string(ln[bytes.LastIndex(ln, kernel)+len(kernel):]))
				sy := st.System(sysnm)
				fcall := string(ln[5:]) // skip "func "
				lp := strings.Index(fcall, "(")
				rp := strings.LastIndex(fcall, ")")
				args := fcall[lp+1 : rp]
				fnm := fcall[:lp]
				funcode := ""
				for ki := li + 1; ki < len(lines); ki++ {
					kl := lines[ki]
					if len(kl) > 0 && kl[0] == '}' { // end of top-level func
						break
					}
					funcode += string(kl) + "\n"
				}
				kn := &Kernel{Name: fnm, Args: args, FuncCode: funcode}
				sy.Kernels[fnm] = kn
				if st.Config.Debug {
					fmt.Println("\tAdded kernel:", fnm, "args:", args, "system:", sy.Name)
				}
			}
			outLines = append(outLines, ln)
		case isKey && bytes.HasPrefix(keyStr, start):
			inReg = true
		case isKey && bytes.HasPrefix(keyStr, nowgsl):
			inReg = true
			inNoHlsl = true
			outLines = append(outLines, ln) // key to include self here
		case isKey && bytes.HasPrefix(keyStr, wgsl):
			inReg = true
			inHlsl = true
			outLines = append(outLines, ln)
		}
	}
	return
}
// AppendGoHeader prepends the standard Go header for extracted files:
// the "imports" package clause, the fixed gosl imports, and any
// additional gosl-imported package paths; slbool replacements are then
// applied to the combined result.
func (st *State) AppendGoHeader(lines [][]byte) [][]byte {
	olns := make([][]byte, 0, len(lines)+10)
	olns = append(olns, []byte("package imports"))
	olns = append(olns, []byte(`import (
	"math"
	"cogentcore.org/lab/gosl/slbool"
	"cogentcore.org/lab/gosl/slrand"
	"cogentcore.org/lab/gosl/sltype"
	"cogentcore.org/lab/tensor"
`))
	for impath := range st.GoImports {
		// gosl's own packages are already in the fixed list above
		if strings.Contains(impath, "core/goal/gosl") {
			continue
		}
		olns = append(olns, []byte("\t\""+impath+"\""))
	}
	olns = append(olns, []byte(")"))
	olns = append(olns, lines...)
	SlBoolReplace(olns)
	return olns
}
// ExtractWGSL extracts the WGSL code embedded within .Go files,
// which is commented out in the Go code -- remove comments.
// It first drops the package / import preamble (searched within the
// first 10 lines), then un-comments //gosl:wgsl regions and deletes
// //gosl:nowgsl regions entirely.
func (st *State) ExtractWGSL(lines [][]byte) [][]byte {
	key := []byte("//gosl:")
	wgsl := []byte("wgsl")
	nowgsl := []byte("nowgsl")
	end := []byte("end")
	stComment := []byte("/*")
	edComment := []byte("*/")
	comment := []byte("// ")
	pack := []byte("package")
	imp := []byte("import")
	lparen := []byte("(")
	rparen := []byte(")")
	mx := min(10, len(lines))
	stln := 0
	gotImp := false
	// find the line just past the package and import declarations
	for li := 0; li < mx; li++ {
		ln := lines[li]
		switch {
		case bytes.HasPrefix(ln, pack):
			stln = li + 1
		case bytes.HasPrefix(ln, imp):
			if bytes.HasSuffix(ln, lparen) {
				gotImp = true // multi-line import ( ... ) block
			} else {
				stln = li + 1
			}
		case gotImp && bytes.HasPrefix(ln, rparen):
			stln = li + 1
		}
	}
	lines = lines[stln:] // get rid of package, import
	inHlsl := false
	inNoHlsl := false
	noHlslStart := 0
	for li := 0; li < len(lines); li++ {
		ln := lines[li]
		isKey := bytes.HasPrefix(ln, key)
		var keyStr []byte
		if isKey {
			keyStr = ln[len(key):]
			// fmt.Printf("key: %s\n", string(keyStr))
		}
		switch {
		case inNoHlsl && isKey && bytes.HasPrefix(keyStr, end):
			// drop the whole nowgsl region including both markers,
			// rewinding li by the number of deleted lines
			lines = slices.Delete(lines, noHlslStart, li+1)
			li -= ((li + 1) - noHlslStart)
			inNoHlsl = false
		case inHlsl && isKey && bytes.HasPrefix(keyStr, end):
			// drop just the end marker line
			lines = slices.Delete(lines, li, li+1)
			li--
			inHlsl = false
		case inHlsl:
			switch {
			case bytes.HasPrefix(ln, stComment) || bytes.HasPrefix(ln, edComment):
				// drop the /* and */ wrapper lines around the wgsl block
				lines = slices.Delete(lines, li, li+1)
				li--
			case bytes.HasPrefix(ln, comment):
				lines[li] = ln[3:] // strip the "// " to expose the wgsl
			}
		case isKey && bytes.HasPrefix(keyStr, wgsl):
			inHlsl = true
			lines = slices.Delete(lines, li, li+1)
			li--
		case isKey && bytes.HasPrefix(keyStr, nowgsl):
			inNoHlsl = true
			noHlslStart = li
		}
	}
	return lines
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gotosl
import (
"bytes"
"fmt"
"io/fs"
"log"
"os"
"path/filepath"
"strings"
"cogentcore.org/core/base/fsx"
"golang.org/x/tools/go/packages"
)
// wgslFile returns the filename with its extension replaced by ".wgsl".
func wgslFile(fn string) string {
	base, _ := fsx.ExtSplit(fn)
	return base + ".wgsl"
}
// bareFile returns the filename with its extension removed.
func bareFile(fn string) string {
	base, _ := fsx.ExtSplit(fn)
	return base
}
// ReadFileLines reads the given file and returns its contents split
// into lines on "\n". Errors are printed as well as returned.
func ReadFileLines(fn string) ([][]byte, error) {
	buf, err := os.ReadFile(fn)
	if err != nil {
		fmt.Println(err)
		return nil, err
	}
	return bytes.Split(buf, []byte("\n")), nil
}
// WriteFileLines writes the lines joined with "\n" to the given file,
// creating it with mode 0644.
func WriteFileLines(fn string, lines [][]byte) error {
	joined := bytes.Join(lines, []byte("\n"))
	return os.WriteFile(fn, joined, 0644)
}
// HasGoslTag returns true if given file content has a //gosl: tag.
// As a side effect, the file's package name is recorded in st.Package
// if not already set.
func (st *State) HasGoslTag(lines [][]byte) bool {
	key := []byte("//gosl:")
	pkgDecl := []byte("package ")
	for _, ln := range lines {
		tln := bytes.TrimSpace(ln)
		if st.Package == "" && bytes.HasPrefix(tln, pkgDecl) {
			st.Package = string(bytes.TrimPrefix(tln, pkgDecl))
		}
		if bytes.HasPrefix(tln, key) {
			return true
		}
	}
	return false
}
// IsGoFile reports whether f is a non-hidden .go file (not a directory).
func IsGoFile(f fs.DirEntry) bool {
	if f.IsDir() {
		return false
	}
	name := f.Name()
	return strings.HasSuffix(name, ".go") && !strings.HasPrefix(name, ".")
}
// IsWGSLFile reports whether f is a non-hidden .wgsl file (not a directory).
func IsWGSLFile(f fs.DirEntry) bool {
	if f.IsDir() {
		return false
	}
	name := f.Name()
	return strings.HasSuffix(name, ".wgsl") && !strings.HasPrefix(name, ".")
}
// ProjectFiles reads the .go files in the current directory, recording
// those with //gosl: tags in st.GoFiles and recursively loading any
// //gosl:import packages they reference.
func (st *State) ProjectFiles() {
	st.GoFiles = make(map[string]*File)
	st.GoVarsFiles = make(map[string]*File)
	for _, fn := range fsx.Filenames(".", ".go") {
		lines, err := ReadFileLines(fn)
		if err != nil {
			continue
		}
		if !st.HasGoslTag(lines) {
			continue
		}
		st.GoFiles[fn] = &File{Name: fn, Lines: lines}
		st.ImportFiles(lines)
	}
}
// ImportFiles checks the given content for //gosl:import tags and loads
// each tagged package's gosl-relevant files into st.GoImports,
// recursing into their own imports as well. Load errors are printed
// and the offending import is skipped.
func (st *State) ImportFiles(lines [][]byte) {
	key := []byte("//gosl:import ")
	for _, ln := range lines {
		tln := bytes.TrimSpace(ln)
		if !bytes.HasPrefix(tln, key) {
			continue
		}
		impath := strings.TrimSpace(string(tln[len(key):]))
		impath = strings.TrimPrefix(impath, `"`)
		impath = strings.TrimSuffix(impath, `"`)
		if impath == "" {
			// guard: a bare "//gosl:import " line previously panicked
			// on impath[0]
			continue
		}
		if _, ok := st.GoImports[impath]; ok {
			continue // already loaded
		}
		pkgs, err := packages.Load(&packages.Config{Mode: packages.NeedName | packages.NeedFiles}, impath)
		if err != nil {
			fmt.Println(err)
			continue
		}
		if len(pkgs) == 0 {
			// defensive: Load can return no packages without an error
			continue
		}
		pfls := make(map[string]*File)
		// register before recursing so import cycles terminate
		st.GoImports[impath] = pfls
		pkg := pkgs[0]
		gofls := pkg.GoFiles
		if len(gofls) == 0 {
			fmt.Printf("WARNING: no go files found in path: %s\n", impath)
		}
		for _, gf := range gofls {
			lns, err := ReadFileLines(gf)
			if err != nil {
				continue
			}
			if !st.HasGoslTag(lns) {
				continue
			}
			_, fo := filepath.Split(gf)
			pfls[fo] = &File{Name: fo, Lines: lns}
			st.ImportFiles(lns)
			// fmt.Printf("added file: %s from package: %s\n", gf, impath)
		}
	}
}
// RemoveGenFiles removes generated .go and .wgsl files in the given
// shader directory tree; walk errors are logged.
func RemoveGenFiles(dir string) {
	walk := func(path string, f fs.DirEntry, err error) error {
		if err != nil {
			return err
		}
		if IsGoFile(f) || IsWGSLFile(f) {
			os.Remove(path)
		}
		return nil
	}
	if err := filepath.WalkDir(dir, walk); err != nil {
		log.Println(err)
	}
}
// CopyPackageFile copies given file name from given package path
// into the current imports directory, recording it in st.SLImportFiles
// (with comments stripped) so it is not re-imported.
// e.g., "slrand.wgsl", "cogentcore.org/lab/gosl/slrand"
func (st *State) CopyPackageFile(fnm, packagePath string) error {
	for _, f := range st.SLImportFiles {
		if f.Name == fnm { // don't re-import
			return nil
		}
	}
	tofn := filepath.Join(st.ImportsDir, fnm)
	pkgs, err := packages.Load(&packages.Config{Mode: packages.NeedName | packages.NeedFiles}, packagePath)
	if err != nil {
		fmt.Println(err)
		return err
	}
	if len(pkgs) != 1 {
		err = fmt.Errorf("%s package not found", packagePath)
		fmt.Println(err)
		return err
	}
	pkg := pkgs[0]
	// locate any file in the package to determine its on-disk directory
	var fn string
	if len(pkg.GoFiles) > 0 {
		fn = pkg.GoFiles[0]
	} else if len(pkg.OtherFiles) > 0 {
		fn = pkg.OtherFiles[0] // bug fix: was pkg.GoFiles[0], which is empty here
	} else {
		err = fmt.Errorf("No files found in package: %s", packagePath)
		fmt.Println(err)
		return err
	}
	dir, _ := filepath.Split(fn)
	fmfn := filepath.Join(dir, fnm)
	lines, err := CopyFile(fmfn, tofn)
	if err != nil {
		return err // bug fix: copy errors were silently dropped
	}
	lines = SlRemoveComments(lines)
	st.SLImportFiles = append(st.SLImportFiles, &File{Name: fnm, Lines: lines})
	return nil
}
// CopyFile copies src to dst as lines, returning the lines read and
// any read or write error.
func CopyFile(src, dst string) ([][]byte, error) {
	lines, err := ReadFileLines(src)
	if err == nil {
		err = WriteFileLines(dst, lines)
	}
	return lines, err
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gotosl
import (
"fmt"
"os"
"slices"
"strings"
"golang.org/x/exp/maps"
)
// genSysName is the name to use for the system in generated code:
// empty when there is only one system, otherwise the system's name.
func (st *State) genSysName(sy *System) string {
	if len(st.Systems) != 1 {
		return sy.Name
	}
	return ""
}
// genSysVar is the generated Go variable name for the system's
// *gpu.ComputeSystem: "GPU<name>System" (name empty for one system).
func (st *State) genSysVar(sy *System) string {
	return "GPU" + st.genSysName(sy) + "System"
}
// GenGPU generates and writes the Go GPU helper code file "gosl.go":
// global system variables, the GPUVars enum, tensor stride variables,
// GPUInit / GPURelease, and the per-system helper functions.
func (st *State) GenGPU() {
	var b strings.Builder

	header := `// Code generated by "gosl"; DO NOT EDIT

package %s

import (
	"embed"
	"unsafe"
	"cogentcore.org/core/gpu"
	"cogentcore.org/lab/tensor"
)

//go:embed %s/*.wgsl
var shaders embed.FS

// ComputeGPU is the compute gpu device
var ComputeGPU *gpu.GPU

// UseGPU indicates whether to use GPU vs. CPU.
var UseGPU bool
`

	b.WriteString(fmt.Sprintf(header, st.Package, st.Config.Output))

	// systems are processed in sorted order for deterministic output
	sys := maps.Keys(st.Systems)
	slices.Sort(sys)
	for _, synm := range sys {
		sy := st.Systems[synm]
		b.WriteString(fmt.Sprintf("// %s is a GPU compute System with kernels operating on the\n// same set of data variables.\n", st.genSysVar(sy)))
		b.WriteString(fmt.Sprintf("var %s *gpu.ComputeSystem\n", st.genSysVar(sy)))
	}

	venum := `
// GPUVars is an enum for GPU variables, for specifying what to sync.
type GPUVars int32 //enums:enum

const (
`
	b.WriteString(venum)
	// enum values are numbered sequentially across all systems / groups
	vidx := 0
	hasTensors := false
	for _, synm := range sys {
		sy := st.Systems[synm]
		if sy.NTensors > 0 {
			hasTensors = true
		}
		for _, gp := range sy.Groups {
			for _, vr := range gp.Vars {
				b.WriteString(fmt.Sprintf("\t%sVar GPUVars = %d\n", vr.Name, vidx))
				vidx++
			}
		}
	}
	b.WriteString(")\n")
	if hasTensors {
		b.WriteString("\n// Tensor stride variables\n")
		for _, synm := range sys {
			sy := st.Systems[synm]
			genSynm := st.genSysName(sy)
			b.WriteString(fmt.Sprintf("var %sTensorStrides tensor.Uint32\n", genSynm))
		}
	} else {
		// tensor is always imported in the header, so reference it
		b.WriteString("\n// Dummy tensor stride variable to avoid import error\n")
		b.WriteString("var __TensorStrides tensor.Uint32\n")
	}

	initf := `
// GPUInit initializes the GPU compute system,
// configuring system(s), variables and kernels.
// It is safe to call multiple times: detects if already run.
func GPUInit() {
	if ComputeGPU != nil {
		return
	}
	gp := gpu.NewComputeGPU()
	ComputeGPU = gp
`
	b.WriteString(initf)
	for _, synm := range sys {
		sy := st.Systems[synm]
		b.WriteString(st.GenGPUSystemInit(sy))
	}
	b.WriteString("}\n\n")

	release := `// GPURelease releases the GPU compute system resources.
// Call this at program exit.
func GPURelease() {
`
	b.WriteString(release)
	sysRelease := `	if %[1]s != nil {
		%[1]s.Release()
		%[1]s = nil
	}
`
	for _, synm := range sys {
		sy := st.Systems[synm]
		b.WriteString(fmt.Sprintf(sysRelease, st.genSysVar(sy)))
	}
	gpuRelease := `
	if ComputeGPU != nil {
		ComputeGPU.Release()
		ComputeGPU = nil
	}
}
`
	b.WriteString(gpuRelease)
	for _, synm := range sys {
		sy := st.Systems[synm]
		b.WriteString(st.GenGPUSystemOps(sy))
	}
	gs := b.String()
	fn := "gosl.go"
	os.WriteFile(fn, []byte(gs), 0644)
}
// GenGPUSystemInit generates the GPUInit body code for the given
// system: creating the compute system, its kernel pipelines, and its
// variable groups with their vars.
func (st *State) GenGPUSystemInit(sy *System) string {
	var b strings.Builder
	syvar := st.genSysVar(sy)
	b.WriteString("\t{\n")
	b.WriteString(fmt.Sprintf("\t\tsy := gpu.NewComputeSystem(gp, %q)\n", sy.Name))
	b.WriteString(fmt.Sprintf("\t\t%s = sy\n", syvar))
	// kernels in sorted order for deterministic output
	kns := maps.Keys(sy.Kernels)
	slices.Sort(kns)
	for _, knm := range kns {
		kn := sy.Kernels[knm]
		b.WriteString(fmt.Sprintf("\t\tgpu.NewComputePipelineShaderFS(shaders, %q, sy)\n", kn.Filename))
	}
	b.WriteString("\t\tvars := sy.Vars()\n")
	for gi, gp := range sy.Groups {
		b.WriteString("\t\t{\n")
		gtyp := "gpu.Storage"
		if gp.Uniform {
			gtyp = "gpu.Uniform"
		}
		b.WriteString(fmt.Sprintf("\t\t\tsgp := vars.AddGroup(%s, %q)\n", gtyp, gp.Name))
		b.WriteString("\t\t\tvar vr *gpu.Var\n\t\t\t_ = vr\n")
		// when any tensor vars exist, TensorStrides is the first var
		// in group 0
		if sy.NTensors > 0 && gi == 0 {
			b.WriteString(fmt.Sprintf("\t\t\tvr = sgp.Add(%q, gpu.%s, 1, gpu.ComputeShader)\n", "TensorStrides", "Uint32"))
			b.WriteString("\t\t\tvr.ReadOnly = true\n")
		}
		for _, vr := range gp.Vars {
			if vr.Tensor {
				typ := strings.TrimPrefix(vr.Type, "tensor.")
				b.WriteString(fmt.Sprintf("\t\t\tvr = sgp.Add(%q, gpu.%s, 1, gpu.ComputeShader)\n", vr.Name, typ))
			} else {
				b.WriteString(fmt.Sprintf("\t\t\tvr = sgp.AddStruct(%q, int(unsafe.Sizeof(%s{})), 1, gpu.ComputeShader)\n", vr.Name, vr.SLType()))
			}
			if vr.ReadOnly {
				b.WriteString("\t\t\tvr.ReadOnly = true\n")
			}
		}
		b.WriteString("\t\t\tsgp.SetNValues(1)\n")
		b.WriteString("\t\t}\n")
	}
	b.WriteString("\t\tsy.Config()\n")
	b.WriteString("\t}\n")
	return b.String()
}
// GenGPUSystemOps generates the per-system GPU helper functions:
// Run* / RunOne* per kernel, RunDone, ToGPU, ReadFromGPU, SyncFromGPU,
// tensor stride sync (when needed), and Get* accessors.
func (st *State) GenGPUSystemOps(sy *System) string {
	var b strings.Builder
	syvar := st.genSysVar(sy)
	synm := st.genSysName(sy)

	// 1 = kernel, 2 = system var, 3 = sysname (blank for 1 default)
	run := `// Run%[1]s runs the %[1]s kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// Can call multiple Run* kernels in a row, which are then all launched
// in the same command submission on the GPU, which is by far the most efficient.
// MUST call RunDone (with optional vars to sync) after all Run calls.
// Alternatively, a single-shot RunOne%[1]s call does Run and Done for a
// single run-and-sync case.
func Run%[1]s(n int) {
	if UseGPU {
		Run%[1]sGPU(n)
	} else {
		Run%[1]sCPU(n)
	}
}

// Run%[1]sGPU runs the %[1]s kernel on the GPU. See [Run%[1]s] for more info.
func Run%[1]sGPU(n int) {
	sy := %[2]s
	pl := sy.ComputePipelines[%[1]q]
	ce, _ := sy.BeginComputePass()
	pl.Dispatch1D(ce, n, 64)
}

// Run%[1]sCPU runs the %[1]s kernel on the CPU.
func Run%[1]sCPU(n int) {
	gpu.VectorizeFunc(0, n, %[1]s)
}

// RunOne%[1]s runs the %[1]s kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// This version then calls RunDone with the given variables to sync
// after the Run, for a single-shot Run-and-Done call. If multiple kernels
// can be run in sequence, it is much more efficient to do multiple Run*
// calls followed by a RunDone call.
func RunOne%[1]s(n int, syncVars ...GPUVars) {
	if UseGPU {
		Run%[1]sGPU(n)
		RunDone%[3]s(syncVars...)
	} else {
		Run%[1]sCPU(n)
	}
}
`
	// 1 = sysname (blank for 1 default), 2 = system var
	runDone := `// RunDone%[1]s must be called after Run* calls to start compute kernels.
// This actually submits the kernel jobs to the GPU, and adds commands
// to synchronize the given variables back from the GPU to the CPU.
// After this function completes, the GPU results will be available in
// the specified variables.
func RunDone%[1]s(syncVars ...GPUVars) {
	if !UseGPU {
		return
	}
	sy := %[2]s
	sy.ComputeEncoder.End()
	%[1]sReadFromGPU(syncVars...)
	sy.EndComputePass()
	%[1]sSyncFromGPU(syncVars...)
}

// %[1]sToGPU copies given variables to the GPU for the system.
func %[1]sToGPU(vars ...GPUVars) {
	if !UseGPU {
		return
	}
	sy := %[2]s
	syVars := sy.Vars()
	for _, vr := range vars {
		switch vr {
`
	// kernels in sorted order for deterministic output
	kns := maps.Keys(sy.Kernels)
	slices.Sort(kns)
	for _, knm := range kns {
		kn := sy.Kernels[knm]
		b.WriteString(fmt.Sprintf(run, kn.Name, syvar, synm))
	}
	b.WriteString(fmt.Sprintf(runDone, synm, syvar))
	// switch cases for ToGPU: one per variable, across all groups
	for gi, gp := range sy.Groups {
		for _, vr := range gp.Vars {
			b.WriteString(fmt.Sprintf("\t\tcase %sVar:\n", vr.Name))
			b.WriteString(fmt.Sprintf("\t\t\tv, _ := syVars.ValueByIndex(%d, %q, 0)\n", gi, vr.Name))
			vv := vr.Name
			if vr.Tensor {
				vv += ".Values"
			}
			b.WriteString(fmt.Sprintf("\t\t\tgpu.SetValueFrom(v, %s)\n", vv))
		}
	}
	b.WriteString("\t\t}\n\t}\n}\n")

	runSync := `// Run%[1]sGPUSync can be called to synchronize data between CPU and GPU.
// Any prior ToGPU* calls will execute to send data to the GPU,
// and any subsequent RunDone* calls will copy data back from the GPU.
func Run%[1]sGPUSync() {
	if !UseGPU {
		return
	}
	sy := %[2]s
	sy.BeginComputePass()
}
`
	b.WriteString(fmt.Sprintf(runSync, synm, syvar))

	if sy.NTensors > 0 {
		tensorStrides := `
// %[1]sToGPUTensorStrides gets tensor strides and starts copying to the GPU.
func %[1]sToGPUTensorStrides() {
	if !UseGPU {
		return
	}
	sy := %[2]s
	syVars := sy.Vars()
`
		b.WriteString(fmt.Sprintf(tensorStrides, synm, syvar))
		strvar := synm + "TensorStrides"
		// 10 stride slots are reserved per tensor variable
		b.WriteString(fmt.Sprintf("\t%s.SetShapeSizes(%d)\n", strvar, sy.NTensors*10))
		for _, gp := range sy.Groups {
			for _, vr := range gp.Vars {
				if !vr.Tensor {
					continue
				}
				for d := range vr.TensorDims {
					b.WriteString(fmt.Sprintf("\t%sTensorStrides.SetInt1D(%s.Shape().Strides[%d], %d)\n", synm, vr.Name, d, vr.TensorIndex*10+d))
				}
			}
		}
		b.WriteString(fmt.Sprintf("\tv, _ := syVars.ValueByIndex(0, %q, 0)\n", strvar))
		b.WriteString(fmt.Sprintf("\tgpu.SetValueFrom(v, %s.Values)\n", strvar))
		b.WriteString("}\n")
	}

	// fix: doc comment previously said "to the GPU"; GPUToRead enqueues
	// GPU -> read-buffer copies.
	fmGPU := `
// %[1]sReadFromGPU starts the process of copying vars back from the GPU.
func %[1]sReadFromGPU(vars ...GPUVars) {
	sy := %[2]s
	syVars := sy.Vars()
	for _, vr := range vars {
		switch vr {
`
	b.WriteString(fmt.Sprintf(fmGPU, synm, syvar))
	for gi, gp := range sy.Groups {
		for _, vr := range gp.Vars {
			b.WriteString(fmt.Sprintf("\t\tcase %sVar:\n", vr.Name))
			b.WriteString(fmt.Sprintf("\t\t\tv, _ := syVars.ValueByIndex(%d, %q, 0)\n", gi, vr.Name))
			b.WriteString("\t\t\tv.GPUToRead(sy.CommandEncoder)\n")
		}
	}
	b.WriteString("\t\t}\n\t}\n}\n")

	syncGPU := `
// %[1]sSyncFromGPU synchronizes vars from the GPU to the actual variable.
func %[1]sSyncFromGPU(vars ...GPUVars) {
	sy := %[2]s
	syVars := sy.Vars()
	for _, vr := range vars {
		switch vr {
`
	b.WriteString(fmt.Sprintf(syncGPU, synm, syvar))
	for gi, gp := range sy.Groups {
		for _, vr := range gp.Vars {
			b.WriteString(fmt.Sprintf("\t\tcase %sVar:\n", vr.Name))
			b.WriteString(fmt.Sprintf("\t\t\tv, _ := syVars.ValueByIndex(%d, %q, 0)\n", gi, vr.Name))
			// fix: was fmt.Sprintf with no args (staticcheck S1039)
			b.WriteString("\t\t\tv.ReadSync()\n")
			vv := vr.Name
			if vr.Tensor {
				vv += ".Values"
			}
			b.WriteString(fmt.Sprintf("\t\t\tgpu.ReadToBytes(v, %s)\n", vv))
		}
	}
	b.WriteString("\t\t}\n\t}\n}\n")

	getFun := `
// Get%[1]s returns a pointer to the given global variable:
// [%[1]s] []%[2]s at given index.
// To ensure that values are updated on the GPU, you must call [Set%[1]s]
// after all changes have been made.
func Get%[1]s(idx uint32) *%[2]s {
	return &%[1]s[idx]
}
`
	for _, gp := range sy.Groups {
		for _, vr := range gp.Vars {
			if vr.Tensor {
				continue // tensors are accessed directly, no Get accessor
			}
			b.WriteString(fmt.Sprintf(getFun, vr.Name, vr.SLType()))
		}
	}
	return b.String()
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gotosl
import (
"fmt"
"strings"
)
// GenKernelHeader returns the novel generated WGSL kernel code
// for given kernel, which goes at the top of the resulting file.
// It emits: the @group/@binding declarations for every system variable
// (wrapped in atomic<> when the var appears in avars), the @compute
// main entry point that computes a linear element index and calls the
// kernel function, and the tensor indexing helpers (GenTensorFuncs).
func (st *State) GenKernelHeader(sy *System, kn *Kernel, avars map[string]*Var) string {
	var b strings.Builder
	b.WriteString("// Code generated by \"gosl\"; DO NOT EDIT\n")
	b.WriteString("// kernel: " + kn.Name + "\n\n")
	for gi, gp := range sy.Groups {
		if gp.Doc != "" {
			b.WriteString("// " + gp.Doc + "\n")
		}
		// address space: storage by default, uniform when the group requests it.
		str := "storage"
		if gp.Uniform {
			str = "uniform"
		}
		// viOff shifts binding indexes in group 0 when the implicit
		// TensorStrides variable occupies binding 0.
		viOff := 0
		if gi == 0 && sy.NTensors > 0 {
			access := ", read"
			if gp.Uniform {
				// uniform vars take no access-mode qualifier in WGSL.
				access = ""
			}
			viOff = 1
			b.WriteString("@group(0) @binding(0)\n")
			b.WriteString(fmt.Sprintf("var<%s%s> TensorStrides: array<u32>;\n", str, access))
		}
		for vi, vr := range gp.Vars {
			access := ", read_write"
			if vr.ReadOnly {
				access = ", read"
			}
			if gp.Uniform {
				access = ""
			}
			if vr.Doc != "" {
				b.WriteString("// " + vr.Doc + "\n")
			}
			b.WriteString(fmt.Sprintf("@group(%d) @binding(%d)\n", gi, vi+viOff))
			b.WriteString(fmt.Sprintf("var<%s%s> %s: ", str, access, vr.Name))
			if _, ok := avars[vr.Name]; ok {
				// vars accessed atomically need an atomic element type.
				b.WriteString(fmt.Sprintf("array<atomic<%s>>;\n", vr.SLType()))
			} else {
				b.WriteString(fmt.Sprintf("array<%s>;\n", vr.SLType()))
			}
		}
	}
	b.WriteString("\nalias GPUVars = i32;\n\n") // gets included when iteratively processing enumgen.go
	b.WriteString("@compute @workgroup_size(64, 1, 1)\n")
	b.WriteString("fn main(@builtin(workgroup_id) wgid: vec3<u32>, @builtin(num_workgroups) nwg: vec3<u32>, @builtin(local_invocation_index) loci: u32) {\n")
	// flatten the 3D workgroup grid + local index into one linear index.
	b.WriteString("\tlet idx = loci + (wgid.x + wgid.y * nwg.x + wgid.z * nwg.x * nwg.y) * 64;\n")
	b.WriteString(fmt.Sprintf("\t%s(idx);\n", kn.Name))
	b.WriteString("}\n")
	b.WriteString(st.GenTensorFuncs(sy))
	return b.String()
}
// GenTensorFuncs returns the generated WGSL code for indexing the
// tensors in the given system: one IndexND helper per distinct
// dimensionality, taking the strides then the indexes, and returning
// the flat element offset.
func (st *State) GenTensorFuncs(sy *System) string {
	var b strings.Builder
	generated := map[string]bool{}
	for _, gp := range sy.Groups {
		for _, vr := range gp.Vars {
			if !vr.Tensor {
				continue
			}
			fname := vr.IndexFunc()
			if generated[fname] {
				continue // one helper per dimensionality is enough
			}
			generated[fname] = true
			dims := vr.TensorDims
			// parameter list: stride args first, then index args.
			args := make([]string, 0, 2*dims)
			for d := 0; d < dims; d++ {
				args = append(args, fmt.Sprintf("s%d: u32", d))
			}
			for d := 0; d < dims; d++ {
				args = append(args, fmt.Sprintf("i%d: u32", d))
			}
			// body: dot product of strides and indexes.
			terms := make([]string, dims)
			for d := range terms {
				terms[d] = fmt.Sprintf("s%d * i%d", d, d)
			}
			b.WriteString("\nfn " + fname + "(" + strings.Join(args, ", "))
			b.WriteString(") -> u32 {\n\treturn " + strings.Join(terms, " + "))
			b.WriteString(";\n}\n")
		}
	}
	return b.String()
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file is largely copied from the Go source,
// src/go/printer/gobuild.go:
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gotosl
import (
"go/build/constraint"
"sort"
"text/tabwriter"
)
// fixGoBuildLines normalizes //go:build and // +build comments in the
// tabwriter-escaped output: it finds the correct insertion point near the
// top of the file, synthesizes a //go:build expression from // +build
// lines when needed, emits a canonical comment block there, and deletes
// the original build-comment lines from the rest of the output.
// (Copied from Go source src/go/printer/gobuild.go.)
func (p *printer) fixGoBuildLines() {
	if len(p.goBuild)+len(p.plusBuild) == 0 {
		return
	}

	// Find latest possible placement of //go:build and // +build comments.
	// That's just after the last blank line before we find a non-comment.
	// (We'll add another blank line after our comment block.)
	// When we start dropping // +build comments, we can skip over /* */ comments too.
	// Note that we are processing tabwriter input, so every comment
	// begins and ends with a tabwriter.Escape byte.
	// And some newlines have turned into \f bytes.
	insert := 0
	for pos := 0; ; {
		// Skip leading space at beginning of line.
		blank := true
		for pos < len(p.output) && (p.output[pos] == ' ' || p.output[pos] == '\t') {
			pos++
		}
		// Skip over // comment if any.
		if pos+3 < len(p.output) && p.output[pos] == tabwriter.Escape && p.output[pos+1] == '/' && p.output[pos+2] == '/' {
			blank = false
			for pos < len(p.output) && !isNL(p.output[pos]) {
				pos++
			}
		}
		// Skip over \n at end of line.
		if pos >= len(p.output) || !isNL(p.output[pos]) {
			break
		}
		pos++

		if blank {
			insert = pos
		}
	}

	// If there is a //go:build comment before the place we identified,
	// use that point instead. (Earlier in the file is always fine.)
	if len(p.goBuild) > 0 && p.goBuild[0] < insert {
		insert = p.goBuild[0]
	} else if len(p.plusBuild) > 0 && p.plusBuild[0] < insert {
		insert = p.plusBuild[0]
	}

	var x constraint.Expr
	switch len(p.goBuild) {
	case 0:
		// Synthesize //go:build expression from // +build lines.
		// Multiple lines AND together.
		for _, pos := range p.plusBuild {
			y, err := constraint.Parse(p.commentTextAt(pos))
			if err != nil {
				x = nil
				break
			}
			if x == nil {
				x = y
			} else {
				x = &constraint.AndExpr{X: x, Y: y}
			}
		}
	case 1:
		// Parse //go:build expression.
		x, _ = constraint.Parse(p.commentTextAt(p.goBuild[0]))
	}

	var block []byte
	if x == nil {
		// Don't have a valid //go:build expression to treat as truth.
		// Bring all the lines together but leave them alone.
		// Note that these are already tabwriter-escaped.
		for _, pos := range p.goBuild {
			block = append(block, p.lineAt(pos)...)
		}
		for _, pos := range p.plusBuild {
			block = append(block, p.lineAt(pos)...)
		}
	} else {
		// Emit the canonical //go:build line, plus regenerated
		// // +build lines when the original file had any.
		block = append(block, tabwriter.Escape)
		block = append(block, "//go:build "...)
		block = append(block, x.String()...)
		block = append(block, tabwriter.Escape, '\n')
		if len(p.plusBuild) > 0 {
			lines, err := constraint.PlusBuildLines(x)
			if err != nil {
				lines = []string{"// +build error: " + err.Error()}
			}
			for _, line := range lines {
				block = append(block, tabwriter.Escape)
				block = append(block, line...)
				block = append(block, tabwriter.Escape, '\n')
			}
		}
	}
	block = append(block, '\n')

	// Build sorted list of lines to delete from remainder of output.
	toDelete := append(p.goBuild, p.plusBuild...)
	sort.Ints(toDelete)

	// Collect output after insertion point, with lines deleted, into after.
	var after []byte
	start := insert
	for _, end := range toDelete {
		if end < start {
			continue
		}
		after = appendLines(after, p.output[start:end])
		start = end + len(p.lineAt(end))
	}
	after = appendLines(after, p.output[start:])
	// Trim a trailing doubled blank line left by the deletions.
	if n := len(after); n >= 2 && isNL(after[n-1]) && isNL(after[n-2]) {
		after = after[:n-1]
	}

	p.output = p.output[:insert]
	p.output = append(p.output, block...)
	p.output = append(p.output, after...)
}
// appendLines is like append(x, y...) but it avoids creating doubled
// blank lines, which would not be gofmt-standard output. It assumes
// that only whole blocks of lines are being appended, not line fragments.
func appendLines(x, y []byte) []byte {
	yStartsBlank := len(y) > 0 && isNL(y[0])
	xEndsBlank := len(x) == 0 || (len(x) >= 2 && isNL(x[len(x)-1]) && isNL(x[len(x)-2]))
	if yStartsBlank && xEndsBlank {
		// drop y's leading blank line to avoid a doubled blank
		y = y[1:]
	}
	return append(x, y...)
}
// lineAt returns the output line starting at offset start, including
// its terminating newline/formfeed byte when present.
func (p *printer) lineAt(start int) []byte {
	end := start
	for end < len(p.output) && !isNL(p.output[end]) {
		end++
	}
	if end < len(p.output) {
		end++ // include the line terminator
	}
	return p.output[start:end]
}
// commentTextAt returns the text of the comment starting at offset
// start, stripping the leading tabwriter.Escape byte if present and
// stopping at the next escape byte or line terminator.
func (p *printer) commentTextAt(start int) string {
	if start < len(p.output) && p.output[start] == tabwriter.Escape {
		start++
	}
	end := start
	for end < len(p.output) && p.output[end] != tabwriter.Escape && !isNL(p.output[end]) {
		end++
	}
	return string(p.output[start:end])
}
// isNL reports whether b terminates a line in tabwriter output:
// a newline, or a formfeed (which tabwriter substitutes for some newlines).
func isNL(b byte) bool {
	switch b {
	case '\n', '\f':
		return true
	}
	return false
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gotosl
import (
"fmt"
"go/ast"
"os"
"path/filepath"
"reflect"
"strings"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/stack"
)
// System represents a ComputeSystem, and its kernels and variables.
type System struct {

	// Name is the name of the system ("Default" for the implicit one).
	Name string

	// Kernels are the kernels using this compute system, keyed by kernel name.
	Kernels map[string]*Kernel

	// Groups are the variable groups for this compute system.
	Groups []*Group

	// NTensors is the number of tensor vars across all groups
	// (set by VarsAdded).
	NTensors int
}
// NewSystem returns a new System with the given name and an
// initialized (empty) Kernels map.
func NewSystem(name string) *System {
	return &System{
		Name:    name,
		Kernels: make(map[string]*Kernel),
	}
}
// Kernel represents a kernel function, which is the basis for
// each wgsl generated code file.
type Kernel struct {

	// Name is the kernel function name.
	Name string

	// Args is the kernel function argument list (as source text).
	Args string

	// Filename is the name of the kernel shader file, e.g., shaders/Compute.wgsl
	Filename string

	// FuncCode is the kernel function code.
	FuncCode string

	// Lines is the full shader code, as lines of bytes.
	Lines [][]byte
}
// Var represents one global system buffer variable.
type Var struct {

	// Name is the variable name.
	Name string

	// Doc has the comment docs about this var.
	Doc string

	// Type of variable: either []Type for slices, or a tensor type
	// (e.g., tensor.Float32, tensor.Uint32) for tensors.
	Type string

	// ReadOnly indicates that this variable is never read back from GPU,
	// specified by the gosl:read-only property in the variable comments.
	// It is important to optimize GPU memory usage to indicate this.
	ReadOnly bool

	// Tensor is true if this is a tensor type.
	Tensor bool

	// TensorDims is the number of dimensions of the tensor.
	TensorDims int

	// TensorKind is the data kind of the tensor (Float32, Uint32, Int32),
	// set from Type by SetTensorKind.
	TensorKind reflect.Kind

	// TensorIndex is the index of this tensor in the list of tensor
	// variables, used for indexing into TensorStrides (set by VarsAdded).
	TensorIndex int
}
// SetTensorKind sets TensorKind from the "tensor."-prefixed Type string.
// Unsupported tensor types are logged as errors and fall back to Float32.
func (vr *Var) SetTensorKind() {
	kindStr := strings.TrimPrefix(vr.Type, "tensor.")
	vr.TensorKind = reflect.Float32
	switch kindStr {
	case "Float32":
		vr.TensorKind = reflect.Float32
	case "Uint32":
		vr.TensorKind = reflect.Uint32
	case "Int32":
		vr.TensorKind = reflect.Int32
	default:
		errors.Log(fmt.Errorf("gosl: variable %q type is not supported: %q", vr.Name, kindStr))
	}
}
// SLType returns the WGSL type string for this variable's element type:
// f32 / i32 / u32 for tensors, otherwise the Type with its "[]" slice
// prefix stripped. Returns "" for an unrecognized tensor kind.
func (vr *Var) SLType() string {
	if !vr.Tensor {
		return vr.Type[2:] // strip the leading "[]"
	}
	switch vr.TensorKind {
	case reflect.Float32:
		return "f32"
	case reflect.Int32:
		return "i32"
	case reflect.Uint32:
		return "u32"
	}
	return ""
}
// IndexFunc returns the name of the tensor index helper function for
// this variable's dimensionality, e.g., "Index2D".
func (vr *Var) IndexFunc() string {
	return "Index" + fmt.Sprint(vr.TensorDims) + "D"
}
// IndexStride returns the WGSL expression referencing this tensor's
// stride for the given dimension, within the global TensorStrides
// array (10 stride slots per tensor).
func (vr *Var) IndexStride(dim int) string {
	slot := vr.TensorIndex*10 + dim
	return fmt.Sprintf("TensorStrides[%d]", slot)
}
// Group represents one variable group, corresponding to one WGSL
// @group binding group.
type Group struct {

	// Name is the group name.
	Name string

	// Doc has the comment docs about this group.
	Doc string

	// Uniform indicates a uniform group; else default is Storage.
	Uniform bool

	// Vars are the variables in this group.
	Vars []*Var
}
// File has contents of a file as lines of bytes.
type File struct {

	// Name is the file name.
	Name string

	// Lines are the file contents, one entry per line.
	Lines [][]byte
}
// GetGlobalVar holds a GetVar expression, to Set the variable back when done.
type GetGlobalVar struct {

	// Var is the global variable being accessed.
	Var *Var

	// TmpVar is the name of the temporary variable holding the value.
	TmpVar string

	// IdxExpr is the index expression passed to the Get function.
	IdxExpr ast.Expr
}
// State holds the current Go -> WGSL processing state.
type State struct {

	// Config options.
	Config *Config

	// ImportsDir is the path to the shaders/imports directory.
	ImportsDir string

	// Package is the name of the package being processed.
	Package string

	// GoFiles are all the files with gosl content in current directory.
	GoFiles map[string]*File

	// GoVarsFiles are all the files with gosl:vars content in current directory.
	// These must be processed first! they are moved from GoFiles to here.
	GoVarsFiles map[string]*File

	// GoImports has all the imported files, keyed by import path then file name.
	GoImports map[string]map[string]*File

	// ImportPackages has short package names, to remove from go code
	// so everything lives in same main package.
	ImportPackages map[string]bool

	// Systems has the kernels and variables for each system.
	// There is an initial "Default" system when system is not specified.
	Systems map[string]*System

	// GetFuncs is a map of GetVar, SetVar function names for global vars.
	GetFuncs map[string]*Var

	// SLImportFiles are all the extracted and translated WGSL files in shaders/imports,
	// which are copied into the generated shader kernel files.
	SLImportFiles []*File

	// GPUFile is the generated Go GPU gosl.go file contents.
	GPUFile File

	// ExcludeMap is the compiled map of functions to exclude in Go -> WGSL translation.
	ExcludeMap map[string]bool

	// GetVarStack is a stack per function definition of GetVar variables
	// that need to be set at the end.
	GetVarStack stack.Stack[map[string]*GetGlobalVar]

	// GetFuncGraph is true if getting the function graph (first pass).
	GetFuncGraph bool

	// KernelFuncs are the list of functions to include for current kernel.
	KernelFuncs map[string]*Function

	// FuncGraph is the call graph of functions, for dead code elimination.
	FuncGraph map[string]*Function
}
// Init initializes the state from the given config, allocating the
// import / system / exclusion maps and installing the "Default" system.
func (st *State) Init(cfg *Config) {
	st.Config = cfg
	st.GoImports = make(map[string]map[string]*File)
	st.Systems = make(map[string]*System)
	st.ExcludeMap = make(map[string]bool)
	// Exclude is a comma-separated list of function names.
	// strings.Split("") yields [""], and users may write "a, b",
	// so trim each entry and skip empties rather than registering them.
	for _, fn := range strings.Split(cfg.Exclude, ",") {
		if fn = strings.TrimSpace(fn); fn != "" {
			st.ExcludeMap[fn] = true
		}
	}
	st.Systems["Default"] = NewSystem("Default")
}
// Run performs the full Go -> WGSL processing pipeline: gathers project
// files, creates output directories, extracts and translates gosl code,
// and generates the Go GPU interface file. Returns an error if not in
// go-modules mode or if the output directories cannot be created.
func (st *State) Run() error {
	if gomod := os.Getenv("GO111MODULE"); gomod == "off" {
		return errors.New("gosl only works in go modules mode, but GO111MODULE=off")
	}
	if st.Config.Output == "" {
		st.Config.Output = "shaders"
	}
	st.ProjectFiles() // get list of all files, recursively gets imports etc.
	if len(st.GoFiles) == 0 {
		return nil // nothing to do
	}
	st.ImportsDir = filepath.Join(st.Config.Output, "imports")
	// Check MkdirAll errors: everything after this writes into these
	// directories, so failing to create them was previously a silent
	// cascade of downstream failures.
	if err := os.MkdirAll(st.Config.Output, 0755); err != nil {
		return err
	}
	if err := os.MkdirAll(st.ImportsDir, 0755); err != nil {
		return err
	}
	RemoveGenFiles(st.Config.Output)
	RemoveGenFiles(st.ImportsDir)
	st.ExtractFiles()   // get .go from project files
	st.ExtractImports() // get .go from imports
	st.TranslateDir("./" + st.ImportsDir)
	st.GenGPU()
	return nil
}
// System returns the given system by name, making it if not yet made.
// If name is empty, "Default" is used.
func (st *State) System(sysname string) *System {
	if sysname == "" {
		sysname = "Default"
	}
	if sy, ok := st.Systems[sysname]; ok {
		return sy
	}
	nsy := NewSystem(sysname)
	st.Systems[sysname] = nsy
	return nsy
}
// GlobalVar returns the global variable of given name across all
// systems and groups, or nil if not found (also nil-safe on st).
func (st *State) GlobalVar(vrnm string) *Var {
	if st == nil || st.Systems == nil {
		return nil
	}
	for _, sy := range st.Systems {
		for _, gp := range sy.Groups {
			for _, vr := range gp.Vars {
				if vr.Name == vrnm {
					return vr
				}
			}
		}
	}
	return nil
}
// GetTempVar returns the temp var record for the global variable of
// given name, searching the GetVarStack from innermost scope outward;
// returns nil if not found (nil-safe on st and the stack).
func (st *State) GetTempVar(vrnm string) *GetGlobalVar {
	if st == nil || st.GetVarStack == nil {
		return nil
	}
	for i := len(st.GetVarStack) - 1; i >= 0; i-- {
		if gv, ok := st.GetVarStack[i][vrnm]; ok {
			return gv
		}
	}
	return nil
}
// VarsAdded is called when a set of vars has been added; it rebuilds
// the GetFuncs map for non-tensor vars and assigns sequential
// TensorIndex values (and the per-system NTensors count) for tensors.
func (st *State) VarsAdded() {
	st.GetFuncs = make(map[string]*Var)
	for _, sy := range st.Systems {
		nTensor := 0
		for _, gp := range sy.Groups {
			for _, vr := range gp.Vars {
				if !vr.Tensor {
					st.GetFuncs["Get"+vr.Name] = vr
					continue
				}
				vr.TensorIndex = nTensor
				nTensor++
			}
		}
		sy.NTensors = nTensor
	}
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file is largely copied from the Go source,
// src/go/printer/nodes.go:
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements printing of AST nodes; specifically
// expressions, statements, declarations, and files. It uses
// the print functionality implemented in printer.go.
package gotosl
import (
"fmt"
"go/ast"
"go/token"
"go/types"
"math"
"path"
"slices"
"strconv"
"strings"
"unicode"
"unicode/utf8"
)
// Formatting issues:
// - better comment formatting for /*-style comments at the end of a line (e.g. a declaration)
// when the comment spans multiple lines; if such a comment is just two lines, formatting is
// not idempotent
// - formatting of expression lists
// - should use blank instead of tab to separate one-line function bodies from
// the function header unless there is a group of consecutive one-liners
// ----------------------------------------------------------------------------
// Common AST nodes.
// Print as many newlines as necessary (but at least min newlines) to get to
// the current line. ws is printed before the first line break. If newSection
// is set, the first line break is printed as formfeed. Returns 0 if no line
// breaks were printed, returns 1 if there was exactly one newline printed,
// and returns a value > 1 if there was a formfeed or more than one newline
// printed.
//
// TODO(gri): linebreak may add too many lines if the next statement at "line"
// is preceded by comments because the computation of n assumes
// the current position before the comment and the target position
// after the comment. Thus, after interspersing such comments, the
// space taken up by them is not considered to reduce the number of
// linebreaks. At the moment there is no easy way to know about
// future (not yet interspersed) comments in this function.
func (p *printer) linebreak(line, min int, ws whiteSpace, newSection bool) (nbreaks int) {
	// n = number of line breaks needed to reach the target line,
	// clamped by nlimit, but at least min.
	n := max(nlimit(line-p.pos.Line), min)
	if n > 0 {
		p.print(ws) // pending whitespace goes before the first break
		if newSection {
			// a formfeed starts a new tabwriter section
			p.print(formfeed)
			n--
			nbreaks = 2
		}
		nbreaks += n
		for ; n > 0; n-- {
			p.print(newline)
		}
	}
	return
}
// gosl: findDirective scans the comment group for //gosl: directives,
// returning the directive text(s) (with the "//gosl:" prefix removed)
// and the remaining comments concatenated as docs.
func (p *printer) findDirective(g *ast.CommentGroup) (dirs []string, docs string) {
	if g == nil {
		return
	}
	const prefix = "//gosl:"
	for _, c := range g.List {
		if strings.HasPrefix(c.Text, prefix) {
			dirs = append(dirs, strings.TrimPrefix(c.Text, prefix))
		} else {
			docs += c.Text + " "
		}
	}
	return
}
// gosl: hasDirective reports whether any of the directive strings
// contains the given substring.
func hasDirective(dirs []string, dir string) bool {
	return slices.ContainsFunc(dirs, func(d string) bool {
		return strings.Contains(d, dir)
	})
}
// gosl: directiveAfter returns the space-trimmed text following the
// given leading directive prefix, and a bool indicating whether any
// directive had that prefix.
func directiveAfter(dirs []string, dir string) (string, bool) {
	for _, d := range dirs {
		if after, ok := strings.CutPrefix(d, dir); ok {
			return strings.TrimSpace(after), true
		}
	}
	return "", false
}
// setComment sets g as the next comment if g != nil and if node comments
// are enabled - this mode is used when printing source code fragments such
// as exports only. It assumes that there is no pending comment in p.comments
// and at most one pending comment in the p.comment cache.
func (p *printer) setComment(g *ast.CommentGroup) {
	if g == nil || !p.useNodeComments {
		return
	}
	if p.comments == nil {
		// initialize p.comments lazily
		p.comments = make([]*ast.CommentGroup, 1)
	} else if p.cindex < len(p.comments) {
		// for some reason there are pending comments; this
		// should never happen - handle gracefully and flush
		// all comments up to g, ignore anything after that
		p.flush(p.posFor(g.List[0].Pos()), token.ILLEGAL)
		p.comments = p.comments[0:1]
		// in debug mode, report error
		p.internalError("setComment found pending comments")
	}
	p.comments[0] = g
	p.cindex = 0
	// don't overwrite any pending comment in the p.comment cache
	// (there may be a pending comment when a line comment is
	// immediately followed by a lead comment with no other
	// tokens between)
	if p.commentOffset == infinity {
		p.nextComment() // get comment ready for use
	}
}
// exprListMode is a set of flags controlling expression list formatting
// in exprList.
type exprListMode uint

const (
	commaTerm exprListMode = 1 << iota // list is optionally terminated by a comma
	noIndent                           // no extra indentation in multi-line lists
)
// identList prints a list of identifiers. If indent is set, a
// multi-line identifier list is indented after the first linebreak
// encountered.
func (p *printer) identList(list []*ast.Ident, indent bool) {
	// convert into an expression list so we can re-use exprList formatting
	xlist := make([]ast.Expr, len(list))
	for i := range list {
		xlist[i] = list[i]
	}
	mode := exprListMode(0)
	if !indent {
		mode = noIndent
	}
	p.exprList(token.NoPos, xlist, 1, mode, token.NoPos, false)
}
// filteredMsg is the placeholder comment emitted where unexported or
// filtered fields were omitted from a printed list.
const filteredMsg = "contains filtered or unexported fields"
// exprList prints a list of expressions. If the list spans multiple
// source lines, the original line breaks are respected between
// expressions. prev0/next0 are the positions of the tokens surrounding
// the list (e.g. the parens), depth is the expression nesting depth,
// mode controls comma termination and indentation, and isIncomplete
// indicates filtered/unexported fields were omitted.
//
// TODO(gri) Consider rewriting this to be independent of []ast.Expr
// so that we can use the algorithm for any kind of list
//
// (e.g., pass list via a channel over which to range).
func (p *printer) exprList(prev0 token.Pos, list []ast.Expr, depth int, mode exprListMode, next0 token.Pos, isIncomplete bool) {
	if len(list) == 0 {
		if isIncomplete {
			prev := p.posFor(prev0)
			next := p.posFor(next0)
			if prev.IsValid() && prev.Line == next.Line {
				p.print("/* " + filteredMsg + " */")
			} else {
				p.print(newline)
				p.print(indent, "// "+filteredMsg, unindent, newline)
			}
		}
		return
	}

	prev := p.posFor(prev0)
	next := p.posFor(next0)
	line := p.lineFor(list[0].Pos())
	endLine := p.lineFor(list[len(list)-1].End())

	if prev.IsValid() && prev.Line == line && line == endLine {
		// all list entries on a single line
		for i, x := range list {
			if i > 0 {
				// use position of expression following the comma as
				// comma position for correct comment placement
				p.setPos(x.Pos())
				p.print(token.COMMA, blank)
			}
			p.expr0(x, depth)
		}
		if isIncomplete {
			p.print(token.COMMA, blank, "/* "+filteredMsg+" */")
		}
		return
	}

	// list entries span multiple lines;
	// use source code positions to guide line breaks

	// Don't add extra indentation if noIndent is set;
	// i.e., pretend that the first line is already indented.
	ws := ignore
	if mode&noIndent == 0 {
		ws = indent
	}

	// The first linebreak is always a formfeed since this section must not
	// depend on any previous formatting.
	prevBreak := -1 // index of last expression that was followed by a linebreak
	if prev.IsValid() && prev.Line < line && p.linebreak(line, 0, ws, true) > 0 {
		ws = ignore
		prevBreak = 0
	}

	// initialize expression/key size: a zero value indicates expr/key doesn't fit on a single line
	size := 0

	// We use the ratio between the geometric mean of the previous key sizes and
	// the current size to determine if there should be a break in the alignment.
	// To compute the geometric mean we accumulate the ln(size) values (lnsum)
	// and the number of sizes included (count).
	lnsum := 0.0
	count := 0

	// print all list elements
	prevLine := prev.Line
	for i, x := range list {
		line = p.lineFor(x.Pos())

		// Determine if the next linebreak, if any, needs to use formfeed:
		// in general, use the entire node size to make the decision; for
		// key:value expressions, use the key size.
		// TODO(gri) for a better result, should probably incorporate both
		// the key and the node size into the decision process
		useFF := true

		// Determine element size: All bets are off if we don't have
		// position information for the previous and next token (likely
		// generated code - simply ignore the size in this case by setting
		// it to 0).
		prevSize := size
		const infinity = 1e6 // larger than any source line
		size = p.nodeSize(x, infinity)
		pair, isPair := x.(*ast.KeyValueExpr)
		if size <= infinity && prev.IsValid() && next.IsValid() {
			// x fits on a single line
			if isPair {
				size = p.nodeSize(pair.Key, infinity) // size <= infinity
			}
		} else {
			// size too large or we don't have good layout information
			size = 0
		}

		// If the previous line and the current line had single-
		// line-expressions and the key sizes are small or the
		// ratio between the current key and the geometric mean
		// if the previous key sizes does not exceed a threshold,
		// align columns and do not use formfeed.
		if prevSize > 0 && size > 0 {
			const smallSize = 40
			if count == 0 || prevSize <= smallSize && size <= smallSize {
				useFF = false
			} else {
				const r = 2.5                               // threshold
				geomean := math.Exp(lnsum / float64(count)) // count > 0
				ratio := float64(size) / geomean
				useFF = r*ratio <= 1 || r <= ratio
			}
		}

		needsLinebreak := 0 < prevLine && prevLine < line
		if i > 0 {
			// Use position of expression following the comma as
			// comma position for correct comment placement, but
			// only if the expression is on the same line.
			if !needsLinebreak {
				p.setPos(x.Pos())
			}
			p.print(token.COMMA)
			needsBlank := true
			if needsLinebreak {
				// Lines are broken using newlines so comments remain aligned
				// unless useFF is set or there are multiple expressions on
				// the same line in which case formfeed is used.
				nbreaks := p.linebreak(line, 0, ws, useFF || prevBreak+1 < i)
				if nbreaks > 0 {
					ws = ignore
					prevBreak = i
					needsBlank = false // we got a line break instead
				}
				// If there was a new section or more than one new line
				// (which means that the tabwriter will implicitly break
				// the section), reset the geomean variables since we are
				// starting a new group of elements with the next element.
				if nbreaks > 1 {
					lnsum = 0
					count = 0
				}
			}
			if needsBlank {
				p.print(blank)
			}
		}

		if len(list) > 1 && isPair && size > 0 && needsLinebreak {
			// We have a key:value expression that fits onto one line
			// and it's not on the same line as the prior expression:
			// Use a column for the key such that consecutive entries
			// can align if possible.
			// (needsLinebreak is set if we started a new line before)
			p.expr(pair.Key)
			p.setPos(pair.Colon)
			p.print(token.COLON, vtab)
			p.expr(pair.Value)
		} else {
			p.expr0(x, depth)
		}

		if size > 0 {
			lnsum += math.Log(float64(size))
			count++
		}

		prevLine = line
	}

	if mode&commaTerm != 0 && next.IsValid() && p.pos.Line < next.Line {
		// Print a terminating comma if the next token is on a new line.
		p.print(token.COMMA)
		if isIncomplete {
			p.print(newline)
			p.print("// " + filteredMsg)
		}
		if ws == ignore && mode&noIndent == 0 {
			// unindent if we indented
			p.print(unindent)
		}
		p.print(formfeed) // terminating comma needs a line break to look good
		return
	}

	if isIncomplete {
		p.print(token.COMMA, newline)
		p.print("// "+filteredMsg, newline)
	}

	if ws == ignore && mode&noIndent == 0 {
		// unindent if we indented
		p.print(unindent)
	}
}
// paramMode distinguishes the kinds of parameter lists printed by
// the parameters method.
type paramMode int

const (
	funcParam  paramMode = iota // ordinary function parameters, in ( )
	funcTParam                  // function type parameters, in [ ]
	typeTParam                  // type declaration type parameters, in [ ]
)
// parameters prints a (type) parameter list in gosl/WGSL style:
// each parameter is emitted as "name: type" (type after the name),
// with pointer types rewritten via ptrType (see ptrType for the
// "ptr<function, T>" form). funcParam lists use parentheses; type
// parameter lists use brackets.
func (p *printer) parameters(fields *ast.FieldList, mode paramMode) {
	openTok, closeTok := token.LPAREN, token.RPAREN
	if mode != funcParam {
		openTok, closeTok = token.LBRACK, token.RBRACK
	}
	p.setPos(fields.Opening)
	p.print(openTok)
	if len(fields.List) > 0 {
		prevLine := p.lineFor(fields.Opening)
		ws := indent
		for i, par := range fields.List {
			// determine par begin and end line (may be different
			// if there are multiple parameter names for this par
			// or the type is on a separate line)
			parLineBeg := p.lineFor(par.Pos())
			parLineEnd := p.lineFor(par.End())
			// separating "," if needed
			needsLinebreak := 0 < prevLine && prevLine < parLineBeg
			if i > 0 {
				// use position of parameter following the comma as
				// comma position for correct comma placement, but
				// only if the next parameter is on the same line
				if !needsLinebreak {
					p.setPos(par.Pos())
				}
				p.print(token.COMMA)
			}
			// separator if needed (linebreak or blank)
			if needsLinebreak && p.linebreak(parLineBeg, 0, ws, true) > 0 {
				// break line if the opening "(" or previous parameter ended on a different line
				ws = ignore
			} else if i > 0 {
				p.print(blank)
			}
			// parameter names
			if len(par.Names) > 1 {
				// multiple names sharing one type: emit "name: type" for each
				nnm := len(par.Names)
				for ni, nm := range par.Names {
					p.print(nm.Name)
					p.print(token.COLON)
					p.print(blank)
					atyp, isPtr := p.ptrType(stripParensAlways(par.Type))
					p.expr(atyp)
					if isPtr {
						p.print(">") // closes the ptr<function, ... form from ptrType
						p.curPtrArgs = append(p.curPtrArgs, par.Names[0])
					}
					if ni < nnm-1 {
						p.print(token.COMMA)
					}
				}
			} else if len(par.Names) > 0 {
				// Very subtle: If we indented before (ws == ignore), identList
				// won't indent again. If we didn't (ws == indent), identList will
				// indent if the identList spans multiple lines, and it will outdent
				// again at the end (and still ws == indent). Thus, a subsequent indent
				// by a linebreak call after a type, or in the next multi-line identList
				// will do the right thing.
				p.identList(par.Names, ws == indent)
				p.print(token.COLON)
				p.print(blank)
				// parameter type -- gosl = type first, replace ptr star with `inout`
				atyp, isPtr := p.ptrType(stripParensAlways(par.Type))
				p.expr(atyp)
				if isPtr {
					p.print(">")
					p.curPtrArgs = append(p.curPtrArgs, par.Names[0])
				}
			} else {
				// anonymous parameter: type only
				atyp, isPtr := p.ptrType(stripParensAlways(par.Type))
				p.expr(atyp)
				if isPtr {
					p.print(">")
				}
			}
			prevLine = parLineEnd
		}

		// if the closing ")" is on a separate line from the last parameter,
		// print an additional "," and line break
		if closing := p.lineFor(fields.Closing); 0 < prevLine && prevLine < closing {
			p.print(token.COMMA)
			p.linebreak(closing, 0, ignore, true)
		} else if mode == typeTParam && fields.NumFields() == 1 && combinesWithName(fields.List[0].Type) {
			// A type parameter list [P T] where the name P and the type expression T syntactically
			// combine to another valid (value) expression requires a trailing comma, as in [P *T,]
			// (or an enclosing interface as in [P interface(*T)]), so that the type parameter list
			// is not parsed as an array length [P*T].
			p.print(token.COMMA)
		}

		// unindent if we indented
		if ws == ignore {
			p.print(unindent)
		}
	}
	p.setPos(fields.Closing)
	p.print(closeTok)
}
// rwArg records a global-variable index expression that was passed to a
// function via a temp var, so the temp value can be assigned back to the
// indexed element afterwards (see assignRwArgs).
type rwArg struct {
	idx    *ast.IndexExpr // the original global[index] expression
	tmpVar string         // temp var name holding the value (may carry a leading '&')
}
// assignRwArgs emits the write-back assignments for read-write global
// args after a call: for each rwArg, "global[idx] = tmpVar", with the
// temp var's leading '&' (if any) stripped.
func (p *printer) assignRwArgs(rwargs []rwArg) {
	n := len(rwargs)
	if n == 0 {
		return
	}
	p.print(token.SEMICOLON, blank, formfeed)
	for i, arg := range rwargs {
		p.expr(arg.idx)
		p.print(token.ASSIGN)
		p.print(strings.TrimPrefix(arg.tmpVar, "&"))
		if i < n-1 {
			p.print(token.SEMICOLON, blank)
		}
	}
}
// gosl: goslFixArgs returns a modified copy of the call arguments,
// ensuring basic literals are properly cast to the parameter type and
// substituting temp vars for global-variable accesses. It also returns
// the list of read-write global args whose values must be assigned
// back after the call (see assignRwArgs).
func (p *printer) goslFixArgs(args []ast.Expr, params *types.Tuple) ([]ast.Expr, []rwArg) {
	ags := slices.Clone(args)
	mx := min(len(args), params.Len())
	var rwargs []rwArg
	for i := 0; i < mx; i++ {
		ag := args[i]
		pr := params.At(i)
		switch x := ag.(type) {
		case *ast.BasicLit:
			// wrap the literal in an explicit cast to the parameter type.
			typ := pr.Type()
			tnm := getLocalTypeName(typ)
			nn := normalizedNumber(x)
			nn.Value = tnm + "(" + nn.Value + ")"
			ags[i] = nn
		case *ast.Ident:
			// temp vars for globals get a '&' so they are passed by reference.
			if gvar := p.GoToSL.GetTempVar(x.Name); gvar != nil {
				x.Name = "&" + x.Name
				ags[i] = x
			}
		case *ast.IndexExpr:
			// global[idx] arg: pass the temp var instead; record for
			// write-back unless the var is read-only.
			isGlobal, tmpVar, _, _, isReadOnly := p.globalVar(x)
			if isGlobal {
				ags[i] = &ast.Ident{Name: tmpVar}
				if !isReadOnly {
					rwargs = append(rwargs, rwArg{idx: x, tmpVar: tmpVar})
				}
			}
		case *ast.UnaryExpr:
			// &global[idx] arg: same substitution as the IndexExpr case.
			if idx, ok := x.X.(*ast.IndexExpr); ok {
				isGlobal, tmpVar, _, _, isReadOnly := p.globalVar(idx)
				if isGlobal {
					ags[i] = &ast.Ident{Name: tmpVar}
					if !isReadOnly {
						rwargs = append(rwargs, rwArg{idx: idx, tmpVar: tmpVar})
					}
				}
			}
		}
	}
	return ags, rwargs
}
// gosl: matchLiteralType prints x wrapped in a cast to typ when x is a
// basic literal, e.g. f32(1.5); reports whether it printed anything.
func (p *printer) matchLiteralType(x ast.Expr, typ *ast.Ident) bool {
	lit, ok := x.(*ast.BasicLit)
	if !ok {
		return false
	}
	p.print(typ.Name, token.LPAREN, normalizedNumber(lit), token.RPAREN)
	return true
}
// gosl: matchAssignType handles a single assignment of a (possibly
// negated) basic literal, printing it wrapped in an explicit cast to
// the LHS type, e.g. f32(-1.5). Reports whether it printed anything;
// returns false when the RHS is not a literal or the LHS type cannot
// be determined.
//
// Fixes vs. original: the StarExpr branch contained a dead
// `if err != nil` check (err could never be set there), and
// `tnm[0]` could panic on an empty type name; both are removed
// by restructuring as a type switch and using strings.TrimPrefix.
func (p *printer) matchAssignType(lhs []ast.Expr, rhs []ast.Expr) bool {
	if len(rhs) != 1 || len(lhs) != 1 {
		return false
	}
	// extract the literal value, handling a leading unary minus.
	val := ""
	lit, ok := rhs[0].(*ast.BasicLit)
	if ok {
		val = normalizedNumber(lit).Value
	} else {
		un, ok := rhs[0].(*ast.UnaryExpr)
		if !ok || un.Op != token.SUB {
			return false
		}
		lit, ok = un.X.(*ast.BasicLit)
		if !ok {
			return false
		}
		val = "-" + normalizedNumber(lit).Value
	}
	// resolve the type of the assignment target.
	var typ types.Type
	switch x := lhs[0].(type) {
	case *ast.Ident:
		typ = p.getIdType(x)
	case *ast.SelectorExpr:
		var err error
		typ, err = p.pathType(x)
		if err != nil {
			return false
		}
	case *ast.StarExpr:
		if id, ok := x.X.(*ast.Ident); ok {
			typ = p.getIdType(id)
		}
	}
	if typ == nil {
		return false
	}
	tnm := getLocalTypeName(typ)
	tnm = strings.TrimPrefix(tnm, "*") // casts use the pointed-to type
	p.print(tnm, "(", val, ")")
	return true
}
// gosl: pathType returns the final type for the selector path.
// a.b.c -> sel.X = (a.b) Sel=c -- returns type of c by tracing
// through the path. Returns an error if the expression is not a
// pure selector path or a type along the path cannot be resolved.
func (p *printer) pathType(x *ast.SelectorExpr) (types.Type, error) {
	// collect the path identifiers, innermost selector first;
	// the root identifier ends up last in paths.
	var paths []*ast.Ident
	cur := x
	for {
		paths = append(paths, cur.Sel)
		if sl, ok := cur.X.(*ast.SelectorExpr); ok { // path is itself a selector
			cur = sl
			continue
		}
		if id, ok := cur.X.(*ast.Ident); ok {
			paths = append(paths, id)
			break
		}
		return nil, fmt.Errorf("gosl pathType: path not a pure selector path")
	}
	// resolve the root identifier's type, then walk fields outward.
	np := len(paths)
	idt := p.getIdType(paths[np-1])
	if idt == nil {
		err := fmt.Errorf("gosl pathType ERROR: cannot find type for name: %q", paths[np-1].Name)
		p.userError(err)
		return nil, err
	}
	bt, err := p.getStructType(idt)
	if err != nil {
		return nil, err
	}
	for pi := np - 2; pi >= 0; pi-- {
		pt := paths[pi]
		f := fieldByName(bt, pt.Name)
		if f == nil {
			return nil, fmt.Errorf("gosl pathType: field not found %q in type: %q:", pt, bt.String())
		}
		if pi == 0 {
			// last selector: its field type is the answer.
			return f.Type(), nil
		} else {
			// intermediate selector: must itself be a struct to continue.
			bt, err = p.getStructType(f.Type())
			if err != nil {
				return nil, err
			}
		}
	}
	return nil, fmt.Errorf("gosl pathType: path not a pure selector path")
}
// gosl: isPtrArg reports whether id names one of the current function's
// pointer parameters (tracked in p.curPtrArgs).
func (p *printer) isPtrArg(id *ast.Ident) bool {
	for _, arg := range p.curPtrArgs {
		if arg.Name == id.Name {
			return true
		}
	}
	return false
}
// gosl: derefPtrArgs prints x, wrapping it as (*x) when x is an identifier
// naming a pointer parameter of the current function; everything else is
// printed unchanged via expr1.
func (p *printer) derefPtrArgs(x ast.Expr, prec, depth int) {
	id, ok := x.(*ast.Ident)
	if !ok || !p.isPtrArg(id) {
		p.expr1(x, prec, depth)
		return
	}
	p.print(token.LPAREN, token.MUL, id, token.RPAREN)
}
// gosl: ptrType emits the WGSL pointer-type prefix for a Go pointer type
// and returns the pointee expression; non-pointer types are returned as-is.
// The second result reports whether x was a pointer.
func (p *printer) ptrType(x ast.Expr) (ast.Expr, bool) {
	star, ok := x.(*ast.StarExpr)
	if !ok {
		return x, false
	}
	p.print("ptr<function", token.COMMA)
	return star.X, true
}
// gosl: printMethRecv returns the method receiver's type name and whether
// the receiver is a pointer. Assumes the receiver type is an identifier or
// a pointer to one (panics otherwise, as in the original).
func (p *printer) printMethRecv() (isPtr bool, typnm string) {
	switch rt := p.curMethRecv.Type.(type) {
	case *ast.StarExpr:
		return true, rt.X.(*ast.Ident).Name
	default:
		return false, p.curMethRecv.Type.(*ast.Ident).Name
	}
}
// combinesWithName reports whether a name followed by the expression x
// syntactically combines to another valid (value) expression. For instance
// using *T for x, "name *T" syntactically appears as the expression x*T.
// On the other hand, using P|Q or *P|~Q for x, "name P|Q" or name *P|~Q"
// cannot be combined into a valid (value) expression.
func combinesWithName(x ast.Expr) bool {
	if star, ok := x.(*ast.StarExpr); ok {
		// name *x.X combines to name*x.X if x.X is not a type element
		return !isTypeElem(star.X)
	}
	if bin, ok := x.(*ast.BinaryExpr); ok {
		return combinesWithName(bin.X) && !isTypeElem(bin.Y)
	}
	if _, ok := x.(*ast.ParenExpr); ok {
		// name(x) combines but we are making sure at
		// the call site that x is never parenthesized.
		panic("unexpected parenthesized expression")
	}
	return false
}
// isTypeElem reports whether x is a (possibly parenthesized) type element expression.
// The result is false if x could be a type element OR an ordinary (value) expression.
func isTypeElem(x ast.Expr) bool {
switch x := x.(type) {
case *ast.ArrayType, *ast.StructType, *ast.FuncType, *ast.InterfaceType, *ast.MapType, *ast.ChanType:
return true
case *ast.UnaryExpr:
return x.Op == token.TILDE
case *ast.BinaryExpr:
return isTypeElem(x.X) || isTypeElem(x.Y)
case *ast.ParenExpr:
return isTypeElem(x.X)
}
return false
}
// signature prints a function signature: type parameters, the parameter
// list (with the method receiver, if any, folded in as the first
// parameter), and results in gosl "-> type" form. A single identifier
// result type is recorded in p.curReturnType for later literal casting.
func (p *printer) signature(sig *ast.FuncType, recv *ast.FieldList) {
	if sig.TypeParams != nil {
		p.parameters(sig.TypeParams, funcTParam)
	}
	if sig.Params != nil {
		if recv != nil {
			// gosl: prepend the receiver to the parameter list
			// NOTE(review): append may share recv's List backing array if it
			// has spare capacity — confirm receiver lists are exact-length
			flist := &ast.FieldList{}
			*flist = *recv
			flist.List = append(flist.List, sig.Params.List...)
			p.parameters(flist, funcParam)
		} else {
			p.parameters(sig.Params, funcParam)
		}
	} else if recv != nil {
		// no declared params: the receiver alone forms the list
		p.parameters(recv, funcParam)
	} else {
		p.print(token.LPAREN, token.RPAREN)
	}
	res := sig.Results
	n := res.NumFields() // NumFields is safe on a nil FieldList
	if n > 0 {
		// res != nil
		if id, ok := res.List[0].Type.(*ast.Ident); ok {
			// remember simple result type for implicit return casts
			p.curReturnType = id
		}
		p.print(blank, "->", blank)
		if n == 1 && res.List[0].Names == nil {
			// single anonymous res; no ()'s
			p.expr(stripParensAlways(res.List[0].Type))
			return
		}
		p.parameters(res, funcParam)
	}
}
func identListSize(list []*ast.Ident, maxSize int) (size int) {
for i, x := range list {
if i > 0 {
size += len(", ")
}
size += utf8.RuneCountInString(x.Name)
if size >= maxSize {
break
}
}
return
}
// isOneLineFieldList reports whether a struct/interface field list is small
// enough (single untagged, uncommented field within an approximate size
// budget) to be printed on one line.
func (p *printer) isOneLineFieldList(list []*ast.Field) bool {
	const maxSize = 30 // approximate budget; adjust as appropriate
	if len(list) != 1 {
		return false // allow only one field
	}
	f := list[0]
	if f.Tag != nil || f.Comment != nil {
		return false // don't allow tags or comments
	}
	// size = names + one separating blank (if any names) + type
	size := p.nodeSize(f.Type, maxSize)
	if identListSize(f.Names, maxSize) > 0 {
		size++ // blank between names and type
	}
	return size <= maxSize
}
// setLineComment queues a synthetic single-line comment with the given text.
func (p *printer) setLineComment(text string) {
	c := &ast.Comment{Slash: token.NoPos, Text: text}
	p.setComment(&ast.CommentGroup{List: []*ast.Comment{c}})
}
// fieldList prints a struct or interface field list in gosl form: struct
// fields are printed as "name: Type," (WGSL style) rather than Go's
// "name Type", and tags are suppressed. isIncomplete adds a trailing
// filtered-content comment.
func (p *printer) fieldList(fields *ast.FieldList, isStruct, isIncomplete bool) {
	lbrace := fields.Opening
	list := fields.List
	rbrace := fields.Closing
	hasComments := isIncomplete || p.commentBefore(p.posFor(rbrace))
	srcIsOneLine := lbrace.IsValid() && rbrace.IsValid() && p.lineFor(lbrace) == p.lineFor(rbrace)
	if !hasComments && srcIsOneLine {
		// possibly a one-line struct/interface
		if len(list) == 0 {
			// no blank between keyword and {} in this case
			p.setPos(lbrace)
			p.print(token.LBRACE)
			p.setPos(rbrace)
			p.print(token.RBRACE)
			return
		} else if p.isOneLineFieldList(list) {
			// small enough - print on one line
			// (don't use identList and ignore source line breaks)
			p.setPos(lbrace)
			p.print(token.LBRACE, blank)
			f := list[0]
			if isStruct {
				for i, x := range f.Names {
					if i > 0 {
						// no comments so no need for comma position
						p.print(token.COMMA, blank)
					}
					p.expr(x)
				}
				// gosl: colon between name(s) and type
				p.print(token.COLON)
				if len(f.Names) > 0 {
					p.print(blank)
				}
				p.expr(f.Type)
			} else { // interface
				if len(f.Names) > 0 {
					name := f.Names[0] // method name
					p.expr(name)
					p.print(token.COLON)
					p.signature(f.Type.(*ast.FuncType), nil) // don't print "func"
				} else {
					// embedded interface
					p.expr(f.Type)
				}
			}
			p.print(blank)
			p.setPos(rbrace)
			p.print(token.RBRACE)
			return
		}
	}
	// hasComments || !srcIsOneLine
	p.print(blank)
	p.setPos(lbrace)
	p.print(token.LBRACE, indent)
	if hasComments || len(list) > 0 {
		p.print(formfeed)
	}
	if isStruct {
		sep := vtab
		if len(list) == 1 {
			sep = blank
		}
		var line int
		for i, f := range list {
			if i > 0 {
				p.linebreak(p.lineFor(f.Pos()), 1, ignore, p.linesFrom(line) > 0)
			}
			extraTabs := 0
			p.setComment(f.Doc)
			p.recordLine(&line)
			if len(f.Names) > 1 {
				// gosl: multiple names sharing one type are expanded into
				// one "name: Type," line per name
				nnm := len(f.Names)
				p.setPos(f.Type.Pos())
				for ni, nm := range f.Names {
					p.print(nm.Name)
					p.print(token.COLON)
					p.print(sep)
					p.expr(f.Type)
					if ni < nnm-1 {
						p.print(token.COMMA)
						p.print(formfeed)
					}
				}
				extraTabs = 1
			} else if len(f.Names) > 0 {
				// named fields
				p.identList(f.Names, false)
				p.print(token.COLON)
				p.print(sep)
				p.expr(f.Type)
				extraTabs = 1
			} else {
				// anonymous field
				p.expr(f.Type)
				extraTabs = 2
			}
			// gosl: every field line is comma-terminated (WGSL struct syntax)
			p.print(token.COMMA)
			// gosl: struct tags are suppressed in the output
			// if f.Tag != nil {
			// 	if len(f.Names) > 0 && sep == vtab {
			// 		p.print(sep)
			// 	}
			// 	p.print(sep)
			// 	p.expr(f.Tag)
			// 	extraTabs = 0
			// }
			if f.Comment != nil {
				for ; extraTabs > 0; extraTabs-- {
					p.print(sep)
				}
				p.setComment(f.Comment)
			}
		}
		if isIncomplete {
			if len(list) > 0 {
				p.print(formfeed)
			}
			p.flush(p.posFor(rbrace), token.RBRACE) // make sure we don't lose the last line comment
			p.setLineComment("// " + filteredMsg)
		}
	} else { // interface
		var line int
		var prev *ast.Ident // previous "type" identifier
		for i, f := range list {
			var name *ast.Ident // first name, or nil
			if len(f.Names) > 0 {
				name = f.Names[0]
			}
			if i > 0 {
				// don't do a line break (min == 0) if we are printing a list of types
				// TODO(gri) this doesn't work quite right if the list of types is
				// spread across multiple lines
				min := 1
				if prev != nil && name == prev {
					min = 0
				}
				p.linebreak(p.lineFor(f.Pos()), min, ignore, p.linesFrom(line) > 0)
			}
			p.setComment(f.Doc)
			p.recordLine(&line)
			if name != nil {
				// method
				p.expr(name)
				p.signature(f.Type.(*ast.FuncType), nil) // don't print "func"
				prev = nil
			} else {
				// embedded interface
				p.expr(f.Type)
				prev = nil
			}
			p.setComment(f.Comment)
		}
		if isIncomplete {
			if len(list) > 0 {
				p.print(formfeed)
			}
			p.flush(p.posFor(rbrace), token.RBRACE) // make sure we don't lose the last line comment
			p.setLineComment("// contains filtered or unexported methods")
		}
	}
	p.print(unindent, formfeed)
	p.setPos(rbrace)
	p.print(token.RBRACE)
}
// ----------------------------------------------------------------------------
// Expressions
func walkBinary(e *ast.BinaryExpr) (has4, has5 bool, maxProblem int) {
switch e.Op.Precedence() {
case 4:
has4 = true
case 5:
has5 = true
}
switch l := e.X.(type) {
case *ast.BinaryExpr:
if l.Op.Precedence() < e.Op.Precedence() {
// parens will be inserted.
// pretend this is an *ast.ParenExpr and do nothing.
break
}
h4, h5, mp := walkBinary(l)
has4 = has4 || h4
has5 = has5 || h5
maxProblem = max(maxProblem, mp)
}
switch r := e.Y.(type) {
case *ast.BinaryExpr:
if r.Op.Precedence() <= e.Op.Precedence() {
// parens will be inserted.
// pretend this is an *ast.ParenExpr and do nothing.
break
}
h4, h5, mp := walkBinary(r)
has4 = has4 || h4
has5 = has5 || h5
maxProblem = max(maxProblem, mp)
case *ast.StarExpr:
if e.Op == token.QUO { // `*/`
maxProblem = 5
}
case *ast.UnaryExpr:
switch e.Op.String() + r.Op.String() {
case "/*", "&&", "&^":
maxProblem = 5
case "++", "--":
maxProblem = max(maxProblem, 4)
}
}
return
}
// cutoff computes the precedence level below which blanks are printed
// around binary operators in e (see the binaryExpr rules); depth == 1 is
// Normal mode, depth > 1 Compact mode.
func cutoff(e *ast.BinaryExpr, depth int) int {
	has4, has5, maxProblem := walkBinary(e)
	switch {
	case maxProblem > 0:
		return maxProblem + 1
	case has4 && has5:
		if depth == 1 {
			return 5
		}
		return 4
	case depth == 1:
		return 6
	default:
		return 4
	}
}
func diffPrec(expr ast.Expr, prec int) int {
x, ok := expr.(*ast.BinaryExpr)
if !ok || prec != x.Op.Precedence() {
return 1
}
return 0
}
// reduceDepth decrements depth, clamping the result at a minimum of 1.
func reduceDepth(depth int) int {
	return max(depth-1, 1)
}
// Format the binary expression: decide the cutoff and then format.
// Let's call depth == 1 Normal mode, and depth > 1 Compact mode.
// (Algorithm suggestion by Russ Cox.)
//
// The precedences are:
//
// 5 * / % << >> & &^
// 4 + - | ^
// 3 == != < <= > >=
// 2 &&
// 1 ||
//
// The only decision is whether there will be spaces around levels 4 and 5.
// There are never spaces at level 6 (unary), and always spaces at levels 3 and below.
//
// To choose the cutoff, look at the whole expression but excluding primary
// expressions (function calls, parenthesized exprs), and apply these rules:
//
// 1. If there is a binary operator with a right side unary operand
// that would clash without a space, the cutoff must be (in order):
//
// /* 6
// && 6
// &^ 6
// ++ 5
// -- 5
//
// (Comparison operators always have spaces around them.)
//
// 2. If there is a mix of level 5 and level 4 operators, then the cutoff
// is 5 (use spaces to distinguish precedence) in Normal mode
// and 4 (never use spaces) in Compact mode.
//
// 3. If there are no level 4 operators or no level 5 operators, then the
// cutoff is 6 (always use spaces) in Normal mode
// and 4 (never use spaces) in Compact mode.
func (p *printer) binaryExpr(x *ast.BinaryExpr, prec1, cutoff, depth int) {
	prec := x.Op.Precedence()
	if prec < prec1 {
		// parenthesis needed
		// Note: The gotoslr inserts an ast.ParenExpr node; thus this case
		// can only occur if the AST is created in a different way.
		p.print(token.LPAREN)
		p.expr0(x, reduceDepth(depth)) // parentheses undo one level of depth
		p.print(token.RPAREN)
		return
	}
	// blanks around the operator only below the computed cutoff level
	printBlank := prec < cutoff
	ws := indent
	p.expr1(x.X, prec, depth+diffPrec(x.X, prec))
	if printBlank {
		p.print(blank)
	}
	xline := p.pos.Line // before the operator (it may be on the next line!)
	yline := p.lineFor(x.Y.Pos())
	p.setPos(x.OpPos)
	if x.Op == token.AND_NOT {
		// gosl: WGSL has no &^ operator; emit "& ~" instead
		p.print(token.AND, blank, token.TILDE)
	} else {
		p.print(x.Op)
	}
	if xline != yline && xline > 0 && yline > 0 {
		// at least one line break, but respect an extra empty line
		// in the source
		if p.linebreak(yline, 1, ws, true) > 0 {
			ws = ignore
			printBlank = false // no blank after line break
		}
	}
	if printBlank {
		p.print(blank)
	}
	// right operand binds one level tighter (left-associativity)
	p.expr1(x.Y, prec+1, depth+1)
	if ws == ignore {
		p.print(unindent)
	}
}
func isBinary(expr ast.Expr) bool {
_, ok := expr.(*ast.BinaryExpr)
return ok
}
// expr1 is the main expression printer: it dispatches on the expression
// node type and prints it at the given precedence and depth, applying the
// gosl translations (int -> i32, array literals with parens, method calls
// via methodExpr, global-variable arg fixups, etc.).
func (p *printer) expr1(expr ast.Expr, prec1, depth int) {
	p.setPos(expr.Pos())
	switch x := expr.(type) {
	case *ast.BadExpr:
		p.print("BadExpr")
	case *ast.Ident:
		// gosl: Go's int maps to WGSL i32
		if x.Name == "int" {
			p.print("i32")
		} else {
			p.print(x)
		}
	case *ast.BinaryExpr:
		if depth < 1 {
			p.internalError("depth < 1:", depth)
			depth = 1
		}
		p.binaryExpr(x, prec1, cutoff(x, depth), depth)
	case *ast.KeyValueExpr:
		p.expr(x.Key)
		p.setPos(x.Colon)
		p.print(token.COLON, blank)
		p.expr(x.Value)
	case *ast.StarExpr:
		const prec = token.UnaryPrec
		if prec < prec1 {
			// parenthesis needed
			p.print(token.LPAREN)
			p.print(token.MUL)
			p.expr(x.X)
			p.print(token.RPAREN)
		} else {
			// no parenthesis needed
			p.print(token.MUL)
			p.expr(x.X)
		}
	case *ast.UnaryExpr:
		const prec = token.UnaryPrec
		if prec < prec1 {
			// parenthesis needed
			p.print(token.LPAREN)
			p.expr(x)
			p.print(token.RPAREN)
		} else {
			// no parenthesis needed
			p.print(x.Op)
			if x.Op == token.RANGE {
				// TODO(gri) Remove this code if it cannot be reached.
				p.print(blank)
			}
			p.expr1(x.X, prec, depth)
		}
	case *ast.BasicLit:
		if p.PrintConfig.Mode&normalizeNumbers != 0 {
			x = normalizedNumber(x)
		}
		p.print(x)
	case *ast.FuncLit:
		p.setPos(x.Type.Pos())
		p.print(token.FUNC)
		// See the comment in funcDecl about how the header size is computed.
		startCol := p.out.Column - len("func")
		p.signature(x.Type, nil)
		p.funcBody(p.distanceFrom(x.Type.Pos(), startCol), blank, x.Body)
	case *ast.ParenExpr:
		if _, hasParens := x.X.(*ast.ParenExpr); hasParens {
			// don't print parentheses around an already parenthesized expression
			// TODO(gri) consider making this more general and incorporate precedence levels
			p.expr0(x.X, depth)
		} else {
			p.print(token.LPAREN)
			p.expr0(x.X, reduceDepth(depth)) // parentheses undo one level of depth
			p.setPos(x.Rparen)
			p.print(token.RPAREN)
		}
	case *ast.SelectorExpr:
		p.selectorExpr(x, depth)
	case *ast.TypeAssertExpr:
		p.expr1(x.X, token.HighestPrec, depth)
		p.print(token.PERIOD)
		p.setPos(x.Lparen)
		p.print(token.LPAREN)
		if x.Type != nil {
			p.expr(x.Type)
		} else {
			p.print(token.TYPE)
		}
		p.setPos(x.Rparen)
		p.print(token.RPAREN)
	case *ast.IndexExpr:
		// TODO(gri): should treat[] like parentheses and undo one level of depth
		p.expr1(x.X, token.HighestPrec, 1)
		p.setPos(x.Lbrack)
		p.print(token.LBRACK)
		p.expr0(x.Index, depth+1)
		p.setPos(x.Rbrack)
		p.print(token.RBRACK)
	case *ast.IndexListExpr:
		// TODO(gri): as for IndexExpr, should treat [] like parentheses and undo
		// one level of depth
		p.expr1(x.X, token.HighestPrec, 1)
		p.setPos(x.Lbrack)
		p.print(token.LBRACK)
		p.exprList(x.Lbrack, x.Indices, depth+1, commaTerm, x.Rbrack, false)
		p.setPos(x.Rbrack)
		p.print(token.RBRACK)
	case *ast.SliceExpr:
		// TODO(gri): should treat[] like parentheses and undo one level of depth
		p.expr1(x.X, token.HighestPrec, 1)
		p.setPos(x.Lbrack)
		p.print(token.LBRACK)
		indices := []ast.Expr{x.Low, x.High}
		if x.Max != nil {
			indices = append(indices, x.Max)
		}
		// determine if we need extra blanks around ':'
		var needsBlanks bool
		if depth <= 1 {
			var indexCount int
			var hasBinaries bool
			for _, x := range indices {
				if x != nil {
					indexCount++
					if isBinary(x) {
						hasBinaries = true
					}
				}
			}
			if indexCount > 1 && hasBinaries {
				needsBlanks = true
			}
		}
		for i, x := range indices {
			if i > 0 {
				if indices[i-1] != nil && needsBlanks {
					p.print(blank)
				}
				p.print(token.COLON)
				if x != nil && needsBlanks {
					p.print(blank)
				}
			}
			if x != nil {
				p.expr0(x, depth+1)
			}
		}
		p.setPos(x.Rbrack)
		p.print(token.RBRACK)
	case *ast.CallExpr:
		if len(x.Args) > 1 {
			depth++
		}
		fid, isid := x.Fun.(*ast.Ident)
		// Conversions to literal function types or <-chan
		// types require parentheses around the type.
		paren := false
		switch t := x.Fun.(type) {
		case *ast.FuncType:
			paren = true
		case *ast.ChanType:
			paren = t.Dir == ast.RECV
		}
		if paren {
			p.print(token.LPAREN)
		}
		if _, ok := x.Fun.(*ast.SelectorExpr); ok {
			// gosl: method calls get full receiver rewriting
			p.methodExpr(x, depth)
			break // handles everything, break out of case
		}
		args := x.Args
		var rwargs []rwArg
		if isid {
			// gosl: record called function and fix up global-var args
			if p.curFunc != nil {
				p.curFunc.Funcs[fid.Name] = p.GoToSL.RecycleFunc(fid.Name)
			}
			if obj, ok := p.pkg.TypesInfo.Uses[fid]; ok {
				if ft, ok := obj.(*types.Func); ok {
					sig := ft.Type().(*types.Signature)
					args, rwargs = p.goslFixArgs(x.Args, sig.Params())
				}
			}
		}
		p.expr1(x.Fun, token.HighestPrec, depth)
		if paren {
			p.print(token.RPAREN)
		}
		p.setPos(x.Lparen)
		p.print(token.LPAREN)
		if x.Ellipsis.IsValid() {
			p.exprList(x.Lparen, args, depth, 0, x.Ellipsis, false)
			p.setPos(x.Ellipsis)
			p.print(token.ELLIPSIS)
			if x.Rparen.IsValid() && p.lineFor(x.Ellipsis) < p.lineFor(x.Rparen) {
				p.print(token.COMMA, formfeed)
			}
		} else {
			p.exprList(x.Lparen, args, depth, commaTerm, x.Rparen, false)
		}
		p.setPos(x.Rparen)
		p.print(token.RPAREN)
		// gosl: write temp vars back to their global sources
		p.assignRwArgs(rwargs)
	case *ast.CompositeLit:
		// composite literal elements that are composite literals themselves may have the type omitted
		// gosl: array literals use () instead of {}
		lb := token.LBRACE
		rb := token.RBRACE
		if _, isAry := x.Type.(*ast.ArrayType); isAry {
			lb = token.LPAREN
			rb = token.RPAREN
		}
		if x.Type != nil {
			p.expr1(x.Type, token.HighestPrec, depth)
		}
		p.level++
		p.setPos(x.Lbrace)
		p.print(lb)
		p.exprList(x.Lbrace, x.Elts, 1, commaTerm, x.Rbrace, x.Incomplete)
		// do not insert extra line break following a /*-style comment
		// before the closing '}' as it might break the code if there
		// is no trailing ','
		mode := noExtraLinebreak
		// do not insert extra blank following a /*-style comment
		// before the closing '}' unless the literal is empty
		if len(x.Elts) > 0 {
			mode |= noExtraBlank
		}
		// need the initial indent to print lone comments with
		// the proper level of indentation
		p.print(indent, unindent, mode)
		p.setPos(x.Rbrace)
		p.print(rb, mode)
		p.level--
	case *ast.Ellipsis:
		p.print(token.ELLIPSIS)
		if x.Elt != nil {
			p.expr(x.Elt)
		}
	case *ast.ArrayType:
		// gosl: emitted as the bare keyword; dimensions handled elsewhere
		p.print("array")
		// p.print(token.LBRACK)
		// if x.Len != nil {
		// 	p.expr(x.Len)
		// }
		// p.print(token.RBRACK)
		// p.expr(x.Elt)
	case *ast.StructType:
		// gosl: "struct" keyword is printed by the declaration code
		// p.print(token.STRUCT)
		p.fieldList(x.Fields, true, x.Incomplete)
	case *ast.FuncType:
		p.print(token.FUNC)
		p.signature(x, nil)
	case *ast.InterfaceType:
		p.print(token.INTERFACE)
		p.fieldList(x.Methods, false, x.Incomplete)
	case *ast.MapType:
		p.print(token.MAP, token.LBRACK)
		p.expr(x.Key)
		p.print(token.RBRACK)
		p.expr(x.Value)
	case *ast.ChanType:
		switch x.Dir {
		case ast.SEND | ast.RECV:
			p.print(token.CHAN)
		case ast.RECV:
			p.print(token.ARROW, token.CHAN) // x.Arrow and x.Pos() are the same
		case ast.SEND:
			p.print(token.CHAN)
			p.setPos(x.Arrow)
			p.print(token.ARROW)
		}
		p.print(blank)
		p.expr(x.Value)
	default:
		panic("unreachable")
	}
}
// normalizedNumber rewrites base prefixes and exponents
// of numbers to use lower-case letters (0X123 to 0x123 and 1.2E3 to 1.2e3),
// and removes leading 0's from integer imaginary literals (0765i to 765i).
// It leaves hexadecimal digits alone.
//
// normalizedNumber doesn't modify the ast.BasicLit value lit points to.
// If lit is not a number or a number in canonical format already,
// lit is returned as is. Otherwise a new ast.BasicLit is created.
func normalizedNumber(lit *ast.BasicLit) *ast.BasicLit {
if lit.Kind != token.INT && lit.Kind != token.FLOAT && lit.Kind != token.IMAG {
return lit // not a number - nothing to do
}
if len(lit.Value) < 2 {
return lit // only one digit (common case) - nothing to do
}
// len(lit.Value) >= 2
// We ignore lit.Kind because for lit.Kind == token.IMAG the literal may be an integer
// or floating-point value, decimal or not. Instead, just consider the literal pattern.
x := lit.Value
switch x[:2] {
default:
// 0-prefix octal, decimal int, or float (possibly with 'i' suffix)
if i := strings.LastIndexByte(x, 'E'); i >= 0 {
x = x[:i] + "e" + x[i+1:]
break
}
// remove leading 0's from integer (but not floating-point) imaginary literals
if x[len(x)-1] == 'i' && !strings.ContainsAny(x, ".e") {
x = strings.TrimLeft(x, "0_")
if x == "i" {
x = "0i"
}
}
case "0X":
x = "0x" + x[2:]
// possibly a hexadecimal float
if i := strings.LastIndexByte(x, 'P'); i >= 0 {
x = x[:i] + "p" + x[i+1:]
}
case "0x":
// possibly a hexadecimal float
i := strings.LastIndexByte(x, 'P')
if i == -1 {
return lit // nothing to do
}
x = x[:i] + "p" + x[i+1:]
case "0O":
x = "0o" + x[2:]
case "0o":
return lit // nothing to do
case "0B":
x = "0b" + x[2:]
case "0b":
return lit // nothing to do
}
return &ast.BasicLit{ValuePos: lit.ValuePos, Kind: lit.Kind, Value: x}
}
// selectorExpr handles an *ast.SelectorExpr node and reports whether x spans
// multiple lines, and thus was indented. Pointer-parameter receivers are
// dereferenced via derefPtrArgs.
func (p *printer) selectorExpr(x *ast.SelectorExpr, depth int) (wasIndented bool) {
	p.derefPtrArgs(x.X, token.HighestPrec, depth)
	p.print(token.PERIOD)
	selLine := p.lineFor(x.Sel.Pos())
	wasIndented = p.pos.IsValid() && p.pos.Line < selLine
	if wasIndented {
		// selector continues on a later line: indent the continuation
		p.print(indent, newline)
	}
	p.setPos(x.Sel.Pos())
	p.print(x.Sel)
	if wasIndented {
		p.print(unindent)
	}
	return wasIndented
}
// gosl: methodPath needs to deal with possible multiple chains of selector
// exprs to determine the actual type and name of the receiver.
// a.b.c() -> sel.X = (a.b) Sel=c
// It returns the receiver expression string (with & prefix), the local type
// name of the final field, and its types.Type.
func (p *printer) methodPath(x *ast.SelectorExpr) (recvPath, recvType string, pathType types.Type, err error) {
	var baseRecv *ast.Ident // first receiver in path
	// collect field names, deepest selector first
	var paths []string
	cur := x
	for {
		paths = append(paths, cur.Sel.Name)
		if sl, ok := cur.X.(*ast.SelectorExpr); ok { // path is itself a selector
			cur = sl
			continue
		}
		if id, ok := cur.X.(*ast.Ident); ok {
			baseRecv = id
			break
		}
		err = fmt.Errorf("gosl methodPath ERROR: path for method call must be simple list of fields, not %#v:", cur.X)
		p.userError(err)
		return
	}
	// build the receiver expression: address of the (dereferenced) base
	if p.isPtrArg(baseRecv) {
		recvPath = "&(*" + baseRecv.Name + ")"
	} else {
		recvPath = "&" + baseRecv.Name
	}
	// resolve the base receiver's type: GetVar temps take precedence
	var idt types.Type
	if gvar := p.GoToSL.GetTempVar(baseRecv.Name); gvar != nil {
		idt = p.getTypeNameType(gvar.Var.SLType())
	} else {
		idt = p.getIdType(baseRecv)
	}
	if idt == nil {
		err = fmt.Errorf("gosl methodPath ERROR: cannot find type for name: %q", baseRecv.Name)
		p.userError(err)
		return
	}
	bt, err := p.getStructType(idt)
	if err != nil {
		fmt.Println(baseRecv)
		return
	}
	// walk the field chain from base to final selector, accumulating the path
	curt := bt
	np := len(paths)
	for pi := np - 1; pi >= 0; pi-- {
		pth := paths[pi]
		recvPath += "." + pth
		f := fieldByName(curt, pth)
		if f == nil {
			err = fmt.Errorf("gosl ERROR: field not found %q in type: %q:", pth, curt.String())
			p.userError(err)
			return
		}
		if pi == 0 {
			// final element: this field's type is the method receiver type
			pathType = f.Type()
			recvType = getLocalTypeName(f.Type())
		} else {
			curt, err = p.getStructType(f.Type())
			if err != nil {
				return
			}
		}
	}
	return
}
func fieldByName(st *types.Struct, name string) *types.Var {
nf := st.NumFields()
for i := range nf {
f := st.Field(i)
if f.Name() == name {
return f
}
}
return nil
}
// getIdType returns the type of the identifier from the package's type
// info Uses map, or nil if the identifier is not recorded there.
func (p *printer) getIdType(id *ast.Ident) types.Type {
	obj, ok := p.pkg.TypesInfo.Uses[id]
	if !ok {
		return nil
	}
	return obj.Type()
}
// getTypeNameType looks up typeName in the package scope and returns its
// type, or nil if the name is not declared there.
func (p *printer) getTypeNameType(typeName string) types.Type {
	if obj := p.pkg.Types.Scope().Lookup(typeName); obj != nil {
		return obj.Type()
	}
	return nil
}
func getLocalTypeName(typ types.Type) string {
_, nm := path.Split(typ.String())
return nm
}
// getStructType returns the underlying *types.Struct for typ, looking
// through one level of pointer or slice indirection. It reports (and
// returns) an error if no struct type can be found.
//
// Fix: the error format used %+t, which is not a valid fmt verb for a
// types.Type (it would print "%!t(...)"); %T prints the Go type as intended.
func (p *printer) getStructType(typ types.Type) (*types.Struct, error) {
	typ = typ.Underlying()
	if st, ok := typ.(*types.Struct); ok {
		return st, nil
	}
	// pointer-to-struct
	if ptr, ok := typ.(*types.Pointer); ok {
		typ = ptr.Elem().Underlying()
		if st, ok := typ.(*types.Struct); ok {
			return st, nil
		}
	}
	// slice-of-struct
	if sl, ok := typ.(*types.Slice); ok {
		typ = sl.Elem().Underlying()
		if st, ok := typ.(*types.Struct); ok {
			return st, nil
		}
	}
	err := fmt.Errorf("gosl ERROR: type is not a struct and it should be: %q %T", typ.String(), typ)
	p.userError(err)
	return nil, err
}
// getNamedType returns the *types.Named for typ, looking through one level
// of pointer or slice indirection. It reports (and returns) an error if no
// named type can be found.
//
// Fix: the error format used %+t, which is not a valid fmt verb for a
// types.Type (it would print "%!t(...)"); %T prints the Go type as intended.
func (p *printer) getNamedType(typ types.Type) (*types.Named, error) {
	if nmd, ok := typ.(*types.Named); ok {
		return nmd, nil
	}
	typ = typ.Underlying()
	// pointer-to-named
	if ptr, ok := typ.(*types.Pointer); ok {
		typ = ptr.Elem()
		if nmd, ok := typ.(*types.Named); ok {
			return nmd, nil
		}
	}
	// slice-of-named
	if sl, ok := typ.(*types.Slice); ok {
		typ = sl.Elem()
		if nmd, ok := typ.(*types.Named); ok {
			return nmd, nil
		}
	}
	err := fmt.Errorf("gosl ERROR: type is not a named type: %q %T", typ.String(), typ)
	p.userError(err)
	return nil, err
}
// gosl: globalVar looks up whether the id in an IndexExpr is a global gosl variable.
// in which case it returns a temp variable name to use, and the type info.
// As a side effect, it prints "var <tmp> = <idx-expr>; " so the temp var is
// defined before the expression that uses it; the returned tmpVar carries a
// leading "&" so callers pass the temp's address.
func (p *printer) globalVar(idx *ast.IndexExpr) (isGlobal bool, tmpVar, typName string, vtyp types.Type, isReadOnly bool) {
	id, ok := idx.X.(*ast.Ident)
	if !ok {
		return
	}
	gvr := p.GoToSL.GlobalVar(id.Name)
	if gvr == nil {
		return
	}
	isGlobal = true
	isReadOnly = gvr.ReadOnly
	// temp var is the lower-cased global name
	tmpVar = strings.ToLower(id.Name)
	vtyp = p.getIdType(id)
	if vtyp == nil {
		err := fmt.Errorf("gosl globalVar ERROR: cannot find type for name: %q", id.Name)
		p.userError(err)
		return
	}
	// prefer the named form of the type when available
	nmd, err := p.getNamedType(vtyp)
	if err == nil {
		vtyp = nmd
	}
	typName = gvr.SLType()
	// emit: var <tmp> = <global>[<index>];
	p.print("var ", tmpVar, token.ASSIGN)
	p.expr(idx)
	p.print(token.SEMICOLON, blank)
	tmpVar = "&" + tmpVar // callers pass the temp's address
	return
}
// gosl: replace GetVar function call with assignment of local var
// ae is "tmp := Global.GetVar(idx)"-style; this prints
// "var tmp = Global[idx];" and records the temp in the current GetVarStack
// frame so it can be written back by setGlobalVars at block end.
func (p *printer) getGlobalVar(ae *ast.AssignStmt, gvr *Var) {
	// lhs is a single identifier; rhs is the GetVar call (both guaranteed
	// by the caller — these assertions panic otherwise)
	tmpVar := ae.Lhs[0].(*ast.Ident).Name
	cf := ae.Rhs[0].(*ast.CallExpr)
	p.print("var", blank, tmpVar, blank, token.ASSIGN, blank, gvr.Name, token.LBRACK)
	p.expr(cf.Args[0]) // the index expression
	p.print(token.RBRACK, token.SEMICOLON)
	// register for write-back when the enclosing block closes
	gvars := p.GoToSL.GetVarStack.Peek()
	gvars[tmpVar] = &GetGlobalVar{Var: gvr, TmpVar: tmpVar, IdxExpr: cf.Args[0]}
	p.GoToSL.GetVarStack[len(p.GoToSL.GetVarStack)-1] = gvars
}
// gosl: setGlobalVars writes each non-read-only temp variable back to its
// source global ("Global[idx] = tmp;"), emitted at the end of a block.
func (p *printer) setGlobalVars(gvrs map[string]*GetGlobalVar) {
	for _, gv := range gvrs {
		if gv.Var.ReadOnly {
			continue // read-only globals need no write-back
		}
		p.print(formfeed, "\t", gv.Var.Name, token.LBRACK)
		p.expr(gv.IdxExpr)
		p.print(token.RBRACK, blank, token.ASSIGN, blank, gv.TmpVar, token.SEMICOLON)
	}
}
// gosl: methodIndex processes an index expression used as the receiver of a
// method call, delegating to globalVar for the temp-variable setup when the
// indexed name is a gosl global.
func (p *printer) methodIndex(idx *ast.IndexExpr) (recvPath, recvType string, pathType types.Type, isReadOnly bool, err error) {
	if _, ok := idx.X.(*ast.Ident); !ok {
		err = fmt.Errorf("gosl methodIndex ERROR: must have a recv variable identifier, not %#v:", idx.X)
		p.userError(err)
		return
	}
	var isGlobal bool
	var tmpVar, typName string
	var vtyp types.Type
	isGlobal, tmpVar, typName, vtyp, isReadOnly = p.globalVar(idx)
	if isGlobal {
		recvPath, recvType, pathType = tmpVar, typName, vtyp
	}
	// non-global indexed receivers are handled by the caller
	return
}
// tensorMethod prints a method call on a global tensor variable as a direct
// indexed element access: Value/Set* methods become
// "Name[IndexFunc(stride..., u32(ix)...)]" with an optional assignment
// operator (Set -> =, SetAdd -> +=, etc.); *Ptr methods take the address.
func (p *printer) tensorMethod(x *ast.CallExpr, vr *Var, methName string) {
	args := x.Args
	// Set* methods carry the value as the first arg; indexes follow
	stArg := 0
	if strings.HasPrefix(methName, "Set") {
		stArg = 1
	}
	if strings.HasSuffix(methName, "Ptr") {
		// address-of element, e.g. for atomic ops
		p.print(token.AND)
		if p.curMethIsAtomic {
			// record this var as needing an atomic declaration
			gv := p.GoToSL.GlobalVar(vr.Name)
			if gv != nil {
				if p.curFunc != nil {
					if p.curFunc.Atomics == nil {
						p.curFunc.Atomics = make(map[string]*Var)
					}
					p.curFunc.Atomics[vr.Name] = vr
				}
			}
		}
	}
	// Name[IndexFunc(stride0, stride1, ..., u32(i0), u32(i1), ...)]
	p.print(vr.Name, token.LBRACK)
	p.print(vr.IndexFunc(), token.LPAREN)
	nd := vr.TensorDims
	for d := range nd {
		p.print(vr.IndexStride(d), token.COMMA, blank)
	}
	n := len(args)
	for i := stArg; i < n; i++ {
		ag := args[i]
		p.print("u32", token.LPAREN)
		if ce, ok := ag.(*ast.CallExpr); ok { // get rid of int() wrapper from goal n-dim index
			if fn, ok := ce.Fun.(*ast.Ident); ok {
				if fn.Name == "int" {
					ag = ce.Args[0]
				}
			}
		}
		p.expr(ag)
		p.print(token.RPAREN)
		if i < n-1 {
			p.print(token.COMMA, blank)
		}
	}
	p.print(token.RPAREN, token.RBRACK)
	if strings.HasPrefix(methName, "Set") {
		// map Set/SetAdd/SetSub/SetMul/SetDiv onto assignment operators
		opnm := strings.TrimPrefix(methName, "Set")
		tok := token.ASSIGN
		switch opnm {
		case "Add":
			tok = token.ADD_ASSIGN
		case "Sub":
			tok = token.SUB_ASSIGN
		case "Mul":
			tok = token.MUL_ASSIGN
		case "Div":
			tok = token.QUO_ASSIGN
		}
		p.print(blank, tok, blank)
		p.expr(args[0]) // the value being assigned
	}
}
// methodExpr prints a method call x (x.Fun is known to be a SelectorExpr),
// rewriting it to gosl form: receiver methods become free functions named
// RecvType_Method with the receiver as the first argument; atomic package
// calls map to WGSL atomic builtins; tensor global-variable methods are
// delegated to tensorMethod; indexed-global receivers are copied to temp
// vars and written back afterward via rwargs.
func (p *printer) methodExpr(x *ast.CallExpr, depth int) {
	path := x.Fun.(*ast.SelectorExpr) // we know fun is selector
	methName := path.Sel.Name
	recvPath := ""
	recvType := ""
	var err error
	pathIsPackage := false
	var rwargs []rwArg // temp vars to write back to globals after the call
	var pathType types.Type
	if sl, ok := path.X.(*ast.SelectorExpr); ok { // path is itself a selector
		recvPath, recvType, pathType, err = p.methodPath(sl)
		if err != nil {
			return
		}
	} else if id, ok := path.X.(*ast.Ident); ok {
		gvr := p.GoToSL.GlobalVar(id.Name)
		if gvr != nil && gvr.Tensor {
			// tensor globals get fully custom index/assign handling
			p.tensorMethod(x, gvr, methName)
			return
		}
		recvPath = id.Name
		typ := p.getIdType(id)
		if typ != nil {
			recvType = getLocalTypeName(typ)
			if strings.HasPrefix(recvType, "invalid") {
				if gvar := p.GoToSL.GetTempVar(id.Name); gvar != nil {
					// receiver is a GetVar temp: pass its address
					recvType = gvar.Var.SLType()
					recvPath = "&" + recvPath
					pathType = p.getTypeNameType(gvar.Var.SLType())
				} else {
					pathIsPackage = true
					recvType = id.Name // is a package path
				}
			} else {
				pathType = typ
				recvPath = recvPath // no-op self-assignment (kept verbatim)
			}
		} else {
			pathIsPackage = true
			recvType = id.Name // is a package path
		}
	} else if idx, ok := path.X.(*ast.IndexExpr); ok {
		// indexed global receiver: copy to a temp var first
		isReadOnly := false
		recvPath, recvType, pathType, isReadOnly, err = p.methodIndex(idx)
		if err != nil {
			return
		}
		if !isReadOnly {
			rwargs = append(rwargs, rwArg{idx: idx, tmpVar: recvPath})
		}
	} else {
		err := fmt.Errorf("gosl methodExpr ERROR: path expression for method call must be simple list of fields, not %#v:", path.X)
		p.userError(err)
		return
	}
	args := x.Args
	if pathType != nil {
		// fix up args per the method's actual signature (pointer params
		// and indexed globals become temp vars)
		meth, _, _ := types.LookupFieldOrMethod(pathType, true, p.pkg.Types, methName)
		if meth != nil {
			if ft, ok := meth.(*types.Func); ok {
				sig := ft.Type().(*types.Signature)
				var rwa []rwArg
				args, rwa = p.goslFixArgs(x.Args, sig.Params())
				rwargs = append(rwargs, rwa...)
			}
		}
		if len(rwargs) > 0 {
			p.print(formfeed)
		}
	}
	// fmt.Println(pathIsPackage, recvType, methName, recvPath)
	if pathIsPackage {
		if recvType == "atomic" || recvType == "atomicx" {
			// map Go atomics onto WGSL atomic builtins
			p.curMethIsAtomic = true
			switch {
			case strings.HasPrefix(methName, "Add"):
				p.print("atomicAdd")
			case strings.HasPrefix(methName, "Max"):
				p.print("atomicMax")
			}
		} else {
			p.print(recvType + "." + methName)
			if p.curFunc != nil {
				p.curFunc.Funcs[methName] = p.GoToSL.RecycleFunc(methName)
			}
		}
		p.setPos(x.Lparen)
		p.print(token.LPAREN)
	} else {
		recvType = strings.TrimPrefix(recvType, "imports.") // no!
		// method becomes a free function RecvType_Method(recv, args...)
		fname := recvType + "_" + methName
		if p.curFunc != nil {
			p.curFunc.Funcs[fname] = p.GoToSL.RecycleFunc(fname)
		}
		p.print(fname)
		p.setPos(x.Lparen)
		p.print(token.LPAREN)
		p.print(recvPath) // receiver is the first argument
		if len(x.Args) > 0 {
			p.print(token.COMMA, blank)
		}
	}
	if x.Ellipsis.IsValid() {
		p.exprList(x.Lparen, args, depth, 0, x.Ellipsis, false)
		p.setPos(x.Ellipsis)
		p.print(token.ELLIPSIS)
		if x.Rparen.IsValid() && p.lineFor(x.Ellipsis) < p.lineFor(x.Rparen) {
			p.print(token.COMMA, formfeed)
		}
	} else {
		p.exprList(x.Lparen, args, depth, commaTerm, x.Rparen, false)
	}
	p.setPos(x.Rparen)
	p.print(token.RPAREN)
	p.curMethIsAtomic = false
	p.assignRwArgs(rwargs) // gosl: assign temp var back to global var
}
// expr0 prints x at the lowest precedence with the given nesting depth.
func (p *printer) expr0(x ast.Expr, depth int) {
	p.expr1(x, token.LowestPrec, depth)
}
// expr prints x at the lowest precedence and the initial nesting depth (1).
func (p *printer) expr(x ast.Expr) {
	p.expr1(x, token.LowestPrec, 1)
}
// ----------------------------------------------------------------------------
// Statements
// Print the statement list indented, but without a newline after the last statement.
// Extra line breaks between statements in the source are respected but at most one
// empty line is printed between statements.
// Print the statement list indented, but without a newline after the last statement.
// Extra line breaks between statements in the source are respected but at most one
// empty line is printed between statements.
func (p *printer) stmtList(list []ast.Stmt, nindent int, nextIsRBrace bool) {
	if nindent > 0 {
		p.print(indent)
	}
	var line int
	i := 0 // counts printed (non-empty) statements
	for _, s := range list {
		// ignore empty statements (was issue 3466)
		if _, isEmpty := s.(*ast.EmptyStmt); !isEmpty {
			// nindent == 0 only for lists of switch/select case clauses;
			// in those cases each clause is a new section
			if len(p.output) > 0 {
				// only print line break if we are not at the beginning of the output
				// (i.e., we are not printing only a partial program)
				p.linebreak(p.lineFor(s.Pos()), 1, ignore, i == 0 || nindent == 0 || p.linesFrom(line) > 0)
			}
			p.recordLine(&line)
			p.stmt(s, nextIsRBrace && i == len(list)-1, false)
			// labeled statements put labels on a separate line, but here
			// we only care about the start line of the actual statement
			// without label - correct line for each label
			for t := s; ; {
				lt, _ := t.(*ast.LabeledStmt)
				if lt == nil {
					break
				}
				line++
				t = lt.Stmt
			}
			i++
		}
	}
	if nindent > 0 {
		p.print(unindent)
	}
}
// block prints an *ast.BlockStmt; it always spans at least two lines.
// gosl: a GetVar frame is pushed for the block, and any temp variables
// created by GetVar calls are written back to their globals at block end —
// before a trailing return statement, so the write-backs still execute.
func (p *printer) block(b *ast.BlockStmt, nindent int) {
	p.GoToSL.GetVarStack.Push(make(map[string]*GetGlobalVar))
	p.setPos(b.Lbrace)
	p.print(token.LBRACE)
	// detect a trailing return so global write-backs can precede it
	nstmt := len(b.List)
	retLast := false
	if nstmt > 1 {
		if _, ok := b.List[nstmt-1].(*ast.ReturnStmt); ok {
			retLast = true
		}
	}
	if retLast {
		p.stmtList(b.List[:nstmt-1], nindent, true)
	} else {
		p.stmtList(b.List, nindent, true)
	}
	getVars := p.GoToSL.GetVarStack.Pop()
	if len(getVars) > 0 { // gosl: set the get vars
		p.setGlobalVars(getVars)
	}
	if retLast {
		// now print the deferred return after the write-backs
		p.stmt(b.List[nstmt-1], true, false)
	}
	p.linebreak(p.lineFor(b.Rbrace), 1, ignore, true)
	p.setPos(b.Rbrace)
	p.print(token.RBRACE)
}
func isTypeName(x ast.Expr) bool {
switch t := x.(type) {
case *ast.Ident:
return true
case *ast.SelectorExpr:
return isTypeName(t.X)
}
return false
}
// stripParens removes redundant outer parentheses from x. Parentheses
// are kept when the enclosed expression contains an unparenthesized
// composite literal starting with a type name, where they are
// syntactically significant.
func stripParens(x ast.Expr) ast.Expr {
	px, isParen := x.(*ast.ParenExpr)
	if !isParen {
		return x
	}
	keep := false
	ast.Inspect(px.X, func(node ast.Node) bool {
		switch n := node.(type) {
		case *ast.ParenExpr:
			// inner parentheses already protect enclosed composite literals
			return false
		case *ast.CompositeLit:
			if isTypeName(n.Type) {
				keep = true // parentheses are load-bearing here
			}
			return false
		}
		// in all other cases, keep inspecting
		return true
	})
	if keep {
		return x
	}
	return stripParens(px.X)
}
func stripParensAlways(x ast.Expr) ast.Expr {
if x, ok := x.(*ast.ParenExpr); ok {
return stripParensAlways(x.X)
}
return x
}
// controlClause prints the parenthesized header of an if/for/switch in
// C/WGSL style: "(cond)", "(init; cond)" or, for for-statements,
// "(init; cond; post)". isForStmt enables the trailing "; post" section.
// A trailing blank is printed when the clause ends in an expression or
// post statement so the following "{" is separated.
func (p *printer) controlClause(isForStmt bool, init ast.Stmt, expr ast.Expr, post ast.Stmt) {
	p.print(blank)
	p.print(token.LPAREN)
	needsBlank := false
	if init == nil && post == nil {
		// no semicolons required
		if expr != nil {
			p.expr(stripParens(expr))
			needsBlank = true
		}
	} else {
		// all semicolons required
		// (they are not separators, print them explicitly)
		if init != nil {
			p.stmt(init, false, false) // false = generate own semi
			p.print(blank)
		} else {
			p.print(token.SEMICOLON, blank)
		}
		if expr != nil {
			p.expr(stripParens(expr))
			needsBlank = true
		}
		if isForStmt {
			p.print(token.SEMICOLON, blank)
			needsBlank = false
			if post != nil {
				p.stmt(post, false, true) // nosemi
				needsBlank = true
			}
		}
	}
	p.print(token.RPAREN)
	if needsBlank {
		p.print(blank)
	}
}
// indentList reports whether an expression list would look better if it
// were indented wholesale (starting with the very first element, rather
// than starting at the first line break).
//
// Heuristic: report true when the list contains more than one multi-line
// element, or when any element starts on a later line than the previous
// one ended.
func (p *printer) indentList(list []ast.Expr) bool {
	if len(list) < 2 {
		return false
	}
	first := p.lineFor(list[0].Pos())
	last := p.lineFor(list[len(list)-1].End())
	if first <= 0 || first >= last {
		// unknown position or list fits on a single line
		return false
	}
	multiLine := 0
	prevEnd := first
	for _, x := range list {
		begin := p.lineFor(x.Pos())
		end := p.lineFor(x.End())
		if prevEnd < begin {
			// x starts below the line where the previous element ended
			return true
		}
		if begin < end {
			multiLine++
		}
		prevEnd = end
	}
	return multiLine > 1
}
// stmt prints statement stmt, translating Go statement forms into their
// gosl (WGSL-style) equivalents: ":=" becomes "var ... =", &^= becomes
// "&= ~", case bodies get explicit braces, and range loops become
// counted for loops. nextIsRBrace indicates the statement is immediately
// followed by a closing brace; nosemi suppresses the trailing semicolon
// (used for statements in for-clause init/post position).
func (p *printer) stmt(stmt ast.Stmt, nextIsRBrace bool, nosemi bool) {
	p.setPos(stmt.Pos())
	switch s := stmt.(type) {
	case *ast.BadStmt:
		p.print("BadStmt")
	case *ast.DeclStmt:
		p.decl(s.Decl)
	case *ast.EmptyStmt:
		// nothing to do
	case *ast.LabeledStmt:
		// a "correcting" unindent immediately following a line break
		// is applied before the line break if there is no comment
		// between (see writeWhitespace)
		p.print(unindent)
		p.expr(s.Label)
		p.setPos(s.Colon)
		p.print(token.COLON, indent)
		if e, isEmpty := s.Stmt.(*ast.EmptyStmt); isEmpty {
			if !nextIsRBrace {
				p.print(newline)
				p.setPos(e.Pos())
				p.print(token.SEMICOLON)
				break
			}
		} else {
			p.linebreak(p.lineFor(s.Stmt.Pos()), 1, ignore, true)
		}
		p.stmt(s.Stmt, nextIsRBrace, nosemi)
	case *ast.ExprStmt:
		const depth = 1
		p.expr0(s.X, depth)
		if !nosemi {
			p.print(token.SEMICOLON)
		}
	case *ast.SendStmt:
		const depth = 1
		p.expr0(s.Chan, depth)
		p.print(blank)
		p.setPos(s.Arrow)
		p.print(token.ARROW, blank)
		p.expr0(s.Value, depth)
	case *ast.IncDecStmt:
		const depth = 1
		p.expr0(s.X, depth+1)
		p.setPos(s.TokPos)
		p.print(s.Tok)
		if !nosemi {
			p.print(token.SEMICOLON)
		}
	case *ast.AssignStmt:
		var depth = 1
		if len(s.Lhs) > 1 && len(s.Rhs) > 1 {
			depth++
		}
		if s.Tok == token.DEFINE {
			// gosl: a registered Get* call on the RHS of := is replaced by
			// an assignment from the corresponding global variable
			if ce, ok := s.Rhs[0].(*ast.CallExpr); ok {
				if fid, ok := ce.Fun.(*ast.Ident); ok {
					if strings.HasPrefix(fid.Name, "Get") {
						if gvr, ok := p.GoToSL.GetFuncs[fid.Name]; ok {
							p.getGlobalVar(s, gvr) // replace GetVar function call with assignment of local var
							return
						}
					}
				}
			}
			p.print("var", blank) // we don't know if it is var or let..
		}
		p.exprList(s.Pos(), s.Lhs, depth, 0, s.TokPos, false)
		p.print(blank)
		p.setPos(s.TokPos)
		switch s.Tok {
		case token.DEFINE:
			p.print(token.ASSIGN, blank)
		case token.AND_NOT_ASSIGN:
			// WGSL has no &^= operator; emit "&= ~rhs" instead
			p.print(token.AND_ASSIGN, blank, "~")
		default:
			p.print(s.Tok, blank)
		}
		// NOTE(review): matchAssignType appears to print the RHS itself when
		// it reports true (otherwise nothing would be emitted) — confirm
		// against its definition elsewhere in this package.
		if p.matchAssignType(s.Lhs, s.Rhs) {
		} else {
			p.exprList(s.TokPos, s.Rhs, depth, 0, token.NoPos, false)
		}
		if !nosemi {
			p.print(token.SEMICOLON)
		}
	case *ast.GoStmt:
		p.print(token.GO, blank)
		p.expr(s.Call)
	case *ast.DeferStmt:
		p.print(token.DEFER, blank)
		p.expr(s.Call)
	case *ast.ReturnStmt:
		p.print(token.RETURN)
		if s.Results != nil {
			p.print(blank)
			// NOTE(review): matchLiteralType presumably emits the result
			// itself when it reports true — confirm against its definition.
			if !p.matchLiteralType(s.Results[0], p.curReturnType) {
				// Use indentList heuristic to make corner cases look
				// better (issue 1207). A more systematic approach would
				// always indent, but this would cause significant
				// reformatting of the code base and not necessarily
				// lead to more nicely formatted code in general.
				if p.indentList(s.Results) {
					p.print(indent)
					// Use NoPos so that a newline never goes before
					// the results (see issue #32854).
					p.exprList(token.NoPos, s.Results, 1, noIndent, token.NoPos, false)
					p.print(unindent)
				} else {
					p.exprList(token.NoPos, s.Results, 1, 0, token.NoPos, false)
				}
			}
		}
		if !nosemi {
			p.print(token.SEMICOLON)
		}
	case *ast.BranchStmt:
		p.print(s.Tok)
		if s.Label != nil {
			p.print(blank)
			p.expr(s.Label)
		}
		p.print(token.SEMICOLON)
	case *ast.BlockStmt:
		p.block(s, 1)
	case *ast.IfStmt:
		p.print(token.IF)
		p.controlClause(false, s.Init, s.Cond, nil)
		p.block(s.Body, 1)
		if s.Else != nil {
			p.print(blank, token.ELSE, blank)
			switch s.Else.(type) {
			case *ast.BlockStmt, *ast.IfStmt:
				p.stmt(s.Else, nextIsRBrace, false)
			default:
				// This can only happen with an incorrectly
				// constructed AST. Permit it but print so
				// that it can be gotosld without errors.
				p.print(token.LBRACE, indent, formfeed)
				p.stmt(s.Else, true, false)
				p.print(unindent, formfeed, token.RBRACE)
			}
		}
	case *ast.CaseClause:
		if s.List != nil {
			p.print(token.CASE, blank)
			p.exprList(s.Pos(), s.List, 1, 0, s.Colon, false)
		} else {
			p.print(token.DEFAULT)
		}
		p.setPos(s.Colon)
		p.print(token.COLON, blank, token.LBRACE) // Go implies new context, C doesn't
		p.stmtList(s.Body, 1, nextIsRBrace)
		p.print(formfeed, token.RBRACE)
	case *ast.SwitchStmt:
		p.print(token.SWITCH)
		p.controlClause(false, s.Init, s.Tag, nil)
		p.block(s.Body, 0)
	case *ast.TypeSwitchStmt:
		p.print(token.SWITCH)
		if s.Init != nil {
			p.print(blank)
			p.stmt(s.Init, false, false)
			p.print(token.SEMICOLON)
		}
		p.print(blank)
		p.stmt(s.Assign, false, false)
		p.print(blank)
		p.block(s.Body, 0)
	case *ast.CommClause:
		if s.Comm != nil {
			p.print(token.CASE, blank)
			p.stmt(s.Comm, false, false)
		} else {
			p.print(token.DEFAULT)
		}
		p.setPos(s.Colon)
		p.print(token.COLON)
		p.stmtList(s.Body, 1, nextIsRBrace)
	case *ast.SelectStmt:
		p.print(token.SELECT, blank)
		body := s.Body
		if len(body.List) == 0 && !p.commentBefore(p.posFor(body.Rbrace)) {
			// print empty select statement w/o comments on one line
			p.setPos(body.Lbrace)
			p.print(token.LBRACE)
			p.setPos(body.Rbrace)
			p.print(token.RBRACE)
		} else {
			p.block(body, 0)
		}
	case *ast.ForStmt:
		p.print(token.FOR)
		p.controlClause(true, s.Init, s.Cond, s.Post)
		p.block(s.Body, 1)
	case *ast.RangeStmt:
		// gosl: only supporting the for i := range 10 kind of range loop;
		// it is rewritten as: for (var i = 0; i < X; i++)
		p.print(token.FOR, blank)
		if s.Key != nil {
			p.print(token.LPAREN, "var", blank)
			p.expr(s.Key)
			p.print(token.ASSIGN, "0", token.SEMICOLON, blank)
			p.expr(s.Key)
			p.print(token.LSS)
			p.expr(stripParens(s.X))
			p.print(token.SEMICOLON, blank)
			p.expr(s.Key)
			p.print(token.INC, token.RPAREN)
			// if s.Value != nil {
			// 	// use position of value following the comma as
			// 	// comma position for correct comment placement
			// 	p.setPos(s.Value.Pos())
			// 	p.print(token.COMMA, blank)
			// 	p.expr(s.Value)
			// }
			// p.print(blank)
			// p.setPos(s.TokPos)
			// p.print(s.Tok, blank)
		}
		// p.print(token.RANGE, blank)
		// p.expr(stripParens(s.X))
		p.print(blank)
		p.block(s.Body, 1)
	default:
		panic("unreachable")
	}
}
// ----------------------------------------------------------------------------
// Declarations
// The keepTypeColumn function determines if the type column of a series of
// consecutive const or var declarations must be kept, or if initialization
// values (V) can be placed in the type column (T) instead. The i'th entry
// in the result slice is true if the type column in spec[i] must be kept.
//
// For example, the declaration:
//
// const (
// foobar int = 42 // comment
// x = 7 // comment
// foo
// bar = 991
// )
//
// leads to the type/values matrix below. A run of value columns (V) can
// be moved into the type column if there is no type for any of the values
// in that column (we only move entire columns so that they align properly).
//
// matrix formatted result
// matrix
// T V -> T V -> true there is a T and so the type
// - V - V true column must be kept
// - - - - false
// - V V - false V is moved into T column
func keepTypeColumn(specs []ast.Spec) []bool {
m := make([]bool, len(specs))
populate := func(i, j int, keepType bool) {
if keepType {
for ; i < j; i++ {
m[i] = true
}
}
}
i0 := -1 // if i0 >= 0 we are in a run and i0 is the start of the run
var keepType bool
for i, s := range specs {
t := s.(*ast.ValueSpec)
if t.Values != nil {
if i0 < 0 {
// start of a run of ValueSpecs with non-nil Values
i0 = i
keepType = false
}
} else {
if i0 >= 0 {
// end of a run
populate(i0, i, keepType)
i0 = -1
}
}
if t.Type != nil {
keepType = true
}
}
if i0 >= 0 {
// end of a run
populate(i0, len(specs), keepType)
}
return m
}
// valueSpec prints a const/var/type ValueSpec in gosl (WGSL-style) form,
// e.g. "const name: Type = value;" or "alias name ...", using tabwriter
// columns for alignment. keepType forces the type column even when this
// spec has no explicit type; firstSpec is the first spec of the enclosing
// group (used to inherit the group's type and to detect iota runs);
// isIota marks an iota-based const group, in which case idx supplies the
// literal value printed for this spec.
func (p *printer) valueSpec(s *ast.ValueSpec, keepType bool, tok token.Token, firstSpec *ast.ValueSpec, isIota bool, idx int) {
	p.setComment(s.Doc)
	// gosl: key to use Pos() as first arg to trigger emitting of comments!
	switch tok {
	case token.CONST:
		p.setPos(s.Pos())
		p.print(tok, blank)
	case token.TYPE:
		p.setPos(s.Pos())
		p.print("alias", blank)
	}
	p.print(vtab)
	extraTabs := 3
	p.identList(s.Names, false) // always present
	if isIota {
		// iota consts: print the explicit (or inherited) type, then the
		// concrete index value in place of "iota"
		if s.Type != nil {
			p.print(token.COLON, blank)
			p.expr(s.Type)
		} else if firstSpec.Type != nil {
			p.print(token.COLON, blank)
			p.expr(firstSpec.Type)
		}
		p.print(vtab, token.ASSIGN, blank)
		p.print(fmt.Sprintf("%d", idx))
	} else if s.Type != nil || keepType {
		p.print(token.COLON, blank)
		// fix: keepType can be set purely for column alignment while this
		// spec has no explicit type; calling p.expr(nil) would panic on
		// x.Pos(). Only print the type when it is actually present
		// (matching upstream go/printer behavior).
		if s.Type != nil {
			p.expr(s.Type)
		}
		extraTabs--
	} else if tok == token.CONST && firstSpec.Type != nil {
		// non-iota const without its own type: inherit the group's type
		p.expr(firstSpec.Type)
		extraTabs--
	}
	if !(isIota && s == firstSpec) && s.Values != nil {
		p.print(vtab, token.ASSIGN, blank)
		p.exprList(token.NoPos, s.Values, 1, 0, token.NoPos, false)
		extraTabs--
	}
	p.print(token.SEMICOLON)
	if s.Comment != nil {
		// pad so trailing comments line up across the group
		for ; extraTabs > 0; extraTabs-- {
			p.print(vtab)
		}
		p.setComment(s.Comment)
	}
}
func sanitizeImportPath(lit *ast.BasicLit) *ast.BasicLit {
// Note: An unmodified AST generated by go/gotoslr will already
// contain a backward- or double-quoted path string that does
// not contain any invalid characters, and most of the work
// here is not needed. However, a modified or generated AST
// may possibly contain non-canonical paths. Do the work in
// all cases since it's not too hard and not speed-critical.
// if we don't have a proper string, be conservative and return whatever we have
if lit.Kind != token.STRING {
return lit
}
s, err := strconv.Unquote(lit.Value)
if err != nil {
return lit
}
// if the string is an invalid path, return whatever we have
//
// spec: "Implementation restriction: A compiler may restrict
// ImportPaths to non-empty strings using only characters belonging
// to Unicode's L, M, N, P, and S general categories (the Graphic
// characters without spaces) and may also exclude the characters
// !"#$%&'()*,:;<=>?[\]^`{|} and the Unicode replacement character
// U+FFFD."
if s == "" {
return lit
}
const illegalChars = `!"#$%&'()*,:;<=>?[\]^{|}` + "`\uFFFD"
for _, r := range s {
if !unicode.IsGraphic(r) || unicode.IsSpace(r) || strings.ContainsRune(illegalChars, r) {
return lit
}
}
// otherwise, return the double-quoted path
s = strconv.Quote(s)
if s == lit.Value {
return lit // nothing wrong with lit
}
return &ast.BasicLit{ValuePos: lit.ValuePos, Kind: token.STRING, Value: s}
}
// spec prints a single declaration spec. Imports pass through (with a
// sanitized path); value specs become "tok name: Type = value;", with
// multi-name specs expanded to one declaration per line; type specs
// become either a struct declaration or an "alias name = type;".
// The parameter n is the number of specs in the group. If doIndent is set,
// multi-line identifier lists in the spec are indented when the first
// linebreak is encountered.
func (p *printer) spec(spec ast.Spec, n int, doIndent bool, tok token.Token) {
	switch s := spec.(type) {
	case *ast.ImportSpec:
		p.setComment(s.Doc)
		if s.Name != nil {
			p.expr(s.Name)
			p.print(blank)
		}
		p.expr(sanitizeImportPath(s.Path))
		p.setComment(s.Comment)
		p.setPos(s.EndPos)
	case *ast.ValueSpec:
		if n != 1 {
			p.internalError("expected n = 1; got", n)
		}
		p.setComment(s.Doc)
		if len(s.Names) > 1 {
			// gosl: expand "var a, b Type" into one declaration per name
			nnm := len(s.Names)
			for ni, nm := range s.Names {
				p.print(tok, blank)
				p.print(nm.Name)
				if s.Type != nil {
					p.print(token.COLON, blank)
					p.expr(s.Type)
				}
				if s.Values != nil {
					p.print(blank, token.ASSIGN, blank)
					p.exprList(token.NoPos, s.Values, 1, 0, token.NoPos, false)
				}
				p.print(token.SEMICOLON)
				if ni < nnm-1 {
					p.print(formfeed)
				}
			}
		} else {
			p.print(tok, blank)
			p.identList(s.Names, doIndent) // always present
			if s.Type != nil {
				p.print(token.COLON, blank)
				p.expr(s.Type)
			}
			if s.Values != nil {
				p.print(blank, token.ASSIGN, blank)
				p.exprList(token.NoPos, s.Values, 1, 0, token.NoPos, false)
			}
			p.print(token.SEMICOLON)
			p.setComment(s.Comment)
		}
	case *ast.TypeSpec:
		p.setComment(s.Doc)
		// struct types keep a struct declaration; all other type specs
		// become WGSL-style "alias Name = Type;"
		st, isStruct := s.Type.(*ast.StructType)
		if isStruct {
			p.setPos(st.Pos())
			p.print(token.STRUCT, blank)
		} else {
			p.print("alias", blank)
		}
		p.expr(s.Name)
		if !isStruct {
			p.print(blank, token.ASSIGN, blank)
		}
		if s.TypeParams != nil {
			p.parameters(s.TypeParams, typeTParam)
		}
		// if n == 1 {
		// 	p.print(blank)
		// } else {
		// 	p.print(vtab)
		// }
		if s.Assign.IsValid() {
			p.print(token.ASSIGN, blank)
		}
		p.expr(s.Type)
		if !isStruct {
			p.print(token.SEMICOLON)
		}
		p.setComment(s.Comment)
	default:
		panic("unreachable")
	}
}
// systemVars processes a //gosl:vars var block for the named compute
// system, recording each declared global variable (group membership,
// name, type, read-only status, and tensor dims) in the GoToSL state.
// It only runs during the function-graph collection pass. Malformed
// specs are reported via userError and skipped.
func (p *printer) systemVars(d *ast.GenDecl, sysname string) {
	if !p.GoToSL.GetFuncGraph {
		return
	}
	sy := p.GoToSL.System(sysname)
	var gp *Group
	var err error
	for _, s := range d.Specs {
		vs := s.(*ast.ValueSpec)
		dirs, docs := p.findDirective(vs.Doc)
		readOnly := false
		if hasDirective(dirs, "read-only") {
			readOnly = true
		}
		// //gosl:group [-uniform] [name] starts a new variable group
		if gpnm, ok := directiveAfter(dirs, "group"); ok {
			if gpnm == "" {
				gp = &Group{Name: fmt.Sprintf("Group_%d", len(sy.Groups)), Doc: docs}
				sy.Groups = append(sy.Groups, gp)
			} else {
				gps := strings.Fields(gpnm)
				gp = &Group{Doc: docs}
				if gps[0] == "-uniform" {
					gp.Uniform = true
					if len(gps) > 1 {
						gp.Name = gps[1]
					}
				} else {
					gp.Name = gps[0]
				}
				sy.Groups = append(sy.Groups, gp)
			}
		}
		if gp == nil {
			// no group directive seen yet: open a default group
			gp = &Group{Name: fmt.Sprintf("Group_%d", len(sy.Groups)), Doc: docs}
			sy.Groups = append(sy.Groups, gp)
		}
		if len(vs.Names) != 1 {
			err = fmt.Errorf("gosl: system %q: vars must have only 1 variable per line", sysname)
			p.userError(err)
			// fix: a spec with zero names would panic on Names[0] below
			if len(vs.Names) == 0 {
				continue
			}
		}
		nm := vs.Names[0].Name
		typ := ""
		if sl, ok := vs.Type.(*ast.ArrayType); ok {
			id, ok := sl.Elt.(*ast.Ident)
			if !ok {
				err = fmt.Errorf("gosl: system %q: Var type not recognized: %#v", sysname, sl.Elt)
				p.userError(err)
				continue
			}
			typ = "[]" + id.Name
		} else {
			// accept pkg.Type or *pkg.Type (e.g. tensor.Float32)
			sel, ok := vs.Type.(*ast.SelectorExpr)
			if !ok {
				st, ok := vs.Type.(*ast.StarExpr)
				if !ok {
					err = fmt.Errorf("gosl: system %q: Var types must be []slices or tensor.Float32, tensor.Uint32", sysname)
					p.userError(err)
					continue
				}
				sel, ok = st.X.(*ast.SelectorExpr)
				if !ok {
					err = fmt.Errorf("gosl: system %q: Var types must be []slices or tensor.Float32, tensor.Uint32", sysname)
					p.userError(err)
					continue
				}
			}
			sid, ok := sel.X.(*ast.Ident)
			if !ok {
				err = fmt.Errorf("gosl: system %q: Var type selector is not recognized: %#v", sysname, sel.X)
				p.userError(err)
				continue
			}
			typ = sid.Name + "." + sel.Sel.Name
		}
		vr := &Var{Name: nm, Type: typ, ReadOnly: readOnly}
		if strings.HasPrefix(typ, "tensor.") {
			vr.Tensor = true
			dstr, ok := directiveAfter(dirs, "dims")
			if !ok {
				err = fmt.Errorf("gosl: system %q: variable %q tensor vars require //gosl:dims <n> to specify number of dimensions", sysname, nm)
				p.userError(err)
				continue
			}
			dims, aerr := strconv.Atoi(dstr)
			// fix: the original tested the stale `ok` flag here (always true
			// at this point), so Atoi failures were silently ignored and
			// dims stayed 0. Check the parse error and skip the var.
			if aerr != nil {
				p.userError(fmt.Errorf("gosl: system %q: variable %q tensor dims parse error: %s", sysname, nm, aerr.Error()))
				continue
			}
			vr.SetTensorKind()
			vr.TensorDims = dims
		}
		gp.Vars = append(gp.Vars, vr)
		if p.GoToSL.Config.Debug {
			fmt.Println("\tAdded var:", nm, typ, "to group:", gp.Name)
		}
	}
	p.GoToSL.VarsAdded()
}
// genDecl prints a general declaration (import/const/var/type).
// Imports are dropped entirely (gosl output has no imports). Grouped
// const/var declarations are expanded one spec per line; a var group
// carrying a //gosl:vars directive is instead recorded as system global
// variables. Top-level vars are emitted as const, since only system vars
// are supported at global scope.
func (p *printer) genDecl(d *ast.GenDecl) {
	p.setComment(d.Doc)
	// note: critical to print here to trigger comment generation in right place
	p.setPos(d.Pos())
	if d.Tok == token.IMPORT {
		return
	}
	p.print(ignore) // don't print the decl token itself
	if d.Lparen.IsValid() || len(d.Specs) != 1 {
		// group of parenthesized declarations
		if n := len(d.Specs); n > 0 {
			if n > 1 && (d.Tok == token.CONST || d.Tok == token.VAR) {
				// two or more grouped const/var declarations:
				if d.Tok == token.VAR {
					dirs, _ := p.findDirective(d.Doc)
					if sysname, ok := directiveAfter(dirs, "vars"); ok {
						p.systemVars(d, sysname)
						return
					}
				}
				// determine if the type column must be kept
				keepType := keepTypeColumn(d.Specs)
				firstSpec := d.Specs[0].(*ast.ValueSpec)
				isIota := false
				// fix: guard against a const spec with no values
				// (e.g. "const ( a Type )"), which would panic on Values[0]
				if d.Tok == token.CONST && len(firstSpec.Values) > 0 {
					if id, isId := firstSpec.Values[0].(*ast.Ident); isId {
						if id.Name == "iota" {
							isIota = true
						}
					}
				}
				var line int
				for i, s := range d.Specs {
					vs := s.(*ast.ValueSpec)
					if i > 0 {
						p.linebreak(p.lineFor(s.Pos()), 1, ignore, p.linesFrom(line) > 0)
					}
					p.recordLine(&line)
					p.valueSpec(vs, keepType[i], d.Tok, firstSpec, isIota, i)
				}
			} else {
				tok := d.Tok
				if p.curFunc == nil && tok == token.VAR { // only system vars are supported at global scope
					// could add further comment-directive logic
					// to specify <private> or <workgroup> scope if needed
					tok = token.CONST
				}
				var line int
				for i, s := range d.Specs {
					if i > 0 {
						p.linebreak(p.lineFor(s.Pos()), 1, ignore, p.linesFrom(line) > 0)
					}
					p.recordLine(&line)
					p.spec(s, n, false, tok)
				}
			}
		}
	} else if len(d.Specs) > 0 {
		tok := d.Tok
		if p.curFunc == nil && tok == token.VAR { // only system vars are supported at global scope
			tok = token.CONST
		}
		// single declaration
		p.spec(d.Specs[0], 1, true, tok)
	}
}
// sizeCounter is an io.Writer which counts the number of bytes written,
// as well as whether a newline ('\n' or '\f') character was seen.
type sizeCounter struct {
	hasNewline bool // set once a '\n' or '\f' byte has been written
	size       int  // total bytes written so far
}

// Write implements io.Writer; it always succeeds and retains nothing.
func (c *sizeCounter) Write(p []byte) (int, error) {
	if !c.hasNewline {
		for i := 0; i < len(p) && !c.hasNewline; i++ {
			if p[i] == '\n' || p[i] == '\f' {
				c.hasNewline = true
			}
		}
	}
	c.size += len(p)
	return len(p), nil
}
// nodeSize determines the size of n in chars after formatting.
// The result is <= maxSize if the node fits on one line with at
// most maxSize chars and the formatted output doesn't contain
// any control chars. Otherwise, the result is > maxSize.
func (p *printer) nodeSize(n ast.Node, maxSize int) (size int) {
	// nodeSize invokes the printer, which may invoke nodeSize
	// recursively. For deep composite literal nests, this can
	// lead to an exponential algorithm. Remember previous
	// results to prune the recursion (was issue 1628).
	if size, found := p.nodeSizes[n]; found {
		return size
	}
	// cache the pessimistic answer first so recursive calls on n terminate
	size = maxSize + 1 // assume n doesn't fit
	p.nodeSizes[n] = size
	// nodeSize computation must be independent of particular
	// style so that we always get the same decision; print
	// in RawFormat
	cfg := PrintConfig{Mode: RawFormat}
	var counter sizeCounter
	if err := cfg.fprint(&counter, p.pkg, n, p.nodeSizes); err != nil {
		return
	}
	if counter.size <= maxSize && !counter.hasNewline {
		// n fits in a single line; overwrite the cached pessimistic value
		size = counter.size
		p.nodeSizes[n] = size
	}
	return
}
// numLines returns the number of lines spanned by node n in the original
// source, or infinity when either end position is unknown.
func (p *printer) numLines(n ast.Node) int {
	from, to := n.Pos(), n.End()
	if !from.IsValid() || !to.IsValid() {
		return infinity
	}
	return p.lineFor(to) - p.lineFor(from) + 1
}
// bodySize is like nodeSize but it is specialized for *ast.BlockStmt's:
// it estimates the printed width of the body plus any same-line comments,
// returning a value > maxSize when the body should not be a one-liner.
func (p *printer) bodySize(b *ast.BlockStmt, maxSize int) int {
	pos1 := b.Pos()
	pos2 := b.Rbrace
	if pos1.IsValid() && pos2.IsValid() && p.lineFor(pos1) != p.lineFor(pos2) {
		// opening and closing brace are on different lines - don't make it a one-liner
		return maxSize + 1
	}
	if len(b.List) > 5 {
		// too many statements - don't make it a one-liner
		return maxSize + 1
	}
	// otherwise, estimate body size: comments before the closing brace
	// plus each statement, separated by "; "
	bodySize := p.commentSizeBefore(p.posFor(pos2))
	for i, s := range b.List {
		if bodySize > maxSize {
			break // no need to continue
		}
		if i > 0 {
			bodySize += 2 // space for a semicolon and blank
		}
		bodySize += p.nodeSize(s, maxSize)
	}
	return bodySize
}
// funcBody prints a function body following a function header of given headerSize.
// If the header's and block's size are "small enough" and the block is "simple enough",
// the block is printed on the current line, without line breaks, spaced from the header
// by sep. Otherwise the block's opening "{" is printed on the current line, followed by
// lines for the block's statements and its closing "}".
func (p *printer) funcBody(headerSize int, sep whiteSpace, b *ast.BlockStmt) {
	if b == nil {
		return
	}
	// save/restore composite literal nesting level
	defer func(level int) {
		p.level = level
	}(p.level)
	p.level = 0
	const maxSize = 100
	if headerSize+p.bodySize(b, maxSize) <= maxSize {
		// one-liner: "{ s1; s2; ... }" on the current line
		p.print(sep)
		p.setPos(b.Lbrace)
		p.print(token.LBRACE)
		if len(b.List) > 0 {
			p.print(blank)
			for i, s := range b.List {
				if i > 0 {
					p.print(token.SEMICOLON, blank)
				}
				p.stmt(s, i == len(b.List)-1, false)
			}
			p.print(blank)
		}
		p.print(noExtraLinebreak)
		p.setPos(b.Rbrace)
		p.print(token.RBRACE, noExtraLinebreak)
		return
	}
	if sep != ignore {
		p.print(blank) // always use blank
	}
	p.block(b, 1)
}
// distanceFrom returns the column difference between p.out (the current
// output position) and startOutCol. If the start position is on a
// different line from the current position (or either is unknown), the
// result is infinity.
func (p *printer) distanceFrom(startPos token.Pos, startOutCol int) int {
	if !startPos.IsValid() || !p.pos.IsValid() {
		return infinity
	}
	if p.posFor(startPos).Line != p.pos.Line {
		return infinity
	}
	return p.out.Column - startOutCol
}
// methRecvType returns the bare type name of a method receiver
// expression, unwrapping any pointer indirection. Unknown receiver
// forms yield a diagnostic string rather than panicking.
func (p *printer) methRecvType(typ ast.Expr) string {
	switch x := typ.(type) {
	case *ast.StarExpr:
		return p.methRecvType(x.X)
	case *ast.Ident:
		return x.Name
	default:
		// fix: use %T (the "+" flag is meaningless for %T and flagged
		// by vet's printf check)
		return fmt.Sprintf("recv type unknown: %T", x)
	}
	// fix: removed the unreachable trailing `return ""` — every switch
	// branch above returns
}
// funcDecl prints a function or method declaration as a gosl "fn".
// Methods are renamed to RecvType_MethodName and a pointer receiver is
// tracked in curPtrArgs so field accesses dereference correctly.
// Functions listed in ExcludeFunctions are skipped entirely; outside of
// the function-graph collection pass, only registered kernel functions
// are emitted.
func (p *printer) funcDecl(d *ast.FuncDecl) {
	fname := ""
	if d.Recv != nil {
		for ex := range p.ExcludeFunctions {
			if d.Name.Name == ex {
				return
			}
		}
		if d.Recv.List[0].Names != nil {
			p.curMethRecv = d.Recv.List[0]
			isptr, typnm := p.printMethRecv()
			if isptr {
				p.curPtrArgs = []*ast.Ident{p.curMethRecv.Names[0]}
			}
			// method name mangling: RecvType_MethodName
			fname = typnm + "_" + d.Name.Name
			// fmt.Printf("cur func recv: %v\n", p.curMethRecv)
		}
		// p.parameters(d.Recv, funcParam) // method: print receiver
		// p.print(blank)
	} else {
		fname = d.Name.Name
	}
	if p.GoToSL.GetFuncGraph {
		p.curFunc = p.GoToSL.RecycleFunc(fname)
	} else {
		// only emit functions that are part of a kernel call graph
		fn, ok := p.GoToSL.KernelFuncs[fname]
		if !ok {
			return
		}
		p.curFunc = fn
	}
	p.setComment(d.Doc)
	p.setPos(d.Pos())
	// We have to save startCol only after emitting FUNC; otherwise it can be on a
	// different line (all whitespace preceding the FUNC is emitted only when the
	// FUNC is emitted).
	startCol := p.out.Column - len("func ")
	p.print("fn", blank, fname)
	p.signature(d.Type, d.Recv)
	p.funcBody(p.distanceFrom(d.Pos(), startCol), vtab, d.Body)
	// reset per-function state
	p.curPtrArgs = nil
	p.curMethRecv = nil
	if p.GoToSL.GetFuncGraph {
		p.GoToSL.FuncGraph[fname] = p.curFunc
		p.curFunc = nil
	}
}
// decl prints declaration decl, dispatching on its concrete type.
func (p *printer) decl(decl ast.Decl) {
	switch d := decl.(type) {
	case *ast.BadDecl:
		p.setPos(d.Pos())
		p.print("BadDecl")
	case *ast.GenDecl:
		p.genDecl(d)
	case *ast.FuncDecl:
		p.funcDecl(d)
	default:
		panic("unreachable")
	}
}
// ----------------------------------------------------------------------------
// Files
func declToken(decl ast.Decl) (tok token.Token) {
tok = token.ILLEGAL
switch d := decl.(type) {
case *ast.GenDecl:
tok = d.Tok
case *ast.FuncDecl:
tok = token.FUNC
}
return
}
// declList prints the top-level declarations in list, inserting an empty
// line between declarations when the declaration kind changes or the next
// declaration carries documentation.
func (p *printer) declList(list []ast.Decl) {
	tok := token.ILLEGAL
	for _, d := range list {
		prev := tok
		tok = declToken(d)
		// If the declaration token changed (e.g., from CONST to TYPE)
		// or the next declaration has documentation associated with it,
		// print an empty line between top-level declarations.
		// (because p.linebreak is called with the position of d, which
		// is past any documentation, the minimum requirement is satisfied
		// even w/o the extra getDoc(d) nil-check - leave it in case the
		// linebreak logic improves - there's already a TODO).
		if len(p.output) > 0 {
			// only print line break if we are not at the beginning of the output
			// (i.e., we are not printing only a partial program)
			min := 1
			if prev != tok || getDoc(d) != nil {
				min = 2
			}
			// start a new section if the next declaration is a function
			// that spans multiple lines (see also issue #19544)
			p.linebreak(p.lineFor(d.Pos()), min, ignore, tok == token.FUNC && p.numLines(d) > 1)
		}
		p.decl(d)
	}
}
// file prints the given source file: package clause, then all
// declarations, followed by a final newline.
func (p *printer) file(src *ast.File) {
	p.setComment(src.Doc)
	p.setPos(src.Pos())
	p.print(token.PACKAGE, blank)
	p.expr(src.Name)
	p.declList(src.Decls)
	p.print(newline)
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file is largely copied from the Go source,
// src/go/printer/printer.go:
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gotosl
import (
"fmt"
"go/ast"
"go/build/constraint"
"go/token"
"io"
"os"
"strings"
"sync"
"text/tabwriter"
"unicode"
"cogentcore.org/core/base/fsx"
"golang.org/x/tools/go/packages"
)
const (
	maxNewlines = 2       // max. number of newlines between source text
	debug       = false   // enable for debugging
	infinity    = 1 << 30 // sentinel for "unbounded" sizes/distances
)
// A whiteSpace value encodes a pending whitespace/indentation
// instruction for the printer's delayed-whitespace buffer (wsbuf).
type whiteSpace byte

const (
	ignore   = whiteSpace(0)    // no-op placeholder
	blank    = whiteSpace(' ')  // single space
	vtab     = whiteSpace('\v') // tabwriter alignment column
	newline  = whiteSpace('\n') // line break
	formfeed = whiteSpace('\f') // line break that also ends a tabwriter section
	indent   = whiteSpace('>')  // increase indentation
	unindent = whiteSpace('<')  // decrease indentation
)
// A pmode value represents the current printer mode (a bit set).
type pmode int

const (
	noExtraBlank     pmode = 1 << iota // disables extra blank after /*-style comment
	noExtraLinebreak                   // disables extra line break after /*-style comment
)
// commentInfo tracks the printer's progress through the source
// comment list; see nextComment.
type commentInfo struct {
	cindex         int               // current comment index
	comment        *ast.CommentGroup // = printer.comments[cindex]; or nil
	commentOffset  int               // = printer.posFor(printer.comments[cindex].List[0].Pos()).Offset; or infinity
	commentNewline bool              // true if the comment group contains newlines
}
// printer holds all state for translating a Go AST to gosl (WGSL-style)
// output: configuration, raw output buffer, pending whitespace,
// position tracking, comment state, and gosl-specific per-function
// context (pointer args, current function, method receiver).
type printer struct {
	// Configuration (does not change after initialization)
	PrintConfig
	fset *token.FileSet
	pkg  *packages.Package // gosl: extra

	// Current state
	output      []byte       // raw printer result
	indent      int          // current indentation
	level       int          // level == 0: outside composite literal; level > 0: inside composite literal
	mode        pmode        // current printer mode
	endAlignment bool        // if set, terminate alignment immediately
	impliedSemi bool         // if set, a linebreak implies a semicolon
	lastTok     token.Token  // last token printed (token.ILLEGAL if it's whitespace)
	prevOpen    token.Token  // previous non-brace "open" token (, [, or token.ILLEGAL
	wsbuf       []whiteSpace // delayed white space
	goBuild     []int        // start index of all //go:build comments in output
	plusBuild   []int        // start index of all // +build comments in output

	// Positions
	// The out position differs from the pos position when the result
	// formatting differs from the source formatting (in the amount of
	// white space). If there's a difference and SourcePos is set in
	// ConfigMode, //line directives are used in the output to restore
	// original source positions for a reader.
	pos          token.Position // current position in AST (source) space
	out          token.Position // current position in output space
	last         token.Position // value of pos after calling writeString
	linePtr      *int           // if set, record out.Line for the next token in *linePtr
	sourcePosErr error          // if non-nil, the first error emitting a //line directive

	// The list of all source comments, in order of appearance.
	comments        []*ast.CommentGroup // may be nil
	useNodeComments bool                // if not set, ignore lead and line comments of nodes

	// Information about p.comments[p.cindex]; set up by nextComment.
	commentInfo

	// Cache of already computed node sizes.
	nodeSizes map[ast.Node]int

	// Cache of most recently computed line position.
	cachedPos  token.Pos
	cachedLine int // line corresponding to cachedPos

	// current arguments to function that are pointers and thus need dereferencing
	// when accessing fields
	curPtrArgs      []*ast.Ident
	curFunc         *Function  // function currently being emitted / graphed
	curMethRecv     *ast.Field // current method receiver, also included in curPtrArgs if ptr
	curReturnType   *ast.Ident // declared return type of the current function, if identifier
	curMethIsAtomic bool       // current method an atomic.* function -- marks arg as atomic
}
// internalError reports an internal printer inconsistency and panics,
// but only when the package-level debug flag is enabled; in normal
// operation it is a no-op.
func (p *printer) internalError(msg ...any) {
	if debug {
		fmt.Print(p.pos.String() + ": ")
		fmt.Println(msg...)
		panic("go/printer")
	}
}
// userError reports err to the user on stdout, prefixed by the current
// source position in directory/file form.
func (p *printer) userError(err error) {
	loc := fsx.DirAndFile(p.pos.String())
	fmt.Print(loc + ": ")
	fmt.Println(err.Error())
}
// commentsHaveNewline reports whether a list of comments belonging to
// an *ast.CommentGroup contains newlines. Because the position information
// may only be partially correct, we also have to read the comment text.
func (p *printer) commentsHaveNewline(list []*ast.Comment) bool {
	// len(list) > 0
	line := p.lineFor(list[0].Pos())
	for i, c := range list {
		if i > 0 && p.lineFor(c.Pos()) != line {
			// not all comments on the same line
			return true
		}
		// a //-style comment, or /*...*/ text containing a newline,
		// forces a line break when printed
		if t := c.Text; len(t) >= 2 && (t[1] == '/' || strings.Contains(t, "\n")) {
			return true
		}
	}
	// fix: dropped the dead `_ = line` statement (line is used in the loop
	// condition above) and the redundant list[i] re-index in favor of the
	// range variable c.
	return false
}
// nextComment advances to the next non-empty comment group, updating
// the cached commentInfo fields. When the comment list is exhausted,
// commentOffset is set to infinity so commentBefore always reports false.
func (p *printer) nextComment() {
	for p.cindex < len(p.comments) {
		c := p.comments[p.cindex]
		p.cindex++
		if list := c.List; len(list) > 0 {
			p.comment = c
			p.commentOffset = p.posFor(list[0].Pos()).Offset
			p.commentNewline = p.commentsHaveNewline(list)
			return
		}
		// we should not reach here (correct ASTs don't have empty
		// ast.CommentGroup nodes), but be conservative and try again
	}
	// no more comments
	p.commentOffset = infinity
}
// commentBefore reports whether the current comment group occurs
// before the next position in the source code and printing it does
// not introduce implicit semicolons.
func (p *printer) commentBefore(next token.Position) bool {
	return p.commentOffset < next.Offset && (!p.impliedSemi || !p.commentNewline)
}
// commentSizeBefore returns the estimated size (in characters) of the
// comments on the same line before the next position.
func (p *printer) commentSizeBefore(next token.Position) int {
	// save/restore current p.commentInfo (p.nextComment() modifies it)
	defer func(info commentInfo) {
		p.commentInfo = info
	}(p.commentInfo)
	size := 0
	for p.commentBefore(next) {
		for _, c := range p.comment.List {
			size += len(c.Text)
		}
		p.nextComment()
	}
	return size
}
// recordLine records the output line number for the next non-whitespace
// token in *linePtr. It is used to compute an accurate line number for a
// formatted construct, independent of pending (not yet emitted) whitespace
// or comments.
func (p *printer) recordLine(linePtr *int) {
	p.linePtr = linePtr
}
// linesFrom returns the number of output lines between the current
// output line and the line argument, ignoring any pending (not yet
// emitted) whitespace or comments. It is used to compute an accurate
// size (in number of lines) for a formatted construct.
func (p *printer) linesFrom(line int) int {
	return p.out.Line - line
}
// posFor returns the token.Position for pos, without adjustment
// by //line directives (absolute position).
func (p *printer) posFor(pos token.Pos) token.Position {
	// not used frequently enough to cache entire token.Position
	return p.fset.PositionFor(pos, false /* absolute position */)
}
// lineFor returns the line number for pos, caching the most recent
// (pos, line) pair since consecutive queries often repeat the same pos.
func (p *printer) lineFor(pos token.Pos) int {
	if pos == p.cachedPos {
		return p.cachedLine
	}
	p.cachedPos = pos
	p.cachedLine = p.fset.PositionFor(pos, false /* absolute position */).Line
	return p.cachedLine
}
// writeLineDirective writes a //line directive if necessary.
//
// A directive is emitted only when the output position has diverged
// from the source position pos (different line or filename), keeping
// SourcePos output minimal.
func (p *printer) writeLineDirective(pos token.Position) {
	if pos.IsValid() && (p.out.Line != pos.Line || p.out.Filename != pos.Filename) {
		// A filename containing '\n' or '\r' would terminate the
		// directive early and corrupt the output; record the error
		// once and skip the directive.
		if strings.ContainsAny(pos.Filename, "\r\n") {
			if p.sourcePosErr == nil {
				p.sourcePosErr = fmt.Errorf("go/printer: source filename contains unexpected newline character: %q", pos.Filename)
			}
			return
		}

		p.output = append(p.output, tabwriter.Escape) // protect '\n' in //line from tabwriter interpretation
		p.output = append(p.output, fmt.Sprintf("//line %s:%d\n", pos.Filename, pos.Line)...)
		p.output = append(p.output, tabwriter.Escape)
		// p.out must match the //line directive
		p.out.Filename = pos.Filename
		p.out.Line = pos.Line
	}
}
// writeIndent emits the current indentation (base config indent plus
// the printer's dynamic indent) as hard htabs, and advances the
// position bookkeeping accordingly.
func (p *printer) writeIndent() {
	// use "hard" htabs - indentation columns
	// must not be discarded by the tabwriter
	depth := p.PrintConfig.Indent + p.indent // include base indentation
	for i := depth; i > 0; i-- {
		p.output = append(p.output, '\t')
	}

	// update positions
	p.pos.Offset += depth
	p.pos.Column += depth
	p.out.Column += depth
}
// writeByte writes ch n times to p.output and updates p.pos.
// Only used to write formatting (white space) characters.
func (p *printer) writeByte(ch byte, n int) {
	if p.endAlignment {
		// Ignore any alignment control character;
		// and at the end of the line, break with
		// a formfeed to indicate termination of
		// existing columns.
		switch ch {
		case '\t', '\v':
			ch = ' '
		case '\n', '\f':
			ch = '\f'
			p.endAlignment = false
		}
	}

	if p.out.Column == 1 {
		// at line start: emit indentation first
		// (no need to write line directives before white space)
		p.writeIndent()
	}

	for i := 0; i < n; i++ {
		p.output = append(p.output, ch)
	}

	// update positions
	p.pos.Offset += n
	if ch == '\n' || ch == '\f' {
		// n line-break characters advance the line count by n
		// and reset both columns to the start of the line
		p.pos.Line += n
		p.out.Line += n
		p.pos.Column = 1
		p.out.Column = 1
		return
	}
	p.pos.Column += n
	p.out.Column += n
}
// writeString writes the string s to p.output and updates p.pos, p.out,
// and p.last. If isLit is set, s is escaped w/ tabwriter.Escape characters
// to protect s from being interpreted by the tabwriter.
//
// Note: writeString is only used to write Go tokens, literals, and
// comments, all of which must be written literally. Thus, it is correct
// to always set isLit = true. However, setting it explicitly only when
// needed (i.e., when we don't know that s contains no tabs or line breaks)
// avoids processing extra escape characters and reduces run time of the
// printer benchmark by up to 10%.
func (p *printer) writeString(pos token.Position, s string, isLit bool) {
	if p.out.Column == 1 {
		// at the beginning of a line: emit a //line directive if
		// SourcePos mode is on, then the current indentation
		if p.PrintConfig.Mode&SourcePos != 0 {
			p.writeLineDirective(pos)
		}
		p.writeIndent()
	}

	if pos.IsValid() {
		// update p.pos (if pos is invalid, continue with existing p.pos)
		// Note: Must do this after handling line beginnings because
		// writeIndent updates p.pos if there's indentation, but p.pos
		// is the position of s.
		p.pos = pos
	}

	if isLit {
		// Protect s such that is passes through the tabwriter
		// unchanged. Note that valid Go programs cannot contain
		// tabwriter.Escape bytes since they do not appear in legal
		// UTF-8 sequences.
		p.output = append(p.output, tabwriter.Escape)
	}

	if debug {
		p.output = append(p.output, fmt.Sprintf("/*%s*/", pos)...) // do not update p.pos!
	}
	p.output = append(p.output, s...)

	// update positions
	nlines := 0
	var li int // index of last newline; valid if nlines > 0
	for i := 0; i < len(s); i++ {
		// Raw string literals may contain any character except back quote (`).
		if ch := s[i]; ch == '\n' || ch == '\f' {
			// account for line break
			nlines++
			li = i
			// A line break inside a literal will break whatever column
			// formatting is in place; ignore any further alignment through
			// the end of the line.
			p.endAlignment = true
		}
	}
	p.pos.Offset += len(s)
	if nlines > 0 {
		p.pos.Line += nlines
		p.out.Line += nlines
		// new column is measured from the character after the last break
		c := len(s) - li
		p.pos.Column = c
		p.out.Column = c
	} else {
		p.pos.Column += len(s)
		p.out.Column += len(s)
	}

	if isLit {
		p.output = append(p.output, tabwriter.Escape)
	}

	p.last = p.pos
}
// writeCommentPrefix writes the whitespace before a comment.
// If there is any pending whitespace, it consumes as much of
// it as is likely to help position the comment nicely.
// pos is the comment position, next the position of the item
// after all pending comments, prev is the previous comment in
// a group of comments (or nil), and tok is the next token.
func (p *printer) writeCommentPrefix(pos, next token.Position, prev *ast.Comment, tok token.Token) {
	if len(p.output) == 0 {
		// the comment is the first item to be printed - don't write any whitespace
		return
	}

	if pos.IsValid() && pos.Filename != p.last.Filename {
		// comment in a different file - separate with newlines
		p.writeByte('\f', maxNewlines)
		return
	}

	if pos.Line == p.last.Line && (prev == nil || prev.Text[1] != '/') {
		// comment on the same line as last item:
		// separate with at least one separator
		hasSep := false
		if prev == nil {
			// first comment of a comment group
			j := 0
			for i, ch := range p.wsbuf {
				switch ch {
				case blank:
					// ignore any blanks before a comment
					p.wsbuf[i] = ignore
					continue
				case vtab:
					// respect existing tabs - important
					// for proper formatting of commented structs
					hasSep = true
					continue
				case indent:
					// apply pending indentation
					continue
				}
				// any other whitespace entry stops the scan;
				// j is the count of entries to flush now
				j = i
				break
			}
			p.writeWhitespace(j)
		}
		// make sure there is at least one separator
		if !hasSep {
			sep := byte('\t')
			if pos.Line == next.Line {
				// next item is on the same line as the comment
				// (which must be a /*-style comment): separate
				// with a blank instead of a tab
				sep = ' '
			}
			p.writeByte(sep, 1)
		}
	} else {
		// comment on a different line:
		// separate with at least one line break
		droppedLinebreak := false
		j := 0
		for i, ch := range p.wsbuf {
			switch ch {
			case blank, vtab:
				// ignore any horizontal whitespace before line breaks
				p.wsbuf[i] = ignore
				continue
			case indent:
				// apply pending indentation
				continue
			case unindent:
				// if this is not the last unindent, apply it
				// as it is (likely) belonging to the last
				// construct (e.g., a multi-line expression list)
				// and is not part of closing a block
				if i+1 < len(p.wsbuf) && p.wsbuf[i+1] == unindent {
					continue
				}
				// if the next token is not a closing }, apply the unindent
				// if it appears that the comment is aligned with the
				// token; otherwise assume the unindent is part of a
				// closing block and stop (this scenario appears with
				// comments before a case label where the comments
				// apply to the next case instead of the current one)
				if tok != token.RBRACE && pos.Column == next.Column {
					continue
				}
			case newline, formfeed:
				p.wsbuf[i] = ignore
				droppedLinebreak = prev == nil // record only if first comment of a group
			}
			j = i
			break
		}
		p.writeWhitespace(j)

		// determine number of linebreaks before the comment
		n := 0
		if pos.IsValid() && p.last.IsValid() {
			n = pos.Line - p.last.Line
			if n < 0 { // should never happen
				n = 0
			}
		}

		// at the package scope level only (p.indent == 0),
		// add an extra newline if we dropped one before:
		// this preserves a blank line before documentation
		// comments at the package scope level (issue 2570)
		if p.indent == 0 && droppedLinebreak {
			n++
		}

		// make sure there is at least one line break
		// if the previous comment was a line comment
		if n == 0 && prev != nil && prev.Text[1] == '/' {
			n = 1
		}

		if n > 0 {
			// use formfeeds to break columns before a comment;
			// this is analogous to using formfeeds to separate
			// individual lines of /*-style comments
			p.writeByte('\f', nlimit(n))
		}
	}
}
// isBlank reports whether s consists entirely of bytes <= ' '
// (only tabs and blanks can appear in the printer's context).
// The empty string is blank.
func isBlank(s string) bool {
	for _, b := range []byte(s) {
		if b > ' ' {
			return false
		}
	}
	return true
}
// commonPrefix returns the common prefix of a and b, restricted to
// whitespace bytes and '*' (the characters that form /*-comment
// decoration prefixes).
func commonPrefix(a, b string) string {
	limit := min(len(a), len(b))
	i := 0
	for i < limit && a[i] == b[i] && (a[i] <= ' ' || a[i] == '*') {
		i++
	}
	return a[:i]
}
// trimRight returns s with trailing whitespace removed.
// All Unicode space runes count as whitespace (unicode.IsSpace),
// not only the tabs and blanks the printer itself emits.
func trimRight(s string) string {
	return strings.TrimRightFunc(s, unicode.IsSpace)
}
// stripCommonPrefix removes a common prefix from /*-style comment lines (unless no
// comment line is indented, all but the first line have some form of space prefix).
// The prefix is computed using heuristics such that is likely that the comment
// contents are nicely laid out after re-printing each line using the printer's
// current indentation.
//
// lines is modified in place: blank interior lines become empty, the
// last line may be rewritten to align the closing */, and the common
// prefix is sliced off all but the first and empty lines.
func stripCommonPrefix(lines []string) {
	if len(lines) <= 1 {
		return // at most one line - nothing to do
	}
	// len(lines) > 1

	// The heuristic in this function tries to handle a few
	// common patterns of /*-style comments: Comments where
	// the opening /* and closing */ are aligned and the
	// rest of the comment text is aligned and indented with
	// blanks or tabs, cases with a vertical "line of stars"
	// on the left, and cases where the closing */ is on the
	// same line as the last comment text.

	// Compute maximum common white prefix of all but the first,
	// last, and blank lines, and replace blank lines with empty
	// lines (the first line starts with /* and has no prefix).
	// In cases where only the first and last lines are not blank,
	// such as two-line comments, or comments where all inner lines
	// are blank, consider the last line for the prefix computation
	// since otherwise the prefix would be empty.
	//
	// Note that the first and last line are never empty (they
	// contain the opening /* and closing */ respectively) and
	// thus they can be ignored by the blank line check.
	prefix := ""
	prefixSet := false
	if len(lines) > 2 {
		for i, line := range lines[1 : len(lines)-1] {
			if isBlank(line) {
				lines[1+i] = "" // range starts with lines[1]
			} else {
				if !prefixSet {
					prefix = line
					prefixSet = true
				}
				prefix = commonPrefix(prefix, line)
			}
		}
	}
	// If we don't have a prefix yet, consider the last line.
	if !prefixSet {
		// commonPrefix(line, line) yields line's own maximal run of
		// whitespace/'*' characters, i.e. its full decoration prefix.
		line := lines[len(lines)-1]
		prefix = commonPrefix(line, line)
	}

	/*
	 * Check for vertical "line of stars" and correct prefix accordingly.
	 */
	lineOfStars := false
	if p, _, ok := strings.Cut(prefix, "*"); ok {
		// remove trailing blank from prefix so stars remain aligned
		prefix = strings.TrimSuffix(p, " ")
		lineOfStars = true
	} else {
		// No line of stars present.
		// Determine the white space on the first line after the /*
		// and before the beginning of the comment text, assume two
		// blanks instead of the /* unless the first character after
		// the /* is a tab. If the first comment line is empty but
		// for the opening /*, assume up to 3 blanks or a tab. This
		// whitespace may be found as suffix in the common prefix.
		first := lines[0]
		if isBlank(first[2:]) {
			// no comment text on the first line:
			// reduce prefix by up to 3 blanks or a tab
			// if present - this keeps comment text indented
			// relative to the /* and */'s if it was indented
			// in the first place
			i := len(prefix)
			for n := 0; n < 3 && i > 0 && prefix[i-1] == ' '; n++ {
				i--
			}
			if i == len(prefix) && i > 0 && prefix[i-1] == '\t' {
				i--
			}
			prefix = prefix[0:i]
		} else {
			// comment text on the first line
			suffix := make([]byte, len(first))
			n := 2 // start after opening /*
			for n < len(first) && first[n] <= ' ' {
				suffix[n] = first[n]
				n++
			}
			if n > 2 && suffix[2] == '\t' {
				// assume the '\t' compensates for the /*
				suffix = suffix[2:n]
			} else {
				// otherwise assume two blanks
				suffix[0], suffix[1] = ' ', ' '
				suffix = suffix[0:n]
			}
			// Shorten the computed common prefix by the length of
			// suffix, if it is found as suffix of the prefix.
			prefix = strings.TrimSuffix(prefix, string(suffix))
		}
	}

	// Handle last line: If it only contains a closing */, align it
	// with the opening /*, otherwise align the text with the other
	// lines.
	last := lines[len(lines)-1]
	closing := "*/"
	before, _, _ := strings.Cut(last, closing) // closing always present
	if isBlank(before) {
		// last line only contains closing */
		if lineOfStars {
			closing = " */" // add blank to align final star
		}
		lines[len(lines)-1] = prefix + closing
	} else {
		// last line contains more comment text - assume
		// it is aligned like the other lines and include
		// in prefix computation
		prefix = commonPrefix(prefix, last)
	}

	// Remove the common prefix from all but the first and empty lines.
	for i, line := range lines {
		if i > 0 && line != "" {
			lines[i] = line[len(prefix):]
		}
	}
}
// writeComment writes a single comment: //-style comments (and //line
// directives and build constraints) are written as one unit; /*-style
// comments are split into lines, have their common decoration prefix
// stripped, and are written line by line so each line is re-indented.
func (p *printer) writeComment(comment *ast.Comment) {
	text := comment.Text
	pos := p.posFor(comment.Pos())

	const linePrefix = "//line "
	if strings.HasPrefix(text, linePrefix) && (!pos.IsValid() || pos.Column == 1) {
		// Possibly a //-style line directive.
		// Suspend indentation temporarily to keep line directive valid.
		defer func(indent int) { p.indent = indent }(p.indent)
		p.indent = 0
	}

	// shortcut common case of //-style comments
	if text[1] == '/' {
		// record positions of build-constraint comments so they can be
		// fixed up later (see fixGoBuildLines)
		if constraint.IsGoBuild(text) {
			p.goBuild = append(p.goBuild, len(p.output))
		} else if constraint.IsPlusBuild(text) {
			p.plusBuild = append(p.plusBuild, len(p.output))
		}
		p.writeString(pos, trimRight(text), true)
		return
	}

	// for /*-style comments, print line by line and let the
	// write function take care of the proper indentation
	lines := strings.Split(text, "\n")

	// The comment started in the first column but is going
	// to be indented. For an idempotent result, add indentation
	// to all lines such that they look like they were indented
	// before - this will make sure the common prefix computation
	// is the same independent of how many times formatting is
	// applied (was issue 1835).
	// NOTE(review): upstream go/printer prepends three blanks here;
	// confirm the single blank is intentional in this fork.
	if pos.IsValid() && pos.Column == 1 && p.indent > 0 {
		for i, line := range lines[1:] {
			lines[1+i] = " " + line
		}
	}

	stripCommonPrefix(lines)

	// write comment lines, separated by formfeed,
	// without a line break after the last line
	for i, line := range lines {
		if i > 0 {
			p.writeByte('\f', 1)
			pos = p.pos
		}
		if len(line) > 0 {
			p.writeString(pos, trimRight(line), true)
		}
	}
}
// writeCommentSuffix writes a line break after a comment if indicated
// and processes any leftover indentation information. If a line break
// is needed, the kind of break (newline vs formfeed) depends on the
// pending whitespace. The writeCommentSuffix result indicates if a
// newline was written or if a formfeed was dropped from the whitespace
// buffer.
func (p *printer) writeCommentSuffix(needsLinebreak bool) (wroteNewline, droppedFF bool) {
	for i, ch := range p.wsbuf {
		switch ch {
		case blank, vtab:
			// ignore trailing whitespace
			p.wsbuf[i] = ignore
		case indent, unindent:
			// don't lose indentation information
		case newline, formfeed:
			// if we need a line break, keep exactly one
			// but remember if we dropped any formfeeds
			if needsLinebreak {
				needsLinebreak = false
				wroteNewline = true
			} else {
				if ch == formfeed {
					droppedFF = true
				}
				p.wsbuf[i] = ignore
			}
		}
	}

	// flush the (possibly edited) whitespace buffer in full
	p.writeWhitespace(len(p.wsbuf))

	// make sure we have a line break
	if needsLinebreak {
		p.writeByte('\n', 1)
		wroteNewline = true
	}

	return
}
// containsLinebreak reports whether the whitespace buffer contains any line breaks.
func (p *printer) containsLinebreak() bool {
	for _, ws := range p.wsbuf {
		switch ws {
		case newline, formfeed:
			return true
		}
	}
	return false
}
// intersperseComments consumes all comments that appear before the next token
// tok and prints it together with the buffered whitespace (i.e., the whitespace
// that needs to be written before the next token). A heuristic is used to mix
// the comments and whitespace. The intersperseComments result indicates if a
// newline was written or if a formfeed was dropped from the whitespace buffer.
func (p *printer) intersperseComments(next token.Position, tok token.Token) (wroteNewline, droppedFF bool) {
	var last *ast.Comment
	for p.commentBefore(next) {
		list := p.comment.List
		changed := false
		if p.lastTok != token.IMPORT && // do not rewrite cgo's import "C" comments
			p.posFor(p.comment.Pos()).Column == 1 &&
			p.posFor(p.comment.End()+1) == next {
			// Unindented comment abutting next token position:
			// a top-level doc comment.
			list = formatDocComment(list)
			changed = true
			if len(p.comment.List) > 0 && len(list) == 0 {
				// The doc comment was removed entirely.
				// Keep preceding whitespace.
				p.writeCommentPrefix(p.posFor(p.comment.Pos()), next, last, tok)
				// Change print state to continue at next.
				p.pos = next
				p.last = next
				// There can't be any more comments.
				p.nextComment()
				return p.writeCommentSuffix(false)
			}
		}
		for _, c := range list {
			p.writeCommentPrefix(p.posFor(c.Pos()), next, last, tok)
			p.writeComment(c)
			last = c
		}
		// In case list was rewritten, change print state to where
		// the original list would have ended.
		if len(p.comment.List) > 0 && changed {
			last = p.comment.List[len(p.comment.List)-1]
			p.pos = p.posFor(last.End())
			p.last = p.pos
		}
		p.nextComment()
	}

	if last != nil {
		// If the last comment is a /*-style comment and the next item
		// follows on the same line but is not a comma, and not a "closing"
		// token immediately following its corresponding "opening" token,
		// add an extra separator unless explicitly disabled. Use a blank
		// as separator unless we have pending linebreaks, they are not
		// disabled, and we are outside a composite literal, in which case
		// we want a linebreak (issue 15137).
		// TODO(gri) This has become overly complicated. We should be able
		// to track whether we're inside an expression or statement and
		// use that information to decide more directly.
		needsLinebreak := false
		if p.mode&noExtraBlank == 0 &&
			last.Text[1] == '*' && p.lineFor(last.Pos()) == next.Line &&
			tok != token.COMMA &&
			(tok != token.RPAREN || p.prevOpen == token.LPAREN) &&
			(tok != token.RBRACK || p.prevOpen == token.LBRACK) {
			if p.containsLinebreak() && p.mode&noExtraLinebreak == 0 && p.level == 0 {
				needsLinebreak = true
			} else {
				p.writeByte(' ', 1)
			}
		}
		// Ensure that there is a line break after a //-style comment,
		// before EOF, and before a closing '}' unless explicitly disabled.
		if last.Text[1] == '/' ||
			tok == token.EOF ||
			tok == token.RBRACE && p.mode&noExtraLinebreak == 0 {
			needsLinebreak = true
		}
		return p.writeCommentSuffix(needsLinebreak)
	}

	// no comment was written - we should never reach here since
	// intersperseComments should not be called in that case
	p.internalError("intersperseComments called without pending comments")
	return
}
// writeWhitespace writes the first n whitespace entries from p.wsbuf,
// applying indent/unindent entries to p.indent and emitting the rest
// as bytes; the consumed entries are then shifted out of the buffer.
func (p *printer) writeWhitespace(n int) {
	// write entries
	for i := 0; i < n; i++ {
		switch ch := p.wsbuf[i]; ch {
		case ignore:
			// ignore!
		case indent:
			p.indent++
		case unindent:
			p.indent--
			if p.indent < 0 {
				p.internalError("negative indentation:", p.indent)
				p.indent = 0
			}
		case newline, formfeed:
			// A line break immediately followed by a "correcting"
			// unindent is swapped with the unindent - this permits
			// proper label positioning. If a comment is between
			// the line break and the label, the unindent is not
			// part of the comment whitespace prefix and the comment
			// will be positioned correctly indented.
			if i+1 < n && p.wsbuf[i+1] == unindent {
				// Use a formfeed to terminate the current section.
				// Otherwise, a long label name on the next line leading
				// to a wide column may increase the indentation column
				// of lines before the label; effectively leading to wrong
				// indentation.
				p.wsbuf[i], p.wsbuf[i+1] = unindent, formfeed
				i-- // do it again
				continue
			}
			fallthrough
		default:
			// whiteSpace values double as their output byte
			p.writeByte(byte(ch), 1)
		}
	}

	// shift remaining entries down
	l := copy(p.wsbuf, p.wsbuf[n:])
	p.wsbuf = p.wsbuf[:l]
}
// ----------------------------------------------------------------------------
// Printing interface
// nlimit limits n to maxNewlines.
func nlimit(n int) int {
	if n > maxNewlines {
		return maxNewlines
	}
	return n
}
func mayCombine(prev token.Token, next byte) (b bool) {
switch prev {
case token.INT:
b = next == '.' // 1.
case token.ADD:
b = next == '+' // ++
case token.SUB:
b = next == '-' // --
case token.QUO:
b = next == '*' // /*
case token.LSS:
b = next == '-' || next == '<' // <- or <<
case token.AND:
b = next == '&' || next == '^' // && or &^
}
return
}
// setPos records pos as the accurate position of the next item,
// if pos is valid; an invalid pos leaves p.pos unchanged.
func (p *printer) setPos(pos token.Pos) {
	if pos.IsValid() {
		p.pos = p.posFor(pos) // accurate position of next item
	}
}
// print prints a list of "items" (roughly corresponding to syntactic
// tokens, but also including whitespace and formatting information).
// It is the only print function that should be called directly from
// any of the AST printing functions in nodes.go.
//
// Whitespace is accumulated until a non-whitespace token appears. Any
// comments that need to appear before that token are printed first,
// taking into account the amount and structure of any pending white-
// space for best comment placement. Then, any leftover whitespace is
// printed, followed by the actual token.
func (p *printer) print(args ...any) {
	for _, arg := range args {
		// information about the current arg
		var data string
		var isLit bool
		var impliedSemi bool // value for p.impliedSemi after this arg

		// record previous opening token, if any
		switch p.lastTok {
		case token.ILLEGAL:
			// ignore (white space)
		case token.LPAREN, token.LBRACK:
			p.prevOpen = p.lastTok
		default:
			// other tokens followed any opening token
			p.prevOpen = token.ILLEGAL
		}

		switch x := arg.(type) {
		case pmode:
			// toggle printer mode
			p.mode ^= x
			continue

		case whiteSpace:
			if x == ignore {
				// don't add ignore's to the buffer; they
				// may screw up "correcting" unindents (see
				// LabeledStmt)
				continue
			}
			i := len(p.wsbuf)
			if i == cap(p.wsbuf) {
				// Whitespace sequences are very short so this should
				// never happen. Handle gracefully (but possibly with
				// bad comment placement) if it does happen.
				p.writeWhitespace(i)
				i = 0
			}
			p.wsbuf = p.wsbuf[0 : i+1]
			p.wsbuf[i] = x
			if x == newline || x == formfeed {
				// newlines affect the current state (p.impliedSemi)
				// and not the state after printing arg (impliedSemi)
				// because comments can be interspersed before the arg
				// in this case
				p.impliedSemi = false
			}
			p.lastTok = token.ILLEGAL
			continue

		case *ast.Ident:
			data = x.Name
			impliedSemi = true
			p.lastTok = token.IDENT

		case *ast.BasicLit:
			data = x.Value
			isLit = true
			impliedSemi = true
			p.lastTok = x.Kind

		case token.Token:
			s := x.String()
			if mayCombine(p.lastTok, s[0]) {
				// the previous and the current token must be
				// separated by a blank otherwise they combine
				// into a different incorrect token sequence
				// (except for token.INT followed by a '.' this
				// should never happen because it is taken care
				// of via binary expression formatting)
				if len(p.wsbuf) != 0 {
					p.internalError("whitespace buffer not empty")
				}
				p.wsbuf = p.wsbuf[0:1]
				p.wsbuf[0] = ' '
			}
			data = s
			// some keywords followed by a newline imply a semicolon
			switch x {
			case token.BREAK, token.CONTINUE, token.FALLTHROUGH, token.RETURN,
				token.INC, token.DEC, token.RPAREN, token.RBRACK, token.RBRACE:
				impliedSemi = true
			}
			p.lastTok = x

		case string:
			// incorrect AST - print error message
			data = x
			isLit = true
			impliedSemi = true
			p.lastTok = token.STRING

		default:
			fmt.Fprintf(os.Stderr, "print: unsupported argument %v (%T)\n", arg, arg)
			panic("go/printer type")
		}
		// data != ""

		next := p.pos // estimated/accurate position of next item
		wroteNewline, droppedFF := p.flush(next, p.lastTok)

		// intersperse extra newlines if present in the source and
		// if they don't cause extra semicolons (don't do this in
		// flush as it will cause extra newlines at the end of a file)
		if !p.impliedSemi {
			n := nlimit(next.Line - p.pos.Line)
			// don't exceed maxNewlines if we already wrote one
			if wroteNewline && n == maxNewlines {
				n = maxNewlines - 1
			}
			if n > 0 {
				ch := byte('\n')
				if droppedFF {
					ch = '\f' // use formfeed since we dropped one before
				}
				p.writeByte(ch, n)
				impliedSemi = false
			}
		}

		// the next token starts now - record its line number if requested
		// (the pointer was registered via recordLine)
		if p.linePtr != nil {
			*p.linePtr = p.out.Line
			p.linePtr = nil
		}

		p.writeString(next, data, isLit)
		p.impliedSemi = impliedSemi
	}
}
// flush prints any pending comments and whitespace occurring textually
// before the position of the next token tok. The flush result indicates
// if a newline was written or if a formfeed was dropped from the whitespace
// buffer.
func (p *printer) flush(next token.Position, tok token.Token) (wroteNewline, droppedFF bool) {
	if !p.commentBefore(next) {
		// no pending comments: just emit any buffered whitespace
		p.writeWhitespace(len(p.wsbuf))
		return false, false
	}
	// interleave the pending comments with the buffered whitespace
	return p.intersperseComments(next, tok)
}
// getDoc returns the ast.CommentGroup associated with n, if any.
func getDoc(n ast.Node) *ast.CommentGroup {
switch n := n.(type) {
case *ast.Field:
return n.Doc
case *ast.ImportSpec:
return n.Doc
case *ast.ValueSpec:
return n.Doc
case *ast.TypeSpec:
return n.Doc
case *ast.GenDecl:
return n.Doc
case *ast.FuncDecl:
return n.Doc
case *ast.File:
return n.Doc
}
return nil
}
func getLastComment(n ast.Node) *ast.CommentGroup {
switch n := n.(type) {
case *ast.Field:
return n.Comment
case *ast.ImportSpec:
return n.Comment
case *ast.ValueSpec:
return n.Comment
case *ast.TypeSpec:
return n.Comment
case *ast.GenDecl:
if len(n.Specs) > 0 {
return getLastComment(n.Specs[len(n.Specs)-1])
}
case *ast.File:
if len(n.Comments) > 0 {
return n.Comments[len(n.Comments)-1]
}
}
return nil
}
// printNode prints node, which must be an ast.Node, *CommentedNode,
// []ast.Decl, or []ast.Stmt. For a *CommentedNode, the supplied comment
// list is restricted to the groups overlapping the node's source range
// (extended to include its doc and trailing comments). It returns any
// deferred //line-directive error (p.sourcePosErr), or an error for an
// unsupported node type.
func (p *printer) printNode(node any) error {
	// unpack *CommentedNode, if any
	var comments []*ast.CommentGroup
	if cnode, ok := node.(*CommentedNode); ok {
		node = cnode.Node
		comments = cnode.Comments
	}

	if comments != nil {
		// commented node - restrict comment list to relevant range
		n, ok := node.(ast.Node)
		if !ok {
			goto unsupported
		}
		beg := n.Pos()
		end := n.End()
		// if the node has associated documentation,
		// include that commentgroup in the range
		// (the comment list is sorted in the order
		// of the comment appearance in the source code)
		if doc := getDoc(n); doc != nil {
			beg = doc.Pos()
		}
		if com := getLastComment(n); com != nil {
			if e := com.End(); e > end {
				end = e
			}
		}
		// token.Pos values are global offsets, we can
		// compare them directly
		i := 0
		for i < len(comments) && comments[i].End() < beg {
			i++
		}
		j := i
		for j < len(comments) && comments[j].Pos() < end {
			j++
		}
		if i < j {
			p.comments = comments[i:j]
		}
	} else if n, ok := node.(*ast.File); ok {
		// use ast.File comments, if any
		p.comments = n.Comments
	}

	// if there are no comments, use node comments
	p.useNodeComments = p.comments == nil

	// get comments ready for use
	p.nextComment()

	p.print(pmode(0))

	// format node
	switch n := node.(type) {
	case ast.Expr:
		p.expr(n)
	case ast.Stmt:
		// A labeled statement will un-indent to position the label.
		// Set p.indent to 1 so we don't get indent "underflow".
		if _, ok := n.(*ast.LabeledStmt); ok {
			p.indent = 1
		}
		p.stmt(n, false, false)
	case ast.Decl:
		p.decl(n)
	case ast.Spec:
		p.spec(n, 1, false, token.EOF)
	case []ast.Stmt:
		// A labeled statement will un-indent to position the label.
		// Set p.indent to 1 so we don't get indent "underflow".
		for _, s := range n {
			if _, ok := s.(*ast.LabeledStmt); ok {
				p.indent = 1
			}
		}
		p.stmtList(n, 0, false)
	case []ast.Decl:
		p.declList(n)
	case *ast.File:
		p.file(n)
	default:
		goto unsupported
	}

	return p.sourcePosErr

unsupported:
	return fmt.Errorf("go/printer: unsupported node type %T", node)
}
// ----------------------------------------------------------------------------
// Trimmer
// A trimmer is an io.Writer filter for stripping tabwriter.Escape
// characters, trailing blanks and tabs, and for converting formfeed
// and vtab characters into newlines and htabs (in case no tabwriter
// is used). Text bracketed by tabwriter.Escape characters is passed
// through unchanged.
type trimmer struct {
	output io.Writer // destination for the filtered text
	state  int       // current state: inSpace, inEscape, or inText
	space  []byte    // pending space, written only if non-space text follows
}
// trimmer is implemented as a state machine.
// It can be in one of the following states:
const (
	inSpace  = iota // inside space (pending space buffered in trimmer.space)
	inEscape        // inside text bracketed by tabwriter.Escape characters
	inText          // inside regular (unescaped) text
)
// resetSpace returns the trimmer to the inSpace state and discards
// any buffered (still unwritten) space.
func (p *trimmer) resetSpace() {
	p.state = inSpace
	p.space = p.space[:0]
}
// Design note: It is tempting to eliminate extra blanks occurring in
// whitespace in this function as it could simplify some
// of the blanks logic in the node printing functions.
// However, this would mess up any formatting done by
// the tabwriter.
// aNewline is the canonical line terminator written by the trimmer.
var aNewline = []byte("\n")
// Write implements io.Writer, filtering data through the trimmer state
// machine: buffered trailing blanks/tabs are dropped before line breaks,
// '\v' becomes '\t', '\n' and '\f' both become '\n', and text between
// tabwriter.Escape bytes is passed through unmodified (the escape bytes
// themselves are stripped).
func (p *trimmer) Write(data []byte) (n int, err error) {
	// invariants:
	// p.state == inSpace:
	//	p.space is unwritten
	// p.state == inEscape, inText:
	//	data[m:n] is unwritten
	m := 0
	var b byte
	for n, b = range data {
		if b == '\v' {
			b = '\t' // convert to htab
		}
		switch p.state {
		case inSpace:
			switch b {
			case '\t', ' ':
				p.space = append(p.space, b)
			case '\n', '\f':
				p.resetSpace() // discard trailing space
				_, err = p.output.Write(aNewline)
			case tabwriter.Escape:
				// leading space is significant before text: flush it
				_, err = p.output.Write(p.space)
				p.state = inEscape
				m = n + 1 // +1: skip tabwriter.Escape
			default:
				_, err = p.output.Write(p.space)
				p.state = inText
				m = n
			}
		case inEscape:
			if b == tabwriter.Escape {
				_, err = p.output.Write(data[m:n])
				p.resetSpace()
			}
		case inText:
			switch b {
			case '\t', ' ':
				// text so far is complete; start buffering space again
				_, err = p.output.Write(data[m:n])
				p.resetSpace()
				p.space = append(p.space, b)
			case '\n', '\f':
				_, err = p.output.Write(data[m:n])
				p.resetSpace()
				if err == nil {
					_, err = p.output.Write(aNewline)
				}
			case tabwriter.Escape:
				_, err = p.output.Write(data[m:n])
				p.state = inEscape
				m = n + 1 // +1: skip tabwriter.Escape
			}
		default:
			panic("unreachable")
		}
		if err != nil {
			return
		}
	}
	n = len(data)

	// flush the unwritten tail; buffered space stays pending for the
	// next Write call (it may yet turn out to be trailing space)
	switch p.state {
	case inEscape, inText:
		_, err = p.output.Write(data[m:n])
		p.resetSpace()
	}

	return
}
// ----------------------------------------------------------------------------
// Public interface
// A Mode value is a set of flags (or 0). They control printing.
// Flags may be combined by or'ing them together.
type Mode uint

const (
	RawFormat Mode = 1 << iota // do not use a tabwriter; if set, UseSpaces is ignored
	TabIndent                  // use tabs for indentation independent of UseSpaces
	UseSpaces                  // use spaces instead of tabs for alignment
	SourcePos                  // emit //line directives to preserve original source positions
)
// The mode below is not included in printer's public API because
// editing code text is deemed out of scope. Because this mode is
// unexported, it's also possible to modify or remove it based on
// the evolving needs of go/format and cmd/gofmt without breaking
// users. See discussion in CL 240683.
const (
	// normalizeNumbers means to canonicalize number
	// literal prefixes and exponents while printing.
	//
	// This value is known in and used by go/format and cmd/gofmt.
	// It is currently more convenient and performant for those
	// packages to apply number normalization during printing,
	// rather than by modifying the AST in advance.
	//
	// The bit value 1<<30 is well clear of the exported Mode flags
	// (1<<0 through 1<<3) and cannot collide with them.
	normalizeNumbers Mode = 1 << 30
)
// A PrintConfig node controls the output of Fprint.
type PrintConfig struct {
	Mode     Mode // default: 0
	Tabwidth int  // default: 8
	Indent   int  // default: 0 (all code is indented at least by this much)

	// GoToSL carries the gosl translation state.
	// NOTE(review): semantics inferred from name (gosl:) — confirm
	// against the gosl callers that populate it.
	GoToSL *State

	// ExcludeFunctions lists function names to skip during printing.
	// NOTE(review): consumers not visible in this file — verify.
	ExcludeFunctions map[string]bool
}
// printerPool recycles printer values across Fprint calls so their
// (potentially large) output and whitespace buffers are reused.
var printerPool = sync.Pool{
	New: func() any {
		return &printer{
			// Whitespace sequences are short.
			wsbuf: make([]whiteSpace, 0, 16),
			// We start the printer with a 16K output buffer, which is currently
			// larger than about 80% of Go files in the standard library.
			output: make([]byte, 0, 16<<10),
		}
	},
}
// newPrinter returns a printer from printerPool, fully reinitialized
// for the given configuration, package (which supplies the file set),
// and shared nodeSizes cache. The caller must release it with free.
func newPrinter(cfg *PrintConfig, pkg *packages.Package, nodeSizes map[ast.Node]int) *printer {
	p := printerPool.Get().(*printer)
	*p = printer{
		PrintConfig: *cfg,
		pkg:         pkg,
		fset:        pkg.Fset,
		pos:         token.Position{Line: 1, Column: 1},
		out:         token.Position{Line: 1, Column: 1},
		wsbuf:       p.wsbuf[:0], // reuse pooled buffer storage
		nodeSizes:   nodeSizes,
		cachedPos:   -1, // invalidate the lineFor cache
		output:      p.output[:0],
	}
	return p
}
// free returns p to printerPool for reuse, unless its output buffer has
// grown beyond the hard cap (see https://golang.org/issue/23199), in
// which case the printer is dropped so the buffer can be collected.
func (p *printer) free() {
	if cap(p.output) <= 64<<10 {
		printerPool.Put(p)
	}
}
// fprint implements Fprint and takes a nodesSizes map for setting up the printer state.
//
// The output pipeline is: printer buffer -> (optional tabwriter) ->
// trimmer -> output.
func (cfg *PrintConfig) fprint(output io.Writer, pkg *packages.Package, node any, nodeSizes map[ast.Node]int) (err error) {
	// print node
	p := newPrinter(cfg, pkg, nodeSizes)
	defer p.free()
	if err = p.printNode(node); err != nil {
		return
	}
	// print outstanding comments
	p.impliedSemi = false // EOF acts like a newline
	p.flush(token.Position{Offset: infinity, Line: infinity}, token.EOF)

	// output is buffered in p.output now.
	// fix //go:build and // +build comments if needed.
	p.fixGoBuildLines()

	// redirect output through a trimmer to eliminate trailing whitespace
	// (Input to a tabwriter must be untrimmed since trailing tabs provide
	// formatting information. The tabwriter could provide trimming
	// functionality but no tabwriter is used when RawFormat is set.)
	output = &trimmer{output: output}

	// redirect output through a tabwriter if necessary
	if cfg.Mode&RawFormat == 0 {
		minwidth := cfg.Tabwidth

		padchar := byte('\t')
		if cfg.Mode&UseSpaces != 0 {
			padchar = ' '
		}

		twmode := tabwriter.DiscardEmptyColumns
		if cfg.Mode&TabIndent != 0 {
			minwidth = 0
			twmode |= tabwriter.TabIndent
		}

		output = tabwriter.NewWriter(output, minwidth, cfg.Tabwidth, 1, padchar, twmode)
	}

	// write printer result via tabwriter/trimmer to output
	if _, err = output.Write(p.output); err != nil {
		return
	}

	// flush tabwriter, if any
	if tw, _ := output.(*tabwriter.Writer); tw != nil {
		err = tw.Flush()
	}

	return
}
// A CommentedNode bundles an AST node and corresponding comments.
// It may be provided as argument to any of the [Fprint] functions.
type CommentedNode struct {
	// Node is the AST node to print.
	Node any // *ast.File, or ast.Expr, ast.Decl, ast.Spec, or ast.Stmt
	// Comments are the comment groups associated with Node.
	Comments []*ast.CommentGroup
}
// Fprint "pretty-prints" an AST node to output for a given configuration cfg.
// Position information is interpreted relative to the file set fset.
// The node type must be *[ast.File], *[CommentedNode], [][ast.Decl], [][ast.Stmt],
// or assignment-compatible to [ast.Expr], [ast.Decl], [ast.Spec], or [ast.Stmt].
// A fresh nodeSizes cache is used for each call.
func (cfg *PrintConfig) Fprint(output io.Writer, pkg *packages.Package, node any) error {
	return cfg.fprint(output, pkg, node, make(map[ast.Node]int))
}
// Fprint "pretty-prints" an AST node to output.
// It calls [PrintConfig.Fprint] with default settings
// (tab width 8, all mode flags off).
// Note that gofmt uses tabs for indentation but spaces for alignment;
// use format.Node (package go/format) for output that matches gofmt.
func Fprint(output io.Writer, pkg *packages.Package, node any) error {
	cfg := &PrintConfig{Tabwidth: 8}
	return cfg.Fprint(output, pkg, node)
}
// Copyright 2024 Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gotosl
import (
"bytes"
)
// MoveLines moves the st,ed region to 'to' line.
// Requires to <= st <= ed; the region [st,ed) is relocated to start
// at index to, shifting the displaced lines [to,st) after it.
func MoveLines(lines *[][]byte, to, st, ed int) {
	ls := *lines
	out := make([][]byte, 0, len(ls))
	out = append(out, ls[:to]...)   // everything before the insertion point
	out = append(out, ls[st:ed]...) // the moved region
	out = append(out, ls[to:st]...) // lines displaced downward
	out = append(out, ls[ed:]...)   // unchanged tail
	*lines = out
}
// SlEdits performs post-generation edits for wgsl,
// replacing type names, slbool, function calls, etc.
// It returns true if a slrand. or sltype. prefix was found,
// driving copying of those support files.
func SlEdits(src []byte) (lines [][]byte, hasSlrand bool, hasSltype bool) {
	lines = bytes.Split(src, []byte("\n"))
	hasSlrand, hasSltype = SlEditsReplace(lines)
	return lines, hasSlrand, hasSltype
}
// Replace is a simple byte-level From -> To replacement pair.
type Replace struct {
	From, To []byte
}
// Replaces are applied in order to each generated line to turn Go
// type names and helper-package calls into WGSL equivalents.
// Order matters: e.g., "uint32" must be rewritten before the remaining
// "int32" matches are treated as true int32.
var Replaces = []Replace{
	{[]byte("sltype.Uint32Vec2"), []byte("vec2<u32>")},
	{[]byte("sltype.Float32Vec2"), []byte("vec2<f32>")},
	{[]byte("float32"), []byte("f32")},
	{[]byte("float64"), []byte("f64")}, // TODO: not yet supported
	{[]byte("uint32"), []byte("u32")},
	{[]byte("uint64"), []byte("su64")},
	{[]byte("int32"), []byte("i32")},
	{[]byte("math32.FastExp("), []byte("FastExp(")}, // FastExp about same speed, numerically identical
	// {[]byte("math32.FastExp("), []byte("exp(")}, // exp is slightly faster it seems
	{[]byte("math.Float32frombits("), []byte("bitcast<f32>(")},
	{[]byte("math.Float32bits("), []byte("bitcast<u32>(")},
	{[]byte("shaders."), []byte("")},
	{[]byte("slrand."), []byte("Rand")},
	{[]byte("RandUi32"), []byte("RandUint32")}, // fix int32 -> i32
	{[]byte(".SetFromVector2("), []byte("=(")},
	{[]byte(".SetFrom2("), []byte("=(")},
	{[]byte(".IsTrue()"), []byte("==1")},
	{[]byte(".IsFalse()"), []byte("==0")},
	{[]byte(".SetBool(true)"), []byte("=1")},
	{[]byte(".SetBool(false)"), []byte("=0")},
	{[]byte(".SetBool("), []byte("=i32(")},
	{[]byte("slbool.Bool"), []byte("i32")},
	{[]byte("slbool.True"), []byte("1")},
	{[]byte("slbool.False"), []byte("0")},
	{[]byte("slbool.IsTrue("), []byte("(1 == ")},
	{[]byte("slbool.IsFalse("), []byte("(0 == ")},
	{[]byte("slbool.FromBool("), []byte("i32(")},
	{[]byte("bools.ToFloat32("), []byte("f32(")},
	{[]byte("bools.FromFloat32("), []byte("bool(")},
	{[]byte("num.FromBool[f32]("), []byte("f32(")},
	{[]byte("num.ToBool("), []byte("bool(")},
	// todo: do this conversion in nodes only for correct types
	// {[]byte(".X"), []byte(".x")},
	// {[]byte(".Y"), []byte(".y")},
	// {[]byte(".Z"), []byte(".z")},
	// {[]byte(""), []byte("")},
	// {[]byte(""), []byte("")},
	// {[]byte(""), []byte("")},
}
// MathReplaceAll strips every occurrence of the package-selector
// prefix mat (e.g., "math32.") from ln, lowercasing the first letter
// of the identifier that follows, so "math32.Sqrt(" becomes "sqrt(".
// It returns the edited line, which may share storage with the input.
func MathReplaceAll(mat, ln []byte) []byte {
	ml := len(mat)
	st := 0
	for {
		sln := ln[st:]
		i := bytes.Index(sln, mat)
		if i < 0 {
			return ln
		}
		// Guard: a match at the very end of the line has no following
		// character to lowercase; previously this indexed past the end
		// of ln and panicked.
		if st+i+ml >= len(ln) {
			return ln
		}
		fl := ln[st+i+ml : st+i+ml+1] // first char after the match
		dl := bytes.ToLower(fl)
		el := ln[st+i+ml+1:] // remainder after that char
		// splice: drop the match, keep the lowercased char + remainder.
		ln = append(ln[:st+i], dl...)
		ln = append(ln, el...)
		st += i + 1 // resume just past the lowered char
	}
}
// SlRemoveComments returns lines with blank lines and whole-line
// // comments removed; trailing comments on code lines are kept.
func SlRemoveComments(lines [][]byte) [][]byte {
	comment := []byte("//")
	kept := make([][]byte, 0, len(lines))
	for _, ln := range lines {
		trimmed := bytes.TrimSpace(ln)
		if len(trimmed) == 0 || bytes.HasPrefix(trimmed, comment) {
			continue
		}
		kept = append(kept, ln)
	}
	return kept
}
// SlEditsReplace replaces Go code with equivalent WGSL code, in place.
// It returns whether any line referenced slrand. or sltype.,
// so those header files can be auto-included.
func SlEditsReplace(lines [][]byte) (bool, bool) {
	var (
		mt32    = []byte("math32.")
		mth     = []byte("math.")
		slr     = []byte("slrand.")
		styp    = []byte("sltype.")
		include = []byte("#include")
	)
	hasSlrand, hasSltype := false, false
	for li, ln := range lines {
		if bytes.Contains(ln, include) {
			continue // leave include directives untouched
		}
		// detect before rewriting, since Replaces removes the prefixes.
		hasSlrand = hasSlrand || bytes.Contains(ln, slr)
		hasSltype = hasSltype || bytes.Contains(ln, styp)
		for _, r := range Replaces {
			ln = bytes.ReplaceAll(ln, r.From, r.To)
		}
		ln = MathReplaceAll(mt32, ln)
		ln = MathReplaceAll(mth, ln)
		lines[li] = ln
	}
	return hasSlrand, hasSltype
}
// SLBools are the replacements mapping slbool method calls and
// constants to literal int32 expressions (the Go-side analog of the
// slbool entries in [Replaces], using Go type names).
var SLBools = []Replace{
	{[]byte(".IsTrue()"), []byte("==1")},
	{[]byte(".IsFalse()"), []byte("==0")},
	{[]byte(".SetBool(true)"), []byte("=1")},
	{[]byte(".SetBool(false)"), []byte("=0")},
	{[]byte(".SetBool("), []byte("=int32(")},
	{[]byte("slbool.Bool"), []byte("int32")},
	{[]byte("slbool.True"), []byte("1")},
	{[]byte("slbool.False"), []byte("0")},
	{[]byte("slbool.IsTrue("), []byte("(1 == ")},
	{[]byte("slbool.IsFalse("), []byte("(0 == ")},
	{[]byte("slbool.FromBool("), []byte("int32(")},
	{[]byte("bools.ToFloat32("), []byte("float32(")},
	{[]byte("bools.FromFloat32("), []byte("bool(")},
	{[]byte("num.FromBool[f32]("), []byte("float32(")},
	{[]byte("num.ToBool("), []byte("bool(")},
}
// SlBoolReplace replaces all the slbool methods with literal int32
// expressions, editing lines in place.
func SlBoolReplace(lines [][]byte) {
	for i := range lines {
		out := lines[i]
		for _, r := range SLBools {
			out = bytes.ReplaceAll(out, r.From, r.To)
		}
		lines[i] = out
	}
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gotosl
import (
"bytes"
"fmt"
"go/ast"
"go/token"
"log"
"os"
"os/exec"
"path/filepath"
"sort"
"cogentcore.org/core/base/errors"
"cogentcore.org/lab/gosl/alignsl"
"golang.org/x/exp/maps"
"golang.org/x/tools/go/packages"
)
// TranslateDir translate all .Go files in given directory to WGSL.
// It runs in two passes: a first pass over all files just to build the
// function call graph (GetFuncGraph = true), then a per-kernel pass that
// generates, writes, and compiles one .wgsl file per kernel.
func (st *State) TranslateDir(pf string) error {
	pkgs, err := packages.Load(&packages.Config{Mode: packages.NeedName | packages.NeedFiles | packages.NeedTypes | packages.NeedSyntax | packages.NeedTypesSizes | packages.NeedTypesInfo}, pf)
	// pkgs, err := packages.Load(&packages.Config{Mode: packages.LoadAllSyntax}, pf)
	if err != nil {
		return errors.Log(err)
	}
	if len(pkgs) != 1 {
		err := fmt.Errorf("More than one package for path: %v", pf)
		return errors.Log(err)
	}
	pkg := pkgs[0]
	if len(pkg.GoFiles) == 0 {
		err := fmt.Errorf("No Go files found in package: %+v", pkg)
		return errors.Log(err)
	}
	// fmt.Printf("go files: %+v", pkg.GoFiles)
	// return nil, err
	files := pkg.GoFiles
	// check struct alignment for GPU compatibility; errors are printed
	// but do not stop translation.
	serr := alignsl.CheckPackage(pkg)
	if serr != nil {
		fmt.Println(serr)
	}
	st.FuncGraph = make(map[string]*Function)
	st.GetFuncGraph = true
	// doFile prints one Go file through the gosl printer into buf,
	// locating its AST in the loaded package by file name.
	doFile := func(gofp string, buf *bytes.Buffer) {
		_, gofn := filepath.Split(gofp)
		if st.Config.Debug {
			fmt.Printf("###################################\nTranslating Go file: %s\n", gofn)
		}
		var afile *ast.File
		var fpos token.Position
		for _, sy := range pkg.Syntax {
			pos := pkg.Fset.Position(sy.Package)
			_, posfn := filepath.Split(pos.Filename)
			if posfn == gofn {
				fpos = pos
				afile = sy
				break
			}
		}
		if afile == nil {
			fmt.Printf("Warning: File named: %s not found in Loaded package\n", gofn)
			return
		}
		pcfg := PrintConfig{GoToSL: st, Mode: printerMode, Tabwidth: tabWidth, ExcludeFunctions: st.ExcludeMap}
		pcfg.Fprint(buf, pkg, afile)
		// on the generation pass, remove the source unless Keep is set.
		if !st.GetFuncGraph && !st.Config.Keep {
			os.Remove(fpos.Filename)
		}
	}
	// first pass is just to get the call graph:
	for fn := range st.GoVarsFiles { // do varsFiles first!!
		var buf bytes.Buffer
		doFile(fn, &buf)
	}
	for _, gofp := range files {
		_, gofn := filepath.Split(gofp)
		if _, ok := st.GoVarsFiles[gofn]; ok {
			continue
		}
		var buf bytes.Buffer
		doFile(gofp, &buf)
	}
	// st.PrintFuncGraph()
	// doKernelFile translates one Go file for the current kernel and
	// appends the resulting WGSL to lines, reporting slrand/sltype use.
	doKernelFile := func(fname string, lines [][]byte) ([][]byte, bool, bool) {
		_, gofn := filepath.Split(fname)
		var buf bytes.Buffer
		doFile(fname, &buf)
		slfix, hasSlrand, hasSltype := SlEdits(buf.Bytes())
		slfix = SlRemoveComments(slfix)
		exsl := st.ExtractWGSL(slfix)
		lines = append(lines, []byte(""))
		lines = append(lines, []byte(fmt.Sprintf("//////// import: %q", gofn)))
		lines = append(lines, exsl...)
		return lines, hasSlrand, hasSltype
	}
	// next pass is per kernel
	st.GetFuncGraph = false
	// iterate systems and kernels in sorted order for deterministic output.
	sys := maps.Keys(st.Systems)
	sort.Strings(sys)
	for _, snm := range sys {
		sy := st.Systems[snm]
		kns := maps.Keys(sy.Kernels)
		sort.Strings(kns)
		for _, knm := range kns {
			kn := sy.Kernels[knm]
			st.KernelFuncs = st.AllFuncs(kn.Name)
			if st.KernelFuncs == nil {
				continue
			}
			var hasSlrand, hasSltype, hasR, hasT bool
			avars := st.AtomicVars(st.KernelFuncs)
			// if st.Config.Debug {
			fmt.Printf("###################################\nTranslating Kernel file: %s\n", kn.Name)
			// }
			hdr := st.GenKernelHeader(sy, kn, avars)
			lines := bytes.Split([]byte(hdr), []byte("\n"))
			for fn := range st.GoVarsFiles { // do varsFiles first!!
				lines, hasR, hasT = doKernelFile(fn, lines)
				if hasR {
					hasSlrand = true
				}
				if hasT {
					hasSltype = true
				}
			}
			for _, gofp := range files {
				_, gofn := filepath.Split(gofp)
				if _, ok := st.GoVarsFiles[gofn]; ok {
					continue
				}
				lines, hasR, hasT = doKernelFile(gofp, lines)
				if hasR {
					hasSlrand = true
				}
				if hasT {
					hasSltype = true
				}
			}
			// slrand depends on sltype, so it forces that copy too.
			if hasSlrand {
				st.CopyPackageFile("slrand.wgsl", "cogentcore.org/lab/gosl/slrand")
				hasSltype = true
			}
			if hasSltype {
				st.CopyPackageFile("sltype.wgsl", "cogentcore.org/lab/gosl/sltype")
			}
			for _, im := range st.SLImportFiles {
				lines = append(lines, []byte(""))
				lines = append(lines, []byte(fmt.Sprintf("//////// import: %q", im.Name)))
				lines = append(lines, im.Lines...)
			}
			kn.Lines = lines
			kfn := kn.Name + ".wgsl"
			fn := filepath.Join(st.Config.Output, kfn)
			kn.Filename = fn
			WriteFileLines(fn, lines)
			st.CompileFile(kfn)
		}
	}
	return nil
}
var (
	// nagaWarned records that the install-naga advice was already printed.
	nagaWarned = false
	// tintWarned records that the install-tint advice was already printed.
	tintWarned = false
)
// CompileFile validates the generated WGSL file fn (relative to
// Config.Output) by running the naga and/or tint compilers on it,
// whichever are installed, printing their output. It returns the
// first compile error, or nil; a missing compiler only prints an
// install suggestion (once per run).
func (st *State) CompileFile(fn string) error {
	dir, _ := filepath.Abs(st.Config.Output)
	if _, err := exec.LookPath("naga"); err == nil {
		// cmd := exec.Command("naga", "--compact", fn, fn) // produces some pretty weird code actually
		cmd := exec.Command("naga", fn)
		cmd.Dir = dir
		out, err := cmd.CombinedOutput()
		fmt.Printf("\n-----------------------------------------------------\nnaga output for: %s\n%s", fn, out)
		if err != nil {
			log.Println(err)
			return err
		}
	} else {
		if !nagaWarned {
			fmt.Println("\nImportant: you should install the 'naga' WGSL compiler from https://github.com/gfx-rs/wgpu to get immediate validation")
			nagaWarned = true
		}
	}
	if _, err := exec.LookPath("tint"); err == nil {
		// validate only; output is discarded.
		cmd := exec.Command("tint", "--validate", "--format", "wgsl", "-o", "/dev/null", fn)
		cmd.Dir = dir
		out, err := cmd.CombinedOutput()
		fmt.Printf("\n-----------------------------------------------------\ntint output for: %s\n%s", fn, out)
		if err != nil {
			log.Println(err)
			return err
		}
	} else {
		if !tintWarned {
			fmt.Println("\nImportant: you should install the 'tint' WGSL compiler from https://dawn.googlesource.com/dawn/ to get immediate validation")
			tintWarned = true
		}
	}
	return nil
}
// Copyright (c) 2019, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
package slbool defines a WGSL friendly int32 Bool type.
The standard WGSL bool type causes obscure errors,
and the int32 obeys the 4 byte basic alignment requirements.
gosl automatically converts this Go code into appropriate WGSL code.
*/
package slbool
// Bool is a WGSL friendly int32 Bool type, where 0 is false and 1 is true.
type Bool int32
const (
	// False is the [Bool] false value
	False Bool = 0
	// True is the [Bool] true value
	True Bool = 1
)
// Bool returns the Bool as a standard Go bool
// (true only for the exact value [True]).
func (b Bool) Bool() bool {
	return b == True
}
// IsTrue returns whether the bool is true (equals [True] exactly).
func (b Bool) IsTrue() bool {
	return b == True
}
// IsFalse returns whether the bool is false (equals [False] exactly).
func (b Bool) IsFalse() bool {
	return b == False
}
// SetBool sets the Bool from a standard Go bool.
func (b *Bool) SetBool(bb bool) {
	*b = FromBool(bb)
}
// String returns the bool as a string: "true" when the value
// is [True] exactly, "false" otherwise.
func (b Bool) String() string {
	if !b.IsTrue() {
		return "false"
	}
	return "true"
}
// FromString sets the bool from the given string:
// only "true" and "True" set it to true.
func (b *Bool) FromString(s string) {
	b.SetBool(s == "true" || s == "True")
}
// MarshalText implements the [encoding/text.Marshaler] interface,
// emitting "true" or "false" per [Bool.String]; never errors.
func (b Bool) MarshalText() ([]byte, error) { return []byte(b.String()), nil }
// UnmarshalText implements the [encoding/text.Unmarshaler] interface
// via [Bool.FromString]; never errors (unrecognized text means false).
func (b *Bool) UnmarshalText(s []byte) error { b.FromString(string(s)); return nil }
// IsTrue returns whether the given bool is true (equals [True] exactly).
func IsTrue(b Bool) bool {
	return b == True
}
// IsFalse returns whether the given bool is false (equals [False] exactly).
func IsFalse(b Bool) bool {
	return b == False
}
// FromBool returns the given Go bool as a [Bool]
// ([True] for true, [False] for false).
func FromBool(b bool) Bool {
	if b {
		return True
	}
	return False
}
// Copyright (c) 2023, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package slboolcore
import (
"cogentcore.org/core/core"
"cogentcore.org/lab/gosl/slbool"
)
// init registers [core.Switch] as the GUI value widget for [slbool.Bool].
func init() {
	core.AddValueType[slbool.Bool, core.Switch]()
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package slrand
import (
"cogentcore.org/core/math32"
"cogentcore.org/lab/gosl/sltype"
)
// These are Go versions of the same Philox2x32 based random number generator
// functions available in .WGSL.
// Philox2x32round does one round of updating of the counter.
func Philox2x32round(counter uint64, key uint32) uint64 {
	ctr := sltype.Uint64ToLoHi(counter)
	// full 64-bit product of the low word with the Philox multiplier.
	mul := sltype.Uint64ToLoHi(sltype.Uint32Mul64(0xD256D193, ctr.X))
	// mix: new lo = hi(product) ^ key ^ old hi; new hi = lo(product).
	ctr.X = mul.Y ^ key ^ ctr.Y
	ctr.Y = mul.X
	return sltype.Uint64FromLoHi(ctr)
}
// Philox2x32bumpkey does one round of updating of the key,
// adding the golden-ratio Weyl constant.
func Philox2x32bumpkey(key uint32) uint32 {
	return key + 0x9E3779B9
}
// Philox2x32 implements the stateless counter-based RNG algorithm
// returning a random number as two uint32 values, given a
// counter and key input that determine the result.
func Philox2x32(counter uint64, key uint32) sltype.Uint32Vec2 {
	// 10 rounds total, bumping the key between successive rounds;
	// this produces the random deviation deterministically from the
	// initial counter and key inputs.
	for round := 1; round <= 9; round++ {
		counter = Philox2x32round(counter, key)
		key = Philox2x32bumpkey(key)
	}
	return sltype.Uint64ToLoHi(Philox2x32round(counter, key)) // round 10
}
/////////
// Methods below provide a standard interface with more
// readable names, mapping onto the Go rand methods.
//
// They assume a global shared counter, which is then
// incremented by a function index, defined for each function
// consuming random numbers that _could_ be called within a parallel
// processing loop. At the end of the loop, the global counter should
// be incremented by the total possible number of such functions.
// This results in fully resproducible results, invariant to
// specific processing order, and invariant to whether any one function
// actually calls the random number generator.
// Uint32Vec2 returns two uniformly distributed 32 unsigned integers,
// based on given global shared counter, function index offset from that
// counter for this specific random number call, and key as unique
// index of the item being processed.
func Uint32Vec2(counter uint64, funcIndex uint32, key uint32) sltype.Uint32Vec2 {
	return Philox2x32(sltype.Uint64Add32(counter, funcIndex), key)
}
// Uint32 returns a uniformly distributed 32 unsigned integer
// (the low word of the Philox pair),
// based on given global shared counter, function index offset from that
// counter for this specific random number call, and key as unique
// index of the item being processed.
func Uint32(counter uint64, funcIndex uint32, key uint32) uint32 {
	return Philox2x32(sltype.Uint64Add32(counter, funcIndex), key).X
}
// Float32Vec2 returns two uniformly distributed float32 values in range (0,1),
// based on given global shared counter, function index offset from that
// counter for this specific random number call, and key as unique
// index of the item being processed.
func Float32Vec2(counter uint64, funcIndex uint32, key uint32) sltype.Float32Vec2 {
	return sltype.Uint32ToFloat32Vec2(Uint32Vec2(counter, funcIndex, key))
}
// Float32 returns a uniformly distributed float32 value in range (0,1),
// based on given global shared counter, function index offset from that
// counter for this specific random number call, and key as unique
// index of the item being processed.
func Float32(counter uint64, funcIndex uint32, key uint32) float32 {
	return sltype.Uint32ToFloat32(Uint32(counter, funcIndex, key))
}
// Float32Range11Vec2 returns two uniformly distributed float32 values in range [-1,1],
// based on given global shared counter, function index offset from that
// counter for this specific random number call, and key as unique
// index of the item being processed.
func Float32Range11Vec2(counter uint64, funcIndex uint32, key uint32) sltype.Float32Vec2 {
	// Use the signed [-1,1] conversion: this previously called
	// Uint32ToFloat32Vec2 (the (0,1) conversion), which made this
	// function identical to Float32Vec2 instead of matching its doc.
	return sltype.Uint32ToFloat32Range11Vec2(Uint32Vec2(counter, funcIndex, key))
}
// Float32Range11 returns a uniformly distributed float32 value in range [-1,1],
// based on given global shared counter, function index offset from that
// counter for this specific random number call, and key as unique
// index of the item being processed.
func Float32Range11(counter uint64, funcIndex uint32, key uint32) float32 {
	return sltype.Uint32ToFloat32Range11(Uint32(counter, funcIndex, key))
}
// BoolP returns a bool true value with probability p,
// based on given global shared counter, function index offset from that
// counter for this specific random number call, and key as unique
// index of the item being processed.
func BoolP(counter uint64, funcIndex uint32, key uint32, p float32) bool {
	return (Float32(counter, funcIndex, key) < p)
}
// SincosPi returns the sine (in X) and cosine (in Y) of pi * x.
func SincosPi(x float32) sltype.Float32Vec2 {
	const PIf = 3.1415926535897932
	var r sltype.Float32Vec2
	// math32.Sincos returns (sin, cos): sin -> Y? no: Y,X = sin,cos order
	// below assigns sin to Y and cos to X? It assigns r.Y, r.X so that
	// r.X = sin, r.Y = cos — TODO confirm intended component order.
	r.Y, r.X = math32.Sincos(PIf * x)
	return r
}
// Float32NormVec2 returns two random float32 numbers
// distributed according to the normal, Gaussian distribution
// with zero mean and unit variance.
// This is done very efficiently using the Box-Muller algorithm
// that consumes two random 32 bit uint values.
// Uses given global shared counter, function index offset from that
// counter for this specific random number call, and key as unique
// index of the item being processed.
func Float32NormVec2(counter uint64, funcIndex uint32, key uint32) sltype.Float32Vec2 {
	ur := Uint32Vec2(counter, funcIndex, key)
	// uniform angle in [-pi, pi) via the [-1,1] conversion.
	f := SincosPi(sltype.Uint32ToFloat32Range11(ur.X))
	// radius from the (0,1) conversion, which never returns 0,
	// so Log is always finite.
	r := math32.Sqrt(-2.0 * math32.Log(sltype.Uint32ToFloat32(ur.Y))) // guaranteed to avoid 0.
	return f.MulScalar(r)
}
// Float32Norm returns a random float32 number
// distributed according to the normal, Gaussian distribution
// with zero mean and unit variance.
// Uses given global shared counter, function index offset from that
// counter for this specific random number call, and key as unique
// index of the item being processed.
func Float32Norm(counter uint64, funcIndex uint32, key uint32) float32 {
return Float32Vec2(counter, funcIndex, key).X
}
// Uint32N returns a uint32 in the range [0,N).
// Uses given global shared counter, function index offset from that
// counter for this specific random number call, and key as unique
// index of the item being processed.
func Uint32N(counter uint64, funcIndex uint32, key uint32, n uint32) uint32 {
	// scales a uniform (0,1) float by n; NOTE(review): float32 scaling
	// has slight bias / precision limits for large n — confirm acceptable.
	v := Float32(counter, funcIndex, key)
	return uint32(v * float32(n))
}
// Counter is used for storing the random counter using aligned 16 byte storage,
// with convenience methods for typical use cases.
// It retains a copy of the last Seed value, which is applied to the Hi uint32 value.
type Counter struct {
	// Counter value
	Counter uint64
	// last seed value set by Seed method, restored by Reset()
	HiSeed uint32
	// padding to fill out 16 byte alignment.
	pad uint32
}
// Reset resets counter to last set Seed state:
// Lo = 0, Hi = HiSeed.
func (ct *Counter) Reset() {
	ct.Counter = sltype.Uint64FromLoHi(sltype.Uint32Vec2{0, ct.HiSeed})
}
// Seed sets the Hi uint32 value from given seed, saving it in HiSeed field.
// Each increment in seed generates a unique sequence of over 4 billion numbers,
// so it is reasonable to just use incremental values there, but more widely
// spaced numbers will result in longer unique sequences.
// Resets Lo to 0.
// This same seed will be restored during Reset
func (ct *Counter) Seed(seed uint32) {
	ct.HiSeed = seed
	ct.Reset()
}
// Add increments the counter by given amount.
// Call this after completing a pass of computation
// where the value passed here is the max of funcIndex+1
// used for any possible random calls during that pass.
func (ct *Counter) Add(inc uint32) {
	ct.Counter = sltype.Uint64Add32(ct.Counter, inc)
}
// Copyright (c) 2022, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sltype
import "cogentcore.org/core/math32"
// Int32Vec2 is a length 2 vector of int32.
type Int32Vec2 = math32.Vector2i
// IntVec3 is a length 3 vector of int32.
// NOTE(review): named IntVec3, unlike the Int32-prefixed siblings —
// confirm whether Int32Vec3 was intended.
type IntVec3 = math32.Vector3i
// Int32Vec4 is a length 4 vector of int32
type Int32Vec4 struct {
	X int32
	Y int32
	Z int32
	W int32
}
//////// Unsigned
// Uint32Vec2 is a length 2 vector of uint32
type Uint32Vec2 struct {
	X uint32
	Y uint32
}
// Uint32Vec3 is a length 3 vector of uint32
type Uint32Vec3 struct {
	X uint32
	Y uint32
	Z uint32
}
// Uint32Vec4 is a length 4 vector of uint32
type Uint32Vec4 struct {
	X uint32
	Y uint32
	Z uint32
	W uint32
}
// SetFromVec2 sets X,Y from the given [Uint32Vec2],
// with Z = 0 and W = 1.
func (u *Uint32Vec4) SetFromVec2(u2 Uint32Vec2) {
	u.X = u2.X
	u.Y = u2.Y
	u.Z = 0
	u.W = 1
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sltype
import (
"math"
)
// Uint32Mul64 multiplies two uint32 numbers into a uint64,
// retaining the full 64-bit product.
func Uint32Mul64(a, b uint32) uint64 {
	wa, wb := uint64(a), uint64(b)
	return wa * wb
}
// Uint64ToLoHi splits a uint64 number into lo (X) and hi (Y)
// uint32 components.
func Uint64ToLoHi(a uint64) Uint32Vec2 {
	return Uint32Vec2{X: uint32(a), Y: uint32(a >> 32)}
}
// Uint64FromLoHi combines lo (X) and hi (Y) uint32 components
// into a uint64 value.
func Uint64FromLoHi(a Uint32Vec2) uint64 {
	return uint64(a.Y)<<32 | uint64(a.X)
}
// Uint64Add32 adds given uint32 number to given uint64,
// wrapping on overflow per Go unsigned arithmetic.
func Uint64Add32(a uint64, b uint32) uint64 {
	sum := a + uint64(b)
	return sum
}
// Uint64Incr returns the given uint64 incremented by one,
// wrapping on overflow.
func Uint64Incr(a uint64) uint64 {
	a++
	return a
}
// Uint32ToFloat32 converts a uint32 integer into a float32
// in the (0,1) interval (i.e., exclusive of 1).
// This differs from the Go standard by excluding 0, which is handy for passing
// directly to Log function, and from the reference Philox code by excluding 1
// which is in the Go standard and most other standard RNGs.
func Uint32ToFloat32(val uint32) float32 {
	const factor = float32(1.) / (float32(0xffffffff) + float32(1.))
	const halffactor = float32(0.5) * factor
	f := float32(val)*factor + halffactor
	if f != 1 {
		return f
	}
	// rounded up to exactly 1: substitute the largest float32 below 1.
	return math.Float32frombits(0x3F7FFFFF)
}
// Uint32ToFloat32Vec2 converts two uint32 bit integers
// into two corresponding 32 bit f32 values
// in the (0,1) interval (i.e., exclusive of 1).
func Uint32ToFloat32Vec2(val Uint32Vec2) Float32Vec2 {
	return Float32Vec2{
		X: Uint32ToFloat32(val.X),
		Y: Uint32ToFloat32(val.Y),
	}
}
// Uint32ToFloat32Range11 converts a uint32 integer into a float32
// in the [-1..1] interval (inclusive of -1 and 1, never identically == 0).
func Uint32ToFloat32Range11(val uint32) float32 {
	const scale = float32(1.) / (float32(0x7fffffff) + float32(1.))
	const half = float32(0.5) * scale
	signed := float32(int32(val))
	return signed*scale + half
}
// Uint32ToFloat32Range11Vec2 converts two uint32 integers into two float32
// in the [-1,1] interval (inclusive of -1 and 1, never identically == 0).
func Uint32ToFloat32Range11Vec2(val Uint32Vec2) Float32Vec2 {
	return Float32Vec2{
		X: Uint32ToFloat32Range11(val.X),
		Y: Uint32ToFloat32Range11(val.Y),
	}
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package threading provides a simple parallel run function. this will be moved elsewhere.
package threading
import (
"math"
"sync"
)
// ParallelRun maps the given function across the [0, total) range of items,
// using nThreads goroutines. fun is called with disjoint [st, ed) index
// ranges that together cover [0, total). Blocks until all goroutines finish.
// nThreads values < 1 are treated as 1.
func ParallelRun(fun func(st, ed int), total int, nThreads int) {
	if nThreads < 1 {
		// guard: nThreads == 0 previously produced a +Inf chunk size and
		// an undefined float-to-int conversion.
		nThreads = 1
	}
	itemsPerThr := int(math.Ceil(float64(total) / float64(nThreads)))
	waitGroup := sync.WaitGroup{}
	for start := 0; start < total; start += itemsPerThr {
		start := start // per-iteration copy for the closure (pre-Go 1.22 semantics)
		end := min(start+itemsPerThr, total)
		waitGroup.Add(1)
		go func() {
			fun(start, end)
			waitGroup.Done()
		}()
	}
	waitGroup.Wait()
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package lab
import (
"io/fs"
"os"
"path/filepath"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/fsx"
"cogentcore.org/core/core"
"cogentcore.org/core/events"
"cogentcore.org/core/styles"
"cogentcore.org/core/tree"
"cogentcore.org/core/yaegicore/coresymbols"
)
// Basic is a basic data browser with the files as the left panel,
// and the Tabber as the right panel.
// It embeds [Browser] for the browsing state and [core.Frame] for the GUI.
type Basic struct {
	core.Frame
	Browser
}
// Init initializes with the data and script directories,
// setting up the interpreter, the splits layout with the file tree
// on the left and tabs on the right, and the updaters that keep
// the file tree connected to the tabs.
func (br *Basic) Init() {
	br.Frame.Init()
	br.Styler(func(s *styles.Style) {
		s.Grow.Set(1, 1)
	})
	br.InitInterp()
	br.Interpreter.Interp.Use(coresymbols.Symbols) // gui imports
	// refresh the file list whenever the window is shown.
	br.OnShow(func(e events.Event) {
		br.UpdateFiles()
	})
	tree.AddChildAt(br, "splits", func(w *core.Splits) {
		br.Splits = w
		w.SetSplits(.15, .85)
		tree.AddChildAt(w, "fileframe", func(w *core.Frame) {
			w.Styler(func(s *styles.Style) {
				s.Direction = styles.Column
				s.Overflow.Set(styles.OverflowAuto)
				s.Grow.Set(1, 1)
			})
			tree.AddChildAt(w, "filetree", func(w *DataTree) {
				br.Files = w
			})
		})
		tree.AddChildAt(w, "tabs", func(w *Tabs) {
			br.Tabs = w
		})
	})
	// keep the file tree pointed at the current tabber on every update.
	br.Updater(func() {
		if br.Files != nil {
			br.Files.Tabber = br.Tabs
		}
	})
}
// NewBasicWindow returns a new data Browser window for given
// file system (nil for os files) and data directory.
// do RunWindow on resulting [core.Body] to open the window.
func NewBasicWindow(fsys fs.FS, dataDir string) (*core.Body, *Basic) {
	startDir, _ := os.Getwd()
	startDir = errors.Log1(filepath.Abs(startDir))
	b := core.NewBody("Cogent Data Browser: " + fsx.DirAndFile(startDir))
	br := NewBasic(b)
	br.FS = fsys
	ddr := dataDir
	// only resolve to an absolute OS path when browsing real files.
	if fsys == nil {
		ddr = errors.Log1(filepath.Abs(dataDir))
	}
	b.AddTopBar(func(bar *core.Frame) {
		tb := core.NewToolbar(bar)
		br.Toolbar = tb
		tb.Maker(br.MakeToolbar)
	})
	br.SetDataRoot(ddr)
	br.SetScriptsDir(filepath.Join(ddr, "dbscripts"))
	// publish globals used by interpreted scripts.
	TheBrowser = &br.Browser
	CurTabber = br.Browser.Tabs
	br.Interpreter.Eval("br := databrowser.TheBrowser") // grab it
	br.UpdateScripts()
	return b, br
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package lab
//go:generate core generate
import (
"io/fs"
"slices"
"cogentcore.org/core/core"
"cogentcore.org/core/events"
"cogentcore.org/core/icons"
"cogentcore.org/core/tree"
"cogentcore.org/lab/goal/interpreter"
"golang.org/x/exp/maps"
)
// TheBrowser is the current browser,
// which is valid immediately after NewBrowserWindow
// where it is used to get a local variable for subsequent use.
var TheBrowser *Browser
// Browser holds all the elements of a data browser, for browsing data
// either on an OS filesystem or as a tensorfs virtual data filesystem.
// It supports the automatic loading of [goal] scripts as toolbar actions to
// perform pre-programmed tasks on the data, to create app-like functionality.
// Scripts are ordered alphabetically and any leading #- prefix is automatically
// removed from the label, so you can use numbers to specify a custom order.
// It is not a [core.Widget] itself, and is intended to be incorporated into
// a [core.Frame] widget, potentially along with other custom elements.
// See [Basic] for a basic implementation.
type Browser struct { //types:add -setters
	// FS is the filesystem, if browsing an FS.
	FS fs.FS
	// DataRoot is the path to the root of the data to browse.
	DataRoot string
	// StartDir is the starting directory, where the app was originally started.
	StartDir string
	// ScriptsDir is the directory containing scripts for toolbar actions.
	// It defaults to DataRoot/dbscripts
	ScriptsDir string
	// Scripts are the loaded scripts, keyed by name.
	Scripts map[string]string `set:"-"`
	// Interpreter is the interpreter to use for running Browser scripts
	Interpreter *interpreter.Interpreter `set:"-"`
	// Files is the [DataTree] tree browser of the tensorfs or files.
	Files *DataTree
	// Tabs is the [Tabber] element managing tabs of data views.
	Tabs Tabber
	// Toolbar is the top-level toolbar for the browser, if used.
	Toolbar *core.Toolbar
	// Splits is the overall [core.Splits] for the browser.
	Splits *core.Splits
}
// UpdateFiles Updates the files list from the current DataRoot,
// using the virtual FS if set, otherwise the OS filesystem.
// No-op if the file tree has not been created yet.
func (br *Browser) UpdateFiles() { //types:add
	if br.Files == nil {
		return
	}
	files := br.Files
	if br.FS != nil {
		// FS browsing sorts by modification time by default.
		files.SortByModTime = true
		files.OpenPathFS(br.FS, br.DataRoot)
	} else {
		files.OpenPath(br.DataRoot)
	}
}
// MakeToolbar adds the standard browser items to the toolbar plan:
// an update-files button, an update-scripts button, and one run
// button per loaded script, sorted by name.
func (br *Browser) MakeToolbar(p *tree.Plan) {
	tree.Add(p, func(w *core.FuncButton) {
		w.SetFunc(br.UpdateFiles).SetText("").SetIcon(icons.Refresh).SetShortcut("Command+U")
	})
	tree.Add(p, func(w *core.FuncButton) {
		w.SetFunc(br.UpdateScripts).SetText("").SetIcon(icons.Code)
	})
	scr := maps.Keys(br.Scripts)
	slices.Sort(scr) // map keys are unordered; sort for a stable toolbar
	for _, s := range scr {
		lbl := TrimOrderPrefix(s) // strip any leading #- ordering prefix
		// NOTE(review): the closures below capture the loop variable s;
		// safe with Go 1.22+ per-iteration loop vars — confirm go.mod version.
		tree.AddAt(p, lbl, func(w *core.Button) {
			w.SetText(lbl).SetIcon(icons.RunCircle).
				OnClick(func(e events.Event) {
					br.RunScript(s)
				})
			sc := br.Scripts[s]
			// the script's leading comment serves as its tooltip
			tt := FirstComment(sc)
			if tt == "" {
				tt = "Run Script (add a comment to top of script to provide more useful info here)"
			}
			w.SetTooltip(tt)
		})
	}
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package lab
import (
"image"
"strings"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/fileinfo"
"cogentcore.org/core/base/fsx"
"cogentcore.org/core/core"
"cogentcore.org/core/events"
"cogentcore.org/core/filetree"
"cogentcore.org/core/icons"
"cogentcore.org/core/styles"
"cogentcore.org/core/styles/states"
"cogentcore.org/core/texteditor/diffbrowser"
"cogentcore.org/core/tree"
"cogentcore.org/core/types"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
"cogentcore.org/lab/tensorfs"
)
// Treer is an interface for getting the Root node as a DataTree struct.
type Treer interface {
	// AsDataTree returns the receiver as a [DataTree].
	AsDataTree() *DataTree
}
// AsDataTree returns the given value as a [DataTree] if it has
// an AsDataTree() method, or nil otherwise.
func AsDataTree(n tree.Node) *DataTree {
	tr, ok := n.(Treer)
	if !ok {
		return nil
	}
	return tr.AsDataTree()
}
// DataTree is the databrowser version of [filetree.Tree],
// which provides the Tabber to show data editors.
type DataTree struct {
	filetree.Tree

	// Tabber is the [Tabber] for this tree, used by [FileNode]
	// actions to open data views in tabs.
	Tabber Tabber
}
// AsDataTree returns this tree as a [DataTree], implementing [Treer].
func (ft *DataTree) AsDataTree() *DataTree {
	return ft
}
// Init initializes the tree, setting itself as the Root
// and using [FileNode] as the node type for files.
func (ft *DataTree) Init() {
	ft.Tree.Init()
	ft.Root = ft
	ft.FileNodeType = types.For[FileNode]()
}
// FileNode is the databrowser version of FileNode for FileTree,
// adding data-oriented context-menu actions (Edit, Plot, Grid, Diff Dirs).
type FileNode struct {
	filetree.Node
}
// Init initializes the node, registering its data context menu.
func (fn *FileNode) Init() {
	fn.Node.Init()
	fn.AddContextMenu(fn.ContextMenu)
}
// Tabber returns the [Tabber] for this filenode, from root tree.
func (fn *FileNode) Tabber() Tabber {
	if fr := AsDataTree(fn.Root); fr != nil {
		return fr.Tabber
	}
	return nil
}
// WidgetTooltip returns the tooltip for this node, appending the
// current value for tensorfs Number and String items.
func (fn *FileNode) WidgetTooltip(pos image.Point) (string, image.Point) {
	res := fn.Tooltip
	if fn.Info.Cat == fileinfo.Data {
		ofn := fn.AsNode()
		switch fn.Info.Known {
		case fileinfo.Number, fileinfo.String:
			// TensorFS is documented to return nil for non-tensorfs items;
			// guard so we don't dereference a nil node.
			if dv := TensorFS(ofn); dv != nil {
				if res != "" {
					res += " "
				}
				res += dv.String()
			}
		}
	}
	return res, fn.DefaultTooltipPos()
}
// TensorFS returns the tensorfs representation of this item.
// returns nil if not a dataFS item.
func TensorFS(fn *filetree.Node) *tensorfs.Node {
	// the file root's FS is a *tensorfs.Node only when browsing a tensorfs
	dfs, ok := fn.FileRoot().FS.(*tensorfs.Node)
	if !ok {
		return nil
	}
	dfi, err := dfs.Stat(string(fn.Filepath))
	if errors.Log(err) != nil {
		return nil
	}
	// NOTE(review): assumes tensorfs Stat always yields a *tensorfs.Node;
	// this would panic otherwise — confirm against the tensorfs API.
	return dfi.(*tensorfs.Node)
}
// GetFileInfo initializes the file info, and for tensorfs items
// sets the Known type and a matching icon.
func (fn *FileNode) GetFileInfo() error {
	err := fn.InitFileInfo()
	if fn.FileRoot().FS == nil {
		return err
	}
	d := TensorFS(fn.AsNode())
	if d == nil {
		return err
	}
	fn.Info.Known = d.KnownFileInfo()
	fn.Info.Cat = fileinfo.Data
	switch fn.Info.Known {
	case fileinfo.Table:
		fn.Info.Ic = icons.BarChart4Bars
	case fileinfo.Number:
		fn.Info.Ic = icons.Tag
	case fileinfo.String:
		fn.Info.Ic = icons.Title
	default: // fileinfo.Tensor and anything else
		fn.Info.Ic = icons.BarChart
	}
	return err
}
// OpenFile opens this file in a suitable viewer tab depending on type:
// tensorfs items get tensor/table editors or scalar-value dialogs,
// recognized data files are loaded as tables, media/binary files open
// with the OS default app, and everything else opens in a text editor.
func (fn *FileNode) OpenFile() error {
	ofn := fn.AsNode()
	ts := fn.Tabber()
	if ts == nil {
		return nil
	}
	df := fsx.DirAndFile(string(fn.Filepath))
	switch {
	case fn.IsDir():
		// NOTE(review): TensorFS returns nil for OS-filesystem dirs;
		// assumes DirTable tolerates a nil node — confirm.
		d := TensorFS(ofn)
		dt := tensorfs.DirTable(d, nil)
		ts.AsLab().TensorTable(df, dt)
	case fn.Info.Cat == fileinfo.Data:
		switch fn.Info.Known {
		case fileinfo.Tensor:
			d := TensorFS(ofn)
			ts.AsLab().TensorEditor(df, d.Tensor)
		case fileinfo.Number:
			// edit a scalar number via a spinner dialog; OK writes it back
			dv := TensorFS(ofn)
			v := dv.Tensor.Float1D(0)
			d := core.NewBody(df)
			core.NewText(d).SetType(core.TextSupporting).SetText(df)
			sp := core.NewSpinner(d).SetValue(float32(v))
			d.AddBottomBar(func(bar *core.Frame) {
				d.AddCancel(bar)
				d.AddOK(bar).OnClick(func(e events.Event) {
					dv.Tensor.SetFloat1D(float64(sp.Value), 0)
				})
			})
			d.RunDialog(fn)
		case fileinfo.String:
			// edit a scalar string via a text field dialog; OK writes it back
			dv := TensorFS(ofn)
			v := dv.Tensor.String1D(0)
			d := core.NewBody(df)
			core.NewText(d).SetType(core.TextSupporting).SetText(df)
			tf := core.NewTextField(d).SetText(v)
			d.AddBottomBar(func(bar *core.Frame) {
				d.AddCancel(bar)
				d.AddOK(bar).OnClick(func(e events.Event) {
					dv.Tensor.SetString1D(tf.Text(), 0)
				})
			})
			d.RunDialog(fn)
		default:
			// generic data file: load as a tab-separated table
			dt := table.New()
			err := dt.OpenCSV(fsx.Filename(fn.Filepath), tensor.Tab) // todo: need more flexible data handling mode
			if err != nil {
				core.ErrorSnackbar(fn, err)
			} else {
				ts.AsLab().TensorTable(df, dt)
			}
		}
	case fn.IsExec(): // todo: use exec?
		fn.OpenFilesDefault()
	case fn.Info.Cat == fileinfo.Video: // todo: use our video viewer
		fn.OpenFilesDefault()
	case fn.Info.Cat == fileinfo.Audio: // todo: use our audio viewer
		fn.OpenFilesDefault()
	case fn.Info.Cat == fileinfo.Image: // todo: use our image viewer
		fn.OpenFilesDefault()
	case fn.Info.Cat == fileinfo.Model: // todo: use xyz
		fn.OpenFilesDefault()
	case fn.Info.Cat == fileinfo.Sheet: // todo: use our spreadsheet :)
		fn.OpenFilesDefault()
	case fn.Info.Cat == fileinfo.Bin: // don't edit
		fn.OpenFilesDefault()
	case fn.Info.Cat == fileinfo.Archive || fn.Info.Cat == fileinfo.Backup: // don't edit
		fn.OpenFilesDefault()
	default:
		ts.AsLab().EditorString(df, string(fn.Filepath))
	}
	return nil
}
// EditFiles calls EditFile on selected files
func (fn *FileNode) EditFiles() { //types:add
	fn.SelectedFunc(func(sn *filetree.Node) {
		// checked assertion, consistent with PlotFiles / GridFiles,
		// so a non-FileNode selection cannot panic.
		if sfn, ok := sn.This.(*FileNode); ok {
			sfn.EditFile()
		}
	})
}
// EditFile pulls up this file in a texteditor tab; directories and
// data items are delegated to OpenFile.
func (fn *FileNode) EditFile() {
	if fn.IsDir() {
		fn.OpenFile()
		return
	}
	ts := fn.Tabber()
	if ts == nil {
		return
	}
	if fn.Info.Cat == fileinfo.Data {
		fn.OpenFile()
		return
	}
	path := string(fn.Filepath)
	ts.AsLab().EditorString(fsx.DirAndFile(path), path)
}
// PlotFiles calls PlotFile on selected files
func (fn *FileNode) PlotFiles() { //types:add
	fn.SelectedFunc(func(sel *filetree.Node) {
		sfn, ok := sel.This.(*FileNode)
		if !ok {
			return
		}
		sfn.PlotFile()
	})
}
// PlotFile creates a plot of data: directly from a tensorfs node,
// or by loading a data file as a table.
func (fn *FileNode) PlotFile() {
	ts := fn.Tabber()
	if ts == nil {
		return
	}
	if d := TensorFS(fn.AsNode()); d != nil {
		ts.AsLab().PlotTensorFS(d)
		return
	}
	if fn.Info.Cat != fileinfo.Data {
		return
	}
	df := fsx.DirAndFile(string(fn.Filepath))
	dt := table.New(df)
	if err := dt.OpenCSV(fsx.Filename(fn.Filepath), tensor.Tab); err != nil { // todo: need more flexible data handling mode
		core.ErrorSnackbar(fn, err)
		return
	}
	ts.AsLab().PlotTable(df+" Plot", dt)
}
// todo: this is too redundant -- need a better soln

// GridFiles calls GridFile on selected files
func (fn *FileNode) GridFiles() { //types:add
	fn.SelectedFunc(func(sel *filetree.Node) {
		sfn, ok := sel.This.(*FileNode)
		if !ok {
			return
		}
		sfn.GridFile()
	})
}
// GridFile creates a grid view of data for a tensorfs node;
// does nothing for non-tensorfs items.
func (fn *FileNode) GridFile() {
	ts := fn.Tabber()
	if ts == nil {
		return
	}
	if d := TensorFS(fn.AsNode()); d != nil {
		ts.AsLab().GridTensorFS(d)
	}
}
// DiffDirs displays a browser with differences between two selected directories
func (fn *FileNode) DiffDirs() { //types:add
	var dirs []*filetree.Node
	fn.SelectedFunc(func(sn *filetree.Node) {
		// collect the first two selected directories, in selection order
		if sn.IsDir() && len(dirs) < 2 {
			dirs = append(dirs, sn)
		}
	})
	if len(dirs) < 2 {
		core.MessageSnackbar(fn, "DiffDirs requires two selected directories")
		return
	}
	NewDiffBrowserDirs(string(dirs[0].Filepath), string(dirs[1].Filepath))
}
// NewDiffBrowserDirs opens a new diff browser window for files that differ
// within the two given directories. Excludes job.* files, dbmeta.toml,
// and tabular data files.
func NewDiffBrowserDirs(pathA, pathB string) {
	brow, b := diffbrowser.NewBrowserWindow()
	brow.DiffDirs(pathA, pathB, func(fname string) bool {
		return IsTableFile(fname) || strings.HasPrefix(fname, "job.") || fname == "dbmeta.toml"
	})
	b.RunWindow()
}
// IsTableFile reports whether the given filename is a tabular
// data file (.tsv or .csv).
func IsTableFile(fname string) bool {
	for _, ext := range []string{".tsv", ".csv"} {
		if strings.HasSuffix(fname, ext) {
			return true
		}
	}
	return false
}
// ContextMenu adds the data actions (Edit, Plot, Grid, Diff Dirs) to the
// context menu, disabling each when the selection does not support it.
func (fn *FileNode) ContextMenu(m *core.Scene) {
	core.NewFuncButton(m).SetFunc(fn.EditFiles).SetText("Edit").SetIcon(icons.Edit).
		Styler(func(s *styles.Style) {
			s.SetState(!fn.HasSelection(), states.Disabled)
		})
	// NOTE(review): all four actions use icons.Edit — looks like a
	// copy-paste leftover; consider distinct icons for Plot/Grid/Diff.
	core.NewFuncButton(m).SetFunc(fn.PlotFiles).SetText("Plot").SetIcon(icons.Edit).
		Styler(func(s *styles.Style) {
			s.SetState(!fn.HasSelection() || fn.Info.Cat != fileinfo.Data, states.Disabled)
		})
	core.NewFuncButton(m).SetFunc(fn.GridFiles).SetText("Grid").SetIcon(icons.Edit).
		Styler(func(s *styles.Style) {
			s.SetState(!fn.HasSelection() || fn.Info.Cat != fileinfo.Data, states.Disabled)
		})
	core.NewFuncButton(m).SetFunc(fn.DiffDirs).SetText("Diff Dirs").SetIcon(icons.Edit).
		Styler(func(s *styles.Style) {
			s.SetState(!fn.HasSelection() || !fn.IsDir(), states.Disabled)
		})
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package lab
import (
"fmt"
"log/slog"
"os"
"path/filepath"
"strconv"
"strings"
"unicode"
"cogentcore.org/core/base/fsx"
"cogentcore.org/core/base/logx"
"cogentcore.org/core/core"
"cogentcore.org/core/events"
"cogentcore.org/core/styles"
"cogentcore.org/lab/goal/goalib"
"cogentcore.org/lab/goal/interpreter"
"github.com/cogentcore/yaegi/interp"
)
// InitInterp creates and configures the goal interpreter
// used to run Browser scripts.
func (br *Browser) InitInterp() {
	br.Interpreter = interpreter.NewInterpreter(interp.Options{})
	br.Interpreter.Config()
	// logx.UserLevel = slog.LevelDebug // for debugging of init loading
}
// RunScript runs the script of the given name (a key in Scripts)
// in the Browser's interpreter, logging any evaluation or
// unbalanced-depth error.
func (br *Browser) RunScript(snm string) {
	sc, ok := br.Scripts[snm]
	if !ok {
		slog.Error("script not found:", "Script:", snm)
		return
	}
	logx.PrintlnDebug("\n################\nrunning script:\n", sc, "\n")
	_, _, err := br.Interpreter.Eval(sc)
	if err == nil {
		err = br.Interpreter.Goal.TrState.DepthError()
	}
	// previously err was silently dropped here; surface script failures
	if err != nil {
		slog.Error("script failed", "Script:", snm, "err", err)
	}
	br.Interpreter.Goal.TrState.ResetDepth()
}
// UpdateScripts updates the Scripts and updates the toolbar.
func (br *Browser) UpdateScripts() { //types:add
	// redo is true on re-runs after the initial load: init scripts
	// (lowercase names) are only evaluated on the first pass.
	redo := (br.Scripts != nil)
	scr := fsx.Filenames(br.ScriptsDir, ".goal")
	br.Scripts = make(map[string]string)
	for _, s := range scr {
		snm := strings.TrimSuffix(s, ".goal")
		sc, err := os.ReadFile(filepath.Join(br.ScriptsDir, s))
		if err == nil {
			// lowercase first letter = init script, run once at startup
			// NOTE(review): rune(snm[0]) inspects only the first byte —
			// fine for ASCII names; confirm non-ASCII names are not expected.
			if unicode.IsLower(rune(snm[0])) {
				if !redo {
					fmt.Println("run init script:", snm)
					br.Interpreter.Eval(string(sc))
				}
			} else {
				// uppercase-named scripts become toolbar actions
				ssc := string(sc)
				br.Scripts[snm] = ssc
			}
		} else {
			slog.Error(err.Error())
		}
	}
	if br.Toolbar != nil {
		br.Toolbar.Update()
	}
}
// TrimOrderPrefix trims any optional #- prefix from given string,
// used for ordering items by name.
func TrimOrderPrefix(s string) string {
	digits, rest, found := strings.Cut(s, "-")
	if !found {
		return s
	}
	if _, err := strconv.Atoi(digits); err != nil {
		return s
	}
	return rest
}
// PromptOKCancel prompts the user for whether to do something,
// calling the given function if the user clicks OK.
func PromptOKCancel(ctx core.Widget, prompt string, fun func()) {
	d := core.NewBody(prompt)
	d.AddBottomBar(func(bar *core.Frame) {
		d.AddCancel(bar)
		okb := d.AddOK(bar)
		okb.OnClick(func(e events.Event) {
			if fun == nil {
				return
			}
			fun()
		})
	})
	d.RunDialog(ctx)
}
// PromptString prompts the user for a string value (initial value given),
// calling the given function if the user clicks OK.
func PromptString(ctx core.Widget, str string, prompt string, fun func(s string)) {
	d := core.NewBody(prompt)
	tf := core.NewTextField(d).SetText(str)
	tf.Styler(func(s *styles.Style) {
		s.Min.X.Ch(60)
	})
	d.AddBottomBar(func(bar *core.Frame) {
		d.AddCancel(bar)
		okb := d.AddOK(bar)
		okb.OnClick(func(e events.Event) {
			if fun == nil {
				return
			}
			fun(tf.Text())
		})
	})
	d.RunDialog(ctx)
}
// PromptStruct prompts the user for the values in given struct (pass a pointer),
// calling the given function if the user clicks OK.
func PromptStruct(ctx core.Widget, str any, prompt string, fun func()) {
	d := core.NewBody(prompt)
	core.NewForm(d).SetStruct(str)
	d.AddBottomBar(func(bar *core.Frame) {
		d.AddCancel(bar)
		okb := d.AddOK(bar)
		okb.OnClick(func(e events.Event) {
			if fun == nil {
				return
			}
			fun()
		})
	})
	d.RunDialog(ctx)
}
// FirstComment returns the leading "// " comment lines from given .goal
// file as one space-joined string, used as the tooltip for scripts.
func FirstComment(sc string) string {
	var b strings.Builder
	for _, ln := range goalib.SplitLines(sc) {
		if !strings.HasPrefix(ln, "// ") {
			break
		}
		b.WriteString(strings.TrimSpace(ln[3:]))
		b.WriteString(" ")
	}
	return b.String()
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package lab
import (
"fmt"
"cogentcore.org/core/base/fsx"
"cogentcore.org/core/core"
"cogentcore.org/core/styles"
"cogentcore.org/core/texteditor"
"cogentcore.org/lab/plotcore"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
"cogentcore.org/lab/tensorcore"
"cogentcore.org/lab/tensorfs"
)
// CurTabber is the current Tabber. Set when one is created.
// NOTE(review): package-level state; the most recently created
// Tabber wins — confirm single-window assumption.
var CurTabber Tabber

// Tabber is a [core.Tabs] based widget that has support for opening
// tabs for [plotcore.PlotEditor] and [tensorcore.Table] editors,
// among others.
type Tabber interface {
	core.Tabber

	// AsLab returns the [lab.Tabs] widget with all the tabs methods.
	AsLab() *Tabs
}
// NewTab recycles a tab with given label, returning the existing widget
// of type T within it if present (the last child, allowing a toolbar at
// the top), or calling mkfun to create and configure a new one.
// On a name / type conflict it posts an error snackbar and returns the
// zero value of T.
func NewTab[T any](tb Tabber, label string, mkfun func(tab *core.Frame) T) T {
	tab := tb.AsLab().RecycleTab(label)
	if !tab.HasChildren() {
		return mkfun(tab)
	}
	last := tab.Child(tab.NumChildren() - 1)
	if tt, ok := last.(T); ok {
		return tt
	}
	err := fmt.Errorf("Name / Type conflict: tab %q does not have the expected type of content: is %T", label, last)
	core.ErrorSnackbar(tb.AsLab(), err)
	var zv T
	return zv
}
// TabAt returns the widget of given type at the tab of given name,
// or the zero value of T if the tab is not found or its last child
// (content widget) is not of the expected type.
func TabAt[T any](tb Tabber, label string) T {
	var zv T
	tab := tb.AsLab().TabByName(label)
	if tab == nil || !tab.HasChildren() { // no children shouldn't happen
		return zv
	}
	last := tab.Child(tab.NumChildren() - 1)
	if tt, ok := last.(T); ok {
		return tt
	}
	err := fmt.Errorf("Name / Type conflict: tab %q does not have the expected type of content: %T", label, last)
	core.ErrorSnackbar(tb.AsLab(), err)
	return zv
}
// Tabs implements the [Tabber] interface.
type Tabs struct {
	core.Tabs
}

// Init initializes the tabs, using the [core.FunctionalTabs] style.
func (ts *Tabs) Init() {
	ts.Tabs.Init()
	ts.Type = core.FunctionalTabs
}

// AsLab implements the [Tabber] interface.
func (ts *Tabs) AsLab() *Tabs {
	return ts
}
// TensorTable recycles a tab with a tensorcore.Table widget
// to view given table.Table, using its own table.Table as tv.Table.
// Use tv.Table.Table to get the underlying *table.Table
// Use tv.Table.Sequential to update the Indexed to view
// all of the rows when done updating the Table, and then call br.Update()
// Returns nil on a tab name / type conflict.
func (ts *Tabs) TensorTable(label string, dt *table.Table) *tensorcore.Table {
	tv := NewTab(ts, label, func(tab *core.Frame) *tensorcore.Table {
		tb := core.NewToolbar(tab)
		tv := tensorcore.NewTable(tab)
		tb.Maker(tv.MakeToolbar)
		return tv
	})
	// NewTab returns a typed nil on name/type conflict: guard before use,
	// consistent with PlotTable.
	if tv == nil {
		return nil
	}
	tv.SetTable(dt)
	ts.Update()
	return tv
}
// TensorEditor recycles a tab with a tensorcore.TensorEditor widget
// to view given Tensor. Returns nil on a tab name / type conflict.
func (ts *Tabs) TensorEditor(label string, tsr tensor.Tensor) *tensorcore.TensorEditor {
	tv := NewTab(ts, label, func(tab *core.Frame) *tensorcore.TensorEditor {
		tb := core.NewToolbar(tab)
		tv := tensorcore.NewTensorEditor(tab)
		tb.Maker(tv.MakeToolbar)
		return tv
	})
	// NewTab returns a typed nil on name/type conflict: guard before use,
	// consistent with PlotTable.
	if tv == nil {
		return nil
	}
	tv.SetTensor(tsr)
	ts.Update()
	return tv
}
// TensorGrid recycles a tab with a tensorcore.TensorGrid widget
// to view given Tensor. Returns nil on a tab name / type conflict.
func (ts *Tabs) TensorGrid(label string, tsr tensor.Tensor) *tensorcore.TensorGrid {
	tv := NewTab(ts, label, func(tab *core.Frame) *tensorcore.TensorGrid {
		// tb := core.NewToolbar(tab)
		tv := tensorcore.NewTensorGrid(tab)
		// tb.Maker(tv.MakeToolbar)
		return tv
	})
	// NewTab returns a typed nil on name/type conflict: guard before use,
	// consistent with PlotTable.
	if tv == nil {
		return nil
	}
	tv.SetTensor(tsr)
	ts.Update()
	return tv
}
// GridTensorFS recycles a tab with a Grid of given [tensorfs.Node].
// Directories are not supported: a snackbar message is shown instead.
func (ts *Tabs) GridTensorFS(dfs *tensorfs.Node) *tensorcore.TensorGrid {
	if dfs.IsDir() {
		core.MessageSnackbar(ts, "Use Edit instead of Grid to view a directory")
		return nil
	}
	label := fsx.DirAndFile(dfs.Path()) + " Grid"
	return ts.TensorGrid(label, dfs.Tensor)
}
// PlotTable recycles a tab with a Plot of given table.Table.
// Returns nil on a tab name / type conflict.
func (ts *Tabs) PlotTable(label string, dt *table.Table) *plotcore.PlotEditor {
	pl := NewTab(ts, label, func(tab *core.Frame) *plotcore.PlotEditor {
		tbar := core.NewToolbar(tab)
		ped := plotcore.NewPlotEditor(tab)
		tab.Styler(func(s *styles.Style) {
			s.Direction = styles.Column
			s.Grow.Set(1, 1)
		})
		tbar.Maker(ped.MakeToolbar)
		return ped
	})
	if pl == nil {
		return nil
	}
	pl.SetTable(dt)
	ts.Update()
	return pl
}
// PlotTensorFS recycles a tab with a Plot of given [tensorfs.Node].
// Directories are plotted via [tensorfs.DirTable]; a single tensor is
// wrapped into a new table with a Row index column.
func (ts *Tabs) PlotTensorFS(dfs *tensorfs.Node) *plotcore.PlotEditor {
	label := fsx.DirAndFile(dfs.Path()) + " Plot"
	if dfs.IsDir() {
		return ts.PlotTable(label, tensorfs.DirTable(dfs, nil))
	}
	tsr := dfs.Tensor
	dt := table.New(label)
	// NOTE(review): assumes tsr has at least one dimension;
	// DimSize(0) on a 0D scalar would fail — confirm callers.
	dt.Columns.Rows = tsr.DimSize(0)
	if ix, ok := tsr.(*tensor.Rows); ok {
		// preserve any row indexing / sorting from the source tensor
		dt.Indexes = ix.Indexes
	}
	rc := dt.AddIntColumn("Row")
	for r := range dt.Columns.Rows {
		rc.Values[r] = r
	}
	dt.AddColumn(dfs.Name(), tsr.AsValues())
	return ts.PlotTable(label, dt)
}
// GoUpdatePlot calls GoUpdatePlot on plot at tab with given name.
// Does nothing if tab name doesn't exist (returns nil).
func (ts *Tabs) GoUpdatePlot(label string) *plotcore.PlotEditor {
	pl := TabAt[*plotcore.PlotEditor](ts, label)
	if pl == nil {
		return nil
	}
	pl.GoUpdatePlot()
	return pl
}
// UpdatePlot calls UpdatePlot on plot at tab with given name.
// Does nothing if tab name doesn't exist (returns nil).
func (ts *Tabs) UpdatePlot(label string) *plotcore.PlotEditor {
	pl := TabAt[*plotcore.PlotEditor](ts, label)
	if pl == nil {
		return nil
	}
	pl.UpdatePlot()
	return pl
}
// SliceTable recycles a tab with a core.Table widget
// to view the given slice of structs.
// Returns nil on a tab name / type conflict.
func (ts *Tabs) SliceTable(label string, slc any) *core.Table {
	tv := NewTab(ts, label, func(tab *core.Frame) *core.Table {
		return core.NewTable(tab)
	})
	// NewTab returns a typed nil on name/type conflict: guard before use,
	// consistent with PlotTable.
	if tv == nil {
		return nil
	}
	tv.SetSlice(slc)
	ts.Update()
	return tv
}
// EditorString recycles a [texteditor.Editor] tab, displaying given string.
// Returns nil on a tab name / type conflict.
func (ts *Tabs) EditorString(label, content string) *texteditor.Editor {
	ed := NewTab(ts, label, func(tab *core.Frame) *texteditor.Editor {
		ed := texteditor.NewEditor(tab)
		ed.Styler(func(s *styles.Style) {
			s.Grow.Set(1, 1)
		})
		return ed
	})
	// NewTab returns a typed nil on name/type conflict: guard before
	// dereferencing ed.Buffer, consistent with PlotTable.
	if ed == nil {
		return nil
	}
	if content != "" {
		ed.Buffer.SetText([]byte(content))
	}
	ts.Update()
	return ed
}
// EditorFile opens an editor tab for given file.
// Returns nil on a tab name / type conflict.
func (ts *Tabs) EditorFile(label, filename string) *texteditor.Editor {
	ed := NewTab(ts, label, func(tab *core.Frame) *texteditor.Editor {
		ed := texteditor.NewEditor(tab)
		ed.Styler(func(s *styles.Style) {
			s.Grow.Set(1, 1)
		})
		return ed
	})
	// NewTab returns a typed nil on name/type conflict: guard before
	// dereferencing ed.Buffer, consistent with PlotTable.
	if ed == nil {
		return nil
	}
	ed.Buffer.Open(core.Filename(filename))
	ts.Update()
	return ed
}
// Code generated by "core generate"; DO NOT EDIT.
package lab
import (
"io/fs"
"cogentcore.org/core/core"
"cogentcore.org/core/tree"
"cogentcore.org/core/types"
)
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/lab.Basic", IDName: "basic", Doc: "Basic is a basic data browser with the files as the left panel,\nand the Tabber as the right panel.", Embeds: []types.Field{{Name: "Frame"}, {Name: "Browser"}}})
// NewBasic returns a new [Basic] with the given optional parent:
// Basic is a basic data browser with the files as the left panel,
// and the Tabber as the right panel.
func NewBasic(parent ...tree.Node) *Basic { return tree.New[Basic](parent...) }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/lab.Browser", IDName: "browser", Doc: "Browser holds all the elements of a data browser, for browsing data\neither on an OS filesystem or as a tensorfs virtual data filesystem.\nIt supports the automatic loading of [goal] scripts as toolbar actions to\nperform pre-programmed tasks on the data, to create app-like functionality.\nScripts are ordered alphabetically and any leading #- prefix is automatically\nremoved from the label, so you can use numbers to specify a custom order.\nIt is not a [core.Widget] itself, and is intended to be incorporated into\na [core.Frame] widget, potentially along with other custom elements.\nSee [Basic] for a basic implementation.", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"-setters"}}}, Methods: []types.Method{{Name: "UpdateFiles", Doc: "UpdateFiles Updates the files list.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "UpdateScripts", Doc: "UpdateScripts updates the Scripts and updates the toolbar.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}}, Fields: []types.Field{{Name: "FS", Doc: "FS is the filesystem, if browsing an FS."}, {Name: "DataRoot", Doc: "DataRoot is the path to the root of the data to browse."}, {Name: "StartDir", Doc: "StartDir is the starting directory, where the app was originally started."}, {Name: "ScriptsDir", Doc: "ScriptsDir is the directory containing scripts for toolbar actions.\nIt defaults to DataRoot/dbscripts"}, {Name: "Scripts", Doc: "Scripts"}, {Name: "Interpreter", Doc: "Interpreter is the interpreter to use for running Browser scripts"}, {Name: "Files", Doc: "Files is the [DataTree] tree browser of the tensorfs or files."}, {Name: "Tabs", Doc: "Tabs is the [Tabber] element managing tabs of data views."}, {Name: "Toolbar", Doc: "Toolbar is the top-level toolbar for the browser, if used."}, {Name: "Splits", Doc: "Splits is the overall [core.Splits] for the 
browser."}}})
// SetFS sets the [Browser.FS]:
// FS is the filesystem, if browsing an FS.
func (t *Browser) SetFS(v fs.FS) *Browser { t.FS = v; return t }
// SetDataRoot sets the [Browser.DataRoot]:
// DataRoot is the path to the root of the data to browse.
func (t *Browser) SetDataRoot(v string) *Browser { t.DataRoot = v; return t }
// SetStartDir sets the [Browser.StartDir]:
// StartDir is the starting directory, where the app was originally started.
func (t *Browser) SetStartDir(v string) *Browser { t.StartDir = v; return t }
// SetScriptsDir sets the [Browser.ScriptsDir]:
// ScriptsDir is the directory containing scripts for toolbar actions.
// It defaults to DataRoot/dbscripts
func (t *Browser) SetScriptsDir(v string) *Browser { t.ScriptsDir = v; return t }
// SetFiles sets the [Browser.Files]:
// Files is the [DataTree] tree browser of the tensorfs or files.
func (t *Browser) SetFiles(v *DataTree) *Browser { t.Files = v; return t }
// SetTabs sets the [Browser.Tabs]:
// Tabs is the [Tabber] element managing tabs of data views.
func (t *Browser) SetTabs(v Tabber) *Browser { t.Tabs = v; return t }
// SetToolbar sets the [Browser.Toolbar]:
// Toolbar is the top-level toolbar for the browser, if used.
func (t *Browser) SetToolbar(v *core.Toolbar) *Browser { t.Toolbar = v; return t }
// SetSplits sets the [Browser.Splits]:
// Splits is the overall [core.Splits] for the browser.
func (t *Browser) SetSplits(v *core.Splits) *Browser { t.Splits = v; return t }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/lab.DataTree", IDName: "data-tree", Doc: "DataTree is the databrowser version of [filetree.Tree],\nwhich provides the Tabber to show data editors.", Embeds: []types.Field{{Name: "Tree"}}, Fields: []types.Field{{Name: "Tabber", Doc: "Tabber is the [Tabber] for this tree."}}})
// NewDataTree returns a new [DataTree] with the given optional parent:
// DataTree is the databrowser version of [filetree.Tree],
// which provides the Tabber to show data editors.
func NewDataTree(parent ...tree.Node) *DataTree { return tree.New[DataTree](parent...) }
// SetTabber sets the [DataTree.Tabber]:
// Tabber is the [Tabber] for this tree.
func (t *DataTree) SetTabber(v Tabber) *DataTree { t.Tabber = v; return t }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/lab.FileNode", IDName: "file-node", Doc: "FileNode is databrowser version of FileNode for FileTree", Methods: []types.Method{{Name: "EditFiles", Doc: "EditFiles calls EditFile on selected files", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "PlotFiles", Doc: "PlotFiles calls PlotFile on selected files", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "GridFiles", Doc: "GridFiles calls GridFile on selected files", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "DiffDirs", Doc: "DiffDirs displays a browser with differences between two selected directories", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}}, Embeds: []types.Field{{Name: "Node"}}})
// NewFileNode returns a new [FileNode] with the given optional parent:
// FileNode is databrowser version of FileNode for FileTree
func NewFileNode(parent ...tree.Node) *FileNode { return tree.New[FileNode](parent...) }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/lab.Tabs", IDName: "tabs", Doc: "Tabs implements the [Tabber] interface.", Embeds: []types.Field{{Name: "Tabs"}}})
// NewTabs returns a new [Tabs] with the given optional parent:
// Tabs implements the [Tabber] interface.
func NewTabs(parent ...tree.Node) *Tabs { return tree.New[Tabs](parent...) }
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package matrix
import (
	"sync"

	"cogentcore.org/core/base/errors"
	"cogentcore.org/lab/stats/stats"
	"cogentcore.org/lab/tensor"
	"cogentcore.org/lab/tensor/tmath"
	"gonum.org/v1/gonum/mat"
)
// Eig performs the eigen decomposition of the given square matrix,
// which is not symmetric. See EigSym for a symmetric square matrix.
// In this non-symmetric case, the results are typically complex valued,
// so the outputs are complex tensors. TODO: need complex support!
// The vectors are same size as the input. Each vector is a column
// in this 2D square matrix, ordered *lowest* to *highest* across the columns,
// i.e., maximum vector is the last column.
// The values are the size of one row, ordered *lowest* to *highest*.
// If the input tensor is > 2D, it is treated as a list of 2D matricies,
// and parallel threading is used where beneficial.
func Eig(a tensor.Tensor) (vecs, vals *tensor.Float64) {
	vecs, vals = tensor.NewFloat64(), tensor.NewFloat64()
	errors.Log(EigOut(a, vecs, vals))
	return vecs, vals
}
// EigOut performs the eigen decomposition of the given square matrix,
// which is not symmetric. See EigSym for a symmetric square matrix.
// In this non-symmetric case, the results are typically complex valued,
// so the outputs are complex tensors. TODO: need complex support!
// The vectors are same size as the input. Each vector is a column
// in this 2D square matrix, ordered *lowest* to *highest* across the columns,
// i.e., maximum vector is the last column.
// The values are the size of one row, ordered *lowest* to *highest*.
// If the input tensor is > 2D, it is treated as a list of 2D matricies,
// and parallel threading is used where beneficial.
func EigOut(a tensor.Tensor, vecs, vals *tensor.Float64) error {
	if err := StringCheck(a); err != nil {
		return err
	}
	na := a.NumDims()
	if na == 1 {
		return mat.ErrShape
	}
	var asz []int
	ea := a
	if na > 2 {
		asz = tensor.SplitAtInnerDims(a, 2)
		if asz[0] == 1 {
			// effectively a single 2D matrix: drop the leading size-1 dims
			ea = tensor.Reshape(a, asz[1:]...)
			na = 2
		}
	}
	if na == 2 {
		// use ea here, not a: when the input had leading size-1 dims,
		// only ea has the proper 2D shape (a is still >2D).
		if ea.DimSize(0) != ea.DimSize(1) {
			return mat.ErrShape
		}
		ma, _ := NewMatrix(ea)
		vecs.SetShapeSizes(ea.DimSize(0), ea.DimSize(1))
		vals.SetShapeSizes(ea.DimSize(0))
		do, _ := NewDense(vecs)
		var eig mat.Eigen
		ok := eig.Factorize(ma, mat.EigenRight)
		if !ok {
			return errors.New("gonum mat.Eigen Factorize failed")
		}
		_ = do
		// eig.VectorsTo(do) // todo: requires complex!
		// eig.Values(vals.Values)
		return nil
	}
	ea = tensor.Reshape(a, asz...)
	if ea.DimSize(1) != ea.DimSize(2) {
		return mat.ErrShape
	}
	nr := ea.DimSize(0)
	sz := ea.DimSize(1)
	vecs.SetShapeSizes(nr, sz, sz)
	vals.SetShapeSizes(nr, sz)
	var errs []error
	var mu sync.Mutex // guards errs: the row function runs in parallel
	tensor.VectorizeThreaded(ea.DimSize(1)*ea.DimSize(2)*1000,
		func(tsr ...tensor.Tensor) int { return nr },
		func(r int, tsr ...tensor.Tensor) {
			sa := tensor.Reslice(ea, r, tensor.FullAxis, tensor.FullAxis)
			ma, _ := NewMatrix(sa)
			do, _ := NewDense(vecs.RowTensor(r).(*tensor.Float64))
			var eig mat.Eigen
			ok := eig.Factorize(ma, mat.EigenRight)
			if !ok {
				mu.Lock()
				errs = append(errs, errors.New("gonum mat.Eigen Factorize failed"))
				mu.Unlock()
			}
			_ = do
			// eig.VectorsTo(do) // todo: requires complex!
			// eig.Values(vals.Values[r*sz : (r+1)*sz])
		})
	return errors.Join(errs...)
}
// EigSym performs the eigen decomposition of the given symmetric square matrix,
// which produces real-valued results. When input is the [metric.CovarianceMatrix],
// this is known as Principal Components Analysis (PCA).
// The vectors are same size as the input. Each vector is a column
// in this 2D square matrix, ordered *lowest* to *highest* across the columns,
// i.e., maximum vector is the last column.
// The values are the size of one row, ordered *lowest* to *highest*.
// Note that Eig produces results in the *opposite* order of [SVD] (which is much faster).
// If the input tensor is > 2D, it is treated as a list of 2D matricies,
// and parallel threading is used where beneficial.
func EigSym(a tensor.Tensor) (vecs, vals *tensor.Float64) {
	vecs, vals = tensor.NewFloat64(), tensor.NewFloat64()
	errors.Log(EigSymOut(a, vecs, vals))
	return vecs, vals
}
// EigSymOut performs the eigen decomposition of the given symmetric square matrix,
// which produces real-valued results. When input is the [metric.CovarianceMatrix],
// this is known as Principal Components Analysis (PCA).
// The vectors are same size as the input. Each vector is a column
// in this 2D square matrix, ordered *lowest* to *highest* across the columns,
// i.e., maximum vector is the last column.
// The values are the size of one row, ordered *lowest* to *highest*.
// Note that Eig produces results in the *opposite* order of [SVD] (which is much faster).
// If the input tensor is > 2D, it is treated as a list of 2D matricies,
// and parallel threading is used where beneficial.
func EigSymOut(a tensor.Tensor, vecs, vals *tensor.Float64) error {
	if err := StringCheck(a); err != nil {
		return err
	}
	na := a.NumDims()
	if na == 1 {
		return mat.ErrShape
	}
	var asz []int
	ea := a
	if na > 2 {
		asz = tensor.SplitAtInnerDims(a, 2)
		if asz[0] == 1 {
			// effectively a single 2D matrix: drop the leading size-1 dims
			ea = tensor.Reshape(a, asz[1:]...)
			na = 2
		}
	}
	if na == 2 {
		// use ea here, not a: when the input had leading size-1 dims,
		// only ea has the proper 2D shape (a is still >2D).
		if ea.DimSize(0) != ea.DimSize(1) {
			return mat.ErrShape
		}
		ma, _ := NewSymmetric(ea)
		vecs.SetShapeSizes(ea.DimSize(0), ea.DimSize(1))
		vals.SetShapeSizes(ea.DimSize(0))
		do, _ := NewDense(vecs)
		var eig mat.EigenSym
		ok := eig.Factorize(ma, true)
		if !ok {
			return errors.New("gonum mat.EigenSym Factorize failed")
		}
		eig.VectorsTo(do)
		eig.Values(vals.Values)
		return nil
	}
	ea = tensor.Reshape(a, asz...)
	if ea.DimSize(1) != ea.DimSize(2) {
		return mat.ErrShape
	}
	nr := ea.DimSize(0)
	sz := ea.DimSize(1)
	vecs.SetShapeSizes(nr, sz, sz)
	vals.SetShapeSizes(nr, sz)
	var errs []error
	var mu sync.Mutex // guards errs: the row function runs in parallel
	tensor.VectorizeThreaded(ea.DimSize(1)*ea.DimSize(2)*1000,
		func(tsr ...tensor.Tensor) int { return nr },
		func(r int, tsr ...tensor.Tensor) {
			sa := tensor.Reslice(ea, r, tensor.FullAxis, tensor.FullAxis)
			ma, _ := NewSymmetric(sa)
			do, _ := NewDense(vecs.RowTensor(r).(*tensor.Float64))
			var eig mat.EigenSym
			ok := eig.Factorize(ma, true)
			if !ok {
				mu.Lock()
				// message corrected: this is EigenSym, not Eigen
				errs = append(errs, errors.New("gonum mat.EigenSym Factorize failed"))
				mu.Unlock()
			}
			eig.VectorsTo(do)
			eig.Values(vals.Values[r*sz : (r+1)*sz])
		})
	return errors.Join(errs...)
}
// SVD performs the singular value decomposition of the given symmetric square matrix,
// which produces real-valued results, and is generally much faster than [EigSym],
// while producing the same results.
// The vectors are same size as the input. Each vector is a column
// in this 2D square matrix, ordered *highest* to *lowest* across the columns,
// i.e., maximum vector is the first column.
// The values are the size of one row ordered in alignment with the vectors.
// Note that SVD produces results in the *opposite* order of [EigSym].
// If the input tensor is > 2D, it is treated as a list of 2D matricies,
// and parallel threading is used where beneficial.
func SVD(a tensor.Tensor) (vecs, vals *tensor.Float64) {
	vecs, vals = tensor.NewFloat64(), tensor.NewFloat64()
	errors.Log(SVDOut(a, vecs, vals))
	return vecs, vals
}
// SVDOut performs the singular value decomposition of the given symmetric square matrix,
// which produces real-valued results, and is generally much faster than [EigSym],
// while producing the same results.
// The vectors are same size as the input. Each vector is a column
// in this 2D square matrix, ordered *highest* to *lowest* across the columns,
// i.e., maximum vector is the first column.
// The values are the size of one row ordered in alignment with the vectors.
// Note that SVD produces results in the *opposite* order of [EigSym].
// If the input tensor is > 2D, it is treated as a list of 2D matricies,
// and parallel threading is used where beneficial.
func SVDOut(a tensor.Tensor, vecs, vals *tensor.Float64) error {
	if err := StringCheck(a); err != nil {
		return err
	}
	na := a.NumDims()
	if na == 1 {
		return mat.ErrShape
	}
	var asz []int
	ea := a
	if na > 2 {
		asz = tensor.SplitAtInnerDims(a, 2)
		if asz[0] == 1 { // singleton list: treat as a single 2D matrix
			ea = tensor.Reshape(a, asz[1:]...)
			na = 2
		}
	}
	if na == 2 {
		if a.DimSize(0) != a.DimSize(1) {
			return mat.ErrShape
		}
		ma, _ := NewSymmetric(a)
		vecs.SetShapeSizes(a.DimSize(0), a.DimSize(1))
		vals.SetShapeSizes(a.DimSize(0))
		do, _ := NewDense(vecs)
		var eig mat.SVD
		ok := eig.Factorize(ma, mat.SVDFull)
		if !ok {
			return errors.New("gonum mat.SVD Factorize failed")
		}
		eig.UTo(do)
		eig.Values(vals.Values)
		return nil
	}
	// > 2D case: process each 2D sub-matrix in parallel.
	ea = tensor.Reshape(a, asz...)
	if ea.DimSize(1) != ea.DimSize(2) {
		return mat.ErrShape
	}
	nr := ea.DimSize(0)
	sz := ea.DimSize(1)
	vecs.SetShapeSizes(nr, sz, sz)
	vals.SetShapeSizes(nr, sz)
	// errs is indexed by row so each parallel goroutine writes only its
	// own slot; appending to a shared slice here would be a data race.
	errs := make([]error, nr)
	tensor.VectorizeThreaded(ea.DimSize(1)*ea.DimSize(2)*1000,
		func(tsr ...tensor.Tensor) int { return nr },
		func(r int, tsr ...tensor.Tensor) {
			sa := tensor.Reslice(ea, r, tensor.FullAxis, tensor.FullAxis)
			ma, _ := NewSymmetric(sa)
			do, _ := NewDense(vecs.RowTensor(r).(*tensor.Float64))
			var eig mat.SVD
			ok := eig.Factorize(ma, mat.SVDFull)
			if !ok {
				errs[r] = errors.New("gonum mat.SVD Factorize failed")
			}
			eig.UTo(do)
			eig.Values(vals.Values[r*sz : (r+1)*sz])
		})
	return errors.Join(errs...) // Join ignores nil entries
}
// SVDValues performs the singular value decomposition of the given
// symmetric square matrix, which produces real-valued results,
// and is generally much faster than [EigSym], while producing the same results.
// This version only generates eigenvalues, not vectors: see [SVD].
// The values are the size of one row ordered highest to lowest,
// which is the opposite of [EigSym].
// If the input tensor is > 2D, it is treated as a list of 2D matricies,
// and parallel threading is used where beneficial.
// Any error from [SVDValuesOut] is logged.
func SVDValues(a tensor.Tensor) *tensor.Float64 {
	out := tensor.NewFloat64()
	errors.Log(SVDValuesOut(a, out))
	return out
}
// SVDValuesOut performs the singular value decomposition of the given
// symmetric square matrix, which produces real-valued results,
// and is generally much faster than [EigSym], while producing the same results.
// This version only generates eigenvalues, not vectors: see [SVDOut].
// The values are the size of one row ordered highest to lowest,
// which is the opposite of [EigSym].
// If the input tensor is > 2D, it is treated as a list of 2D matricies,
// and parallel threading is used where beneficial.
func SVDValuesOut(a tensor.Tensor, vals *tensor.Float64) error {
	if err := StringCheck(a); err != nil {
		return err
	}
	na := a.NumDims()
	if na == 1 {
		return mat.ErrShape
	}
	var asz []int
	ea := a
	if na > 2 {
		asz = tensor.SplitAtInnerDims(a, 2)
		if asz[0] == 1 { // singleton list: treat as a single 2D matrix
			ea = tensor.Reshape(a, asz[1:]...)
			na = 2
		}
	}
	if na == 2 {
		if a.DimSize(0) != a.DimSize(1) {
			return mat.ErrShape
		}
		ma, _ := NewSymmetric(a)
		vals.SetShapeSizes(a.DimSize(0))
		var eig mat.SVD
		ok := eig.Factorize(ma, mat.SVDNone) // values only: no vectors needed
		if !ok {
			return errors.New("gonum mat.SVD Factorize failed")
		}
		eig.Values(vals.Values)
		return nil
	}
	// > 2D case: process each 2D sub-matrix in parallel.
	ea = tensor.Reshape(a, asz...)
	if ea.DimSize(1) != ea.DimSize(2) {
		return mat.ErrShape
	}
	nr := ea.DimSize(0)
	sz := ea.DimSize(1)
	vals.SetShapeSizes(nr, sz)
	// errs is indexed by row so each parallel goroutine writes only its
	// own slot; appending to a shared slice here would be a data race.
	errs := make([]error, nr)
	tensor.VectorizeThreaded(ea.DimSize(1)*ea.DimSize(2)*1000,
		func(tsr ...tensor.Tensor) int { return nr },
		func(r int, tsr ...tensor.Tensor) {
			sa := tensor.Reslice(ea, r, tensor.FullAxis, tensor.FullAxis)
			ma, _ := NewSymmetric(sa)
			var eig mat.SVD
			ok := eig.Factorize(ma, mat.SVDNone)
			if !ok {
				errs[r] = errors.New("gonum mat.SVD Factorize failed")
			}
			eig.Values(vals.Values[r*sz : (r+1)*sz])
		})
	return errors.Join(errs...) // Join ignores nil entries
}
// ProjectOnMatrixColumn is a convenience function for projecting given vector
// of values along a specific column (2nd dimension) of the given 2D matrix,
// specified by the scalar colindex, putting results into out.
// If the vec is more than 1 dimensional, then it is treated as rows x cells,
// and each row of cells is projected through the matrix column, producing a
// 1D output with the number of rows. Otherwise a single number is produced.
// This is typically done with results from SVD or EigSym (PCA).
// Any error from [ProjectOnMatrixColumnOut] is logged.
func ProjectOnMatrixColumn(mtx, vec, colindex tensor.Tensor) tensor.Values {
	res := tensor.NewOfType(vec.DataType())
	errors.Log(ProjectOnMatrixColumnOut(mtx, vec, colindex, res))
	return res
}
// ProjectOnMatrixColumnOut is a convenience function for projecting given vector
// of values along a specific column (2nd dimension) of the given 2D matrix,
// specified by the scalar colindex, putting results into out.
// If the vec is more than 1 dimensional, then it is treated as rows x cells,
// and each row of cells is projected through the matrix column, producing a
// 1D output with the number of rows. Otherwise a single number is produced.
// This is typically done with results from SVD or EigSym (PCA).
func ProjectOnMatrixColumnOut(mtx, vec, colindex tensor.Tensor, out tensor.Values) error {
	ci := int(colindex.Float1D(0))
	// extract the selected column as a 1D vector.
	col := tensor.As1D(tensor.Reslice(mtx, tensor.Slice{}, ci))
	rows, cells := vec.Shape().RowCellSize()
	if rows > 0 && cells > 0 {
		// multi-row case: project each row of cells through the column.
		msum := tensor.NewFloat64Scalar(0)
		out.SetShapeSizes(rows)
		mout := tensor.NewFloat64(cells)
		for i := range rows {
			if err := tmath.MulOut(tensor.Cells1D(vec, i), col, mout); err != nil {
				return err
			}
			stats.SumOut(mout, msum)
			out.SetFloat1D(msum.Float1D(0), i)
		}
		return nil
	}
	// scalar case: single dot product.
	mout := tensor.NewFloat64(1)
	if err := tmath.MulOut(vec, col, mout); err != nil { // was silently ignored
		return err
	}
	stats.SumOut(mout, out)
	return nil
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package matrix
import (
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/num"
"cogentcore.org/lab/tensor"
"cogentcore.org/lab/vector"
)
// offCols is a helper function to process the optional offset and columns
// arguments shared by the matrix construction functions.
// With no extra args, the offset is 0 and cols equals size;
// the first extra arg sets the offset, the second sets cols.
func offCols(size int, offsetCols ...int) (off, cols int) {
	off = 0
	cols = size
	if len(offsetCols) >= 1 {
		off = offsetCols[0]
	}
	if len(offsetCols) == 2 {
		cols = offsetCols[1]
	}
	return
}
// Identity returns a new 2D Float64 tensor with 1s along the diagonal and
// 0s elsewhere, with the given row and column size.
//   - If one additional parameter is passed, it is the offset,
//     to set values above (positive) or below (negative) the diagonal.
//   - If a second additional parameter is passed, it is the number of columns
//     for a non-square matrix (first size parameter = number of rows).
func Identity(size int, offset_cols ...int) *tensor.Float64 {
	off, cols := offCols(size, offset_cols...)
	out := tensor.NewFloat64(size, cols)
	for r := 0; r < size; r++ {
		if c := r + off; c >= 0 && c < cols {
			out.SetFloat(1, r, c)
		}
	}
	return out
}
// DiagonalN returns the number of elements along the diagonal
// of a 2D matrix of given row and column size.
//   - If one additional parameter is passed, it is the offset,
//     to include values above (positive) or below (negative) the diagonal.
//   - If a second additional parameter is passed, it is the number of columns
//     for a non-square matrix (first size parameter = number of rows).
func DiagonalN(size int, offset_cols ...int) int {
	off, cols := offCols(size, offset_cols...)
	rows := size
	if off == 0 {
		return min(rows, cols)
	}
	// reduce offset cases to the zero-offset case recursively.
	oa := num.Abs(off)
	switch {
	case off > 0 && cols > rows:
		return DiagonalN(rows, 0, cols-oa)
	case off > 0:
		return DiagonalN(rows-oa, 0, cols-oa)
	case rows > cols: // off < 0 from here down
		return DiagonalN(rows-oa, 0, cols)
	default:
		return DiagonalN(rows-oa, 0, cols-oa)
	}
}
// DiagonalIndices returns a list of indices for the diagonal elements of
// a 2D matrix of given row and column size.
// The result is a 2D list of indices, where the outer (row) dimension
// is the number of indices, and the inner dimension is 2 for the r, c coords.
//   - If one additional parameter is passed, it is the offset,
//     to set values above (positive) or below (negative) the diagonal.
//   - If a second additional parameter is passed, it is the number of columns
//     for a non-square matrix (first size parameter = number of rows).
func DiagonalIndices(size int, offset_cols ...int) *tensor.Int {
	off, cols := offCols(size, offset_cols...)
	out := tensor.NewInt(DiagonalN(size, off, cols), 2)
	row := 0
	for r := 0; r < size; r++ {
		c := r + off
		if c < 0 || c >= cols { // off-matrix: skip
			continue
		}
		out.SetInt(r, row, 0)
		out.SetInt(c, row, 1)
		row++
	}
	return out
}
// Diagonal returns an [Indexed] view of the given tensor for the diagonal
// values, as a 1D list. An error is logged and nil returned if the tensor
// is not 2D.
// Use the optional offset parameter to get values above (positive) or
// below (negative) the diagonal.
func Diagonal(tsr tensor.Tensor, offset ...int) *tensor.Indexed {
	if tsr.NumDims() != 2 {
		// fixed copy-paste error: message previously named matrix.TriLView.
		errors.Log(errors.New("matrix.Diagonal requires a 2D tensor"))
		return nil
	}
	off := 0
	if len(offset) == 1 {
		off = offset[0]
	}
	return tensor.NewIndexed(tsr, DiagonalIndices(tsr.DimSize(0), off, tsr.DimSize(1)))
}
// Trace returns the sum of the [Diagonal] elements of the given
// tensor, as a tensor scalar.
// An error is logged if the tensor is not 2D.
// Use the optional offset parameter to get values above (positive) or
// below (negative) the diagonal.
func Trace(tsr tensor.Tensor, offset ...int) tensor.Values {
	diag := Diagonal(tsr, offset...)
	return vector.Sum(diag)
}
// Tri returns a new 2D Float64 tensor with 1s along the diagonal and
// below it, and 0s elsewhere (i.e., a filled lower triangle).
//   - If one additional parameter is passed, it is the offset,
//     to include values above (positive) or below (negative) the diagonal.
//   - If a second additional parameter is passed, it is the number of columns
//     for a non-square matrix (first size parameter = number of rows).
func Tri(size int, offset_cols ...int) *tensor.Float64 {
	off, cols := offCols(size, offset_cols...)
	out := tensor.NewFloat64(size, cols)
	for r := 0; r < size; r++ {
		last := min(r+off, cols-1) // negative when row is entirely above the band
		for c := 0; c <= last; c++ {
			out.SetFloat(1, r, c)
		}
	}
	return out
}
// TriUpper returns a new 2D Float64 tensor with 1s along the diagonal and
// above it, and 0s elsewhere (i.e., a filled upper triangle).
//   - If one additional parameter is passed, it is the offset,
//     to include values above (positive) or below (negative) the diagonal.
//   - If a second additional parameter is passed, it is the number of columns
//     for a non-square matrix (first size parameter = number of rows).
func TriUpper(size int, offset_cols ...int) *tensor.Float64 {
	off, cols := offCols(size, offset_cols...)
	out := tensor.NewFloat64(size, cols)
	for r := 0; r < size; r++ {
		for c := max(r+off, 0); c < cols; c++ {
			out.SetFloat(1, r, c)
		}
	}
	return out
}
// TriUNum returns the number of elements in the upper triangular region
// of a 2D matrix of given row and column size, where the triangle includes the
// elements along the diagonal.
//   - If one additional parameter is passed, it is the offset,
//     to include values above (positive) or below (negative) the diagonal.
//   - If a second additional parameter is passed, it is the number of columns
//     for a non-square matrix (first size parameter = number of rows).
func TriUNum(size int, offset_cols ...int) int {
	off, cols := offCols(size, offset_cols...)
	rows := size
	switch {
	case off > 0:
		// shift the band left to reduce to the zero-offset case.
		if cols > rows {
			return TriUNum(rows, 0, cols-off)
		}
		return TriUNum(rows-off, 0, cols-off)
	case off < 0:
		// complement: all elements minus the strictly-lower region.
		return cols*rows - TriUNum(cols, -(off-1), rows)
	}
	// zero-offset closed forms:
	if cols <= size {
		return cols + (cols*(cols-1))/2
	}
	return rows + (rows*(2*cols-rows-1))/2
}
// TriLNum returns the number of elements in the lower triangular region
// of a 2D matrix of given row and column size, where the triangle includes the
// elements along the diagonal.
//   - If one additional parameter is passed, it is the offset,
//     to include values above (positive) or below (negative) the diagonal.
//   - If a second additional parameter is passed, it is the number of columns
//     for a non-square matrix (first size parameter = number of rows).
func TriLNum(size int, offset_cols ...int) int {
	off, cols := offCols(size, offset_cols...)
	// the lower-triangle count is the upper-triangle count of the transpose.
	return TriUNum(cols, -off, size)
}
// TriLIndicies returns the list of r, c indexes for the lower triangular
// portion of a square matrix of size n, including the diagonal.
// The result is a 2D list of indices, where the outer (row) dimension
// is the number of indices, and the inner dimension is 2 for the r, c coords.
//   - If one additional parameter is passed, it is the offset,
//     to include values above (positive) or below (negative) the diagonal.
//   - If a second additional parameter is passed, it is the number of columns
//     for a non-square matrix.
func TriLIndicies(size int, offset_cols ...int) *tensor.Int {
	off, cols := offCols(size, offset_cols...)
	out := tensor.NewInt(TriLNum(size, off, cols), 2)
	idx := 0
	for r := 0; r < size; r++ {
		last := min(r+off, cols-1)
		for c := 0; c <= last; c++ {
			out.SetInt(r, idx, 0)
			out.SetInt(c, idx, 1)
			idx++
		}
	}
	return out
}
// TriUIndicies returns the list of r, c indexes for the upper triangular
// portion of a square matrix of size n, including the diagonal.
// If one additional parameter is passed, it is the offset,
// to include values above (positive) or below (negative) the diagonal.
// If a second additional parameter is passed, it is the number of columns
// for a non-square matrix.
// The result is a 2D list of indices, where the outer (row) dimension
// is the number of indices, and the inner dimension is 2 for the r, c coords.
func TriUIndicies(size int, offset_cols ...int) *tensor.Int {
	off, cols := offCols(size, offset_cols...)
	out := tensor.NewInt(TriUNum(size, off, cols), 2)
	idx := 0
	for r := 0; r < size; r++ {
		for c := max(r+off, 0); c < cols; c++ {
			out.SetInt(r, idx, 0)
			out.SetInt(c, idx, 1)
			idx++
		}
	}
	return out
}
// TriLView returns an [Indexed] view of the given tensor for the lower triangular
// region of values, as a 1D list. An error is logged and nil returned if the
// tensor is not 2D.
// Use the optional offset parameter to get values above (positive) or
// below (negative) the diagonal.
func TriLView(tsr tensor.Tensor, offset ...int) *tensor.Indexed {
	if tsr.NumDims() != 2 {
		errors.Log(errors.New("matrix.TriLView requires a 2D tensor"))
		return nil
	}
	off := 0
	if len(offset) == 1 {
		off = offset[0]
	}
	idx := TriLIndicies(tsr.DimSize(0), off, tsr.DimSize(1))
	return tensor.NewIndexed(tsr, idx)
}
// TriUView returns an [Indexed] view of the given tensor for the upper triangular
// region of values, as a 1D list. An error is logged and nil returned if the
// tensor is not 2D.
// Use the optional offset parameter to get values above (positive) or
// below (negative) the diagonal.
func TriUView(tsr tensor.Tensor, offset ...int) *tensor.Indexed {
	if tsr.NumDims() != 2 {
		errors.Log(errors.New("matrix.TriUView requires a 2D tensor"))
		return nil
	}
	off := 0
	if len(offset) == 1 {
		off = offset[0]
	}
	idx := TriUIndicies(tsr.DimSize(0), off, tsr.DimSize(1))
	return tensor.NewIndexed(tsr, idx)
}
// TriL returns a copy of the given tensor containing the lower triangular
// region of values (including the diagonal), with the upper triangular region
// zeroed. An error is logged if the tensor is not 2D.
// Use the optional offset parameter to include values above (positive) or
// below (negative) the diagonal.
func TriL(tsr tensor.Tensor, offset ...int) tensor.Tensor {
	if tsr.NumDims() != 2 {
		errors.Log(errors.New("matrix.TriL requires a 2D tensor"))
		return nil
	}
	off := 0
	if len(offset) == 1 {
		off = offset[0]
	}
	off += 1 // zero the complementary region, starting one above the band
	tc := tensor.Clone(tsr)
	tv := TriUView(tc, off) // opposite
	tensor.SetAllFloat64(tv, 0)
	return tc
}
// TriU returns a copy of the given tensor containing the upper triangular
// region of values (including the diagonal), with the lower triangular region
// zeroed. An error is logged if the tensor is not 2D.
// Use the optional offset parameter to include values above (positive) or
// below (negative) the diagonal.
func TriU(tsr tensor.Tensor, offset ...int) tensor.Tensor {
	if tsr.NumDims() != 2 {
		errors.Log(errors.New("matrix.TriU requires a 2D tensor"))
		return nil
	}
	off := 0
	if len(offset) == 1 {
		off = offset[0]
	}
	cp := tensor.Clone(tsr)
	// zero the complementary (strictly lower) region via an opposite view.
	lower := TriLView(cp, off-1)
	tensor.SetAllFloat64(lower, 0)
	return cp
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package matrix
import (
"errors"
"cogentcore.org/lab/tensor"
"gonum.org/v1/gonum/mat"
)
// Matrix provides a view of the given [tensor.Tensor] as a [gonum]
// [mat.Matrix] interface type.
type Matrix struct {
	Tensor tensor.Tensor // the wrapped tensor; dim 0 = rows, dim 1 = columns
}
// StringCheck returns an error if the given tensor holds string values,
// which cannot be used in numeric matrix operations; nil otherwise.
func StringCheck(tsr tensor.Tensor) error {
	if !tsr.IsString() {
		return nil
	}
	return errors.New("matrix: tensor has string values; must be numeric")
}
// NewMatrix returns given [tensor.Tensor] as a [gonum] [mat.Matrix].
// It returns an error if the tensor is not 2D or holds string values.
func NewMatrix(tsr tensor.Tensor) (*Matrix, error) {
	if err := StringCheck(tsr); err != nil {
		return nil, err
	}
	if tsr.NumDims() != 2 {
		return nil, errors.New("matrix.NewMatrix: tensor is not 2D")
	}
	return &Matrix{Tensor: tsr}, nil
}
// Dims is the gonum/mat.Matrix interface method for returning the
// dimension sizes of the 2D Matrix. Assumes Row-major ordering.
func (mx *Matrix) Dims() (r, c int) {
	r, c = mx.Tensor.DimSize(0), mx.Tensor.DimSize(1)
	return r, c
}
// At is the gonum/mat.Matrix interface method for returning the 2D
// matrix element at given row, column index. Assumes Row-major ordering.
func (mx *Matrix) At(i, j int) float64 {
	v := mx.Tensor.Float(i, j)
	return v
}
// T is the gonum/mat.Matrix transpose method.
// It performs an implicit transpose by returning the receiver inside a Transpose.
func (mx *Matrix) T() mat.Matrix {
	// keyed composite literal: go vet's composites check flags unkeyed
	// literals of structs from other packages.
	return mat.Transpose{Matrix: mx}
}
//////// Symmetric

// Symmetric provides a view of the given [tensor.Tensor] as a [gonum]
// [mat.Symmetric] matrix interface type.
type Symmetric struct {
	Matrix // embeds Matrix: Dims, At, T methods are inherited
}
// NewSymmetric returns given [tensor.Tensor] as a [gonum] [mat.Symmetric] matrix.
// It returns an error if the tensor is not 2D, not square, or holds string values.
// Note: only the shape is checked for symmetry, not the values.
func NewSymmetric(tsr tensor.Tensor) (*Symmetric, error) {
	switch {
	case tsr.IsString():
		return nil, errors.New("matrix.NewSymmetric: tensor has string values; must be numeric")
	case tsr.NumDims() != 2:
		return nil, errors.New("matrix.NewSymmetric: tensor is not 2D")
	case tsr.DimSize(0) != tsr.DimSize(1):
		return nil, errors.New("matrix.NewSymmetric: tensor is not symmetric")
	}
	return &Symmetric{Matrix: Matrix{Tensor: tsr}}, nil
}
// SymmetricDim is the gonum/mat.Symmetric interface method for returning
// the dimensionality of a symmetric 2D Matrix (rows == columns).
func (sy *Symmetric) SymmetricDim() (r int) {
	r = sy.Tensor.DimSize(0)
	return r
}
// NewDense returns given [tensor.Float64] as a [gonum] [mat.Dense]
// Matrix, on which many of the matrix operations are defined.
// It functions similar to the [tensor.Values] type, as the output
// of matrix operations. The Dense type serves as a view onto
// the tensor's data, so operations directly modify it.
// It returns an error if the tensor is not 2D.
func NewDense(tsr *tensor.Float64) (*mat.Dense, error) {
	if tsr.NumDims() != 2 {
		return nil, errors.New("matrix.NewDense: tensor is not 2D")
	}
	nr, nc := tsr.DimSize(0), tsr.DimSize(1)
	return mat.NewDense(nr, nc, tsr.Values), nil
}
// CopyFromDense copies a gonum mat.Dense matrix into given Tensor
// using the standard Float64 interface, reshaping the destination
// to match the Dense dimensions.
func CopyFromDense(to tensor.Values, dm *mat.Dense) {
	nr, nc := dm.Dims()
	to.SetShapeSizes(nr, nc)
	for ri := range nr {
		for ci := range nc {
			// flat row-major index = ri*nc + ci
			to.SetFloat1D(dm.At(ri, ci), ri*nc+ci)
		}
	}
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package matrix
import (
"slices"
"cogentcore.org/core/base/errors"
"cogentcore.org/lab/tensor"
"gonum.org/v1/gonum/mat"
)
// CallOut1 calls an Out function with 1 input arg, allocating the output.
// All matrix functions require *tensor.Float64 outputs.
// Any error from the called function is logged.
func CallOut1(fun func(a tensor.Tensor, out *tensor.Float64) error, a tensor.Tensor) *tensor.Float64 {
	res := tensor.NewFloat64()
	errors.Log(fun(a, res))
	return res
}
// CallOut2 calls an Out function with 2 input args, allocating the output.
// All matrix functions require *tensor.Float64 outputs.
// Any error from the called function is logged.
func CallOut2(fun func(a, b tensor.Tensor, out *tensor.Float64) error, a, b tensor.Tensor) *tensor.Float64 {
	res := tensor.NewFloat64()
	errors.Log(fun(a, b, res))
	return res
}
// Mul performs matrix multiplication, using the following rules based
// on the shapes of the relevant tensors. If the tensor shapes are not
// suitable, an error is logged (see [MulOut] for a version returning the error).
// N > 2 dimensional cases use parallel threading where beneficial.
//   - If both arguments are 2-D they are multiplied like conventional matrices.
//   - If either argument is N-D, N > 2, it is treated as a stack of matrices
//     residing in the last two indexes and broadcast accordingly.
//   - If the first argument is 1-D, it is promoted to a matrix by prepending
//     a 1 to its dimensions. After matrix multiplication the prepended 1 is removed.
//   - If the second argument is 1-D, it is promoted to a matrix by appending
//     a 1 to its dimensions. After matrix multiplication the appended 1 is removed.
func Mul(a, b tensor.Tensor) *tensor.Float64 {
	out := tensor.NewFloat64()
	errors.Log(MulOut(a, b, out))
	return out
}
// MulOut performs matrix multiplication, into the given output tensor,
// using the following rules based on the shapes of the relevant tensors.
// If the tensor shapes are not suitable, a [gonum] [mat.ErrShape] error is returned.
// N > 2 dimensional cases use parallel threading where beneficial.
// - If both arguments are 2-D they are multiplied like conventional matrices.
// The result has shape a.Rows, b.Columns.
// - If either argument is N-D, N > 2, it is treated as a stack of matrices
// residing in the last two indexes and broadcast accordingly. Both cannot
// be > 2 dimensional, unless their outer dimension size is 1 or the same.
// - If the first argument is 1-D, it is promoted to a matrix by prepending
// a 1 to its dimensions. After matrix multiplication the prepended 1 is removed.
// - If the second argument is 1-D, it is promoted to a matrix by appending
// a 1 to its dimensions. After matrix multiplication the appended 1 is removed.
func MulOut(a, b tensor.Tensor, out *tensor.Float64) error {
	if err := StringCheck(a); err != nil {
		return err
	}
	if err := StringCheck(b); err != nil {
		return err
	}
	na := a.NumDims()
	nb := b.NumDims()
	ea := a // effective a, after any promotion / reshaping below
	eb := b // effective b
	// collapse records that a 1D arg was promoted, so the singleton
	// dimension must be removed from the output at the end;
	// colDim selects which output dim to remove (-2 = row, -1 = column).
	collapse := false
	colDim := 0
	if na == 1 {
		// promote 1D a to a 1-row matrix
		ea = tensor.Reshape(a, 1, a.DimSize(0))
		collapse = true
		colDim = -2
		na = 2
	}
	if nb == 1 {
		// promote 1D b to a 1-column matrix
		eb = tensor.Reshape(b, b.DimSize(0), 1)
		collapse = true
		colDim = -1
		nb = 2
	}
	if na > 2 {
		// collapse outer dims into a single list dim: (list, rows, cols);
		// a singleton list reduces to the plain 2D case.
		asz := tensor.SplitAtInnerDims(a, 2)
		if asz[0] == 1 {
			ea = tensor.Reshape(a, asz[1:]...)
			na = 2
		} else {
			ea = tensor.Reshape(a, asz...)
		}
	}
	if nb > 2 {
		bsz := tensor.SplitAtInnerDims(b, 2)
		if bsz[0] == 1 {
			eb = tensor.Reshape(b, bsz[1:]...)
			nb = 2
		} else {
			eb = tensor.Reshape(b, bsz...)
		}
	}
	switch {
	case na == nb && na == 2:
		// conventional 2D x 2D: requires a.cols == b.rows.
		if ea.DimSize(1) != eb.DimSize(0) {
			return mat.ErrShape
		}
		ma, _ := NewMatrix(ea)
		mb, _ := NewMatrix(eb)
		out.SetShapeSizes(ea.DimSize(0), eb.DimSize(1))
		do, _ := NewDense(out)
		do.Mul(ma, mb)
	case na > 2 && nb == 2:
		// stack of a matrices, each multiplied by the same 2D b.
		if ea.DimSize(2) != eb.DimSize(0) {
			return mat.ErrShape
		}
		mb, _ := NewMatrix(eb)
		nr := ea.DimSize(0)
		out.SetShapeSizes(nr, ea.DimSize(1), eb.DimSize(1))
		tensor.VectorizeThreaded(ea.DimSize(1)*ea.DimSize(2)*eb.Len()*100, // always beneficial
			func(tsr ...tensor.Tensor) int { return nr },
			func(r int, tsr ...tensor.Tensor) {
				sa := tensor.Reslice(ea, r, tensor.FullAxis, tensor.FullAxis)
				ma, _ := NewMatrix(sa)
				do, _ := NewDense(out.RowTensor(r).(*tensor.Float64))
				do.Mul(ma, mb)
			})
	case nb > 2 && na == 2:
		// single 2D a multiplied by each matrix in the b stack.
		if ea.DimSize(1) != eb.DimSize(1) {
			return mat.ErrShape
		}
		ma, _ := NewMatrix(ea)
		nr := eb.DimSize(0)
		out.SetShapeSizes(nr, ea.DimSize(0), eb.DimSize(2))
		tensor.VectorizeThreaded(ea.Len()*eb.DimSize(1)*eb.DimSize(2)*100,
			func(tsr ...tensor.Tensor) int { return nr },
			func(r int, tsr ...tensor.Tensor) {
				sb := tensor.Reslice(eb, r, tensor.FullAxis, tensor.FullAxis)
				mb, _ := NewMatrix(sb)
				do, _ := NewDense(out.RowTensor(r).(*tensor.Float64))
				do.Mul(ma, mb)
			})
	case na > 2 && nb > 2:
		// both stacked: pairwise multiply; outer dims must match.
		if ea.DimSize(0) != eb.DimSize(0) {
			return errors.New("matrix.Mul: a and b input matricies are > 2 dimensional; must have same outer dimension sizes")
		}
		if ea.DimSize(2) != eb.DimSize(1) {
			return mat.ErrShape
		}
		nr := ea.DimSize(0)
		out.SetShapeSizes(nr, ea.DimSize(1), eb.DimSize(2))
		tensor.VectorizeThreaded(ea.DimSize(1)*ea.DimSize(2)*eb.DimSize(1)*eb.DimSize(2),
			func(tsr ...tensor.Tensor) int { return nr },
			func(r int, tsr ...tensor.Tensor) {
				sa := tensor.Reslice(ea, r, tensor.FullAxis, tensor.FullAxis)
				ma, _ := NewMatrix(sa)
				sb := tensor.Reslice(eb, r, tensor.FullAxis, tensor.FullAxis)
				mb, _ := NewMatrix(sb)
				do, _ := NewDense(out.RowTensor(r).(*tensor.Float64))
				do.Mul(ma, mb)
			})
	default:
		return mat.ErrShape
	}
	if collapse {
		// remove the singleton dimension added by 1D promotion above.
		nd := out.NumDims()
		sz := slices.Clone(out.Shape().Sizes)
		if colDim == -1 {
			out.SetShapeSizes(sz[:nd-1]...)
		} else {
			out.SetShapeSizes(append(sz[:nd-2], sz[nd-1])...)
		}
	}
	return nil
}
// todo: following should handle N>2 dim case.

// Det returns the determinant of the given tensor.
// For a 2D matrix [[a, b], [c, d]] this is ad - bc.
// A scalar 0 is returned (and the error logged) if the input
// is not a valid 2D numeric matrix.
// See also [LogDet] for a version that is more numerically
// stable for large matrices.
func Det(a tensor.Tensor) *tensor.Float64 {
	m, err := NewMatrix(a)
	if errors.Log(err) != nil {
		return tensor.NewFloat64Scalar(0)
	}
	return tensor.NewFloat64Scalar(mat.Det(m))
}
// LogDet returns the determinant of the given tensor,
// as the log and sign of the value, which is more
// numerically stable. The return is a 1D vector of length 2,
// with the first value being the log, and the second the sign.
// A scalar 0 is returned (and the error logged) if the input
// is not a valid 2D numeric matrix.
func LogDet(a tensor.Tensor) *tensor.Float64 {
	m, err := NewMatrix(a)
	if errors.Log(err) != nil {
		return tensor.NewFloat64Scalar(0)
	}
	lg, sign := mat.LogDet(m)
	return tensor.NewFloat64FromValues(lg, sign)
}
// Inverse performs matrix inversion of a square matrix,
// logging an error for non-invertable cases.
// See [InverseOut] for a version that returns an error.
// If the input tensor is > 2D, it is treated as a list of 2D matricies
// which are each inverted.
func Inverse(a tensor.Tensor) *tensor.Float64 {
	out := tensor.NewFloat64()
	errors.Log(InverseOut(a, out))
	return out
}
// InverseOut performs matrix inversion of a square matrix,
// returning an error for non-invertable cases. If the input tensor
// is > 2D, it is treated as a list of 2D matricies which are each inverted.
func InverseOut(a tensor.Tensor, out *tensor.Float64) error {
	if err := StringCheck(a); err != nil {
		return err
	}
	na := a.NumDims()
	if na == 1 {
		return mat.ErrShape
	}
	var asz []int
	ea := a
	if na > 2 {
		asz = tensor.SplitAtInnerDims(a, 2)
		if asz[0] == 1 { // singleton list: treat as a single 2D matrix
			ea = tensor.Reshape(a, asz[1:]...)
			na = 2
		}
	}
	if na == 2 {
		if a.DimSize(0) != a.DimSize(1) {
			return mat.ErrShape
		}
		ma, _ := NewMatrix(a)
		out.SetShapeSizes(a.DimSize(0), a.DimSize(1))
		do, _ := NewDense(out)
		return do.Inverse(ma)
	}
	// > 2D case: invert each 2D sub-matrix in parallel.
	ea = tensor.Reshape(a, asz...)
	if ea.DimSize(1) != ea.DimSize(2) {
		return mat.ErrShape
	}
	nr := ea.DimSize(0)
	out.SetShapeSizes(nr, ea.DimSize(1), ea.DimSize(2))
	// errs is indexed by row so each parallel goroutine writes only its
	// own slot; appending to a shared slice here would be a data race.
	errs := make([]error, nr)
	tensor.VectorizeThreaded(ea.DimSize(1)*ea.DimSize(2)*100,
		func(tsr ...tensor.Tensor) int { return nr },
		func(r int, tsr ...tensor.Tensor) {
			sa := tensor.Reslice(ea, r, tensor.FullAxis, tensor.FullAxis)
			ma, _ := NewMatrix(sa)
			do, _ := NewDense(out.RowTensor(r).(*tensor.Float64))
			errs[r] = do.Inverse(ma)
		})
	return errors.Join(errs...) // Join ignores nil entries
}
// Copyright (c) 2019, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package patterns
import (
"cogentcore.org/lab/base/randx"
"cogentcore.org/lab/tensor"
)
// FlipBits turns nOff bits that are currently On to Off and
// nOn bits that are currently Off to On, using permuted lists.
// Any value not equal to offVal is considered On.
// The counts are clamped to the number of available On / Off bits.
func FlipBits(tsr tensor.Values, nOff, nOn int, onVal, offVal float64) {
	n := tsr.Len()
	if n == 0 {
		return
	}
	var onIdx, offIdx []int
	for i := range n {
		if tsr.Float1D(i) == offVal {
			offIdx = append(offIdx, i)
		} else {
			onIdx = append(onIdx, i)
		}
	}
	randx.PermuteInts(onIdx, RandSource)
	randx.PermuteInts(offIdx, RandSource)
	nOff = min(nOff, len(onIdx))
	nOn = min(nOn, len(offIdx))
	for _, i := range onIdx[:nOff] {
		tsr.SetFloat1D(offVal, i)
	}
	for _, i := range offIdx[:nOn] {
		tsr.SetFloat1D(onVal, i)
	}
}
// FlipBitsRows turns nOff bits that are currently On to Off and
// nOn bits that are currently Off to On, using permuted lists.
// Iterates over the outer-most tensor dimension as rows,
// calling [FlipBits] on each row independently.
func FlipBitsRows(tsr tensor.Values, nOff, nOn int, onVal, offVal float64) {
	nrows, _ := tsr.Shape().RowCellSize()
	for r := 0; r < nrows; r++ {
		FlipBits(tsr.SubSpace(r), nOff, nOn, onVal, offVal)
	}
}
// Copyright (c) 2019, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package patterns
import (
"fmt"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/reflectx"
"cogentcore.org/lab/tensor"
)
// Mix mixes patterns from different tensors into a combined set of patterns,
// over the outermost row dimension (i.e., each source is a list of patterns over rows).
// The source tensors must have the same cell size, and the existing shape of the destination
// will be used if compatible, otherwise reshaped with linear list of sub-tensors.
// Each source list wraps around if shorter than the total number of rows specified.
// Note: only float-valued destinations are populated by the copy loop below.
func Mix(dest tensor.Values, rows int, srcs ...tensor.Values) error {
	// all sources must share the same cell size.
	var cells int
	for i, src := range srcs {
		_, c := src.Shape().RowCellSize()
		if i == 0 {
			cells = c
			continue
		}
		if c != cells {
			// fixed stale function name in message (was "MixPatterns").
			return errors.Log(fmt.Errorf("patterns.Mix: cells size of source number %d, %d != first source: %d", i, c, cells))
		}
	}
	// reshape dest if its total length does not match rows x sources x cells.
	totlen := len(srcs) * cells * rows
	if dest.Len() != totlen {
		_, dcells := dest.Shape().RowCellSize()
		if dcells == cells*len(srcs) {
			dest.SetNumRows(rows) // row-compatible: just adjust row count
		} else {
			dest.SetShapeSizes(rows, len(srcs), cells)
		}
	}
	dtype := dest.DataType()
	for i, src := range srcs {
		si := i * cells // cell offset of this source within each dest row
		srows := src.DimSize(0)
		for row := range rows {
			srow := row % srows // wrap around shorter sources
			for ci := range cells {
				switch {
				case reflectx.KindIsFloat(dtype):
					dest.SetFloatRow(src.FloatRow(srow, ci), row, si+ci)
				}
			}
		}
	}
	return nil
}
// Copyright (c) 2019, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package patterns
import (
"fmt"
"math"
"cogentcore.org/core/base/errors"
"cogentcore.org/lab/base/randx"
"cogentcore.org/lab/stats/metric"
"cogentcore.org/lab/tensor"
)
// NFromPct returns the number of bits for given pct (proportion 0-1),
// relative to total n, rounding to the nearest integer
// (halves round away from zero, per math.Round).
func NFromPct(pct float64, n int) int {
	return int(math.Round(float64(n) * pct))
}
// PermutedBinary sets the given tensor to contain nOn onVal values and the
// remainder are offVal values, using a permuted order of tensor elements (i.e.,
// randomly shuffled or permuted) drawn from RandSource.
func PermutedBinary(tsr tensor.Values, nOn int, onVal, offVal float64) {
	n := tsr.Len()
	if n == 0 {
		return
	}
	order := RandSource.Perm(n)
	for i, idx := range order {
		// first nOn permuted positions get onVal, the rest offVal
		v := offVal
		if i < nOn {
			v = onVal
		}
		tsr.SetFloat1D(v, idx)
	}
}
// PermutedBinaryRows uses the [tensor.RowMajor] view of a tensor as a column of rows
// as in a [table.Table], setting each row to contain nOn onVal values with the
// remainder being offVal values, using a permuted order of tensor elements
// (i.e., randomly shuffled or permuted). See also [PermutedBinaryMinDiff].
func PermutedBinaryRows(tsr tensor.Values, nOn int, onVal, offVal float64) {
	rows, cells := tsr.Shape().RowCellSize()
	if rows == 0 || cells == 0 {
		return
	}
	order := RandSource.Perm(cells)
	for rw := 0; rw < rows; rw++ {
		start := rw * cells
		for i, idx := range order {
			// first nOn permuted cells are on, the rest off
			v := offVal
			if i < nOn {
				v = onVal
			}
			tsr.SetFloat1D(v, start+idx)
		}
		// re-permute so each row gets a fresh random assignment
		randx.PermuteInts(order, RandSource)
	}
}
// MinDiffPrintIterations set this to true to see the iteration stats for
// PermutedBinaryMinDiff -- for large, long-running cases.
var MinDiffPrintIterations = false

// PermutedBinaryMinDiff uses the [tensor.RowMajor] view of a tensor as a column of rows
// as in a [table.Table], setting each row to contain nOn onVal values, with the
// remainder being offVal values, using a permuted order of tensor elements
// (i.e., randomly shuffled or permuted). This version (see also [PermutedBinaryRows])
// ensures that all patterns have at least a given minimum distance
// from each other, expressed using minDiff = number of bits that must be different
// (can't be > nOn). If the mindiff constraint cannot be met within 100 iterations,
// an error is returned and automatically logged.
func PermutedBinaryMinDiff(tsr tensor.Values, nOn int, onVal, offVal float64, minDiff int) error {
	rows, cells := tsr.Shape().RowCellSize()
	if rows == 0 || cells == 0 {
		return errors.New("empty tensor")
	}
	pord := RandSource.Perm(cells)
	iters := 100
	nunder := make([]int, rows) // per-row count of too-close neighbor rows
	fails := 0
	for itr := range iters {
		// Regenerate only the rows that violated the constraint on the
		// previous pass (all rows on the first pass).
		for rw := range rows {
			if itr > 0 && nunder[rw] == 0 {
				continue
			}
			stidx := rw * cells
			for i := range cells {
				if i < nOn {
					tsr.SetFloat1D(onVal, stidx+pord[i])
				} else {
					tsr.SetFloat1D(offVal, stidx+pord[i])
				}
			}
			randx.PermuteInts(pord, RandSource)
		}
		for i := range nunder {
			nunder[i] = 0
		}
		nbad := 0
		mxnun := 0
		for r1 := range rows {
			r1v := tsr.SubSpace(r1)
			for r2 := r1 + 1; r2 < rows; r2++ {
				r2v := tsr.SubSpace(r2)
				dst := metric.Hamming(tensor.As1D(r1v), tensor.As1D(r2v)).Float1D(0)
				// Hamming counts both on->off and off->on flips, so the
				// number of differing on-bits is half the distance.
				df := int(math.Round(0.5 * dst))
				if df < minDiff {
					nunder[r1]++
					// was max(nunder[r1]): single-arg max is identity and
					// discarded the running maximum
					mxnun = max(mxnun, nunder[r1])
					nunder[r2]++
					mxnun = max(mxnun, nunder[r2])
					nbad++
				}
			}
		}
		if nbad == 0 {
			break
		}
		fails++
		if MinDiffPrintIterations {
			fmt.Printf("PermutedBinaryMinDiff: Itr: %d NBad: %d MaxN: %d\n", itr, nbad, mxnun)
		}
	}
	if fails == iters {
		return errors.Log(fmt.Errorf("PermutedBinaryMinDiff: minimum difference of: %d was not met: %d times, rows: %d", minDiff, fails, rows))
	}
	return nil
}
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package patterns
import "cogentcore.org/lab/base/randx"
var (
// RandSource is a random source to use for all random numbers used in patterns.
// By default it just uses the standard Go math/rand source.
// If initialized, e.g., by calling NewRand(seed), then a separate stream of
// random numbers will be generated for all calls, and the seed is saved as
// RandSeed. It can be reinstated by calling RestoreSeed.
// Can also set RandSource to another existing randx.Rand source to use it.
RandSource = &randx.SysRand{}
// Random seed last set by NewRand or SetRandSeed.
RandSeed int64
)
// NewRand sets RandSource to a new separate random number stream
// using given seed, which is saved as RandSeed -- see RestoreSeed.
func NewRand(seed int64) {
RandSource = randx.NewSysRand(seed)
RandSeed = seed
}
// SetRandSeed sets the existing RandSource random number stream to use the
// given random seed, starting from the next call. Saves the seed in
// RandSeed -- see RestoreSeed.
func SetRandSeed(seed int64) {
	RandSeed = seed
	RestoreSeed()
}

// RestoreSeed restores the random seed last saved in RandSeed (by NewRand or
// SetRandSeed) -- the random number sequence will repeat what was generated
// from that point onward.
func RestoreSeed() {
	RandSource.Seed(RandSeed)
}
// Copyright (c) 2019, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package patterns
//go:generate core generate -add-types
import (
"fmt"
"strconv"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/metadata"
"cogentcore.org/lab/stats/stats"
"cogentcore.org/lab/tensor"
"cogentcore.org/lab/tensorfs"
)
// NOnInTensor returns the number of bits active in given tensor,
// computed as the integer sum of all values (so the result equals a
// count only for binary 0/1 patterns).
func NOnInTensor(trow tensor.Values) int {
	return stats.Sum(trow).Int1D(0)
}

// PctActInTensor returns the percent activity in given tensor
// (NOn / size), as a proportion 0-1.
func PctActInTensor(trow tensor.Values) float32 {
	return float32(NOnInTensor(trow)) / float32(trow.Len())
}
// Note: AppendFrom can be used to concatenate tensors.
// NameRows sets the string values of the tensor rows to prefix + row number,
// zero-padded to the given number of digits (nzeros).
func NameRows(tsr tensor.Values, prefix string, nzeros int) {
	rows := tsr.DimSize(0)
	for i := range rows {
		// Use the %0*d dynamic-width verb and concatenate the prefix
		// directly, instead of embedding the prefix inside a format
		// string, where a '%' in the prefix would corrupt the output.
		tsr.SetString1D(prefix+fmt.Sprintf("%0*d", nzeros, i), i)
	}
}
// Shuffle returns a [tensor.Rows] view of the given source tensor
// with the outer row-wise dimension randomly shuffled (permuted),
// using RandSource for the permutation. The source data itself is
// not copied or reordered; only the row index view is permuted.
func Shuffle(src tensor.Values) *tensor.Rows {
	idx := RandSource.Perm(src.DimSize(0))
	return tensor.NewRows(src, idx...)
}
// ReplicateRows adds nCopies rows of the source tensor pattern into
// the destination tensor. The destination shape is set to ensure
// it can contain the results, preserving any existing rows of data.
func ReplicateRows(dest, src tensor.Values, nCopies int) {
	existing := 0
	if dest.NumDims() > 0 {
		existing = dest.DimSize(0)
	}
	// new shape: (existing + nCopies) rows of the source cell shape
	shp := make([]int, 0, 1+src.NumDims())
	shp = append(shp, existing+nCopies)
	shp = append(shp, src.Shape().Sizes...)
	dest.SetShapeSizes(shp...)
	for i := 0; i < nCopies; i++ {
		dest.SetRowTensor(src, existing+i)
	}
}
// SplitRows splits a source tensor into a set of tensors in the given
// tensorfs directory, with the given list of names, splitting at given
// rows. There should be 1 more name than rows. If names are omitted then
// the source name + incrementing counter will be used.
// Returns an error if the name count is wrong or the rows do not
// increase progressively.
func SplitRows(dir *tensorfs.Node, src tensor.Values, names []string, rows ...int) error {
	hasNames := len(names) != 0
	if hasNames && len(names) != len(rows)+1 {
		return errors.Log(fmt.Errorf("patterns.SplitRows: must pass one more name than number of rows to split on"))
	}
	// Copy the split points before appending the final row: appending
	// directly to the variadic slice could mutate the caller's backing
	// array when called with an existing slice (rows...).
	all := make([]int, 0, len(rows)+1)
	all = append(all, rows...)
	all = append(all, src.DimSize(0)) // final row
	srcName := metadata.Name(src)
	srcShape := src.ShapeSizes()
	dtype := src.DataType()
	prev := 0
	for i, cur := range all {
		if prev >= cur {
			return errors.Log(fmt.Errorf("patterns.SplitRows: rows must increase progressively"))
		}
		// Output name: explicit name, derived from source name, or index.
		name := ""
		switch {
		case hasNames:
			name = names[i]
		case len(srcName) > 0:
			name = fmt.Sprintf("%s_%d", srcName, i)
		default:
			name = strconv.Itoa(i)
		}
		nrows := cur - prev
		srcShape[0] = nrows
		spl := tensorfs.ValueType(dir, name, dtype, srcShape...)
		for rw := range nrows {
			spl.SubSpace(rw).CopyFrom(src.SubSpace(prev + rw))
		}
		prev = cur
	}
	return nil
}
// AddVocabDrift adds a row-by-row drifting pool to the vocabulary,
// starting from the given row in existing vocabulary item
// (which becomes starting row in this one -- drift starts in second row).
// The current row patterns are generated by taking the previous row
// pattern and flipping pctDrift percent of active bits (min of 1 bit).
// func AddVocabDrift(mp Vocab, name string, rows int, pctDrift float32, copyFrom string, copyRow int) (tensor.Values, error) {
// cp, err := mp.ByName(copyFrom)
// if err != nil {
// return nil, err
// }
// tsr := &tensor.Float32{}
// cpshp := cp.Shape().Sizes
// cpshp[0] = rows
// tsr.SetShapeSizes(cpshp...)
// mp[name] = tsr
// cprow := cp.SubSpace(copyRow).(tensor.Values)
// trow := tsr.SubSpace(0)
// trow.CopyFrom(cprow)
// nOn := NOnInTensor(cprow)
// rmdr := 0.0 // remainder carryover in drift
// drift := float64(nOn) * float64(pctDrift) // precise fractional amount of drift
// for i := 1; i < rows; i++ {
// srow := tsr.SubSpace(i - 1)
// trow := tsr.SubSpace(i)
// trow.CopyFrom(srow)
// curDrift := math.Round(drift + rmdr) // integer amount
// nDrift := int(curDrift)
// if nDrift > 0 {
// FlipBits(trow, nDrift, nDrift, 1, 0)
// }
// rmdr += drift - curDrift // accumulate remainder
// }
// return tsr, nil
// }
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Adapted initially from gonum/plot:
// Copyright ©2015 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package plot
import (
"math"
"cogentcore.org/core/math32"
"cogentcore.org/core/math32/minmax"
"cogentcore.org/core/styles"
"cogentcore.org/core/styles/units"
)
// AxisScales are the scaling options for how values are distributed
// along an axis: Linear, Log, etc.
type AxisScales int32 //enums:enum
const (
// Linear is a linear axis scale.
Linear AxisScales = iota
// Log is a Logarithmic axis scale.
Log
// InverseLinear is an inverted linear axis scale.
InverseLinear
// InverseLog is an inverted log axis scale.
InverseLog
)
// AxisStyle has style properties for the axis.
type AxisStyle struct { //types:add -setters

	// Text has the text style parameters for the axis text label.
	Text TextStyle

	// Line has styling properties for the axis line.
	Line LineStyle

	// Padding between the axis line and the data. Having
	// non-zero padding ensures that the data is never drawn
	// on the axis, thus making it easier to see.
	Padding units.Value

	// NTicks is the desired number of ticks (actual likely will be different).
	NTicks int

	// Scale specifies how values are scaled along the axis:
	// Linear, Log, InverseLinear, InverseLog.
	Scale AxisScales

	// TickText has the text style for rendering tick labels,
	// and is shared for actual rendering.
	TickText TextStyle

	// TickLine has line style for drawing tick lines.
	TickLine LineStyle

	// TickLength is the length of tick lines.
	TickLength units.Value
}
// Defaults sets the default axis style values.
func (ax *AxisStyle) Defaults() {
	// label text
	ax.Text.Defaults()
	ax.Text.Size.Dp(20)
	// axis line and spacing
	ax.Line.Defaults()
	ax.Padding.Pt(5)
	ax.NTicks = 5
	// tick labels and tick marks
	ax.TickText.Defaults()
	ax.TickText.Size.Dp(16)
	ax.TickText.Padding.Dp(2)
	ax.TickLine.Defaults()
	ax.TickLength.Pt(8)
}
// Axis represents either a horizontal or vertical
// axis of a plot.
type Axis struct {
	// Range has the Min, Max range of values for the axis (in raw data units.)
	Range minmax.F64

	// Axis specifies which axis this is: X, Y or Z.
	Axis math32.Dims

	// Label for the axis.
	Label Text

	// Style has the style parameters for the Axis.
	Style AxisStyle

	// TickText is used for rendering the tick text labels.
	TickText Text

	// Ticker generates the tick marks. Any tick marks
	// returned by the Marker function that are not in
	// range of the axis are not drawn.
	Ticker Ticker

	// Scale transforms a value given in the data coordinate system
	// to the normalized coordinate system of the axis—its distance
	// along the axis as a fraction of the axis range.
	Scale Normalizer

	// AutoRescale enables an axis to automatically adapt its minimum
	// and maximum boundaries, according to its underlying Ticker.
	AutoRescale bool

	// ticks is the cached list of tick marks, computed in sizeX / sizeY.
	ticks []Tick
}

// Defaults sets default axis styling for the given dimension (Y gets a
// rotated label and end-aligned tick text), with linear scaling and the
// default ticker. The range convention is (∞, −∞), so that any finite
// value is less than Min and greater than Max.
func (ax *Axis) Defaults(dim math32.Dims) {
	ax.Style.Defaults()
	ax.Axis = dim
	if dim == math32.Y {
		ax.Label.Style.Rotation = -90
		ax.Style.TickText.Align = styles.End
	}
	ax.Scale = LinearScale{}
	ax.Ticker = DefaultTicks{}
}
// drawConfig installs the Normalizer implementation corresponding
// to the currently-styled AxisScales setting, prior to drawing.
func (ax *Axis) drawConfig() {
	scales := map[AxisScales]Normalizer{
		Linear:        LinearScale{},
		Log:           LogScale{},
		InverseLinear: InvertedScale{LinearScale{}},
		InverseLog:    InvertedScale{LogScale{}},
	}
	// unknown values leave the existing Scale unchanged,
	// matching the original switch with no default case
	if sc, ok := scales[ax.Style.Scale]; ok {
		ax.Scale = sc
	}
}
// SanitizeRange ensures that the range of the axis makes sense:
// it calls Range.Sanitize, and if AutoRescale is set, expands the
// range to include all of the tick values generated by the Ticker.
func (ax *Axis) SanitizeRange() {
	ax.Range.Sanitize()
	if ax.AutoRescale {
		marks := ax.Ticker.Ticks(ax.Range.Min, ax.Range.Max, ax.Style.NTicks)
		for _, t := range marks {
			ax.Range.FitValInRange(t.Value)
		}
	}
}
// Normalizer rescales values from the data coordinate system to the
// normalized coordinate system.
type Normalizer interface {
	// Normalize transforms a value x in the data coordinate system to
	// the normalized coordinate system.
	Normalize(min, max, x float64) float64
}

// LinearScale can be used as the value of an Axis.Scale function to
// set the axis to a standard linear scale.
type LinearScale struct{}

var _ Normalizer = LinearScale{}

// Normalize returns the fractional distance of x between min and max.
func (LinearScale) Normalize(min, max, x float64) float64 {
	return (x - min) / (max - min)
}

// LogScale can be used as the value of an Axis.Scale function to
// set the axis to a log scale.
type LogScale struct{}

var _ Normalizer = LogScale{}

// Normalize returns the fractional logarithmic distance of
// x between min and max. It panics if min, max, or x is not
// strictly positive.
func (LogScale) Normalize(min, max, x float64) float64 {
	if min <= 0 || max <= 0 || x <= 0 {
		panic("Values must be greater than 0 for a log scale.")
	}
	logMin := math.Log(min)
	return (math.Log(x) - logMin) / (math.Log(max) - logMin)
}

// InvertedScale can be used as the value of an Axis.Scale function to
// invert the axis using any Normalizer.
type InvertedScale struct{ Normalizer }

var _ Normalizer = InvertedScale{}

// Normalize returns a normalized [0, 1] value for the position of x,
// by delegating to the wrapped Normalizer with min and max swapped.
func (is InvertedScale) Normalize(min, max, x float64) float64 {
	return is.Normalizer.Normalize(max, min, x)
}

// Norm returns the value of x, given in the data coordinate
// system, normalized to its distance as a fraction of the
// range of this axis. For example, if x is a.Min then the return
// value is 0, and if x is a.Max then the return value is 1.
func (ax *Axis) Norm(x float64) float64 {
	return ax.Scale.Normalize(ax.Range.Min, ax.Range.Max, x)
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Adapted from github.com/gonum/plot:
// Copyright ©2015 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package plot
import (
"log/slog"
"math"
"strconv"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/math32/minmax"
)
// data defines the main data interfaces for plotting
// and the different Roles for data.
var (
	// ErrInfinity indicates an infinite data point was encountered.
	ErrInfinity = errors.New("plotter: infinite data point")

	// ErrNoData indicates there were no (non-NaN) data points.
	ErrNoData = errors.New("plotter: no data points")
)

// Data is a map of Roles and Data for that Role, providing the
// primary way of passing data to a Plotter.
type Data map[Roles]Valuer

// Valuer is the data interface for plotting, supporting either
// float64 or string representations. It is satisfied by the tensor.Tensor
// interface, so a tensor can be used directly for plot Data.
type Valuer interface {
	// Len returns the number of values.
	Len() int

	// Float1D returns the float64 value at given index.
	Float1D(i int) float64

	// String1D returns the string value at given index.
	String1D(i int) string
}
// Roles are the roles that a given set of data values can play,
// designed to be sufficiently generalizable across all different
// types of plots, even if sometimes it is a bit of a stretch.
type Roles int32 //enums:enum
const (
// NoRole is the default no-role specified case.
NoRole Roles = iota
// X axis
X
// Y axis
Y
// Z axis
Z
// U is the X component of a vector or first quartile in Box plot, etc.
U
// V is the Y component of a vector or third quartile in a Box plot, etc.
V
// W is the Z component of a vector
W
// Low is a lower error bar or region.
Low
// High is an upper error bar or region.
High
// Size controls the size of points etc.
Size
// Color controls the color of points or other elements.
Color
// Label renders a label, typically from string data, but can also be used for values.
Label
)
// CheckFloats returns ErrInfinity if any of the arguments are infinite,
// or ErrNoData if there are no non-NaN data points available for plotting.
func CheckFloats(fs ...float64) error {
	valid := 0
	for _, f := range fs {
		if math.IsInf(f, 0) {
			return ErrInfinity
		}
		if !math.IsNaN(f) {
			valid++ // NaNs are skipped, all other finite values count
		}
	}
	if valid == 0 {
		return ErrNoData
	}
	return nil
}
// CheckNaNs returns true if any of the given floats are NaN.
func CheckNaNs(fs ...float64) bool {
	found := false
	for _, f := range fs {
		found = found || math.IsNaN(f)
	}
	return found
}
// Range updates the given Range with values from data.
// NaN values are skipped.
func Range(data Valuer, rng *minmax.F64) {
	for i := 0; i < data.Len(); i++ {
		v := data.Float1D(i)
		if math.IsNaN(v) {
			continue
		}
		rng.FitValInRange(v)
	}
}

// RangeClamp updates the given axis Min, Max range values based
// on the range of values in the given [Data], and then clamps that
// range according to the given style range.
func RangeClamp(data Valuer, axisRng *minmax.F64, styleRng *minmax.Range64) {
	Range(data, axisRng)
	axisRng.Min, axisRng.Max = styleRng.Clamp(axisRng.Min, axisRng.Max)
}
// CheckLengths checks that all the data elements have the same length.
// Logs and returns an error if not.
func (dt Data) CheckLengths() error {
	// Use an explicit first-element flag rather than n == 0 as the
	// sentinel: with the sentinel, any zero-length element reset the
	// baseline, so (given random map iteration order) inconsistent
	// lengths involving an empty element could go undetected.
	n := 0
	first := true
	for _, v := range dt {
		if first {
			n = v.Len()
			first = false
			continue
		}
		if v.Len() != n {
			err := errors.New("plot.Data has inconsistent lengths -- all data elements must have the same length -- plotting aborted")
			return errors.Log(err)
		}
	}
	return nil
}
// Values provides a minimal implementation of the Data interface
// using a slice of float64.
type Values []float64

// Len returns the number of values.
func (vs Values) Len() int {
	return len(vs)
}

// Float1D returns the value at index i.
func (vs Values) Float1D(i int) float64 {
	return vs[i]
}

// String1D returns the value at index i formatted using
// strconv.FormatFloat with the 'g' format at full precision.
func (vs Values) String1D(i int) string {
	return strconv.FormatFloat(vs[i], 'g', -1, 64)
}
// CopyValues returns a Values that is a copy of the values
// from Data, or an error if data is nil (ErrNoData) or one of
// the copied values is an Infinity (ErrInfinity).
// NaN values are skipped in the copying process.
func CopyValues(data Valuer) (Values, error) {
	if data == nil {
		return nil, ErrNoData
	}
	out := make(Values, 0, data.Len())
	for i := 0; i < data.Len(); i++ {
		v := data.Float1D(i)
		if math.IsNaN(v) {
			continue // NaNs are dropped, not copied
		}
		if err := CheckFloats(v); err != nil {
			return nil, err
		}
		out = append(out, v)
	}
	return out, nil
}
// MustCopyRole returns a Values copy of the given role from the given data
// map, logging an error and returning nil if the role is not present.
func MustCopyRole(data Data, role Roles) Values {
	if d, ok := data[role]; ok {
		return errors.Log1(CopyValues(d))
	}
	slog.Error("plot Data role not present, but is required", "role:", role)
	return nil
}

// CopyRole returns a Values copy of the given role from the given data
// map, returning nil if the role is not present.
func CopyRole(data Data, role Roles) Values {
	if d, ok := data[role]; ok {
		vals, _ := CopyValues(d) // errors yield nil values
		return vals
	}
	return nil
}
// PlotX returns plot pixel X coordinate values for given data.
func PlotX(plt *Plot, data Valuer) []float32 {
	n := data.Len()
	xs := make([]float32, n)
	for i := 0; i < n; i++ {
		xs[i] = plt.PX(data.Float1D(i))
	}
	return xs
}

// PlotY returns plot pixel Y coordinate values for given data.
func PlotY(plt *Plot, data Valuer) []float32 {
	n := data.Len()
	ys := make([]float32, n)
	for i := 0; i < n; i++ {
		ys[i] = plt.PY(data.Float1D(i))
	}
	return ys
}
//////// Labels

// Labels provides a minimal implementation of the Data interface
// using a slice of string. It always returns 0 for Float1D.
type Labels []string

// Len returns the number of labels.
func (lb Labels) Len() int {
	return len(lb)
}

// Float1D always returns 0: Labels carry no numeric data.
func (lb Labels) Float1D(i int) float64 {
	return 0
}

// String1D returns the label at index i.
func (lb Labels) String1D(i int) string {
	return lb[i]
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Adapted from gonum/plot:
// Copyright ©2015 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package plot
import (
"bufio"
"bytes"
"image"
"io"
"os"
"cogentcore.org/core/math32"
"cogentcore.org/core/styles"
)
// SVGString returns an SVG representation of the plot as a string.
func (pt *Plot) SVGString() string {
	b := &bytes.Buffer{}
	pt.Paint.SVGOut = b
	pt.svgDraw()
	pt.Paint.SVGOut = nil // reset so subsequent draws are not redirected
	return b.String()
}

// svgDraw draws to the SVGOut writer, which must already be set in Paint,
// wrapping the plot drawing in the SVG start/end markup.
// Write errors from the in-memory/buffered writer are ignored here;
// callers that write to files check errors on Flush (see SVGToFile).
func (pt *Plot) svgDraw() {
	pt.drawConfig()
	io.WriteString(pt.Paint.SVGOut, pt.Paint.SVGStart())
	pt.Draw()
	io.WriteString(pt.Paint.SVGOut, pt.Paint.SVGEnd())
}
// SVGToFile saves the SVG rendering of the plot to the given file,
// returning any error from creating, writing, or closing the file.
func (pt *Plot) SVGToFile(filename string) error {
	fp, err := os.Create(filename)
	if err != nil {
		return err
	}
	bw := bufio.NewWriter(fp)
	pt.Paint.SVGOut = bw
	pt.svgDraw()
	pt.Paint.SVGOut = nil
	if err := bw.Flush(); err != nil {
		fp.Close() // best-effort close; the flush error takes precedence
		return err
	}
	// Return the Close error explicitly: for file writes, Close can
	// surface buffered write failures that would otherwise be lost
	// (the original deferred Close discarded it).
	return fp.Close()
}
// drawConfig configures everything for drawing: applies styles,
// re-applies the current size, configures each axis scale, and
// converts all units to dots.
func (pt *Plot) drawConfig() {
	pt.applyStyle()
	pt.Resize(pt.Size) // ensure size-dependent state is current
	pt.X.drawConfig()
	pt.Y.drawConfig()
	pt.Z.drawConfig()
	pt.Paint.ToDots()
}
// Draw draws the plot to image.
// Plotters are drawn in the order in which they were
// added to the plot.
func (pt *Plot) Draw() {
	pt.drawConfig()
	pc := pt.Paint
	ptw := float32(pt.Size.X)
	pth := float32(pt.Size.Y) // plot height; was pt.Size.X (width) by mistake
	ptb := image.Rectangle{Max: pt.Size}
	pc.PushBounds(ptb)
	if pt.Style.Background != nil {
		pc.BlitBox(math32.Vector2{}, math32.FromPoint(pt.Size), pt.Style.Background)
	}
	// Title goes across the top; remaining height shrinks accordingly.
	if pt.Title.Text != "" {
		pt.Title.Config(pt)
		pos := pt.Title.PosX(ptw)
		pad := pt.Title.Style.Padding.Dots
		pos.Y = pad
		pt.Title.Draw(pt, pos)
		th := pt.Title.PaintText.BBox.Size().Y + 2*pad
		pth -= th
		ptb.Min.Y += int(math32.Ceil(th))
	}
	pt.X.SanitizeRange()
	pt.Y.SanitizeRange()
	ywidth, tickWidth, tpad, bpad := pt.Y.sizeY(pt)
	xheight, lpad, rpad := pt.X.sizeX(pt, float32(pt.Size.X-int(ywidth)))
	// X axis: drawn in a band inset by the Y axis width.
	tb := ptb
	tb.Min.X += ywidth
	pc.PushBounds(tb)
	pt.X.drawX(pt, lpad, rpad)
	pc.PopBounds()
	// Y axis: drawn in a band above the X axis height.
	tb = ptb
	tb.Max.Y -= xheight
	pc.PushBounds(tb)
	pt.Y.drawY(pt, tickWidth, tpad, bpad)
	pc.PopBounds()
	// Remaining interior region is the plotting box.
	tb = ptb
	tb.Min.X += ywidth + lpad
	tb.Max.X -= rpad
	tb.Max.Y -= xheight + bpad
	tb.Min.Y += tpad
	pt.PlotBox.SetFromRect(tb)
	// expand slightly so that lines at the edges are not cut off
	tb.Min.X -= 2
	tb.Min.Y -= 2
	tb.Max.X += 2
	tb.Max.Y += 2
	pc.PushBounds(tb)
	for _, plt := range pt.Plotters {
		plt.Plot(pt)
	}
	pt.Legend.draw(pt)
	pc.PopBounds()
	pc.PopBounds() // global
}
////////////////////////////////////////////////////////////////
// Axis

// drawTicks returns true if the tick marks should be drawn:
// both the tick line width and tick length must be positive.
func (ax *Axis) drawTicks() bool {
	return ax.Style.TickLine.Width.Value > 0 && ax.Style.TickLength.Value > 0
}

// sizeX returns the total height of the X axis, and the left and right
// padding. It also computes and caches the tick marks in ax.ticks.
func (ax *Axis) sizeX(pt *Plot, axw float32) (ht, lpad, rpad int) {
	pc := pt.Paint
	uc := &pc.UnitContext
	ax.Style.TickLength.ToDots(uc)
	ax.ticks = ax.Ticker.Ticks(ax.Range.Min, ax.Range.Max, ax.Style.NTicks)
	h := float32(0)
	if ax.Label.Text != "" { // We assume that the label isn't rotated.
		ax.Label.Config(pt)
		h += ax.Label.PaintText.BBox.Size().Y
		h += ax.Label.Style.Padding.Dots
	}
	lw := ax.Style.Line.Width.Dots
	lpad = int(math32.Ceil(lw)) + 2
	rpad = int(math32.Ceil(lw)) + 10
	tht := float32(0) // max tick label height
	if len(ax.ticks) > 0 {
		if ax.drawTicks() {
			h += ax.Style.TickLength.Dots
		}
		// The first labeled tick may extend left of the axis start:
		// add its overhang to the left padding.
		ftk := ax.firstTickLabel()
		if ftk.Label != "" {
			px, _ := ax.tickPosX(pt, ftk, axw)
			if px < 0 {
				lpad += int(math32.Ceil(-px))
			}
			tht = max(tht, ax.TickText.PaintText.BBox.Size().Y)
		}
		// Likewise the last labeled tick may extend past the axis end.
		ltk := ax.lastTickLabel()
		if ltk.Label != "" {
			px, wd := ax.tickPosX(pt, ltk, axw)
			if px+wd > axw {
				rpad += int(math32.Ceil((px + wd) - axw))
			}
			tht = max(tht, ax.TickText.PaintText.BBox.Size().Y)
		}
		ax.TickText.Text = ax.longestTickLabel()
		if ax.TickText.Text != "" {
			ax.TickText.Config(pt)
			tht = max(tht, ax.TickText.PaintText.BBox.Size().Y)
		}
		h += ax.TickText.Style.Padding.Dots
	}
	h += tht + lw + ax.Style.Padding.Dots
	ht = int(math32.Ceil(h))
	return
}

// tickPosX returns the relative position and width for given tick along
// the X axis, for given total axis width. Both returns are zero when the
// tick falls outside the axis range.
func (ax *Axis) tickPosX(pt *Plot, t Tick, axw float32) (px, wd float32) {
	x := axw * float32(ax.Norm(t.Value))
	if x < 0 || x > axw {
		return
	}
	ax.TickText.Text = t.Label
	ax.TickText.Config(pt)
	pos := ax.TickText.PosX(0)
	px = pos.X + x
	wd = ax.TickText.PaintText.BBox.Size().X
	return
}
// firstTickLabel returns the first tick with a non-empty label,
// or a zero Tick if none have labels.
func (ax *Axis) firstTickLabel() Tick {
	for _, t := range ax.ticks {
		if t.Label != "" {
			return t
		}
	}
	return Tick{}
}

// lastTickLabel returns the last tick with a non-empty label,
// or a zero Tick if none have labels.
func (ax *Axis) lastTickLabel() Tick {
	for i := len(ax.ticks) - 1; i >= 0; i-- {
		if t := ax.ticks[i]; t.Label != "" {
			return t
		}
	}
	return Tick{}
}

// longestTickLabel returns the longest tick label string
// (by byte length), for sizing the tick label region.
func (ax *Axis) longestTickLabel() string {
	longest := ""
	for _, t := range ax.ticks {
		if len(t.Label) > len(longest) {
			longest = t.Label
		}
	}
	return longest
}
// sizeY returns the total width of the Y axis, the tick label width,
// and the top and bottom padding. It also computes and caches the
// tick marks in ax.ticks.
func (ax *Axis) sizeY(pt *Plot) (ywidth, tickWidth, tpad, bpad int) {
	pc := pt.Paint
	uc := &pc.UnitContext
	ax.ticks = ax.Ticker.Ticks(ax.Range.Min, ax.Range.Max, ax.Style.NTicks)
	ax.Style.TickLength.ToDots(uc)
	w := float32(0)
	if ax.Label.Text != "" {
		ax.Label.Config(pt)
		w += ax.Label.PaintText.BBox.Size().X
		w += ax.Label.Style.Padding.Dots
	}
	lw := ax.Style.Line.Width.Dots
	tpad = int(math32.Ceil(lw)) + 2
	bpad = int(math32.Ceil(lw)) + 2
	if len(ax.ticks) > 0 {
		if ax.drawTicks() {
			w += ax.Style.TickLength.Dots
		}
		ax.TickText.Text = ax.longestTickLabel()
		if ax.TickText.Text != "" {
			ax.TickText.Config(pt)
			tw := ax.TickText.PaintText.BBox.Size().X
			w += tw
			tickWidth = int(math32.Ceil(tw))
			w += ax.TickText.Style.Padding.Dots
			// Tick labels are vertically centered on their tick (see drawY),
			// so they extend half their height beyond the top and bottom
			// ticks; this previously used .X (the label width), over-padding.
			tht := int(math32.Ceil(0.5 * ax.TickText.PaintText.BBox.Size().Y))
			tpad += tht
			bpad += tht
		}
	}
	w += lw + ax.Style.Padding.Dots
	ywidth = int(math32.Ceil(w))
	return
}
// drawX draws the horizontal axis along the bottom of the current
// bounds: label, tick labels, tick marks, then the axis line,
// shrinking the available vertical space as each element is drawn.
func (ax *Axis) drawX(pt *Plot, lpad, rpad int) {
	ab := pt.Paint.Bounds
	ab.Min.X += lpad
	ab.Max.X -= rpad
	axw := float32(ab.Size().X)
	// axh := float32(ab.Size().Y) // height of entire plot
	// axis label at the very bottom
	if ax.Label.Text != "" {
		ax.Label.Config(pt)
		pos := ax.Label.PosX(axw)
		pos.X += float32(ab.Min.X)
		th := ax.Label.PaintText.BBox.Size().Y
		pos.Y = float32(ab.Max.Y) - th
		ax.Label.Draw(pt, pos)
		ab.Max.Y -= int(math32.Ceil(th + ax.Label.Style.Padding.Dots))
	}
	// major tick labels, positioned by the normalized tick value
	tickHt := float32(0)
	for _, t := range ax.ticks {
		x := axw * float32(ax.Norm(t.Value))
		if x < 0 || x > axw || t.IsMinor() {
			continue
		}
		ax.TickText.Text = t.Label
		ax.TickText.Config(pt)
		pos := ax.TickText.PosX(0)
		pos.X += x + float32(ab.Min.X)
		tickHt = ax.TickText.PaintText.BBox.Size().Y + ax.TickText.Style.Padding.Dots
		pos.Y += float32(ab.Max.Y) - tickHt
		ax.TickText.Draw(pt, pos)
	}
	if len(ax.ticks) > 0 {
		ab.Max.Y -= int(math32.Ceil(tickHt))
		// } else {
		// y += ax.Width / 2
	}
	// tick marks (minor ticks drawn at half length)
	if len(ax.ticks) > 0 && ax.drawTicks() {
		ln := ax.Style.TickLength.Dots
		for _, t := range ax.ticks {
			yoff := float32(0)
			if t.IsMinor() {
				yoff = 0.5 * ln
			}
			x := axw * float32(ax.Norm(t.Value))
			if x < 0 || x > axw {
				continue
			}
			x += float32(ab.Min.X)
			ax.Style.TickLine.Draw(pt, math32.Vec2(x, float32(ab.Max.Y)-yoff), math32.Vec2(x, float32(ab.Max.Y)-ln))
		}
		ab.Max.Y -= int(ln - 0.5*ax.Style.Line.Width.Dots)
	}
	// the axis line itself
	ax.Style.Line.Draw(pt, math32.Vec2(float32(ab.Min.X), float32(ab.Max.Y)), math32.Vec2(float32(ab.Min.X)+axw, float32(ab.Max.Y)))
}

// drawY draws the Y axis along the left side of the current bounds:
// label, tick labels, tick marks, then the axis line, shrinking the
// available horizontal space as each element is drawn.
func (ax *Axis) drawY(pt *Plot, tickWidth, tpad, bpad int) {
	ab := pt.Paint.Bounds
	ab.Min.Y += tpad
	ab.Max.Y -= bpad
	axh := float32(ab.Size().Y)
	// rotated axis label along the far left
	if ax.Label.Text != "" {
		ax.Label.Style.Align = styles.Center
		pos := ax.Label.PosY(axh)
		tw := ax.Label.PaintText.BBox.Size().X
		pos.Y += float32(ab.Min.Y) + ax.Label.PaintText.BBox.Size().Y
		pos.X = float32(ab.Min.X)
		ax.Label.Draw(pt, pos)
		ab.Min.X += int(math32.Ceil(tw + ax.Label.Style.Padding.Dots))
	}
	// major tick labels, vertically centered on their tick position
	// (note the inverted Norm: y grows downward in pixels)
	tickWd := float32(0)
	for _, t := range ax.ticks {
		y := axh * (1 - float32(ax.Norm(t.Value)))
		if y < 0 || y > axh || t.IsMinor() {
			continue
		}
		ax.TickText.Text = t.Label
		ax.TickText.Config(pt)
		pos := ax.TickText.PosX(float32(tickWidth))
		pos.X += float32(ab.Min.X)
		pos.Y = float32(ab.Min.Y) + y - 0.5*ax.TickText.PaintText.BBox.Size().Y
		tickWd = max(tickWd, ax.TickText.PaintText.BBox.Size().X+ax.TickText.Style.Padding.Dots)
		ax.TickText.Draw(pt, pos)
	}
	if len(ax.ticks) > 0 {
		ab.Min.X += int(math32.Ceil(tickWd))
		// } else {
		// y += ax.Width / 2
	}
	// tick marks (minor ticks drawn at half length)
	if len(ax.ticks) > 0 && ax.drawTicks() {
		ln := ax.Style.TickLength.Dots
		for _, t := range ax.ticks {
			xoff := float32(0)
			if t.IsMinor() {
				xoff = 0.5 * ln
			}
			y := axh * (1 - float32(ax.Norm(t.Value)))
			if y < 0 || y > axh {
				continue
			}
			y += float32(ab.Min.Y)
			ax.Style.TickLine.Draw(pt, math32.Vec2(float32(ab.Min.X)+xoff, y), math32.Vec2(float32(ab.Min.X)+ln, y))
		}
		ab.Min.X += int(ln + 0.5*ax.Style.Line.Width.Dots)
	}
	// the axis line itself
	ax.Style.Line.Draw(pt, math32.Vec2(float32(ab.Min.X), float32(ab.Min.Y)), math32.Vec2(float32(ab.Min.X), float32(ab.Max.Y)))
}
////////////////////////////////////////////////
// Legend

// draw draws the legend: measures all entry texts, positions the
// legend box within the plot bounds per the style, optionally fills
// the background, then renders each entry's thumbnails and label.
func (lg *Legend) draw(pt *Plot) {
	pc := pt.Paint
	uc := &pc.UnitContext
	ptb := pc.Bounds
	lg.Style.ThumbnailWidth.ToDots(uc)
	lg.Style.Position.XOffs.ToDots(uc)
	lg.Style.Position.YOffs.ToDots(uc)
	var ltxt Text
	ltxt.Defaults()
	ltxt.Style = lg.Style.Text
	ltxt.ToDots(uc)
	pad := math32.Ceil(ltxt.Style.Padding.Dots)
	ltxt.openFont(pt)
	em := ltxt.font.Face.Metrics.Em
	// measure: widest text and tallest (padded) entry height
	var sz image.Point
	maxTht := 0
	for _, e := range lg.Entries {
		ltxt.Text = e.Text
		ltxt.Config(pt)
		sz.X = max(sz.X, int(math32.Ceil(ltxt.PaintText.BBox.Size().X)))
		tht := int(math32.Ceil(ltxt.PaintText.BBox.Size().Y + pad))
		maxTht = max(tht, maxTht)
	}
	sz.X += int(em)
	sz.Y = len(lg.Entries) * maxTht
	txsz := sz // text-only size; thumbnails are placed to its right
	sz.X += int(lg.Style.ThumbnailWidth.Dots)
	// position the legend box per the styled corner and offsets
	pos := ptb.Min
	if lg.Style.Position.Left {
		pos.X += int(lg.Style.Position.XOffs.Dots)
	} else {
		pos.X = ptb.Max.X - sz.X - int(lg.Style.Position.XOffs.Dots)
	}
	if lg.Style.Position.Top {
		pos.Y += int(lg.Style.Position.YOffs.Dots)
	} else {
		pos.Y = ptb.Max.Y - sz.Y - int(lg.Style.Position.YOffs.Dots)
	}
	if lg.Style.Fill != nil {
		pc.FillBox(math32.FromPoint(pos), math32.FromPoint(sz), lg.Style.Fill)
	}
	// render each entry: thumbnails clipped to their box, then the label
	cp := pos
	thsz := image.Point{X: int(lg.Style.ThumbnailWidth.Dots), Y: maxTht - 2*int(pad)}
	for _, e := range lg.Entries {
		tp := cp
		tp.X += int(txsz.X)
		tp.Y += int(pad)
		tb := image.Rectangle{Min: tp, Max: tp.Add(thsz)}
		pc.PushBounds(tb)
		for _, t := range e.Thumbs {
			t.Thumbnail(pt)
		}
		pc.PopBounds()
		ltxt.Text = e.Text
		ltxt.Config(pt)
		ltxt.Draw(pt, math32.FromPoint(cp))
		cp.Y += maxTht
	}
}
// Code generated by "core generate -add-types"; DO NOT EDIT.
package plot
import (
"cogentcore.org/core/enums"
)
// _AxisScalesValues holds all valid AxisScales values, in ordinal order.
var _AxisScalesValues = []AxisScales{0, 1, 2, 3}

// AxisScalesN is the highest valid value for type AxisScales, plus one.
const AxisScalesN AxisScales = 4

// _AxisScalesValueMap maps names back to values, for SetString.
var _AxisScalesValueMap = map[string]AxisScales{`Linear`: 0, `Log`: 1, `InverseLinear`: 2, `InverseLog`: 3}

// _AxisScalesDescMap maps values to their doc-comment descriptions, for Desc.
var _AxisScalesDescMap = map[AxisScales]string{0: `Linear is a linear axis scale.`, 1: `Log is a Logarithmic axis scale.`, 2: `InverseLinear is an inverted linear axis scale.`, 3: `InverseLog is an inverted log axis scale.`}

// _AxisScalesMap maps values to their names, for String.
var _AxisScalesMap = map[AxisScales]string{0: `Linear`, 1: `Log`, 2: `InverseLinear`, 3: `InverseLog`}

// String returns the string representation of this AxisScales value.
func (i AxisScales) String() string { return enums.String(i, _AxisScalesMap) }

// SetString sets the AxisScales value from its string representation,
// and returns an error if the string is invalid.
func (i *AxisScales) SetString(s string) error {
	return enums.SetString(i, s, _AxisScalesValueMap, "AxisScales")
}

// Int64 returns the AxisScales value as an int64.
func (i AxisScales) Int64() int64 { return int64(i) }

// SetInt64 sets the AxisScales value from an int64.
func (i *AxisScales) SetInt64(in int64) { *i = AxisScales(in) }

// Desc returns the description of the AxisScales value.
func (i AxisScales) Desc() string { return enums.Desc(i, _AxisScalesDescMap) }

// AxisScalesValues returns all possible values for the type AxisScales.
func AxisScalesValues() []AxisScales { return _AxisScalesValues }

// Values returns all possible values for the type AxisScales.
func (i AxisScales) Values() []enums.Enum { return enums.Values(_AxisScalesValues) }

// MarshalText implements the [encoding.TextMarshaler] interface.
func (i AxisScales) MarshalText() ([]byte, error) { return []byte(i.String()), nil }

// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *AxisScales) UnmarshalText(text []byte) error {
	return enums.UnmarshalText(i, text, "AxisScales")
}
// _RolesValues holds all valid Roles values, in ordinal order.
var _RolesValues = []Roles{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}

// RolesN is the highest valid value for type Roles, plus one.
const RolesN Roles = 12

// _RolesValueMap maps names back to values, for SetString.
var _RolesValueMap = map[string]Roles{`NoRole`: 0, `X`: 1, `Y`: 2, `Z`: 3, `U`: 4, `V`: 5, `W`: 6, `Low`: 7, `High`: 8, `Size`: 9, `Color`: 10, `Label`: 11}

// _RolesDescMap maps values to their doc-comment descriptions, for Desc.
var _RolesDescMap = map[Roles]string{0: `NoRole is the default no-role specified case.`, 1: `X axis`, 2: `Y axis`, 3: `Z axis`, 4: `U is the X component of a vector or first quartile in Box plot, etc.`, 5: `V is the Y component of a vector or third quartile in a Box plot, etc.`, 6: `W is the Z component of a vector`, 7: `Low is a lower error bar or region.`, 8: `High is an upper error bar or region.`, 9: `Size controls the size of points etc.`, 10: `Color controls the color of points or other elements.`, 11: `Label renders a label, typically from string data, but can also be used for values.`}

// _RolesMap maps values to their names, for String.
var _RolesMap = map[Roles]string{0: `NoRole`, 1: `X`, 2: `Y`, 3: `Z`, 4: `U`, 5: `V`, 6: `W`, 7: `Low`, 8: `High`, 9: `Size`, 10: `Color`, 11: `Label`}

// String returns the string representation of this Roles value.
func (i Roles) String() string { return enums.String(i, _RolesMap) }

// SetString sets the Roles value from its string representation,
// and returns an error if the string is invalid.
func (i *Roles) SetString(s string) error { return enums.SetString(i, s, _RolesValueMap, "Roles") }

// Int64 returns the Roles value as an int64.
func (i Roles) Int64() int64 { return int64(i) }

// SetInt64 sets the Roles value from an int64.
func (i *Roles) SetInt64(in int64) { *i = Roles(in) }

// Desc returns the description of the Roles value.
func (i Roles) Desc() string { return enums.Desc(i, _RolesDescMap) }

// RolesValues returns all possible values for the type Roles.
func RolesValues() []Roles { return _RolesValues }

// Values returns all possible values for the type Roles.
func (i Roles) Values() []enums.Enum { return enums.Values(_RolesValues) }

// MarshalText implements the [encoding.TextMarshaler] interface.
func (i Roles) MarshalText() ([]byte, error) { return []byte(i.String()), nil }

// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *Roles) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "Roles") }
// _StepKindValues holds all valid StepKind values, in ordinal order.
var _StepKindValues = []StepKind{0, 1, 2, 3}

// StepKindN is the highest valid value for type StepKind, plus one.
const StepKindN StepKind = 4

// _StepKindValueMap maps names back to values, for SetString.
var _StepKindValueMap = map[string]StepKind{`NoStep`: 0, `PreStep`: 1, `MidStep`: 2, `PostStep`: 3}

// _StepKindDescMap maps values to their doc-comment descriptions, for Desc.
var _StepKindDescMap = map[StepKind]string{0: `NoStep connects two points by simple line.`, 1: `PreStep connects two points by following lines: vertical, horizontal.`, 2: `MidStep connects two points by following lines: horizontal, vertical, horizontal. Vertical line is placed in the middle of the interval.`, 3: `PostStep connects two points by following lines: horizontal, vertical.`}

// _StepKindMap maps values to their names, for String.
var _StepKindMap = map[StepKind]string{0: `NoStep`, 1: `PreStep`, 2: `MidStep`, 3: `PostStep`}

// String returns the string representation of this StepKind value.
func (i StepKind) String() string { return enums.String(i, _StepKindMap) }

// SetString sets the StepKind value from its string representation,
// and returns an error if the string is invalid.
func (i *StepKind) SetString(s string) error {
	return enums.SetString(i, s, _StepKindValueMap, "StepKind")
}

// Int64 returns the StepKind value as an int64.
func (i StepKind) Int64() int64 { return int64(i) }

// SetInt64 sets the StepKind value from an int64.
func (i *StepKind) SetInt64(in int64) { *i = StepKind(in) }

// Desc returns the description of the StepKind value.
func (i StepKind) Desc() string { return enums.Desc(i, _StepKindDescMap) }

// StepKindValues returns all possible values for the type StepKind.
func StepKindValues() []StepKind { return _StepKindValues }

// Values returns all possible values for the type StepKind.
func (i StepKind) Values() []enums.Enum { return enums.Values(_StepKindValues) }

// MarshalText implements the [encoding.TextMarshaler] interface.
func (i StepKind) MarshalText() ([]byte, error) { return []byte(i.String()), nil }

// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *StepKind) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "StepKind") }
// _ShapesValues holds all valid Shapes values, in ordinal order.
var _ShapesValues = []Shapes{0, 1, 2, 3, 4, 5, 6, 7}

// ShapesN is the highest valid value for type Shapes, plus one.
const ShapesN Shapes = 8

// _ShapesValueMap maps names back to values, for SetString.
var _ShapesValueMap = map[string]Shapes{`Ring`: 0, `Circle`: 1, `Square`: 2, `Box`: 3, `Triangle`: 4, `Pyramid`: 5, `Plus`: 6, `Cross`: 7}

// _ShapesDescMap maps values to their doc-comment descriptions, for Desc.
var _ShapesDescMap = map[Shapes]string{0: `Ring is the outline of a circle`, 1: `Circle is a solid circle`, 2: `Square is the outline of a square`, 3: `Box is a filled square`, 4: `Triangle is the outline of a triangle`, 5: `Pyramid is a filled triangle`, 6: `Plus is a plus sign`, 7: `Cross is a big X`}

// _ShapesMap maps values to their names, for String.
var _ShapesMap = map[Shapes]string{0: `Ring`, 1: `Circle`, 2: `Square`, 3: `Box`, 4: `Triangle`, 5: `Pyramid`, 6: `Plus`, 7: `Cross`}

// String returns the string representation of this Shapes value.
func (i Shapes) String() string { return enums.String(i, _ShapesMap) }

// SetString sets the Shapes value from its string representation,
// and returns an error if the string is invalid.
func (i *Shapes) SetString(s string) error { return enums.SetString(i, s, _ShapesValueMap, "Shapes") }

// Int64 returns the Shapes value as an int64.
func (i Shapes) Int64() int64 { return int64(i) }

// SetInt64 sets the Shapes value from an int64.
func (i *Shapes) SetInt64(in int64) { *i = Shapes(in) }

// Desc returns the description of the Shapes value.
func (i Shapes) Desc() string { return enums.Desc(i, _ShapesDescMap) }

// ShapesValues returns all possible values for the type Shapes.
func ShapesValues() []Shapes { return _ShapesValues }

// Values returns all possible values for the type Shapes.
func (i Shapes) Values() []enums.Enum { return enums.Values(_ShapesValues) }

// MarshalText implements the [encoding.TextMarshaler] interface.
func (i Shapes) MarshalText() ([]byte, error) { return []byte(i.String()), nil }

// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *Shapes) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "Shapes") }
// _DefaultOffOnValues holds all valid DefaultOffOn values, in ordinal order.
var _DefaultOffOnValues = []DefaultOffOn{0, 1, 2}

// DefaultOffOnN is the highest valid value for type DefaultOffOn, plus one.
const DefaultOffOnN DefaultOffOn = 3

// _DefaultOffOnValueMap maps names back to values, for SetString.
var _DefaultOffOnValueMap = map[string]DefaultOffOn{`Default`: 0, `Off`: 1, `On`: 2}

// _DefaultOffOnDescMap maps values to their doc-comment descriptions, for Desc.
var _DefaultOffOnDescMap = map[DefaultOffOn]string{0: `Default means use the default value.`, 1: `Off means to override the default and turn Off.`, 2: `On means to override the default and turn On.`}

// _DefaultOffOnMap maps values to their names, for String.
var _DefaultOffOnMap = map[DefaultOffOn]string{0: `Default`, 1: `Off`, 2: `On`}

// String returns the string representation of this DefaultOffOn value.
func (i DefaultOffOn) String() string { return enums.String(i, _DefaultOffOnMap) }

// SetString sets the DefaultOffOn value from its string representation,
// and returns an error if the string is invalid.
func (i *DefaultOffOn) SetString(s string) error {
	return enums.SetString(i, s, _DefaultOffOnValueMap, "DefaultOffOn")
}

// Int64 returns the DefaultOffOn value as an int64.
func (i DefaultOffOn) Int64() int64 { return int64(i) }

// SetInt64 sets the DefaultOffOn value from an int64.
func (i *DefaultOffOn) SetInt64(in int64) { *i = DefaultOffOn(in) }

// Desc returns the description of the DefaultOffOn value.
func (i DefaultOffOn) Desc() string { return enums.Desc(i, _DefaultOffOnDescMap) }

// DefaultOffOnValues returns all possible values for the type DefaultOffOn.
func DefaultOffOnValues() []DefaultOffOn { return _DefaultOffOnValues }

// Values returns all possible values for the type DefaultOffOn.
func (i DefaultOffOn) Values() []enums.Enum { return enums.Values(_DefaultOffOnValues) }

// MarshalText implements the [encoding.TextMarshaler] interface.
func (i DefaultOffOn) MarshalText() ([]byte, error) { return []byte(i.String()), nil }

// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *DefaultOffOn) UnmarshalText(text []byte) error {
	return enums.UnmarshalText(i, text, "DefaultOffOn")
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Copied directly from gonum/plot:
// Copyright ©2017 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This is an implementation of the Talbot, Lin and Hanrahan algorithm
// described in doi:10.1109/TVCG.2010.130 with reference to the R
// implementation in the labeling package, ©2014 Justin Talbot (Licensed
// MIT+file LICENSE|Unlimited).
package plot
import "math"
// Machine constants used as numerical tolerances by the labelling code.
const (
	// dlamchE is the machine epsilon. For IEEE this is 2^{-53}.
	dlamchE = 1.0 / (1 << 53)

	// dlamchB is the radix of the machine (the base of the number system).
	dlamchB = 2

	// dlamchP is base * eps; used (scaled by 100) as the comparison
	// tolerance in talbotLinHanrahan and simplicity.
	dlamchP = dlamchB * dlamchE
)
// Containment modes accepted by talbotLinHanrahan's containment parameter.
const (
	// free indicates no restriction on label containment.
	free = iota

	// containData specifies that all the data range lies
	// within the interval [label_min, label_max].
	containData

	// withinData specifies that all labels lie within the
	// interval [dMin, dMax].
	withinData
)
// talbotLinHanrahan returns an optimal set of approximately want label values
// for the data range [dMin, dMax], and the step and magnitude of the step between values.
// containment specifies guarantees for label and data range containment; valid
// values are free, containData and withinData.
// The optional parameters Q, nice numbers, and w, weights, allow tuning of the
// algorithm but by default (when nil) are set to the parameters described in the
// paper.
// The legibility function allows tuning of the legibility assessment for labels.
// By default, when nil, legibility will set the legibility score for each candidate
// labelling scheme to 1.
// See the paper for an explanation of the function of Q, w and legibility.
func talbotLinHanrahan(dMin, dMax float64, want int, containment int, Q []float64, w *weights, legibility func(lMin, lMax, lStep float64) float64) (values []float64, step, q float64, magnitude int) {
	const eps = dlamchP * 100

	if dMin > dMax {
		panic("labelling: invalid data range: min greater than max")
	}

	// Default nice numbers and weights are the values given in the paper.
	if Q == nil {
		Q = []float64{1, 5, 2, 2.5, 4, 3}
	}
	if w == nil {
		w = &weights{
			simplicity: 0.25,
			coverage:   0.2,
			density:    0.5,
			legibility: 0.05,
		}
	}
	if legibility == nil {
		legibility = unitLegibility
	}

	// Degenerate (numerically zero-width) range: fall back to uniformly
	// spaced labels over the tiny interval.
	if r := dMax - dMin; r < eps {
		l := make([]float64, want)
		step := r / float64(want-1)
		for i := range l {
			l[i] = dMin + float64(i)*step
		}
		magnitude = minAbsMag(dMin, dMax)
		return l, step, 0, magnitude
	}

	// selection records a candidate labelling scheme and its score.
	type selection struct {
		// n is the number of labels selected.
		n int
		// lMin and lMax are the selected min
		// and max label values. lq is the q
		// chosen.
		lMin, lMax, lStep, lq float64
		// score is the score for the selection.
		score float64
		// magnitude is the magnitude of the
		// label step distance.
		magnitude int
	}
	// -2 is below any achievable score and acts as a "nothing found" sentinel.
	best := selection{score: -2}

outer:
	for skip := 1; ; skip++ {
		for _, q := range Q {
			sm := maxSimplicity(q, Q, skip)
			// Upper-bound pruning: if a perfect score on every other term
			// still cannot beat the current best, the search is done.
			if w.score(sm, 1, 1, 1) < best.score {
				break outer
			}
			for have := 2; ; have++ {
				dm := maxDensity(have, want)
				if w.score(sm, 1, dm, 1) < best.score {
					break
				}
				delta := (dMax - dMin) / float64(have+1) / float64(skip) / q
				const maxExp = 309 // approximately the largest base-10 exponent of a float64
				for mag := int(math.Ceil(math.Log10(delta))); mag < maxExp; mag++ {
					step := float64(skip) * q * math.Pow10(mag)
					cm := maxCoverage(dMin, dMax, step*float64(have-1))
					if w.score(sm, cm, dm, 1) < best.score {
						break
					}
					fracStep := step / float64(skip)
					kStep := step * float64(have-1)
					minStart := (math.Floor(dMax/step) - float64(have-1)) * float64(skip)
					maxStart := math.Ceil(dMax/step) * float64(skip)
					// The start != start-1 condition terminates the loop when
					// start grows so large that incrementing it by one no
					// longer changes its floating-point value.
					for start := minStart; start <= maxStart && start != start-1; start++ {
						lMin := start * fracStep
						lMax := lMin + kStep
						// Enforce the requested containment guarantee.
						switch containment {
						case containData:
							if dMin < lMin || lMax < dMax {
								continue
							}
						case withinData:
							if lMin < dMin || dMax < lMax {
								continue
							}
						case free:
							// Free choice.
						}
						score := w.score(
							simplicity(q, Q, skip, lMin, lMax, step),
							coverage(dMin, dMax, lMin, lMax),
							density(have, want, dMin, dMax, lMin, lMax),
							legibility(lMin, lMax, step),
						)
						if score > best.score {
							best = selection{
								n:         have,
								lMin:      lMin,
								lMax:      lMax,
								lStep:     float64(skip) * q,
								lq:        q,
								score:     score,
								magnitude: mag,
							}
						}
					}
				}
			}
		}
	}

	// No acceptable labelling found: fall back to uniform spacing.
	if best.score == -2 {
		l := make([]float64, want)
		step := (dMax - dMin) / float64(want-1)
		for i := range l {
			l[i] = dMin + float64(i)*step
		}
		magnitude = minAbsMag(dMin, dMax)
		return l, step, 0, magnitude
	}

	// Materialize the best labelling scheme found.
	l := make([]float64, best.n)
	step = best.lStep * math.Pow10(best.magnitude)
	for i := range l {
		l[i] = best.lMin + float64(i)*step
	}
	return l, best.lStep, best.lq, best.magnitude
}
// minAbsMag returns the minimum order of magnitude (floor of base-10 log)
// of the absolute values of a and b.
func minAbsMag(a, b float64) int {
	magA := math.Floor(math.Log10(math.Abs(a)))
	magB := math.Floor(math.Log10(math.Abs(b)))
	return int(math.Min(magA, magB))
}
// simplicity returns the simplicity score for how well the current q, lMin,
// lMax, lStep and skip match the given nice numbers, Q.
// Panics if q is not an element of Q.
func simplicity(q float64, Q []float64, skip int, lMin, lMax, lStep float64) float64 {
	const eps = dlamchP * 100
	for i, nice := range Q {
		if nice != q {
			continue
		}
		// Bonus point when zero lies within the label range and lMin
		// falls on a step boundary (within numerical tolerance).
		m := math.Mod(lMin, lStep)
		bonus := 0.0
		if (m < eps || lStep-m < eps) && lMin <= 0 && 0 <= lMax {
			bonus = 1
		}
		return 1 - float64(i)/(float64(len(Q))-1) - float64(skip) + bonus
	}
	panic("labelling: invalid q for Q")
}
// maxSimplicity returns the maximum simplicity score attainable for q, Q and
// skip, i.e. the simplicity assuming the zero-inclusion bonus is achieved.
// Panics if q is not an element of Q.
func maxSimplicity(q float64, Q []float64, skip int) float64 {
	for i, nice := range Q {
		if nice != q {
			continue
		}
		return 1 - float64(i)/(float64(len(Q))-1) - float64(skip) + 1
	}
	panic("labelling: invalid q for Q")
}
// coverage returns the coverage score based on the average squared distance
// between the extreme labels, lMin and lMax, and the extreme data points,
// dMin and dMax. A perfect match of extremes scores 1.
func coverage(dMin, dMax, lMin, lMax float64) float64 {
	r := 0.1 * (dMax - dMin)
	hi := dMax - lMax
	lo := dMin - lMin
	return 1 - 0.5*(hi*hi+lo*lo)/(r*r)
}
// maxCoverage returns the maximum coverage score achievable for the data
// range [dMin, dMax] with a label span of the given width.
func maxCoverage(dMin, dMax, span float64) float64 {
	r := dMax - dMin
	if span <= r {
		// The span fits inside the data range; full coverage is possible.
		return 1
	}
	// Best case: the overhang is split evenly on both sides.
	half := 0.5 * (span - r)
	tenth := r * 0.1
	return 1 - (half*half)/(tenth*tenth)
}
// density returns the density score, which measures the goodness of the
// labelling density compared to the user-defined target given by the want
// parameter to talbotLinHanrahan. A perfect match scores 1.
func density(have, want int, dMin, dMax, lMin, lMax float64) float64 {
	rho := float64(have-1) / (lMax - lMin)
	rhot := float64(want-1) / (math.Max(lMax, dMax) - math.Min(dMin, lMin))
	// Penalize by whichever ratio is >= 1, so the score is symmetric
	// in over- vs. under-density.
	d := rho / rhot
	if d < 1 {
		d = rhot / rho
	}
	return 2 - d
}
// maxDensity returns the maximum density score achievable for have labels
// against a target of want.
func maxDensity(have, want int) float64 {
	if have >= want {
		return 2 - float64(have-1)/float64(want-1)
	}
	// Fewer labels than requested can always achieve the ideal density.
	return 1
}
// unitLegibility is the default legibility function: it scores every
// candidate labelling 1, ignoring label spacing entirely.
func unitLegibility(_, _, _ float64) float64 { return 1 }
// weights holds the coefficients used to combine the four partial scores
// into a labelling scheme's total score.
type weights struct {
	simplicity, coverage, density, legibility float64
}

// score returns the weighted total for a labelling scheme with
// simplicity sim, coverage cov, density den and legibility leg.
func (w *weights) score(sim, cov, den, leg float64) float64 {
	total := w.simplicity * sim
	total += w.coverage * cov
	total += w.density * den
	total += w.legibility * leg
	return total
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package plot
import (
"image"
"cogentcore.org/core/colors"
"cogentcore.org/core/colors/gradient"
"cogentcore.org/core/styles/units"
)
// LegendStyle has the styling properties for the Legend.
type LegendStyle struct { //types:add -setters

	// Column is for table-based plotting, specifying the column with legend values.
	Column string

	// Text is the style given to the legend entry texts.
	Text TextStyle `display:"add-fields"`

	// Position specifies where the legend is placed within the plot.
	Position LegendPosition `display:"inline"`

	// ThumbnailWidth is the width of legend thumbnails.
	ThumbnailWidth units.Value `display:"inline"`

	// Fill specifies the background fill color for the legend box,
	// if non-nil.
	Fill image.Image
}
// Defaults sets default legend styling: 20dp entry text with 2dp padding,
// 20pt thumbnails, and a 75%-opacity surface-colored background fill.
func (ls *LegendStyle) Defaults() {
	ls.Text.Defaults()
	ls.Text.Size.Dp(20)
	ls.Text.Padding.Dp(2)
	ls.Position.Defaults()
	ls.ThumbnailWidth.Pt(20)
	ls.Fill = gradient.ApplyOpacity(colors.Scheme.Surface, 0.75)
}
// LegendPosition specifies where to put the legend.
type LegendPosition struct {

	// Top and Left specify the location of the legend:
	// anchored to the top vs. bottom, and left vs. right
	// edges of the plot area.
	Top, Left bool

	// XOffs and YOffs are added to the legend's final position,
	// relative to the relevant anchor position.
	XOffs, YOffs units.Value
}
// Defaults anchors the legend at the top of the plot;
// Left keeps its zero value (false), i.e., right-anchored.
func (lg *LegendPosition) Defaults() {
	lg.Top = true
}
// A Legend gives a description of the meaning of different
// data elements of the plot. Each legend entry has a name
// and a thumbnail, where the thumbnail shows a small
// sample of the display style of the corresponding data.
type Legend struct {

	// Style has the legend styling parameters.
	Style LegendStyle

	// Entries are all of the LegendEntries described by this legend,
	// in display order.
	Entries []LegendEntry
}
// Defaults sets the default styling parameters for the legend.
func (lg *Legend) Defaults() {
	lg.Style.Defaults()
}
// Add adds an entry to the legend with the given name.
// The entry's thumbnail is drawn as the composite of all of the
// thumbnails.
func (lg *Legend) Add(name string, thumbs ...Thumbnailer) {
	entry := LegendEntry{Text: name, Thumbs: thumbs}
	lg.Entries = append(lg.Entries, entry)
}
// LegendForPlotter returns the legend Text for the given plotter,
// if it exists as a Thumbnailer in the legend entries.
// Otherwise it returns the empty string.
func (lg *Legend) LegendForPlotter(plt Plotter) string {
	for _, entry := range lg.Entries {
		for _, tn := range entry.Thumbs {
			p, ok := tn.(Plotter)
			if ok && p == plt {
				return entry.Text
			}
		}
	}
	return ""
}
// Thumbnailer wraps the Thumbnail method, which draws the small
// image in a legend representing the style of data.
type Thumbnailer interface {
	// Thumbnail draws a thumbnail representing a legend entry.
	// The thumbnail will usually show a smaller representation
	// of the style used to plot the corresponding data.
	Thumbnail(pt *Plot)
}
// A LegendEntry represents a single line of a legend: it
// has a name (Text) and an icon composed of the Thumbs.
type LegendEntry struct {

	// Text is the text associated with this entry.
	Text string

	// Thumbs is a slice of all of the thumbnail styles for this entry,
	// drawn together to form its icon.
	Thumbs []Thumbnailer
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package plot
import (
"image"
"cogentcore.org/core/colors"
"cogentcore.org/core/math32"
"cogentcore.org/core/styles/units"
)
// LineStyle has style properties for drawing lines.
type LineStyle struct { //types:add -setters

	// On indicates whether to plot lines.
	On DefaultOffOn

	// Color is the stroke color image specification.
	// Setting to nil turns the line off.
	Color image.Image

	// Width is the line width, with a default of 1 Pt (point).
	// Setting to 0 turns the line off.
	Width units.Value

	// Dashes are the dashes of the stroke. Each pair of values specifies
	// the amount to paint and then the amount to skip.
	Dashes []float32

	// Fill is the color to fill solid regions, in a plot-specific
	// way (e.g., the area below a Line plot, the bar color).
	// Use nil to disable filling.
	Fill image.Image

	// NegativeX specifies whether to draw lines that connect points with a negative
	// X-axis direction; otherwise there is a break in the line.
	// The default is false, so that repeated series of data across the X axis
	// are plotted separately.
	NegativeX bool

	// Step specifies how to step the line between points.
	Step StepKind
}
// Defaults sets default line styling: OnSurface stroke color,
// transparent fill, and 1pt width.
func (ls *LineStyle) Defaults() {
	ls.Color = colors.Scheme.OnSurface
	ls.Fill = colors.Uniform(colors.Transparent)
	ls.Width.Pt(1)
}
// SetStroke sets the stroke style in the plot paint context to the current
// line style. It returns false, without touching the paint stroke style,
// if the line is turned Off, the Color is nil, or the Width resolves to 0 dots.
func (ls *LineStyle) SetStroke(pt *Plot) bool {
	if ls.On == Off || ls.Color == nil {
		return false
	}
	pc := pt.Paint
	uc := &pc.UnitContext
	// resolve the unit value to raw dots before the zero-width check
	ls.Width.ToDots(uc)
	if ls.Width.Dots == 0 {
		return false
	}
	pc.StrokeStyle.Width = ls.Width
	pc.StrokeStyle.Color = ls.Color
	pc.StrokeStyle.ToDots(uc)
	return true
}
// HasFill reports whether the style specifies a usable fill:
// a non-nil Fill image whose uniform color is not fully transparent.
func (ls *LineStyle) HasFill() bool {
	if ls.Fill == nil {
		return false
	}
	return colors.ToUniform(ls.Fill) != colors.Transparent
}
// Draw draws a line between the given coordinates, setting the stroke style
// to the current parameters. It returns false, drawing nothing, if stroking
// is disabled (see SetStroke: Off, nil Color, or zero Width).
func (ls *LineStyle) Draw(pt *Plot, start, end math32.Vector2) bool {
	if !ls.SetStroke(pt) {
		return false
	}
	pc := pt.Paint
	pc.MoveTo(start.X, start.Y)
	pc.LineTo(end.X, end.Y)
	pc.Stroke()
	return true
}
// StepKind specifies a form of a connection of two consecutive points.
// Its enum methods (String, SetString, etc.) are generated by core generate.
type StepKind int32 //enums:enum

const (
	// NoStep connects two points by simple line.
	NoStep StepKind = iota

	// PreStep connects two points by following lines: vertical, horizontal.
	PreStep

	// MidStep connects two points by following lines: horizontal, vertical, horizontal.
	// Vertical line is placed in the middle of the interval.
	MidStep

	// PostStep connects two points by following lines: horizontal, vertical.
	PostStep
)
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Adapted from github.com/gonum/plot:
// Copyright ©2015 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package plot
//go:generate core generate -add-types
import (
"image"
"cogentcore.org/core/base/iox/imagex"
"cogentcore.org/core/colors"
"cogentcore.org/core/math32"
"cogentcore.org/core/math32/minmax"
"cogentcore.org/core/paint"
"cogentcore.org/core/styles"
"cogentcore.org/core/styles/units"
)
// XAxisStyle has overall plot level styling properties for the XAxis.
type XAxisStyle struct { //types:add -setters

	// Column specifies the column to use for the common X axis,
	// for [plot.NewTablePlot] [table.Table] driven plots.
	// If empty, standard Group-based role binding is used: the last column
	// within the same group with Role=X is used.
	Column string

	// Rotation is the rotation of the X Axis labels, in degrees.
	Rotation float32

	// Label is the optional label to use for the XAxis instead of the default.
	Label string

	// Range is the effective range of XAxis data to plot, where either end can be fixed.
	Range minmax.Range64 `display:"inline"`

	// Scale specifies how values are scaled along the X axis:
	// Linear, Log, InverseLinear, or InverseLog.
	Scale AxisScales
}
// PlotStyle has overall plot level styling properties.
// Some properties provide defaults for individual elements, which can
// then be overwritten by element-level properties.
type PlotStyle struct { //types:add -setters

	// Title is the overall title of the plot.
	Title string

	// TitleStyle is the text styling parameters for the title.
	TitleStyle TextStyle

	// Background is the background of the plot.
	// The default is [colors.Scheme.Surface].
	Background image.Image

	// Scale multiplies the plot DPI value, to change the overall scale
	// of the rendered plot. Larger numbers produce larger scaling.
	// Typically use larger numbers when generating plots for inclusion in
	// documents or other cases where the overall plot size will be small.
	Scale float32 `default:"1,2"`

	// Legend has the styling properties for the Legend.
	Legend LegendStyle `display:"add-fields"`

	// Axis has the styling properties for the Axes.
	Axis AxisStyle `display:"add-fields"`

	// XAxis has plot-level XAxis style properties.
	XAxis XAxisStyle `display:"add-fields"`

	// YAxisLabel is the optional label to use for the YAxis instead of the default.
	YAxisLabel string

	// LinesOn determines whether lines are plotted by default,
	// for elements that plot lines (e.g., plots.XY).
	LinesOn DefaultOffOn

	// LineWidth sets the default line width for data plotting lines.
	LineWidth units.Value

	// PointsOn determines whether points are plotted by default,
	// for elements that plot points (e.g., plots.XY).
	PointsOn DefaultOffOn

	// PointSize sets the default point size.
	PointSize units.Value

	// LabelSize sets the default label text size.
	LabelSize units.Value

	// BarWidth for Bar plot sets the default width of the bars,
	// which should be less than the Stride (1 typically) to prevent
	// bar overlap. Defaults to .8.
	BarWidth float64
}
// Defaults sets the default plot styling: surface background, scale 1,
// 24dp title, 1pt lines, 4pt points, 16dp labels, and 0.8 bar width.
func (ps *PlotStyle) Defaults() {
	ps.TitleStyle.Defaults()
	ps.TitleStyle.Size.Dp(24)
	ps.Background = colors.Scheme.Surface
	ps.Scale = 1
	ps.Legend.Defaults()
	ps.Axis.Defaults()
	ps.LineWidth.Pt(1)
	ps.PointSize.Pt(4)
	ps.LabelSize.Dp(16)
	ps.BarWidth = .8
}
// SetElementStyle sets the properties for the given element's style
// based on the global default settings in this PlotStyle.
// Sizes and widths are always copied; the On flags are only propagated
// when they have been explicitly overridden (i.e., are not Default).
func (ps *PlotStyle) SetElementStyle(es *Style) {
	es.Line.Width = ps.LineWidth
	es.Point.Size = ps.PointSize
	es.Width.Width = ps.BarWidth
	es.Text.Size = ps.LabelSize
	if ps.LinesOn != Default {
		es.Line.On = ps.LinesOn
	}
	if ps.PointsOn != Default {
		es.Point.On = ps.PointsOn
	}
}
// PanZoom provides post-styling pan and zoom range manipulation.
type PanZoom struct {

	// XOffset adds offset to X range (pan).
	XOffset float64

	// XScale multiplies X range (zoom).
	XScale float64

	// YOffset adds offset to Y range (pan).
	YOffset float64

	// YScale multiplies Y range (zoom).
	YScale float64
}

// Defaults sets both scale factors to 1 (no zoom).
// The offsets keep their zero values (no pan).
func (pz *PanZoom) Defaults() {
	pz.YScale = 1
	pz.XScale = 1
}
// Plot is the basic type representing a plot.
// It renders into its own image.RGBA Pixels image,
// and can also save a corresponding SVG version.
type Plot struct {

	// Title of the plot.
	Title Text

	// Style has the styling properties for the plot.
	Style PlotStyle

	// StandardTextStyle is the standard text style with default options.
	StandardTextStyle styles.Text

	// X, Y, and Z are the horizontal, vertical, and depth axes
	// of the plot respectively.
	X, Y, Z Axis

	// Legend is the plot's legend.
	Legend Legend

	// Plotters are drawn by calling their Plot method after the axes are drawn.
	Plotters []Plotter

	// Size is the target size of the image to render to.
	Size image.Point

	// DPI is the dots per inch for rendering the image.
	// Larger numbers result in larger scaling of the plot contents
	// which is strongly recommended for print (e.g., use 300 for print).
	DPI float32 `default:"96,160,300"`

	// PanZoom provides post-styling pan and zoom range factors.
	PanZoom PanZoom

	// HighlightPlotter is the Plotter to highlight. Used for mouse hovering for example.
	// It is the responsibility of the Plotter Plot function to implement highlighting.
	HighlightPlotter Plotter

	// HighlightIndex is the index of the data point to highlight, for HighlightPlotter.
	HighlightIndex int

	// Pixels is the backing image that the plot renders into.
	Pixels *image.RGBA `copier:"-" json:"-" xml:"-" edit:"-"`

	// Paint is the painter for rendering.
	Paint *paint.Context

	// PlotBox is the current plot bounding box in image coordinates,
	// for plotting coordinates.
	PlotBox math32.Box2
}
// New returns a new plot with some reasonable default settings.
func New() *Plot {
	plt := new(Plot)
	plt.Defaults()
	return plt
}
// Defaults sets reasonable default settings for the plot:
// 96 DPI, 1280x1024 image size, 24dp title, default axes,
// legend, and pan/zoom, and non-wrapping text.
func (pt *Plot) Defaults() {
	pt.Style.Defaults()
	pt.Title.Defaults()
	pt.Title.Style.Size.Dp(24)
	pt.X.Defaults(math32.X)
	pt.Y.Defaults(math32.Y)
	pt.Legend.Defaults()
	pt.DPI = 96
	pt.PanZoom.Defaults()
	pt.Size = image.Point{1280, 1024}
	pt.StandardTextStyle.Defaults()
	// plot text is not wrapped
	pt.StandardTextStyle.WhiteSpace = styles.WhiteSpaceNowrap
}
// applyStyle applies all the style parameters: it runs all plotter stylers
// to produce the final PlotStyle, pushes that style down to each plotter
// element, copies the plot-level settings onto the title, legend and axes,
// and finally updates the axis ranges.
func (pt *Plot) applyStyle() {
	// first update the global plot style settings
	var st Style
	st.Defaults()
	st.Plot = pt.Style
	for _, plt := range pt.Plotters {
		stlr := plt.Stylers()
		stlr.Run(&st)
	}
	pt.Style = st.Plot
	// then apply to elements
	for _, plt := range pt.Plotters {
		plt.ApplyStyle(&pt.Style)
	}
	// now style plot:
	// NOTE(review): DPI is multiplied by Scale on every call; if applyStyle
	// can run more than once on the same Plot, this compounds — confirm.
	pt.DPI *= pt.Style.Scale
	pt.Title.Style = pt.Style.TitleStyle
	if pt.Style.Title != "" {
		pt.Title.Text = pt.Style.Title
	}
	pt.Legend.Style = pt.Style.Legend
	pt.X.Style = pt.Style.Axis
	pt.X.Style.Scale = pt.Style.XAxis.Scale
	pt.Y.Style = pt.Style.Axis
	if pt.Style.XAxis.Label != "" {
		pt.X.Label.Text = pt.Style.XAxis.Label
	}
	if pt.Style.YAxisLabel != "" {
		pt.Y.Label.Text = pt.Style.YAxisLabel
	}
	pt.X.Label.Style = pt.Style.Axis.Text
	pt.Y.Label.Style = pt.Style.Axis.Text
	pt.X.TickText.Style = pt.Style.Axis.TickText
	pt.X.TickText.Style.Rotation = pt.Style.XAxis.Rotation
	pt.Y.TickText.Style = pt.Style.Axis.TickText
	pt.Y.Label.Style.Rotation = -90 // Y axis label reads vertically
	pt.Y.Style.TickText.Align = styles.End
	pt.UpdateRange()
}
// Add adds Plotter element(s) to the plot.
// When drawing the plot, Plotters are drawn in the
// order in which they were added to the plot.
func (pt *Plot) Add(plotters ...Plotter) {
	pt.Plotters = append(pt.Plotters, plotters...)
}
// SetPixels sets the backing pixels image to the given image.RGBA,
// creates a new paint context for it at the current DPI,
// and updates Size from the image bounds.
func (pt *Plot) SetPixels(img *image.RGBA) {
	pt.Pixels = img
	pt.Paint = paint.NewContextFromImage(pt.Pixels)
	pt.Paint.UnitContext.DPI = pt.DPI
	pt.Size = pt.Pixels.Bounds().Size()
}
// Resize sets the size of the output image to the given size.
// If the existing image is already that size, only Size and the
// paint DPI are refreshed; otherwise a new backing image is allocated.
func (pt *Plot) Resize(sz image.Point) {
	if pt.Pixels != nil && pt.Pixels.Bounds().Size() == sz {
		// already the right size; just refresh metadata
		pt.Size = sz
		pt.Paint.UnitContext.DPI = pt.DPI
		return
	}
	pt.SetPixels(image.NewRGBA(image.Rectangle{Max: sz}))
}
// SaveImage saves the plot's rendered Pixels image to the given file
// using imagex.Save (presumably the format follows the file
// extension — confirm against imagex docs).
func (pt *Plot) SaveImage(filename string) error {
	return imagex.Save(pt.Pixels, filename)
}
// NominalX configures the plot to have a nominal X
// axis—an X axis with names instead of numbers. The
// X location corresponding to each name are the integers,
// e.g., the x value 0 is centered above the first name and
// 1 is above the second name, etc. Labels for x values
// that do not end up in range of the X axis will not have
// tick marks.
func (pt *Plot) NominalX(names ...string) {
	// one tick per name, at consecutive integer locations
	ticks := make([]Tick, len(names))
	for i, nm := range names {
		ticks[i] = Tick{float64(i), nm}
	}
	// hide the axis line and tick marks; only the labels remain
	pt.X.Style.TickLine.Width.Pt(0)
	pt.X.Style.TickLength.Pt(0)
	pt.X.Style.Line.Width.Pt(0)
	// pt.Y.Padding.Pt(pt.X.Style.Tick.Label.Width(names[0]) / 2)
	pt.X.Ticker = ConstantTicks(ticks)
}
// HideX configures the X axis so that it will not be drawn:
// zero line width and tick length, and an empty set of ticks.
func (pt *Plot) HideX() {
	pt.X.Style.Line.Width.Pt(0)
	pt.X.Style.TickLength.Pt(0)
	pt.X.Ticker = ConstantTicks([]Tick{})
}
// HideY configures the Y axis so that it will not be drawn:
// zero line width and tick length, and an empty set of ticks.
func (pt *Plot) HideY() {
	pt.Y.Style.Line.Width.Pt(0)
	pt.Y.Style.TickLength.Pt(0)
	pt.Y.Ticker = ConstantTicks([]Tick{})
}
// HideAxes hides both the X and Y axes by zeroing their line widths
// and tick lengths and clearing their tickers (see HideX, HideY).
func (pt *Plot) HideAxes() {
	pt.HideX()
	pt.HideY()
}
// NominalY is like NominalX, but for the Y axis.
func (pt *Plot) NominalY(names ...string) {
	// one tick per name, at consecutive integer locations
	ticks := make([]Tick, len(names))
	for i, nm := range names {
		ticks[i] = Tick{float64(i), nm}
	}
	// hide the axis line and tick marks; only the labels remain
	pt.Y.Style.TickLine.Width.Pt(0)
	pt.Y.Style.TickLength.Pt(0)
	pt.Y.Style.Line.Width.Pt(0)
	// pt.X.Padding = pt.Y.Tick.Label.Height(names[0]) / 2
	pt.Y.Ticker = ConstantTicks(ticks)
}
// UpdateRange updates the axis range values based on current Plot values.
// This first resets the range so any fixed additional range values should
// be set after this point.
func (pt *Plot) UpdateRange() {
pt.X.Range.SetInfinity()
pt.Y.Range.SetInfinity()
pt.Z.Range.SetInfinity()
if pt.Style.XAxis.Range.FixMin {
pt.X.Range.Min = pt.Style.XAxis.Range.Min
}
if pt.Style.XAxis.Range.FixMax {
pt.X.Range.Max = pt.Style.XAxis.Range.Max
}
for _, pl := range pt.Plotters {
pl.UpdateRange(pt, &pt.X.Range, &pt.Y.Range, &pt.Z.Range)
}
pt.X.Range.Sanitize()
pt.Y.Range.Sanitize()
pt.Z.Range.Sanitize()
pt.X.Range.Min *= pt.PanZoom.XScale
pt.X.Range.Max *= pt.PanZoom.XScale
pt.X.Range.Min += pt.PanZoom.XOffset
pt.X.Range.Max += pt.PanZoom.XOffset
pt.Y.Range.Min *= pt.PanZoom.YScale
pt.Y.Range.Max *= pt.PanZoom.YScale
pt.Y.Range.Min += pt.PanZoom.YOffset
pt.Y.Range.Max += pt.PanZoom.YOffset
}
// PX returns the X-axis plotting coordinate for given raw data value
// using the current plot bounding region
func (pt *Plot) PX(v float64) float32 {
return pt.PlotBox.ProjectX(float32(pt.X.Norm(v)))
}
// PY returns the Y-axis plotting coordinate for given raw data value
func (pt *Plot) PY(v float64) float32 {
return pt.PlotBox.ProjectY(float32(1 - pt.Y.Norm(v)))
}
// ClosestDataToPixel returns the Plotter data point closest to given pixel point,
// in the Pixels image.
func (pt *Plot) ClosestDataToPixel(px, py int) (plt Plotter, plotterIndex, pointIndex int, dist float32, pixel math32.Vector2, data Data, legend string) {
tp := math32.Vec2(float32(px), float32(py))
dist = float32(math32.MaxFloat32)
for pi, pl := range pt.Plotters {
dts, pxX, pxY := pl.Data()
if len(pxY) != len(pxX) {
continue
}
for i, ptx := range pxX {
pty := pxY[i]
pxy := math32.Vec2(ptx, pty)
d := pxy.DistanceTo(tp)
if d < dist {
dist = d
pixel = pxy
plt = pl
plotterIndex = pi
pointIndex = i
data = dts
legend = pt.Legend.LegendForPlotter(pl)
}
}
}
return
}
// Copyright (c) 2019, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This is copied and modified directly from gonum to add better error-bar
// plotting for bar plots, along with multiple groups.
// Copyright ©2015 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package plots
import (
"math"
"cogentcore.org/core/math32"
"cogentcore.org/core/math32/minmax"
"cogentcore.org/lab/plot"
)
// BarType is be used for specifying the type name.
const BarType = "Bar"
func init() {
plot.RegisterPlotter(BarType, "A Bar presents ordinally-organized data with rectangular bars with lengths proportional to the data values, and an optional error bar at the top of the bar using the High data role.", []plot.Roles{plot.Y}, []plot.Roles{plot.High}, func(data plot.Data) plot.Plotter {
return NewBar(data)
})
}
// A Bar presents ordinally-organized data with rectangular bars
// with lengths proportional to the data values, and an optional
// error bar ("handle") at the top of the bar using the High data role.
//
// Bars are plotted centered at integer multiples of Stride plus Start offset.
// Full data range also includes Pad value to extend range beyond edge bar centers.
// Bar Width is in data units, e.g., should be <= Stride.
// Defaults provide a unit-spaced plot.
type Bar struct {
// copies of data
Y, Err plot.Values
// actual plotting X, Y values in data coordinates, taking into account stacking etc.
X, Yp plot.Values
// PX, PY are the actual pixel plotting coordinates for each XY value.
PX, PY []float32
// Style has the properties used to render the bars.
Style plot.Style
// Horizontal dictates whether the bars should be in the vertical
// (default) or horizontal direction. If Horizontal is true, all
// X locations and distances referred to here will actually be Y
// locations and distances.
Horizontal bool
// stackedOn is the bar chart upon which this bar chart is stacked.
StackedOn *Bar
stylers plot.Stylers
}
// NewBar returns a new bar plotter with a single bar for each value.
// The bars heights correspond to the values and their x locations correspond
// to the index of their value in the Valuer.
// Optional error-bar values can be provided using the High data role.
// Styler functions are obtained from the Y metadata if present.
func NewBar(data plot.Data) *Bar {
if data.CheckLengths() != nil {
return nil
}
bc := &Bar{}
bc.Y = plot.MustCopyRole(data, plot.Y)
if bc.Y == nil {
return nil
}
bc.stylers = plot.GetStylersFromData(data, plot.Y)
bc.Err = plot.CopyRole(data, plot.High)
bc.Defaults()
return bc
}
func (bc *Bar) Defaults() {
bc.Style.Defaults()
}
func (bc *Bar) Styler(f func(s *plot.Style)) *Bar {
bc.stylers.Add(f)
return bc
}
func (bc *Bar) ApplyStyle(ps *plot.PlotStyle) {
ps.SetElementStyle(&bc.Style)
bc.stylers.Run(&bc.Style)
}
func (bc *Bar) Stylers() *plot.Stylers { return &bc.stylers }
func (bc *Bar) Data() (data plot.Data, pixX, pixY []float32) {
pixX = bc.PX
pixY = bc.PY
data = plot.Data{}
data[plot.X] = bc.X
data[plot.Y] = bc.Y
if bc.Err != nil {
data[plot.High] = bc.Err
}
return
}
// BarHeight returns the maximum y value of the
// ith bar, taking into account any bars upon
// which it is stacked.
func (bc *Bar) BarHeight(i int) float64 {
ht := float64(0.0)
if bc == nil {
return 0
}
if i >= 0 && i < len(bc.Y) {
ht += bc.Y[i]
}
if bc.StackedOn != nil {
ht += bc.StackedOn.BarHeight(i)
}
return ht
}
// StackOn stacks a bar chart on top of another,
// and sets the bar positioning options to that of the
// chart upon which it is being stacked.
func (bc *Bar) StackOn(on *Bar) {
bc.Style.Width = on.Style.Width
bc.StackedOn = on
}
// Plot implements the plot.Plotter interface.
func (bc *Bar) Plot(plt *plot.Plot) {
pc := plt.Paint
bc.Style.Line.SetStroke(plt)
pc.FillStyle.Color = bc.Style.Line.Fill
bw := bc.Style.Width
nv := len(bc.Y)
bc.X = make(plot.Values, nv)
bc.Yp = make(plot.Values, nv)
bc.PX = make([]float32, nv)
bc.PY = make([]float32, nv)
hw := 0.5 * bw.Width
ew := bw.Width / 3
for i, ht := range bc.Y {
cat := bw.Offset + float64(i)*bw.Stride
var bottom float64
var catVal, catMin, catMax, valMin, valMax float32
var box math32.Box2
if bc.Horizontal {
catVal = plt.PY(cat)
catMin = plt.PY(cat - hw)
catMax = plt.PY(cat + hw)
bottom = bc.StackedOn.BarHeight(i) // nil safe
valMin = plt.PX(bottom)
valMax = plt.PX(bottom + ht)
bc.X[i] = bottom + ht
bc.Yp[i] = cat
bc.PX[i] = valMax
bc.PY[i] = catVal
box.Min.Set(valMin, catMin)
box.Max.Set(valMax, catMax)
} else {
catVal = plt.PX(cat)
catMin = plt.PX(cat - hw)
catMax = plt.PX(cat + hw)
bottom = bc.StackedOn.BarHeight(i) // nil safe
valMin = plt.PY(bottom)
valMax = plt.PY(bottom + ht)
bc.X[i] = cat
bc.Yp[i] = bottom + ht
bc.PX[i] = catVal
bc.PY[i] = valMax
box.Min.Set(catMin, valMin)
box.Max.Set(catMax, valMax)
}
pc.DrawRectangle(box.Min.X, box.Min.Y, box.Size().X, box.Size().Y)
pc.FillStrokeClear()
if i < len(bc.Err) {
errval := math.Abs(bc.Err[i])
if bc.Horizontal {
eVal := plt.PX(bottom + ht + math.Abs(errval))
pc.MoveTo(valMax, catVal)
pc.LineTo(eVal, catVal)
pc.MoveTo(eVal, plt.PY(cat-ew))
pc.LineTo(eVal, plt.PY(cat+ew))
} else {
eVal := plt.PY(bottom + ht + math.Abs(errval))
pc.MoveTo(catVal, valMax)
pc.LineTo(catVal, eVal)
pc.MoveTo(plt.PX(cat-ew), eVal)
pc.LineTo(plt.PX(cat+ew), eVal)
}
pc.Stroke()
}
}
pc.FillStyle.Color = nil
}
// UpdateRange updates the given ranges.
func (bc *Bar) UpdateRange(plt *plot.Plot, xr, yr, zr *minmax.F64) {
bw := bc.Style.Width
catMin := bw.Offset - bw.Pad
catMax := bw.Offset + float64(len(bc.Y)-1)*bw.Stride + bw.Pad
for i, val := range bc.Y {
valBot := bc.StackedOn.BarHeight(i)
valTop := valBot + val
if i < len(bc.Err) {
valTop += math.Abs(bc.Err[i])
}
if bc.Horizontal {
xr.FitValInRange(valBot)
xr.FitValInRange(valTop)
} else {
yr.FitValInRange(valBot)
yr.FitValInRange(valTop)
}
}
if bc.Horizontal {
xr.Min, xr.Max = bc.Style.Range.Clamp(xr.Min, xr.Max)
yr.FitInRange(minmax.F64{catMin, catMax})
} else {
yr.Min, yr.Max = bc.Style.Range.Clamp(yr.Min, yr.Max)
xr.FitInRange(minmax.F64{catMin, catMax})
}
}
// Thumbnail fulfills the plot.Thumbnailer interface.
func (bc *Bar) Thumbnail(plt *plot.Plot) {
pc := plt.Paint
bc.Style.Line.SetStroke(plt)
pc.FillStyle.Color = bc.Style.Line.Fill
ptb := pc.Bounds
pc.DrawRectangle(float32(ptb.Min.X), float32(ptb.Min.Y), float32(ptb.Size().X), float32(ptb.Size().Y))
pc.FillStrokeClear()
pc.FillStyle.Color = nil
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package plots
import (
"math"
"cogentcore.org/core/math32/minmax"
"cogentcore.org/lab/plot"
)
const (
// YErrorBarsType is be used for specifying the type name.
YErrorBarsType = "YErrorBars"
// XErrorBarsType is be used for specifying the type name.
XErrorBarsType = "XErrorBars"
)
func init() {
plot.RegisterPlotter(YErrorBarsType, "draws draws vertical error bars, denoting error in Y values, using either High or Low & High data roles for error deviations around X, Y coordinates.", []plot.Roles{plot.X, plot.Y, plot.High}, []plot.Roles{plot.Low}, func(data plot.Data) plot.Plotter {
return NewYErrorBars(data)
})
plot.RegisterPlotter(XErrorBarsType, "draws draws horizontal error bars, denoting error in X values, using either High or Low & High data roles for error deviations around X, Y coordinates.", []plot.Roles{plot.X, plot.Y, plot.High}, []plot.Roles{plot.Low}, func(data plot.Data) plot.Plotter {
return NewXErrorBars(data)
})
}
// YErrorBars draws vertical error bars, denoting error in Y values,
// using ether High or Low, High data roles for error deviations
// around X, Y coordinates.
type YErrorBars struct {
// copies of data for this line
X, Y, Low, High plot.Values
// PX, PY are the actual pixel plotting coordinates for each XY value.
PX, PY []float32
// Style is the style for plotting.
Style plot.Style
stylers plot.Stylers
ystylers plot.Stylers
}
func (eb *YErrorBars) Defaults() {
eb.Style.Defaults()
}
// NewYErrorBars returns a new YErrorBars plotter,
// using Low, High data roles for error deviations around X, Y coordinates.
// Styler functions are obtained from the High data if present.
func NewYErrorBars(data plot.Data) *YErrorBars {
if data.CheckLengths() != nil {
return nil
}
eb := &YErrorBars{}
eb.X = plot.MustCopyRole(data, plot.X)
eb.Y = plot.MustCopyRole(data, plot.Y)
eb.Low = plot.CopyRole(data, plot.Low)
eb.High = plot.CopyRole(data, plot.High)
if eb.Low == nil && eb.High != nil {
eb.Low = eb.High
}
if eb.X == nil || eb.Y == nil || eb.Low == nil || eb.High == nil {
return nil
}
eb.stylers = plot.GetStylersFromData(data, plot.High)
eb.ystylers = plot.GetStylersFromData(data, plot.Y)
eb.Defaults()
return eb
}
// Styler adds a style function to set style parameters.
func (eb *YErrorBars) Styler(f func(s *plot.Style)) *YErrorBars {
eb.stylers.Add(f)
return eb
}
func (eb *YErrorBars) ApplyStyle(ps *plot.PlotStyle) {
ps.SetElementStyle(&eb.Style)
yst := &plot.Style{}
eb.ystylers.Run(yst)
eb.Style.Range = yst.Range // get range from y
eb.stylers.Run(&eb.Style)
}
func (eb *YErrorBars) Stylers() *plot.Stylers { return &eb.stylers }
func (eb *YErrorBars) Data() (data plot.Data, pixX, pixY []float32) {
pixX = eb.PX
pixY = eb.PY
data = plot.Data{}
data[plot.X] = eb.X
data[plot.Y] = eb.Y
data[plot.Low] = eb.Low
data[plot.High] = eb.High
return
}
func (eb *YErrorBars) Plot(plt *plot.Plot) {
pc := plt.Paint
uc := &pc.UnitContext
eb.Style.Width.Cap.ToDots(uc)
cw := 0.5 * eb.Style.Width.Cap.Dots
nv := len(eb.X)
eb.PX = make([]float32, nv)
eb.PY = make([]float32, nv)
eb.Style.Line.SetStroke(plt)
for i, y := range eb.Y {
x := plt.PX(eb.X.Float1D(i))
ylow := plt.PY(y - math.Abs(eb.Low[i]))
yhigh := plt.PY(y + math.Abs(eb.High[i]))
eb.PX[i] = x
eb.PY[i] = yhigh
pc.MoveTo(x, ylow)
pc.LineTo(x, yhigh)
pc.MoveTo(x-cw, ylow)
pc.LineTo(x+cw, ylow)
pc.MoveTo(x-cw, yhigh)
pc.LineTo(x+cw, yhigh)
pc.Stroke()
}
}
// UpdateRange updates the given ranges.
func (eb *YErrorBars) UpdateRange(plt *plot.Plot, xr, yr, zr *minmax.F64) {
plot.Range(eb.X, xr)
plot.RangeClamp(eb.Y, yr, &eb.Style.Range)
for i, y := range eb.Y {
ylow := y - math.Abs(eb.Low[i])
yhigh := y + math.Abs(eb.High[i])
yr.FitInRange(minmax.F64{ylow, yhigh})
}
return
}
//////// XErrorBars
// XErrorBars draws horizontal error bars, denoting error in X values,
// using ether High or Low, High data roles for error deviations
// around X, Y coordinates.
type XErrorBars struct {
// copies of data for this line
X, Y, Low, High plot.Values
// PX, PY are the actual pixel plotting coordinates for each XY value.
PX, PY []float32
// Style is the style for plotting.
Style plot.Style
stylers plot.Stylers
ystylers plot.Stylers
yrange minmax.Range64
}
func (eb *XErrorBars) Defaults() {
eb.Style.Defaults()
}
// NewXErrorBars returns a new XErrorBars plotter,
// using Low, High data roles for error deviations around X, Y coordinates.
func NewXErrorBars(data plot.Data) *XErrorBars {
if data.CheckLengths() != nil {
return nil
}
eb := &XErrorBars{}
eb.X = plot.MustCopyRole(data, plot.X)
eb.Y = plot.MustCopyRole(data, plot.Y)
eb.Low = plot.MustCopyRole(data, plot.Low)
eb.High = plot.MustCopyRole(data, plot.High)
eb.Low = plot.CopyRole(data, plot.Low)
eb.High = plot.CopyRole(data, plot.High)
if eb.Low == nil && eb.High != nil {
eb.Low = eb.High
}
if eb.X == nil || eb.Y == nil || eb.Low == nil || eb.High == nil {
return nil
}
eb.stylers = plot.GetStylersFromData(data, plot.High)
eb.ystylers = plot.GetStylersFromData(data, plot.Y)
eb.Defaults()
return eb
}
// Styler adds a style function to set style parameters.
func (eb *XErrorBars) Styler(f func(s *plot.Style)) *XErrorBars {
eb.stylers.Add(f)
return eb
}
func (eb *XErrorBars) ApplyStyle(ps *plot.PlotStyle) {
ps.SetElementStyle(&eb.Style)
yst := &plot.Style{}
eb.ystylers.Run(yst)
eb.yrange = yst.Range // get range from y
eb.stylers.Run(&eb.Style)
}
func (eb *XErrorBars) Stylers() *plot.Stylers { return &eb.stylers }
func (eb *XErrorBars) Data() (data plot.Data, pixX, pixY []float32) {
pixX = eb.PX
pixY = eb.PY
data = plot.Data{}
data[plot.X] = eb.X
data[plot.Y] = eb.Y
data[plot.Low] = eb.Low
data[plot.High] = eb.High
return
}
func (eb *XErrorBars) Plot(plt *plot.Plot) {
pc := plt.Paint
uc := &pc.UnitContext
eb.Style.Width.Cap.ToDots(uc)
cw := 0.5 * eb.Style.Width.Cap.Dots
nv := len(eb.X)
eb.PX = make([]float32, nv)
eb.PY = make([]float32, nv)
eb.Style.Line.SetStroke(plt)
for i, x := range eb.X {
y := plt.PY(eb.Y.Float1D(i))
xlow := plt.PX(x - math.Abs(eb.Low[i]))
xhigh := plt.PX(x + math.Abs(eb.High[i]))
eb.PX[i] = xhigh
eb.PY[i] = y
pc.MoveTo(xlow, y)
pc.LineTo(xhigh, y)
pc.MoveTo(xlow, y-cw)
pc.LineTo(xlow, y+cw)
pc.MoveTo(xhigh, y-cw)
pc.LineTo(xhigh, y+cw)
pc.Stroke()
}
}
// UpdateRange updates the given ranges.
func (eb *XErrorBars) UpdateRange(plt *plot.Plot, xr, yr, zr *minmax.F64) {
plot.RangeClamp(eb.X, xr, &eb.Style.Range)
plot.RangeClamp(eb.Y, yr, &eb.yrange)
for i, xv := range eb.X {
xlow := xv - math.Abs(eb.Low[i])
xhigh := xv + math.Abs(eb.High[i])
xr.FitInRange(minmax.F64{xlow, xhigh})
}
return
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package plots
import (
"image"
"cogentcore.org/core/math32"
"cogentcore.org/core/math32/minmax"
"cogentcore.org/lab/plot"
)
// LabelsType is be used for specifying the type name.
const LabelsType = "Labels"
func init() {
plot.RegisterPlotter(LabelsType, "draws text labels at specified X, Y points.", []plot.Roles{plot.X, plot.Y, plot.Label}, []plot.Roles{}, func(data plot.Data) plot.Plotter {
return NewLabels(data)
})
}
// Labels draws text labels at specified X, Y points.
type Labels struct {
// copies of data for this line
X, Y plot.Values
Labels plot.Labels
// PX, PY are the actual pixel plotting coordinates for each XY value.
PX, PY []float32
// Style is the style of the label text.
Style plot.Style
// plot size and number of TextStyle when styles last generated -- don't regen
styleSize image.Point
stylers plot.Stylers
ystylers plot.Stylers
}
// NewLabels returns a new Labels using defaults
// Styler functions are obtained from the Label metadata if present.
func NewLabels(data plot.Data) *Labels {
if data.CheckLengths() != nil {
return nil
}
lb := &Labels{}
lb.X = plot.MustCopyRole(data, plot.X)
lb.Y = plot.MustCopyRole(data, plot.Y)
if lb.X == nil || lb.Y == nil {
return nil
}
ld := data[plot.Label]
if ld == nil {
return nil
}
lb.Labels = make(plot.Labels, lb.X.Len())
for i := range ld.Len() {
lb.Labels[i] = ld.String1D(i)
}
lb.stylers = plot.GetStylersFromData(data, plot.Label)
lb.ystylers = plot.GetStylersFromData(data, plot.Y)
lb.Defaults()
return lb
}
func (lb *Labels) Defaults() {
lb.Style.Defaults()
}
// Styler adds a style function to set style parameters.
func (lb *Labels) Styler(f func(s *plot.Style)) *Labels {
lb.stylers.Add(f)
return lb
}
func (lb *Labels) ApplyStyle(ps *plot.PlotStyle) {
ps.SetElementStyle(&lb.Style)
yst := &plot.Style{}
lb.ystylers.Run(yst)
lb.Style.Range = yst.Range // get range from y
lb.stylers.Run(&lb.Style) // can still override here
}
func (lb *Labels) Stylers() *plot.Stylers { return &lb.stylers }
func (lb *Labels) Data() (data plot.Data, pixX, pixY []float32) {
pixX = lb.PX
pixY = lb.PY
data = plot.Data{}
data[plot.X] = lb.X
data[plot.Y] = lb.Y
data[plot.Label] = lb.Labels
return
}
// Plot implements the Plotter interface, drawing labels.
func (lb *Labels) Plot(plt *plot.Plot) {
pc := plt.Paint
uc := &pc.UnitContext
lb.PX = plot.PlotX(plt, lb.X)
lb.PY = plot.PlotY(plt, lb.Y)
st := &lb.Style.Text
st.Offset.ToDots(uc)
var ltxt plot.Text
ltxt.Defaults()
ltxt.Style = *st
ltxt.ToDots(uc)
for i, label := range lb.Labels {
if label == "" {
continue
}
ltxt.Text = label
ltxt.Config(plt)
tht := ltxt.PaintText.BBox.Size().Y
ltxt.Draw(plt, math32.Vec2(lb.PX[i]+st.Offset.X.Dots, lb.PY[i]+st.Offset.Y.Dots-tht))
}
}
// UpdateRange updates the given ranges.
func (lb *Labels) UpdateRange(plt *plot.Plot, xr, yr, zr *minmax.F64) {
// todo: include point sizes!
plot.Range(lb.X, xr)
plot.RangeClamp(lb.Y, yr, &lb.Style.Range)
pxToData := math32.FromPoint(plt.Size)
pxToData.X = float32(xr.Range()) / pxToData.X
pxToData.Y = float32(yr.Range()) / pxToData.Y
st := &lb.Style.Text
var ltxt plot.Text
ltxt.Style = *st
for i, label := range lb.Labels {
if label == "" {
continue
}
ltxt.Text = label
ltxt.Config(plt)
tht := pxToData.Y * ltxt.PaintText.BBox.Size().Y
twd := 1.1 * pxToData.X * ltxt.PaintText.BBox.Size().X
x := lb.X[i]
y := lb.Y[i]
maxx := x + float64(pxToData.X*st.Offset.X.Dots+twd)
maxy := y + float64(pxToData.Y*st.Offset.Y.Dots+tht) // y is up here
xr.FitInRange(minmax.F64{x, maxx})
yr.FitInRange(minmax.F64{y, maxy})
}
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Adapted from github.com/gonum/plot:
// Copyright ©2015 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package plots
//go:generate core generate
import (
"cogentcore.org/core/math32"
"cogentcore.org/core/math32/minmax"
"cogentcore.org/lab/plot"
)
// XYType is be used for specifying the type name.
const XYType = "XY"
func init() {
plot.RegisterPlotter(XYType, "draws lines between and / or points for X,Y data values, using optional Size and Color data for the points, for a bubble plot.", []plot.Roles{plot.X, plot.Y}, []plot.Roles{plot.Size, plot.Color}, func(data plot.Data) plot.Plotter {
return NewXY(data)
})
}
// XY draws lines between and / or points for XY data values.
type XY struct {
// copies of data for this line
X, Y, Color, Size plot.Values
// PX, PY are the actual pixel plotting coordinates for each XY value.
PX, PY []float32
// Style is the style for plotting.
Style plot.Style
stylers plot.Stylers
}
// NewXY returns an XY plotter for given X, Y data.
// data can also include Color and / or Size for the points.
// Styler functions are obtained from the Y metadata if present.
func NewXY(data plot.Data) *XY {
if data.CheckLengths() != nil {
return nil
}
ln := &XY{}
ln.X = plot.MustCopyRole(data, plot.X)
ln.Y = plot.MustCopyRole(data, plot.Y)
if ln.X == nil || ln.Y == nil {
return nil
}
ln.stylers = plot.GetStylersFromData(data, plot.Y)
ln.Color = plot.CopyRole(data, plot.Color)
ln.Size = plot.CopyRole(data, plot.Size)
ln.Defaults()
return ln
}
// NewLine returns an XY plot drawing Lines by default.
func NewLine(data plot.Data) *XY {
ln := NewXY(data)
if ln == nil {
return ln
}
ln.Style.Line.On = plot.On
ln.Style.Point.On = plot.Off
return ln
}
// NewScatter returns an XY scatter plot drawing Points by default.
func NewScatter(data plot.Data) *XY {
ln := NewXY(data)
if ln == nil {
return ln
}
ln.Style.Line.On = plot.Off
ln.Style.Point.On = plot.On
return ln
}
func (ln *XY) Defaults() {
ln.Style.Defaults()
}
// Styler adds a style function to set style parameters.
func (ln *XY) Styler(f func(s *plot.Style)) *XY {
ln.stylers.Add(f)
return ln
}
func (ln *XY) Stylers() *plot.Stylers { return &ln.stylers }
func (ln *XY) ApplyStyle(ps *plot.PlotStyle) {
ps.SetElementStyle(&ln.Style)
ln.stylers.Run(&ln.Style)
}
func (ln *XY) Data() (data plot.Data, pixX, pixY []float32) {
pixX = ln.PX
pixY = ln.PY
data = plot.Data{}
data[plot.X] = ln.X
data[plot.Y] = ln.Y
if ln.Size != nil {
data[plot.Size] = ln.Size
}
if ln.Color != nil {
data[plot.Color] = ln.Color
}
return
}
// Plot does the drawing, implementing the plot.Plotter interface.
func (ln *XY) Plot(plt *plot.Plot) {
ln.PX = plot.PlotX(plt, ln.X)
ln.PY = plot.PlotY(plt, ln.Y)
np := len(ln.PX)
if np == 0 || len(ln.PY) == 0 {
return
}
pc := plt.Paint
if ln.Style.Line.HasFill() {
pc.FillStyle.Color = ln.Style.Line.Fill
minY := plt.PY(plt.Y.Range.Min)
prevX := ln.PX[0]
prevY := minY
pc.MoveTo(prevX, prevY)
for i, ptx := range ln.PX {
pty := ln.PY[i]
switch ln.Style.Line.Step {
case plot.NoStep:
if ptx < prevX {
pc.LineTo(prevX, minY)
pc.ClosePath()
pc.MoveTo(ptx, minY)
}
pc.LineTo(ptx, pty)
case plot.PreStep:
if i == 0 {
continue
}
if ptx < prevX {
pc.LineTo(prevX, minY)
pc.ClosePath()
pc.MoveTo(ptx, minY)
} else {
pc.LineTo(prevX, pty)
}
pc.LineTo(ptx, pty)
case plot.MidStep:
if ptx < prevX {
pc.LineTo(prevX, minY)
pc.ClosePath()
pc.MoveTo(ptx, minY)
} else {
pc.LineTo(0.5*(prevX+ptx), prevY)
pc.LineTo(0.5*(prevX+ptx), pty)
}
pc.LineTo(ptx, pty)
case plot.PostStep:
if ptx < prevX {
pc.LineTo(prevX, minY)
pc.ClosePath()
pc.MoveTo(ptx, minY)
} else {
pc.LineTo(ptx, prevY)
}
pc.LineTo(ptx, pty)
}
prevX, prevY = ptx, pty
}
pc.LineTo(prevX, minY)
pc.ClosePath()
pc.Fill()
}
pc.FillStyle.Color = nil
if ln.Style.Line.SetStroke(plt) {
if plt.HighlightPlotter == ln {
pc.StrokeStyle.Width.Dots *= 1.5
}
prevX, prevY := ln.PX[0], ln.PY[0]
pc.MoveTo(prevX, prevY)
for i := 1; i < np; i++ {
ptx, pty := ln.PX[i], ln.PY[i]
if ln.Style.Line.Step != plot.NoStep {
if ptx >= prevX {
switch ln.Style.Line.Step {
case plot.PreStep:
pc.LineTo(prevX, pty)
case plot.MidStep:
pc.LineTo(0.5*(prevX+ptx), prevY)
pc.LineTo(0.5*(prevX+ptx), pty)
case plot.PostStep:
pc.LineTo(ptx, prevY)
}
} else {
pc.MoveTo(ptx, pty)
}
}
if !ln.Style.Line.NegativeX && ptx < prevX {
pc.MoveTo(ptx, pty)
} else {
pc.LineTo(ptx, pty)
}
prevX, prevY = ptx, pty
}
pc.Stroke()
}
if ln.Style.Point.SetStroke(plt) {
origWidth := pc.StrokeStyle.Width.Dots
for i, ptx := range ln.PX {
pty := ln.PY[i]
if plt.HighlightPlotter == ln {
if i == plt.HighlightIndex {
pc.StrokeStyle.Width.Dots *= 1.5
} else {
pc.StrokeStyle.Width.Dots = origWidth
}
}
ln.Style.Point.DrawShape(pc, math32.Vec2(ptx, pty))
}
} else if plt.HighlightPlotter == ln {
op := ln.Style.Point.On
ln.Style.Point.On = plot.On
ln.Style.Point.SetStroke(plt)
ptx := ln.PX[plt.HighlightIndex]
pty := ln.PY[plt.HighlightIndex]
ln.Style.Point.DrawShape(pc, math32.Vec2(ptx, pty))
ln.Style.Point.On = op
}
pc.FillStyle.Color = nil
}
// UpdateRange updates the given ranges.
func (ln *XY) UpdateRange(plt *plot.Plot, xr, yr, zr *minmax.F64) {
// todo: include point sizes!
plot.Range(ln.X, xr)
plot.RangeClamp(ln.Y, yr, &ln.Style.Range)
}
// Thumbnail returns the thumbnail, implementing the plot.Thumbnailer interface.
func (ln *XY) Thumbnail(plt *plot.Plot) {
pc := plt.Paint
ptb := pc.Bounds
midY := 0.5 * float32(ptb.Min.Y+ptb.Max.Y)
if ln.Style.Line.Fill != nil {
tb := ptb
if ln.Style.Line.Width.Value > 0 {
tb.Min.Y = int(midY)
}
pc.FillBox(math32.FromPoint(tb.Min), math32.FromPoint(tb.Size()), ln.Style.Line.Fill)
}
if ln.Style.Line.SetStroke(plt) {
pc.MoveTo(float32(ptb.Min.X), midY)
pc.LineTo(float32(ptb.Max.X), midY)
pc.Stroke()
}
if ln.Style.Point.SetStroke(plt) {
midX := 0.5 * float32(ptb.Min.X+ptb.Max.X)
ln.Style.Point.DrawShape(pc, math32.Vec2(midX, midY))
}
pc.FillStyle.Color = nil
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package plot
import (
"fmt"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/math32/minmax"
)
// Plotter is an interface that wraps the Plot method.
// Standard implementations of Plotter are in the [plots] package.
type Plotter interface {
// Plot draws the data to the Plot Paint.
Plot(pt *Plot)
// UpdateRange updates the given ranges.
UpdateRange(plt *Plot, xr, yr, zr *minmax.F64)
// Data returns the data by roles for this plot, for both the original
// data and the pixel-transformed X,Y coordinates for that data.
// This allows a GUI interface to inspect data etc.
Data() (data Data, pixX, pixY []float32)
// Stylers returns the styler functions for this element.
Stylers() *Stylers
// ApplyStyle applies any stylers to this element,
// first initializing from the given global plot style, which has
// already been styled with defaults and all the plot element stylers.
ApplyStyle(plotStyle *PlotStyle)
}
// PlotterType registers a Plotter so that it can be created with appropriate data.
type PlotterType struct {
// Name of the plot type.
Name string
// Doc is the documentation for this Plotter.
Doc string
// Required Data roles for this plot. Data for these Roles must be provided.
Required []Roles
// Optional Data roles for this plot.
Optional []Roles
// New returns a new plotter of this type with given data in given roles.
New func(data Data) Plotter
}
// PlotterName is the name of a specific plotter type.
type PlotterName string
// Plotters is the registry of [Plotter] types.
var Plotters = map[string]PlotterType{}
// RegisterPlotter registers a plotter type.
func RegisterPlotter(name, doc string, required, optional []Roles, newFun func(data Data) Plotter) {
Plotters[name] = PlotterType{Name: name, Doc: doc, Required: required, Optional: optional, New: newFun}
}
// PlotterByType returns [PlotterType] info for a registered [Plotter]
// of given type name, e.g., "XY", "Bar" etc,
// Returns an error and nil if type name is not a registered type.
func PlotterByType(typeName string) (*PlotterType, error) {
pt, ok := Plotters[typeName]
if !ok {
return nil, fmt.Errorf("plot.PlotterByType type name is not registered: %s", typeName)
}
return &pt, nil
}
// NewPlotter returns a new plotter of given type, e.g., "XY", "Bar" etc,
// for given data roles (which must include Required roles, and may include Optional ones).
// Logs an error and returns nil if type name is not a registered type.
func NewPlotter(typeName string, data Data) Plotter {
pt, err := PlotterByType(typeName)
if errors.Log(err) != nil {
return nil
}
return pt.New(data)
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package plot
import (
"image"
"cogentcore.org/core/colors"
"cogentcore.org/core/math32"
"cogentcore.org/core/paint"
"cogentcore.org/core/styles/units"
)
// PointStyle has style properties for drawing points as different shapes.
type PointStyle struct { //types:add -setters
// On indicates whether to plot points.
On DefaultOffOn
// Shape to draw.
Shape Shapes
// Color is the stroke color image specification.
// Setting to nil turns line off.
Color image.Image
// Fill is the color to fill solid regions, in a plot-specific
// way (e.g., the area below a Line plot, the bar color).
// Use nil to disable filling.
Fill image.Image
// Width is the line width for point glyphs, with a default of 1 Pt (point).
// Setting to 0 turns line off.
Width units.Value
// Size of shape to draw for each point.
// Defaults to 4 Pt (point).
Size units.Value
}
func (ps *PointStyle) Defaults() {
ps.Color = colors.Scheme.OnSurface
ps.Fill = colors.Scheme.OnSurface
ps.Width.Pt(1)
ps.Size.Pt(4)
}
// SetStroke sets the stroke style in plot paint to current line style.
// returns false if either the Width = 0 or Color is nil
func (ps *PointStyle) SetStroke(pt *Plot) bool {
if ps.On == Off || ps.Color == nil {
return false
}
pc := pt.Paint
uc := &pc.UnitContext
ps.Width.ToDots(uc)
ps.Size.ToDots(uc)
if ps.Width.Dots == 0 || ps.Size.Dots == 0 {
return false
}
pc.StrokeStyle.Width = ps.Width
pc.StrokeStyle.Color = ps.Color
pc.StrokeStyle.ToDots(uc)
pc.FillStyle.Color = ps.Fill
return true
}
// DrawShape draws the given shape
func (ps *PointStyle) DrawShape(pc *paint.Context, pos math32.Vector2) {
size := ps.Size.Dots
if size == 0 {
return
}
switch ps.Shape {
case Ring:
DrawRing(pc, pos, size)
case Circle:
DrawCircle(pc, pos, size)
case Square:
DrawSquare(pc, pos, size)
case Box:
DrawBox(pc, pos, size)
case Triangle:
DrawTriangle(pc, pos, size)
case Pyramid:
DrawPyramid(pc, pos, size)
case Plus:
DrawPlus(pc, pos, size)
case Cross:
DrawCross(pc, pos, size)
}
}
func DrawRing(pc *paint.Context, pos math32.Vector2, size float32) {
pc.DrawCircle(pos.X, pos.Y, size)
pc.Stroke()
}
func DrawCircle(pc *paint.Context, pos math32.Vector2, size float32) {
pc.DrawCircle(pos.X, pos.Y, size)
pc.FillStrokeClear()
}
func DrawSquare(pc *paint.Context, pos math32.Vector2, size float32) {
x := size * 0.9
pc.MoveTo(pos.X-x, pos.Y-x)
pc.LineTo(pos.X+x, pos.Y-x)
pc.LineTo(pos.X+x, pos.Y+x)
pc.LineTo(pos.X-x, pos.Y+x)
pc.ClosePath()
pc.Stroke()
}
func DrawBox(pc *paint.Context, pos math32.Vector2, size float32) {
x := size * 0.9
pc.MoveTo(pos.X-x, pos.Y-x)
pc.LineTo(pos.X+x, pos.Y-x)
pc.LineTo(pos.X+x, pos.Y+x)
pc.LineTo(pos.X-x, pos.Y+x)
pc.ClosePath()
pc.FillStrokeClear()
}
func DrawTriangle(pc *paint.Context, pos math32.Vector2, size float32) {
x := size * 0.9
pc.MoveTo(pos.X, pos.Y-x)
pc.LineTo(pos.X-x, pos.Y+x)
pc.LineTo(pos.X+x, pos.Y+x)
pc.ClosePath()
pc.Stroke()
}
func DrawPyramid(pc *paint.Context, pos math32.Vector2, size float32) {
x := size * 0.9
pc.MoveTo(pos.X, pos.Y-x)
pc.LineTo(pos.X-x, pos.Y+x)
pc.LineTo(pos.X+x, pos.Y+x)
pc.ClosePath()
pc.FillStrokeClear()
}
func DrawPlus(pc *paint.Context, pos math32.Vector2, size float32) {
x := size * 1.05
pc.MoveTo(pos.X-x, pos.Y)
pc.LineTo(pos.X+x, pos.Y)
pc.MoveTo(pos.X, pos.Y-x)
pc.LineTo(pos.X, pos.Y+x)
pc.ClosePath()
pc.Stroke()
}
func DrawCross(pc *paint.Context, pos math32.Vector2, size float32) {
x := size * 0.9
pc.MoveTo(pos.X-x, pos.Y-x)
pc.LineTo(pos.X+x, pos.Y+x)
pc.MoveTo(pos.X+x, pos.Y-x)
pc.LineTo(pos.X-x, pos.Y+x)
pc.ClosePath()
pc.Stroke()
}
// Shapes has the options for how to draw points in the plot.
type Shapes int32 //enums:enum
const (
// Ring is the outline of a circle
Ring Shapes = iota
// Circle is a solid circle
Circle
// Square is the outline of a square
Square
// Box is a filled square
Box
// Triangle is the outline of a triangle
Triangle
// Pyramid is a filled triangle
Pyramid
// Plus is a plus sign
Plus
// Cross is a big X
Cross
)
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package plot
import (
"cogentcore.org/core/base/metadata"
"cogentcore.org/core/math32/minmax"
"cogentcore.org/core/styles/units"
)
// Style contains the plot styling properties relevant across
// most plot types. These properties apply to individual plot elements
// while the Plot properties applies to the overall plot itself.
// A Style is produced per column by running its [Stylers] (see NewTablePlot).
type Style struct { //types:add -setters
// Plot has overall plot-level properties, which can be set by any
// plot element, and are updated first, before applying element-wise styles.
Plot PlotStyle `display:"-"`
// On specifies whether to plot this item, for table-based plots.
On bool
// Plotter is the type of plotter to use in plotting this data,
// for [plot.NewTablePlot] [table.Table] driven plots.
// Blank means use default ([plots.XY] is overall default).
Plotter PlotterName
// Role specifies how a particular column of data should be used,
// for [plot.NewTablePlot] [table.Table] driven plots.
Role Roles
// Group specifies a group of related data items,
// for [plot.NewTablePlot] [table.Table] driven plots,
// where different columns of data within the same Group play different Roles.
Group string
// Range is the effective range of data to plot, where either end can be fixed.
Range minmax.Range64 `display:"inline"`
// Label provides an alternative label to use for axis, if set.
Label string
// NoLegend excludes this item from the legend when it otherwise would be included,
// for [plot.NewTablePlot] [table.Table] driven plots.
// Role = Y values are included in the Legend by default.
NoLegend bool
// NTicks sets the desired number of ticks for the axis, if > 0.
NTicks int
// Line has style properties for drawing lines.
Line LineStyle `display:"add-fields"`
// Point has style properties for drawing points.
Point PointStyle `display:"add-fields"`
// Text has style properties for rendering text.
Text TextStyle `display:"add-fields"`
// Width has various plot width properties.
Width WidthStyle `display:"inline"`
}
// NewStyle returns a new Style object with defaults applied.
func NewStyle() *Style {
	st := new(Style)
	st.Defaults()
	return st
}
// Defaults sets default values on this Style by delegating to the
// Defaults method of each of its component style structs.
func (st *Style) Defaults() {
st.Plot.Defaults()
st.Line.Defaults()
st.Point.Defaults()
st.Text.Defaults()
st.Width.Defaults()
}
// WidthStyle contains various plot width properties relevant across
// different plot types.
type WidthStyle struct { //types:add -setters
// Cap is the width of the caps drawn at the top of error bars.
// The default is 10dp.
Cap units.Value
// Offset for Bar plot is the offset added to each X axis value
// relative to the Stride computed value (X = offset + index * Stride)
// Defaults to 0.
// NOTE(review): Defaults() currently sets Offset to 1, contradicting
// the stated default of 0 -- confirm which is intended.
Offset float64
// Stride for Bar plot is distance between bars. Defaults to 1.
Stride float64
// Width for Bar plot is the width of the bars, as a fraction of the Stride,
// to prevent bar overlap. Defaults to .8.
Width float64 `min:"0.01" max:"1" default:"0.8"`
// Pad for Bar plot is additional space at start / end of data range,
// to keep bars from overflowing ends. This amount is subtracted from Offset
// and added to (len(Values)-1)*Stride -- no other accommodation for bar
// width is provided, so that should be built into this value as well.
// Defaults to 1.
Pad float64
}
// Defaults sets default width values.
func (ws *WidthStyle) Defaults() {
ws.Cap.Dp(10)
// NOTE(review): the Offset field docs say "Defaults to 0" but this sets 1
// -- confirm which is intended.
ws.Offset = 1
ws.Stride = 1
ws.Width = .8
ws.Pad = 1
}
// Stylers is a list of styling functions that set Style properties.
// These are called in the order added, so later functions override
// earlier ones.
type Stylers []func(s *Style)
// Add adds a styling function to the list.
func (st *Stylers) Add(f func(s *Style)) {
*st = append(*st, f)
}
// Run runs the list of styling functions on given [Style] object,
// in the order they were added. Running a nil list is a no-op.
func (st *Stylers) Run(s *Style) {
	for _, styler := range *st {
		styler(s)
	}
}
// NewStyle returns a new Style object with styling functions applied
// on top of Style defaults, after first applying the element-level
// settings from the given plot-wide [PlotStyle].
func (st *Stylers) NewStyle(ps *PlotStyle) *Style {
s := NewStyle()
ps.SetElementStyle(s)
st.Run(s)
return s
}
// SetStylersTo sets the [Stylers] into given object's [metadata],
// under the "PlotStylers" key, replacing any existing list.
func SetStylersTo(obj any, st Stylers) {
metadata.SetTo(obj, "PlotStylers", st)
}
// GetStylersFrom returns [Stylers] from given object's [metadata],
// under the "PlotStylers" key.
// Returns nil if none or no metadata.
func GetStylersFrom(obj any) Stylers {
st, _ := metadata.GetFrom[Stylers](obj, "PlotStylers")
return st
}
// SetStylerTo sets the [Styler] function into given object's [metadata],
// replacing anything that might have already been added.
// Routes through [SetStylersTo] so the metadata key is defined
// in exactly one place.
func SetStylerTo(obj any, f func(s *Style)) {
	SetStylersTo(obj, Stylers{f})
}
// SetFirstStylerTo sets the [Styler] function into given object's [metadata],
// only if there are no other stylers present.
// Routes through [SetStylersTo] so the metadata key is defined
// in exactly one place.
func SetFirstStylerTo(obj any, f func(s *Style)) {
	if len(GetStylersFrom(obj)) > 0 {
		return // existing stylers take precedence
	}
	SetStylersTo(obj, Stylers{f})
}
// AddStylerTo adds the given [Styler] function into given object's [metadata],
// appending to any stylers already present.
func AddStylerTo(obj any, f func(s *Style)) {
	stylers := GetStylersFrom(obj)
	stylers.Add(f)
	SetStylersTo(obj, stylers)
}
// GetStylersFromData returns [Stylers] from given role
// in given [Data]. nil if not present.
func GetStylersFromData(data Data, role Roles) Stylers {
	if vr, ok := data[role]; ok {
		return GetStylersFrom(vr)
	}
	return nil
}
////////
// DefaultOffOn specifies whether to use the default value for a bool option,
// or to override the default and set Off or On.
// The order of these constants must not change: it is relied on
// by the //enums:enum generated code.
type DefaultOffOn int32 //enums:enum

const (
// Default means use the default value.
Default DefaultOffOn = iota
// Off means to override the default and turn Off.
Off
// On means to override the default and turn On.
On
)
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package plot
import (
"fmt"
"reflect"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/reflectx"
"cogentcore.org/lab/table"
"golang.org/x/exp/maps"
)
// NewTablePlot returns a new Plot with all configuration based on given
// [table.Table] set of columns and associated metadata, which must have
// [Stylers] functions set (e.g., [SetStylersTo]) that at least set basic
// table parameters, including:
// - On: Set the main (typically Role = Y) column On to include in plot.
// - Role: Set the appropriate [Roles] role for this column (Y, X, etc).
// - Group: Multiple columns used for a given Plotter type must be grouped
// together with a common name (typically the name of the main Y axis),
// e.g., for Low, High error bars, Size, Color, etc. If only one On column,
// then Group can be empty and all other such columns will be grouped.
// - Plotter: Determines the type of Plotter element to use, which in turn
// determines the additional Roles that can be used within a Group.
func NewTablePlot(dt *table.Table) (*Plot, error) {
nc := len(dt.Columns.Values)
if nc == 0 {
return nil, errors.New("plot.NewTablePlot: no columns in data table")
}
csty := make([]*Style, nc) // per-column styles, indexed by column
gps := make(map[string][]int, nc) // Group name -> column indexes in that group
xi := -1 // get the _last_ role = X column -- most specific counter
var errs []error
var pstySt Style // overall PlotStyle accumulator
pstySt.Defaults()
// First pass: run each column's stylers to build its Style,
// group columns by Group name, and accumulate plot-level style.
for ci, cl := range dt.Columns.Values {
st := &Style{}
st.Defaults()
stl := GetStylersFrom(cl)
if stl != nil {
stl.Run(st)
}
csty[ci] = st
// Run on a nil Stylers ranges over nothing, so no nil check is needed here.
stl.Run(&pstySt)
gps[st.Group] = append(gps[st.Group], ci)
if st.Role == X {
xi = ci
}
}
psty := pstySt.Plot
globalX := false
xidxs := map[int]bool{} // map of all the _unique_ x indexes used
// An explicit XAxis.Column in the plot style overrides any Role = X column.
if psty.XAxis.Column != "" {
xc := dt.Columns.IndexByKey(psty.XAxis.Column)
if xc >= 0 {
xi = xc
globalX = true
xidxs[xi] = true
} else {
errs = append(errs, errors.New("XAxis.Column name not found: "+psty.XAxis.Column))
}
}
// Second pass: build one plotter per On column (or per Group),
// gathering the required and optional Role data from group members.
doneGps := map[string]bool{}
plt := New()
var legends []Thumbnailer // candidates for legend adding -- only add if > 1
var legLabels []string
var barCols []int // column indexes of bar plots
var barPlots []int // plotter indexes of bar plots
for ci, cl := range dt.Columns.Values {
cnm := dt.Columns.Keys[ci]
st := csty[ci]
if !st.On || st.Role == X {
continue
}
lbl := cnm
if st.Label != "" {
lbl = st.Label
}
gp := st.Group
// Each named group produces only one plotter; later members are skipped.
if doneGps[gp] {
continue
}
if gp != "" {
doneGps[gp] = true
}
ptyp := "XY"
if st.Plotter != "" {
ptyp = string(st.Plotter)
}
pt, err := PlotterByType(ptyp)
if err != nil {
errs = append(errs, err)
continue
}
data := Data{st.Role: cl}
gcols := gps[gp]
gotReq := true
gotX := -1
if globalX {
data[X] = dt.Columns.Values[xi]
gotX = xi
}
// Fill in the plotter's required roles from this column's group.
for _, rl := range pt.Required {
if rl == st.Role || (rl == X && globalX) {
continue
}
got := false
for _, gi := range gcols {
gst := csty[gi]
if gst.Role == rl {
if rl == Y {
// Only On columns can serve as the Y role.
if !gst.On {
continue
}
}
data[rl] = dt.Columns.Values[gi]
got = true
if rl == X {
gotX = gi // fallthrough so we get the last X
} else {
break
}
}
}
if !got {
// Fall back on the last global Role = X column, if any.
if rl == X && xi >= 0 {
gotX = xi
data[rl] = dt.Columns.Values[xi]
} else {
err = fmt.Errorf("plot.NewTablePlot: Required Role %q not found in Group %q, Plotter %q not added for Column: %q", rl.String(), gp, ptyp, cnm)
errs = append(errs, err)
gotReq = false
}
}
}
if !gotReq {
continue
}
if gotX >= 0 {
xidxs[gotX] = true
}
// Optional roles: first matching group member wins.
for _, rl := range pt.Optional {
if rl == st.Role { // should not happen
continue
}
for _, gi := range gcols {
gst := csty[gi]
if gst.Role == rl {
data[rl] = dt.Columns.Values[gi]
break
}
}
}
pl := pt.New(data)
// Plotter constructors may return a typed-nil interface on bad data;
// reflectx.IsNil catches that where a plain == nil check would not.
if reflectx.IsNil(reflect.ValueOf(pl)) {
err = fmt.Errorf("plot.NewTablePlot: error in creating plotter type: %q", ptyp)
errs = append(errs, err)
continue
}
plt.Add(pl)
if !st.NoLegend {
if tn, ok := pl.(Thumbnailer); ok {
legends = append(legends, tn)
legLabels = append(legLabels, lbl)
}
}
if ptyp == "Bar" {
barCols = append(barCols, ci)
barPlots = append(barPlots, len(plt.Plotters)-1)
}
}
// Only show a legend when there is more than one entry.
if len(legends) > 1 {
for i, l := range legends {
plt.Legend.Add(legLabels[i], l)
}
}
// Default the X axis label from the single X column used, if unambiguous.
if psty.XAxis.Label == "" && len(xidxs) == 1 {
// maps.Keys order is nondeterministic, but there is exactly one key here.
xi := maps.Keys(xidxs)[0]
lbl := dt.Columns.Keys[xi]
if csty[xi].Label != "" {
lbl = csty[xi].Label
}
if len(plt.Plotters) > 0 {
pl0 := plt.Plotters[0]
if pl0 != nil {
pl0.Stylers().Add(func(s *Style) {
s.Plot.XAxis.Label = lbl
})
}
}
}
// Multiple bar plots: offset each so bars sit side-by-side per X value.
nbar := len(barCols)
if nbar > 1 {
sz := 1.0 / (float64(nbar) + 0.5)
for bi, bp := range barPlots {
pl := plt.Plotters[bp]
// NOTE(review): the closure captures loop variable bi; this relies on
// per-iteration loop variables (Go 1.22+) -- confirm module go version.
pl.Stylers().Add(func(s *Style) {
s.Width.Stride = 1
s.Width.Offset = float64(bi) * sz
s.Width.Width = psty.BarWidth * sz
})
}
}
return plt, errors.Join(errs...)
}
// todo: bar chart rows, if needed
//
// netn := pl.table.NumRows() * stride
// xc := pl.table.ColumnByIndex(xi)
// vals := make([]string, netn)
// for i, dx := range pl.table.Indexes {
// pi := mid + i*stride
// if pi < netn && dx < xc.Len() {
// vals[pi] = xc.String1D(dx)
// }
// }
// plt.NominalX(vals...)
// todo:
// Use string labels for X axis if X is a string
// xc := pl.table.ColumnByIndex(xi)
// if xc.Tensor.IsString() {
// xcs := xc.Tensor.(*tensor.String)
// vals := make([]string, pl.table.NumRows())
// for i, dx := range pl.table.Indexes {
// vals[i] = xcs.Values[dx]
// }
// plt.NominalX(vals...)
// }
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package plot
import (
"image"
"cogentcore.org/core/colors"
"cogentcore.org/core/math32"
"cogentcore.org/core/paint"
"cogentcore.org/core/styles"
"cogentcore.org/core/styles/units"
)
// DefaultFontFamily specifies a default font for plotting.
// If not set, the standard Cogent Core default font is used.
var DefaultFontFamily = ""
// TextStyle specifies styling parameters for Text elements.
type TextStyle struct { //types:add -setters
// Size of font to render. Default is 16dp.
Size units.Value
// Family name for font (inherited): ordered list of comma-separated names
// from more general to more specific to use. Use split on , to parse.
Family string
// Color of text.
Color image.Image
// Align specifies how to align text along the relevant
// dimension for the text element.
// Note: [Text.Config] overwrites this to End when Rotation > 10 degrees.
Align styles.Aligns
// Padding is used in a case-dependent manner to add
// space around text elements.
Padding units.Value
// Rotation of the text, in degrees.
Rotation float32
// Offset is added directly to the final label location.
Offset units.XY
}
// Defaults sets default text styling values: 16dp size, theme
// OnSurface color, centered alignment, and [DefaultFontFamily] if set.
func (ts *TextStyle) Defaults() {
ts.Size.Dp(16)
ts.Color = colors.Scheme.OnSurface
ts.Align = styles.Center
if DefaultFontFamily != "" {
ts.Family = DefaultFontFamily
}
}
// Text specifies a single text element in a plot.
type Text struct {
// Text is the text string, which can use HTML formatting.
Text string
// Style is the styling for this text element.
Style TextStyle
// font has the full font rendering styles.
font styles.FontRender
// PaintText is the [paint.Text] for the text.
PaintText paint.Text
}
// Defaults sets default styling for the text element.
func (tx *Text) Defaults() {
tx.Style.Defaults()
}
// Config is called during the layout of the plot, prior to drawing.
// It configures the font from the Style, lays out the (HTML-capable)
// text, and applies any rotation transform.
// Side effect: if |Rotation| > 10 degrees, Style.Align is overwritten
// with styles.End.
func (tx *Text) Config(pt *Plot) {
uc := &pt.Paint.UnitContext
fs := &tx.font
fs.Size = tx.Style.Size
fs.Family = tx.Style.Family
fs.Color = tx.Style.Color
if math32.Abs(tx.Style.Rotation) > 10 {
tx.Style.Align = styles.End
}
fs.ToDots(uc)
tx.Style.Padding.ToDots(uc)
txln := float32(len(tx.Text))
fht := fs.Size.Dots
// hsz is a rough layout-width estimate of ~12 units per byte of text
// -- presumably generous enough to avoid wrapping; TODO confirm.
hsz := float32(12) * txln
txs := &pt.StandardTextStyle
tx.PaintText.SetHTML(tx.Text, fs, txs, uc, nil)
tx.PaintText.Layout(txs, fs, uc, math32.Vector2{X: hsz, Y: fht})
if tx.Style.Rotation != 0 {
rotx := math32.Rotate2D(math32.DegToRad(tx.Style.Rotation))
tx.PaintText.Transform(rotx, fs, uc)
}
}
// openFont opens the font face from the plot's unit context,
// if it has not already been opened.
func (tx *Text) openFont(pt *Plot) {
if tx.font.Face == nil {
paint.OpenFont(&tx.font, &pt.Paint.UnitContext) // calls SetUnitContext after updating metrics
}
}
// ToDots converts the font and padding units into dots (actual pixels),
// using the given unit context.
func (tx *Text) ToDots(uc *units.Context) {
tx.font.ToDots(uc)
tx.Style.Padding.ToDots(uc)
}
// PosX returns the starting position for a horizontally-aligned text element,
// based on given width. Text must have been config'd already.
func (tx *Text) PosX(width float32) math32.Vector2 {
	var pos math32.Vector2
	pos.X = styles.AlignFactor(tx.Style.Align) * width
	bbsz := tx.PaintText.BBox.Size()
	if tx.Style.Align == styles.Center {
		pos.X -= 0.5 * bbsz.X
	} else if tx.Style.Align == styles.End {
		pos.X -= bbsz.X
	}
	// Rotated text is shifted down by half its height.
	if math32.Abs(tx.Style.Rotation) > 10 {
		pos.Y += 0.5 * bbsz.Y
	}
	return pos
}
// PosY returns the starting position for a vertically-rotated text element,
// based on given height. Text must have been config'd already.
func (tx *Text) PosY(height float32) math32.Vector2 {
	var pos math32.Vector2
	pos.Y = styles.AlignFactor(tx.Style.Align) * height
	bbsz := tx.PaintText.BBox.Size()
	if tx.Style.Align == styles.Center {
		pos.Y -= 0.5 * bbsz.Y
	} else if tx.Style.Align == styles.End {
		pos.Y -= bbsz.Y
	}
	return pos
}
// Draw renders the text at given upper left position,
// using the previously-configured [Text.PaintText].
func (tx *Text) Draw(pt *Plot, pos math32.Vector2) {
tx.PaintText.Render(pt.Paint, pos)
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package plot
import (
"math"
"strconv"
"time"
)
// A Tick is a single tick mark on an axis.
type Tick struct {
	// Value is the data value marked by this Tick.
	Value float64

	// Label is the text to display at the tick mark.
	// If Label is an empty string then this is a minor tick mark.
	Label string
}

// IsMinor returns true if this is a minor (unlabeled) tick mark.
func (tk *Tick) IsMinor() bool {
	return len(tk.Label) == 0
}
// Ticker creates Ticks in a specified range.
type Ticker interface {
// Ticks returns Ticks in a specified range, with desired number of ticks,
// which can be ignored depending on the ticker type.
Ticks(min, max float64, nticks int) []Tick
}
// DefaultTicks is suitable for the Ticker field of an Axis,
// it returns a reasonable default set of tick marks.
type DefaultTicks struct{}

// compile-time check that DefaultTicks satisfies Ticker.
var _ Ticker = DefaultTicks{}
// Ticks returns Ticks in the specified range.
// Labeled major ticks are placed by the Talbot, Lin & Hanrahan labeling
// algorithm (talbotLinHanrahan, defined elsewhere in this package);
// unlabeled minor ticks are then interleaved between them.
// Panics if max <= min.
func (DefaultTicks) Ticks(min, max float64, nticks int) []Tick {
if max <= min {
panic("illegal range")
}
labels, step, q, mag := talbotLinHanrahan(min, max, nticks, withinData, nil, nil, nil)
majorDelta := step * math.Pow10(mag)
if q == 0 {
// Simple fall back was chosen, so
// majorDelta is the label distance.
majorDelta = labels[1] - labels[0]
}
// Choose a reasonable, but ad
// hoc formatting for labels.
fc := byte('f')
var off int
if mag < -1 || 6 < mag {
// Very small or very large magnitudes read better in exponent ('g') form.
off = 1
fc = 'g'
}
if math.Trunc(q) != q {
off += 2
}
prec := minInt(6, maxInt(off, -mag))
ticks := make([]Tick, len(labels))
for i, v := range labels {
ticks[i] = Tick{Value: v, Label: strconv.FormatFloat(float64(v), fc, prec, 64)}
}
var minorDelta float64
// See talbotLinHanrahan for the values used here.
switch step {
case 1, 2.5:
minorDelta = majorDelta / 5
case 2, 3, 4, 5:
minorDelta = majorDelta / step
default:
// dlamchP (machine precision, defined elsewhere in this package):
// the spacing is too small to subdivide meaningfully.
if majorDelta/2 < dlamchP {
return ticks
}
minorDelta = majorDelta / 2
}
// Find the first minor tick not greater
// than the lowest data value.
var i float64
for labels[0]+(i-1)*minorDelta > min {
i--
}
// Add ticks at minorDelta intervals when
// they are not within minorDelta/2 of a
// labelled tick.
for {
val := labels[0] + i*minorDelta
if val > max {
break
}
found := false
for _, t := range ticks {
if math.Abs(t.Value-val) < minorDelta/2 {
found = true
}
}
if !found {
ticks = append(ticks, Tick{Value: val})
}
i++
}
return ticks
}
// minInt returns the smaller of a and b.
func minInt(a, b int) int {
	if b < a {
		return b
	}
	return a
}
// maxInt returns the larger of a and b.
func maxInt(a, b int) int {
	if b > a {
		return b
	}
	return a
}
// LogTicks is suitable for the Ticker field of an Axis,
// it returns tick marks suitable for a log-scale axis.
type LogTicks struct {
// Prec specifies the precision of tick rendering
// according to the documentation for strconv.FormatFloat.
Prec int
}

// compile-time check that LogTicks satisfies Ticker.
var _ Ticker = LogTicks{}
// Ticks returns Ticks in a specified range:
// a labeled major tick at each power of ten, with unlabeled minor
// ticks at the integer multiples in between, plus a final labeled
// tick at the top of the range. The nticks request is ignored.
// Panics if min or max is not positive.
func (t LogTicks) Ticks(min, max float64, nticks int) []Tick {
if min <= 0 || max <= 0 {
panic("Values must be greater than 0 for a log scale.")
}
// NOTE(review): int() truncates toward zero, so for min < 1 this starts
// at 10^0 rather than the power of ten at or below min -- confirm intended.
val := math.Pow10(int(math.Log10(min)))
max = math.Pow10(int(math.Ceil(math.Log10(max))))
var ticks []Tick
for val < max {
for i := 1; i < 10; i++ {
if i == 1 {
// Label only the power-of-ten tick.
ticks = append(ticks, Tick{Value: val, Label: formatFloatTick(val, t.Prec)})
}
ticks = append(ticks, Tick{Value: val * float64(i)})
}
val *= 10
}
ticks = append(ticks, Tick{Value: val, Label: formatFloatTick(val, t.Prec)})
return ticks
}
// ConstantTicks is suitable for the Ticker field of an Axis.
// It always returns its own fixed set of ticks, ignoring the
// requested range and tick count.
type ConstantTicks []Tick

// compile-time check that ConstantTicks satisfies Ticker.
var _ Ticker = ConstantTicks{}

// Ticks returns the fixed set of ticks, ignoring all arguments.
func (ts ConstantTicks) Ticks(float64, float64, int) []Tick {
	return []Tick(ts)
}
// UnixTimeIn returns a time conversion function for the given location.
func UnixTimeIn(loc *time.Location) func(t float64) time.Time {
return func(t float64) time.Time {
return time.Unix(int64(t), 0).In(loc)
}
}
// UTCUnixTime is the default time conversion for TimeTicks.
var UTCUnixTime = UnixTimeIn(time.UTC)
// TimeTicks is suitable for axes representing time values.
type TimeTicks struct {
// Ticker is used to generate a set of ticks.
// If nil, DefaultTicks will be used.
Ticker Ticker

// Format is the textual representation of the time value.
// If empty, time.RFC3339 will be used.
Format string

// Time takes a float64 value and converts it into a time.Time.
// If nil, UTCUnixTime is used.
Time func(t float64) time.Time
}

// compile-time check that TimeTicks satisfies Ticker.
var _ Ticker = TimeTicks{}
// Ticks implements plot.Ticker. It delegates tick placement to the
// configured Ticker (DefaultTicks if nil), then reformats the labels
// of major ticks as times using Format (time.RFC3339 if empty) and
// the Time conversion function (UTCUnixTime if nil).
func (t TimeTicks) Ticks(min, max float64, nticks int) []Tick {
	// Fill in defaults on the local copy (value receiver).
	if t.Ticker == nil {
		t.Ticker = DefaultTicks{}
	}
	if t.Format == "" {
		t.Format = time.RFC3339
	}
	if t.Time == nil {
		t.Time = UTCUnixTime
	}
	ticks := t.Ticker.Ticks(min, max, nticks)
	for i := range ticks {
		if ticks[i].Label == "" {
			continue // minor ticks stay unlabeled
		}
		ticks[i].Label = t.Time(ticks[i].Value).Format(t.Format)
	}
	return ticks
}
/*
// lengthOffset returns an offset that should be added to the
// tick mark's line to account for its length. I.e., the start of
// the line for a minor tick mark must be shifted by half of
// the length.
func (t Tick) lengthOffset(len vg.Length) vg.Length {
if t.IsMinor() {
return len / 2
}
return 0
}
// tickLabelHeight returns height of the tick mark labels.
func tickLabelHeight(sty text.Style, ticks []Tick) vg.Length {
maxHeight := vg.Length(0)
for _, t := range ticks {
if t.IsMinor() {
continue
}
r := sty.Rectangle(t.Label)
h := r.Max.Y - r.Min.Y
if h > maxHeight {
maxHeight = h
}
}
return maxHeight
}
// tickLabelWidth returns the width of the widest tick mark label.
func tickLabelWidth(sty text.Style, ticks []Tick) vg.Length {
maxWidth := vg.Length(0)
for _, t := range ticks {
if t.IsMinor() {
continue
}
r := sty.Rectangle(t.Label)
w := r.Max.X - r.Min.X
if w > maxWidth {
maxWidth = w
}
}
return maxWidth
}
*/
// formatFloatTick returns a 'g'-formatted string representation of v
// to the specified precision (see strconv.FormatFloat).
func formatFloatTick(v float64, prec int) string {
	// v is already a float64; no conversion needed.
	return strconv.FormatFloat(v, 'g', prec, 64)
}
// // TickerFunc is suitable for the Ticker field of an Axis.
// // It is an adapter which allows to quickly setup a Ticker using a function with an appropriate signature.
// type TickerFunc func(min, max float64) []Tick
//
// var _ Ticker = TickerFunc(nil)
//
// // Ticks implements plot.Ticker.
// func (f TickerFunc) Ticks(min, max float64) []Tick {
// return f(min, max)
// }
// Code generated by "core generate -add-types"; DO NOT EDIT.
package plot
import (
"image"
"cogentcore.org/core/math32/minmax"
"cogentcore.org/core/styles"
"cogentcore.org/core/styles/units"
"cogentcore.org/core/types"
)
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.AxisScales", IDName: "axis-scales", Doc: "AxisScales are the scaling options for how values are distributed\nalong an axis: Linear, Log, etc."})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.AxisStyle", IDName: "axis-style", Doc: "AxisStyle has style properties for the axis.", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"-setters"}}}, Fields: []types.Field{{Name: "Text", Doc: "Text has the text style parameters for the text label."}, {Name: "Line", Doc: "Line has styling properties for the axis line."}, {Name: "Padding", Doc: "Padding between the axis line and the data. Having\nnon-zero padding ensures that the data is never drawn\non the axis, thus making it easier to see."}, {Name: "NTicks", Doc: "NTicks is the desired number of ticks (actual likely will be different)."}, {Name: "Scale", Doc: "Scale specifies how values are scaled along the axis:\nLinear, Log, Inverted"}, {Name: "TickText", Doc: "TickText has the text style for rendering tick labels,\nand is shared for actual rendering."}, {Name: "TickLine", Doc: "TickLine has line style for drawing tick lines."}, {Name: "TickLength", Doc: "TickLength is the length of tick lines."}}})
// SetText sets the [AxisStyle.Text]:
// Text has the text style parameters for the text label.
func (t *AxisStyle) SetText(v TextStyle) *AxisStyle { t.Text = v; return t }
// SetLine sets the [AxisStyle.Line]:
// Line has styling properties for the axis line.
func (t *AxisStyle) SetLine(v LineStyle) *AxisStyle { t.Line = v; return t }
// SetPadding sets the [AxisStyle.Padding]:
// Padding between the axis line and the data. Having
// non-zero padding ensures that the data is never drawn
// on the axis, thus making it easier to see.
func (t *AxisStyle) SetPadding(v units.Value) *AxisStyle { t.Padding = v; return t }
// SetNTicks sets the [AxisStyle.NTicks]:
// NTicks is the desired number of ticks (actual likely will be different).
func (t *AxisStyle) SetNTicks(v int) *AxisStyle { t.NTicks = v; return t }
// SetScale sets the [AxisStyle.Scale]:
// Scale specifies how values are scaled along the axis:
// Linear, Log, Inverted
func (t *AxisStyle) SetScale(v AxisScales) *AxisStyle { t.Scale = v; return t }
// SetTickText sets the [AxisStyle.TickText]:
// TickText has the text style for rendering tick labels,
// and is shared for actual rendering.
func (t *AxisStyle) SetTickText(v TextStyle) *AxisStyle { t.TickText = v; return t }
// SetTickLine sets the [AxisStyle.TickLine]:
// TickLine has line style for drawing tick lines.
func (t *AxisStyle) SetTickLine(v LineStyle) *AxisStyle { t.TickLine = v; return t }
// SetTickLength sets the [AxisStyle.TickLength]:
// TickLength is the length of tick lines.
func (t *AxisStyle) SetTickLength(v units.Value) *AxisStyle { t.TickLength = v; return t }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.Axis", IDName: "axis", Doc: "Axis represents either a horizontal or vertical\naxis of a plot.", Fields: []types.Field{{Name: "Range", Doc: "Range has the Min, Max range of values for the axis (in raw data units.)"}, {Name: "Axis", Doc: "specifies which axis this is: X, Y or Z."}, {Name: "Label", Doc: "Label for the axis."}, {Name: "Style", Doc: "Style has the style parameters for the Axis."}, {Name: "TickText", Doc: "TickText is used for rendering the tick text labels."}, {Name: "Ticker", Doc: "Ticker generates the tick marks. Any tick marks\nreturned by the Marker function that are not in\nrange of the axis are not drawn."}, {Name: "Scale", Doc: "Scale transforms a value given in the data coordinate system\nto the normalized coordinate system of the axis—its distance\nalong the axis as a fraction of the axis range."}, {Name: "AutoRescale", Doc: "AutoRescale enables an axis to automatically adapt its minimum\nand maximum boundaries, according to its underlying Ticker."}, {Name: "ticks", Doc: "cached list of ticks, set in size"}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.Normalizer", IDName: "normalizer", Doc: "Normalizer rescales values from the data coordinate system to the\nnormalized coordinate system.", Methods: []types.Method{{Name: "Normalize", Doc: "Normalize transforms a value x in the data coordinate system to\nthe normalized coordinate system.", Args: []string{"min", "max", "x"}, Returns: []string{"float64"}}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.LinearScale", IDName: "linear-scale", Doc: "LinearScale an be used as the value of an Axis.Scale function to\nset the axis to a standard linear scale."})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.LogScale", IDName: "log-scale", Doc: "LogScale can be used as the value of an Axis.Scale function to\nset the axis to a log scale."})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.InvertedScale", IDName: "inverted-scale", Doc: "InvertedScale can be used as the value of an Axis.Scale function to\ninvert the axis using any Normalizer.", Embeds: []types.Field{{Name: "Normalizer"}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.Data", IDName: "data", Doc: "Data is a map of Roles and Data for that Role, providing the\nprimary way of passing data to a Plotter"})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.Valuer", IDName: "valuer", Doc: "Valuer is the data interface for plotting, supporting either\nfloat64 or string representations. It is satisfied by the tensor.Tensor\ninterface, so a tensor can be used directly for plot Data.", Methods: []types.Method{{Name: "Len", Doc: "Len returns the number of values.", Returns: []string{"int"}}, {Name: "Float1D", Doc: "Float1D(i int) returns float64 value at given index.", Args: []string{"i"}, Returns: []string{"float64"}}, {Name: "String1D", Doc: "String1D(i int) returns string value at given index.", Args: []string{"i"}, Returns: []string{"string"}}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.Roles", IDName: "roles", Doc: "Roles are the roles that a given set of data values can play,\ndesigned to be sufficiently generalizable across all different\ntypes of plots, even if sometimes it is a bit of a stretch."})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.Values", IDName: "values", Doc: "Values provides a minimal implementation of the Data interface\nusing a slice of float64."})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.Labels", IDName: "labels", Doc: "Labels provides a minimal implementation of the Data interface\nusing a slice of string. It always returns 0 for Float1D."})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.selection", IDName: "selection", Fields: []types.Field{{Name: "n", Doc: "n is the number of labels selected."}, {Name: "lMin", Doc: "lMin and lMax are the selected min\nand max label values. lq is the q\nchosen."}, {Name: "lMax", Doc: "lMin and lMax are the selected min\nand max label values. lq is the q\nchosen."}, {Name: "lStep", Doc: "lMin and lMax are the selected min\nand max label values. lq is the q\nchosen."}, {Name: "lq", Doc: "lMin and lMax are the selected min\nand max label values. lq is the q\nchosen."}, {Name: "score", Doc: "score is the score for the selection."}, {Name: "magnitude", Doc: "magnitude is the magnitude of the\nlabel step distance."}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.weights", IDName: "weights", Doc: "weights is a helper type to calcuate the labelling scheme's total score.", Fields: []types.Field{{Name: "simplicity"}, {Name: "coverage"}, {Name: "density"}, {Name: "legibility"}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.LegendStyle", IDName: "legend-style", Doc: "LegendStyle has the styling properties for the Legend.", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"-setters"}}}, Fields: []types.Field{{Name: "Column", Doc: "Column is for table-based plotting, specifying the column with legend values."}, {Name: "Text", Doc: "Text is the style given to the legend entry texts."}, {Name: "Position", Doc: "position of the legend"}, {Name: "ThumbnailWidth", Doc: "ThumbnailWidth is the width of legend thumbnails."}, {Name: "Fill", Doc: "Fill specifies the background fill color for the legend box,\nif non-nil."}}})
// SetColumn sets the [LegendStyle.Column]:
// Column is for table-based plotting, specifying the column with legend values.
func (t *LegendStyle) SetColumn(v string) *LegendStyle { t.Column = v; return t }
// SetText sets the [LegendStyle.Text]:
// Text is the style given to the legend entry texts.
func (t *LegendStyle) SetText(v TextStyle) *LegendStyle { t.Text = v; return t }
// SetPosition sets the [LegendStyle.Position]:
// position of the legend
func (t *LegendStyle) SetPosition(v LegendPosition) *LegendStyle { t.Position = v; return t }
// SetThumbnailWidth sets the [LegendStyle.ThumbnailWidth]:
// ThumbnailWidth is the width of legend thumbnails.
func (t *LegendStyle) SetThumbnailWidth(v units.Value) *LegendStyle { t.ThumbnailWidth = v; return t }
// SetFill sets the [LegendStyle.Fill]:
// Fill specifies the background fill color for the legend box,
// if non-nil.
func (t *LegendStyle) SetFill(v image.Image) *LegendStyle { t.Fill = v; return t }
// NOTE(review): typegen-style type registrations; these appear generated —
// change the source types and regenerate rather than editing by hand.
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.LegendPosition", IDName: "legend-position", Doc: "LegendPosition specifies where to put the legend", Fields: []types.Field{{Name: "Top", Doc: "Top and Left specify the location of the legend."}, {Name: "Left", Doc: "Top and Left specify the location of the legend."}, {Name: "XOffs", Doc: "XOffs and YOffs are added to the legend's final position,\nrelative to the relevant anchor position"}, {Name: "YOffs", Doc: "XOffs and YOffs are added to the legend's final position,\nrelative to the relevant anchor position"}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.Legend", IDName: "legend", Doc: "A Legend gives a description of the meaning of different\ndata elements of the plot. Each legend entry has a name\nand a thumbnail, where the thumbnail shows a small\nsample of the display style of the corresponding data.", Fields: []types.Field{{Name: "Style", Doc: "Style has the legend styling parameters."}, {Name: "Entries", Doc: "Entries are all of the LegendEntries described by this legend."}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.Thumbnailer", IDName: "thumbnailer", Doc: "Thumbnailer wraps the Thumbnail method, which draws the small\nimage in a legend representing the style of data.", Methods: []types.Method{{Name: "Thumbnail", Doc: "Thumbnail draws an thumbnail representing a legend entry.\nThe thumbnail will usually show a smaller representation\nof the style used to plot the corresponding data.", Args: []string{"pt"}}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.LegendEntry", IDName: "legend-entry", Doc: "A LegendEntry represents a single line of a legend, it\nhas a name and an icon.", Fields: []types.Field{{Name: "Text", Doc: "text is the text associated with this entry."}, {Name: "Thumbs", Doc: "thumbs is a slice of all of the thumbnails styles"}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.LineStyle", IDName: "line-style", Doc: "LineStyle has style properties for drawing lines.", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"-setters"}}}, Fields: []types.Field{{Name: "On", Doc: "On indicates whether to plot lines."}, {Name: "Color", Doc: "Color is the stroke color image specification.\nSetting to nil turns line off."}, {Name: "Width", Doc: "Width is the line width, with a default of 1 Pt (point).\nSetting to 0 turns line off."}, {Name: "Dashes", Doc: "Dashes are the dashes of the stroke. Each pair of values specifies\nthe amount to paint and then the amount to skip."}, {Name: "Fill", Doc: "Fill is the color to fill solid regions, in a plot-specific\nway (e.g., the area below a Line plot, the bar color).\nUse nil to disable filling."}, {Name: "NegativeX", Doc: "NegativeX specifies whether to draw lines that connect points with a negative\nX-axis direction; otherwise there is a break in the line.\ndefault is false, so that repeated series of data across the X axis\nare plotted separately."}, {Name: "Step", Doc: "Step specifies how to step the line between points."}}})
// SetOn sets the [LineStyle.On]:
// On indicates whether to plot lines.
func (t *LineStyle) SetOn(v DefaultOffOn) *LineStyle {
	t.On = v
	return t
}

// SetColor sets the [LineStyle.Color]:
// Color is the stroke color image specification.
// Setting to nil turns line off.
func (t *LineStyle) SetColor(v image.Image) *LineStyle {
	t.Color = v
	return t
}

// SetWidth sets the [LineStyle.Width]:
// Width is the line width, with a default of 1 Pt (point).
// Setting to 0 turns line off.
func (t *LineStyle) SetWidth(v units.Value) *LineStyle {
	t.Width = v
	return t
}

// SetDashes sets the [LineStyle.Dashes]:
// Dashes are the dashes of the stroke. Each pair of values specifies
// the amount to paint and then the amount to skip.
func (t *LineStyle) SetDashes(v ...float32) *LineStyle {
	t.Dashes = v
	return t
}

// SetFill sets the [LineStyle.Fill]:
// Fill is the color to fill solid regions, in a plot-specific
// way (e.g., the area below a Line plot, the bar color).
// Use nil to disable filling.
func (t *LineStyle) SetFill(v image.Image) *LineStyle {
	t.Fill = v
	return t
}

// SetNegativeX sets the [LineStyle.NegativeX]:
// NegativeX specifies whether to draw lines that connect points with a negative
// X-axis direction; otherwise there is a break in the line.
// default is false, so that repeated series of data across the X axis
// are plotted separately.
func (t *LineStyle) SetNegativeX(v bool) *LineStyle {
	t.NegativeX = v
	return t
}

// SetStep sets the [LineStyle.Step]:
// Step specifies how to step the line between points.
func (t *LineStyle) SetStep(v StepKind) *LineStyle {
	t.Step = v
	return t
}
// NOTE(review): typegen-style type registrations; appear generated — do not edit by hand.
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.StepKind", IDName: "step-kind", Doc: "StepKind specifies a form of a connection of two consecutive points."})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.XAxisStyle", IDName: "x-axis-style", Doc: "XAxisStyle has overall plot level styling properties for the XAxis.", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"-setters"}}}, Fields: []types.Field{{Name: "Column", Doc: "Column specifies the column to use for the common X axis,\nfor [plot.NewTablePlot] [table.Table] driven plots.\nIf empty, standard Group-based role binding is used: the last column\nwithin the same group with Role=X is used."}, {Name: "Rotation", Doc: "Rotation is the rotation of the X Axis labels, in degrees."}, {Name: "Label", Doc: "Label is the optional label to use for the XAxis instead of the default."}, {Name: "Range", Doc: "Range is the effective range of XAxis data to plot, where either end can be fixed."}, {Name: "Scale", Doc: "Scale specifies how values are scaled along the X axis:\nLinear, Log, Inverted"}}})
// SetColumn sets the [XAxisStyle.Column]:
// Column specifies the column to use for the common X axis,
// for [plot.NewTablePlot] [table.Table] driven plots.
// If empty, standard Group-based role binding is used: the last column
// within the same group with Role=X is used.
func (t *XAxisStyle) SetColumn(v string) *XAxisStyle {
	t.Column = v
	return t
}

// SetRotation sets the [XAxisStyle.Rotation]:
// Rotation is the rotation of the X Axis labels, in degrees.
func (t *XAxisStyle) SetRotation(v float32) *XAxisStyle {
	t.Rotation = v
	return t
}

// SetLabel sets the [XAxisStyle.Label]:
// Label is the optional label to use for the XAxis instead of the default.
func (t *XAxisStyle) SetLabel(v string) *XAxisStyle {
	t.Label = v
	return t
}

// SetRange sets the [XAxisStyle.Range]:
// Range is the effective range of XAxis data to plot, where either end can be fixed.
func (t *XAxisStyle) SetRange(v minmax.Range64) *XAxisStyle {
	t.Range = v
	return t
}

// SetScale sets the [XAxisStyle.Scale]:
// Scale specifies how values are scaled along the X axis:
// Linear, Log, Inverted
func (t *XAxisStyle) SetScale(v AxisScales) *XAxisStyle {
	t.Scale = v
	return t
}
// NOTE(review): typegen-style type registration; appears generated — do not edit by hand.
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.PlotStyle", IDName: "plot-style", Doc: "PlotStyle has overall plot level styling properties.\nSome properties provide defaults for individual elements, which can\nthen be overwritten by element-level properties.", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"-setters"}}}, Fields: []types.Field{{Name: "Title", Doc: "Title is the overall title of the plot."}, {Name: "TitleStyle", Doc: "TitleStyle is the text styling parameters for the title."}, {Name: "Background", Doc: "Background is the background of the plot.\nThe default is [colors.Scheme.Surface]."}, {Name: "Scale", Doc: "Scale multiplies the plot DPI value, to change the overall scale\nof the rendered plot. Larger numbers produce larger scaling.\nTypically use larger numbers when generating plots for inclusion in\ndocuments or other cases where the overall plot size will be small."}, {Name: "Legend", Doc: "Legend has the styling properties for the Legend."}, {Name: "Axis", Doc: "Axis has the styling properties for the Axes."}, {Name: "XAxis", Doc: "XAxis has plot-level XAxis style properties."}, {Name: "YAxisLabel", Doc: "YAxisLabel is the optional label to use for the YAxis instead of the default."}, {Name: "LinesOn", Doc: "LinesOn determines whether lines are plotted by default,\nfor elements that plot lines (e.g., plots.XY)."}, {Name: "LineWidth", Doc: "LineWidth sets the default line width for data plotting lines."}, {Name: "PointsOn", Doc: "PointsOn determines whether points are plotted by default,\nfor elements that plot points (e.g., plots.XY)."}, {Name: "PointSize", Doc: "PointSize sets the default point size."}, {Name: "LabelSize", Doc: "LabelSize sets the default label text size."}, {Name: "BarWidth", Doc: "BarWidth for Bar plot sets the default width of the bars,\nwhich should be less than the Stride (1 typically) to prevent\nbar overlap. Defaults to .8."}}})
// SetTitle sets the [PlotStyle.Title]:
// Title is the overall title of the plot.
func (t *PlotStyle) SetTitle(v string) *PlotStyle {
	t.Title = v
	return t
}

// SetTitleStyle sets the [PlotStyle.TitleStyle]:
// TitleStyle is the text styling parameters for the title.
func (t *PlotStyle) SetTitleStyle(v TextStyle) *PlotStyle {
	t.TitleStyle = v
	return t
}

// SetBackground sets the [PlotStyle.Background]:
// Background is the background of the plot.
// The default is [colors.Scheme.Surface].
func (t *PlotStyle) SetBackground(v image.Image) *PlotStyle {
	t.Background = v
	return t
}

// SetScale sets the [PlotStyle.Scale]:
// Scale multiplies the plot DPI value, to change the overall scale
// of the rendered plot. Larger numbers produce larger scaling.
// Typically use larger numbers when generating plots for inclusion in
// documents or other cases where the overall plot size will be small.
func (t *PlotStyle) SetScale(v float32) *PlotStyle {
	t.Scale = v
	return t
}

// SetLegend sets the [PlotStyle.Legend]:
// Legend has the styling properties for the Legend.
func (t *PlotStyle) SetLegend(v LegendStyle) *PlotStyle {
	t.Legend = v
	return t
}

// SetAxis sets the [PlotStyle.Axis]:
// Axis has the styling properties for the Axes.
func (t *PlotStyle) SetAxis(v AxisStyle) *PlotStyle {
	t.Axis = v
	return t
}

// SetXAxis sets the [PlotStyle.XAxis]:
// XAxis has plot-level XAxis style properties.
func (t *PlotStyle) SetXAxis(v XAxisStyle) *PlotStyle {
	t.XAxis = v
	return t
}

// SetYAxisLabel sets the [PlotStyle.YAxisLabel]:
// YAxisLabel is the optional label to use for the YAxis instead of the default.
func (t *PlotStyle) SetYAxisLabel(v string) *PlotStyle {
	t.YAxisLabel = v
	return t
}

// SetLinesOn sets the [PlotStyle.LinesOn]:
// LinesOn determines whether lines are plotted by default,
// for elements that plot lines (e.g., plots.XY).
func (t *PlotStyle) SetLinesOn(v DefaultOffOn) *PlotStyle {
	t.LinesOn = v
	return t
}

// SetLineWidth sets the [PlotStyle.LineWidth]:
// LineWidth sets the default line width for data plotting lines.
func (t *PlotStyle) SetLineWidth(v units.Value) *PlotStyle {
	t.LineWidth = v
	return t
}

// SetPointsOn sets the [PlotStyle.PointsOn]:
// PointsOn determines whether points are plotted by default,
// for elements that plot points (e.g., plots.XY).
func (t *PlotStyle) SetPointsOn(v DefaultOffOn) *PlotStyle {
	t.PointsOn = v
	return t
}

// SetPointSize sets the [PlotStyle.PointSize]:
// PointSize sets the default point size.
func (t *PlotStyle) SetPointSize(v units.Value) *PlotStyle {
	t.PointSize = v
	return t
}

// SetLabelSize sets the [PlotStyle.LabelSize]:
// LabelSize sets the default label text size.
func (t *PlotStyle) SetLabelSize(v units.Value) *PlotStyle {
	t.LabelSize = v
	return t
}

// SetBarWidth sets the [PlotStyle.BarWidth]:
// BarWidth for Bar plot sets the default width of the bars,
// which should be less than the Stride (1 typically) to prevent
// bar overlap. Defaults to .8.
func (t *PlotStyle) SetBarWidth(v float64) *PlotStyle {
	t.BarWidth = v
	return t
}
// NOTE(review): typegen-style type registrations; appear generated — do not edit by hand.
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.PanZoom", IDName: "pan-zoom", Doc: "PanZoom provides post-styling pan and zoom range manipulation.", Fields: []types.Field{{Name: "XOffset", Doc: "XOffset adds offset to X range (pan)."}, {Name: "XScale", Doc: "XScale multiplies X range (zoom)."}, {Name: "YOffset", Doc: "YOffset adds offset to Y range (pan)."}, {Name: "YScale", Doc: "YScale multiplies Y range (zoom)."}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.Plot", IDName: "plot", Doc: "Plot is the basic type representing a plot.\nIt renders into its own image.RGBA Pixels image,\nand can also save a corresponding SVG version.", Fields: []types.Field{{Name: "Title", Doc: "Title of the plot"}, {Name: "Style", Doc: "Style has the styling properties for the plot."}, {Name: "StandardTextStyle", Doc: "standard text style with default options"}, {Name: "X", Doc: "X, Y, and Z are the horizontal, vertical, and depth axes\nof the plot respectively."}, {Name: "Y", Doc: "X, Y, and Z are the horizontal, vertical, and depth axes\nof the plot respectively."}, {Name: "Z", Doc: "X, Y, and Z are the horizontal, vertical, and depth axes\nof the plot respectively."}, {Name: "Legend", Doc: "Legend is the plot's legend."}, {Name: "Plotters", Doc: "Plotters are drawn by calling their Plot method after the axes are drawn."}, {Name: "Size", Doc: "Size is the target size of the image to render to."}, {Name: "DPI", Doc: "DPI is the dots per inch for rendering the image.\nLarger numbers result in larger scaling of the plot contents\nwhich is strongly recommended for print (e.g., use 300 for print)"}, {Name: "PanZoom", Doc: "PanZoom provides post-styling pan and zoom range factors."}, {Name: "HighlightPlotter", Doc: "\tHighlightPlotter is the Plotter to highlight. Used for mouse hovering for example.\nIt is the responsibility of the Plotter Plot function to implement highlighting."}, {Name: "HighlightIndex", Doc: "HighlightIndex is the index of the data point to highlight, for HighlightPlotter."}, {Name: "Pixels", Doc: "pixels that we render into"}, {Name: "Paint", Doc: "Paint is the painter for rendering"}, {Name: "PlotBox", Doc: "Current plot bounding box in image coordinates, for plotting coordinates"}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.Plotter", IDName: "plotter", Doc: "Plotter is an interface that wraps the Plot method.\nStandard implementations of Plotter are in the [plots] package.", Methods: []types.Method{{Name: "Plot", Doc: "Plot draws the data to the Plot Paint.", Args: []string{"pt"}}, {Name: "UpdateRange", Doc: "UpdateRange updates the given ranges.", Args: []string{"plt", "xr", "yr", "zr"}}, {Name: "Data", Doc: "Data returns the data by roles for this plot, for both the original\ndata and the pixel-transformed X,Y coordinates for that data.\nThis allows a GUI interface to inspect data etc.", Returns: []string{"data", "pixX", "pixY"}}, {Name: "Stylers", Doc: "Stylers returns the styler functions for this element.", Returns: []string{"Stylers"}}, {Name: "ApplyStyle", Doc: "ApplyStyle applies any stylers to this element,\nfirst initializing from the given global plot style, which has\nalready been styled with defaults and all the plot element stylers.", Args: []string{"plotStyle"}}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.PlotterType", IDName: "plotter-type", Doc: "PlotterType registers a Plotter so that it can be created with appropriate data.", Fields: []types.Field{{Name: "Name", Doc: "Name of the plot type."}, {Name: "Doc", Doc: "Doc is the documentation for this Plotter."}, {Name: "Required", Doc: "Required Data roles for this plot. Data for these Roles must be provided."}, {Name: "Optional", Doc: "Optional Data roles for this plot."}, {Name: "New", Doc: "New returns a new plotter of this type with given data in given roles."}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.PlotterName", IDName: "plotter-name", Doc: "PlotterName is the name of a specific plotter type."})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.PointStyle", IDName: "point-style", Doc: "PointStyle has style properties for drawing points as different shapes.", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"-setters"}}}, Fields: []types.Field{{Name: "On", Doc: "On indicates whether to plot points."}, {Name: "Shape", Doc: "Shape to draw."}, {Name: "Color", Doc: "Color is the stroke color image specification.\nSetting to nil turns line off."}, {Name: "Fill", Doc: "Fill is the color to fill solid regions, in a plot-specific\nway (e.g., the area below a Line plot, the bar color).\nUse nil to disable filling."}, {Name: "Width", Doc: "Width is the line width for point glyphs, with a default of 1 Pt (point).\nSetting to 0 turns line off."}, {Name: "Size", Doc: "Size of shape to draw for each point.\nDefaults to 4 Pt (point)."}}})
// SetOn sets the [PointStyle.On]:
// On indicates whether to plot points.
func (t *PointStyle) SetOn(v DefaultOffOn) *PointStyle {
	t.On = v
	return t
}

// SetShape sets the [PointStyle.Shape]:
// Shape to draw.
func (t *PointStyle) SetShape(v Shapes) *PointStyle {
	t.Shape = v
	return t
}

// SetColor sets the [PointStyle.Color]:
// Color is the stroke color image specification.
// Setting to nil turns line off.
func (t *PointStyle) SetColor(v image.Image) *PointStyle {
	t.Color = v
	return t
}

// SetFill sets the [PointStyle.Fill]:
// Fill is the color to fill solid regions, in a plot-specific
// way (e.g., the area below a Line plot, the bar color).
// Use nil to disable filling.
func (t *PointStyle) SetFill(v image.Image) *PointStyle {
	t.Fill = v
	return t
}

// SetWidth sets the [PointStyle.Width]:
// Width is the line width for point glyphs, with a default of 1 Pt (point).
// Setting to 0 turns line off.
func (t *PointStyle) SetWidth(v units.Value) *PointStyle {
	t.Width = v
	return t
}

// SetSize sets the [PointStyle.Size]:
// Size of shape to draw for each point.
// Defaults to 4 Pt (point).
func (t *PointStyle) SetSize(v units.Value) *PointStyle {
	t.Size = v
	return t
}
// NOTE(review): typegen-style type registrations; appear generated — do not edit by hand.
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.Shapes", IDName: "shapes", Doc: "Shapes has the options for how to draw points in the plot."})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.Style", IDName: "style", Doc: "Style contains the plot styling properties relevant across\nmost plot types. These properties apply to individual plot elements\nwhile the Plot properties applies to the overall plot itself.", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"-setters"}}}, Fields: []types.Field{{Name: "Plot", Doc: "\tPlot has overall plot-level properties, which can be set by any\nplot element, and are updated first, before applying element-wise styles."}, {Name: "On", Doc: "On specifies whether to plot this item, for table-based plots."}, {Name: "Plotter", Doc: "Plotter is the type of plotter to use in plotting this data,\nfor [plot.NewTablePlot] [table.Table] driven plots.\nBlank means use default ([plots.XY] is overall default)."}, {Name: "Role", Doc: "Role specifies how a particular column of data should be used,\nfor [plot.NewTablePlot] [table.Table] driven plots."}, {Name: "Group", Doc: "Group specifies a group of related data items,\nfor [plot.NewTablePlot] [table.Table] driven plots,\nwhere different columns of data within the same Group play different Roles."}, {Name: "Range", Doc: "Range is the effective range of data to plot, where either end can be fixed."}, {Name: "Label", Doc: "Label provides an alternative label to use for axis, if set."}, {Name: "NoLegend", Doc: "NoLegend excludes this item from the legend when it otherwise would be included,\nfor [plot.NewTablePlot] [table.Table] driven plots.\nRole = Y values are included in the Legend by default."}, {Name: "NTicks", Doc: "NTicks sets the desired number of ticks for the axis, if > 0."}, {Name: "Line", Doc: "Line has style properties for drawing lines."}, {Name: "Point", Doc: "Point has style properties for drawing points."}, {Name: "Text", Doc: "Text has style properties for rendering text."}, {Name: "Width", Doc: "Width has various plot width properties."}}})
// SetPlot sets the [Style.Plot]:
// Plot has overall plot-level properties, which can be set by any
// plot element, and are updated first, before applying element-wise styles.
func (t *Style) SetPlot(v PlotStyle) *Style {
	t.Plot = v
	return t
}

// SetOn sets the [Style.On]:
// On specifies whether to plot this item, for table-based plots.
func (t *Style) SetOn(v bool) *Style {
	t.On = v
	return t
}

// SetPlotter sets the [Style.Plotter]:
// Plotter is the type of plotter to use in plotting this data,
// for [plot.NewTablePlot] [table.Table] driven plots.
// Blank means use default ([plots.XY] is overall default).
func (t *Style) SetPlotter(v PlotterName) *Style {
	t.Plotter = v
	return t
}

// SetRole sets the [Style.Role]:
// Role specifies how a particular column of data should be used,
// for [plot.NewTablePlot] [table.Table] driven plots.
func (t *Style) SetRole(v Roles) *Style {
	t.Role = v
	return t
}

// SetGroup sets the [Style.Group]:
// Group specifies a group of related data items,
// for [plot.NewTablePlot] [table.Table] driven plots,
// where different columns of data within the same Group play different Roles.
func (t *Style) SetGroup(v string) *Style {
	t.Group = v
	return t
}

// SetRange sets the [Style.Range]:
// Range is the effective range of data to plot, where either end can be fixed.
func (t *Style) SetRange(v minmax.Range64) *Style {
	t.Range = v
	return t
}

// SetLabel sets the [Style.Label]:
// Label provides an alternative label to use for axis, if set.
func (t *Style) SetLabel(v string) *Style {
	t.Label = v
	return t
}

// SetNoLegend sets the [Style.NoLegend]:
// NoLegend excludes this item from the legend when it otherwise would be included,
// for [plot.NewTablePlot] [table.Table] driven plots.
// Role = Y values are included in the Legend by default.
func (t *Style) SetNoLegend(v bool) *Style {
	t.NoLegend = v
	return t
}

// SetNTicks sets the [Style.NTicks]:
// NTicks sets the desired number of ticks for the axis, if > 0.
func (t *Style) SetNTicks(v int) *Style {
	t.NTicks = v
	return t
}

// SetLine sets the [Style.Line]:
// Line has style properties for drawing lines.
func (t *Style) SetLine(v LineStyle) *Style {
	t.Line = v
	return t
}

// SetPoint sets the [Style.Point]:
// Point has style properties for drawing points.
func (t *Style) SetPoint(v PointStyle) *Style {
	t.Point = v
	return t
}

// SetText sets the [Style.Text]:
// Text has style properties for rendering text.
func (t *Style) SetText(v TextStyle) *Style {
	t.Text = v
	return t
}

// SetWidth sets the [Style.Width]:
// Width has various plot width properties.
func (t *Style) SetWidth(v WidthStyle) *Style {
	t.Width = v
	return t
}
// NOTE(review): typegen-style type registration; appears generated — do not edit by hand.
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.WidthStyle", IDName: "width-style", Doc: "WidthStyle contains various plot width properties relevant across\ndifferent plot types.", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"-setters"}}}, Fields: []types.Field{{Name: "Cap", Doc: "Cap is the width of the caps drawn at the top of error bars.\nThe default is 10dp"}, {Name: "Offset", Doc: "Offset for Bar plot is the offset added to each X axis value\nrelative to the Stride computed value (X = offset + index * Stride)\nDefaults to 0."}, {Name: "Stride", Doc: "Stride for Bar plot is distance between bars. Defaults to 1."}, {Name: "Width", Doc: "Width for Bar plot is the width of the bars, as a fraction of the Stride,\nto prevent bar overlap. Defaults to .8."}, {Name: "Pad", Doc: "Pad for Bar plot is additional space at start / end of data range,\nto keep bars from overflowing ends. This amount is subtracted from Offset\nand added to (len(Values)-1)*Stride -- no other accommodation for bar\nwidth is provided, so that should be built into this value as well.\nDefaults to 1."}}})
// SetCap sets the [WidthStyle.Cap]:
// Cap is the width of the caps drawn at the top of error bars.
// The default is 10dp
func (t *WidthStyle) SetCap(v units.Value) *WidthStyle {
	t.Cap = v
	return t
}

// SetOffset sets the [WidthStyle.Offset]:
// Offset for Bar plot is the offset added to each X axis value
// relative to the Stride computed value (X = offset + index * Stride)
// Defaults to 0.
func (t *WidthStyle) SetOffset(v float64) *WidthStyle {
	t.Offset = v
	return t
}

// SetStride sets the [WidthStyle.Stride]:
// Stride for Bar plot is distance between bars. Defaults to 1.
func (t *WidthStyle) SetStride(v float64) *WidthStyle {
	t.Stride = v
	return t
}

// SetWidth sets the [WidthStyle.Width]:
// Width for Bar plot is the width of the bars, as a fraction of the Stride,
// to prevent bar overlap. Defaults to .8.
func (t *WidthStyle) SetWidth(v float64) *WidthStyle {
	t.Width = v
	return t
}

// SetPad sets the [WidthStyle.Pad]:
// Pad for Bar plot is additional space at start / end of data range,
// to keep bars from overflowing ends. This amount is subtracted from Offset
// and added to (len(Values)-1)*Stride -- no other accommodation for bar
// width is provided, so that should be built into this value as well.
// Defaults to 1.
func (t *WidthStyle) SetPad(v float64) *WidthStyle {
	t.Pad = v
	return t
}
// NOTE(review): typegen-style type registrations; appear generated — do not edit by hand.
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.Stylers", IDName: "stylers", Doc: "Stylers is a list of styling functions that set Style properties.\nThese are called in the order added."})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.DefaultOffOn", IDName: "default-off-on", Doc: "DefaultOffOn specifies whether to use the default value for a bool option,\nor to override the default and set Off or On."})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.TextStyle", IDName: "text-style", Doc: "TextStyle specifies styling parameters for Text elements.", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"-setters"}}}, Fields: []types.Field{{Name: "Size", Doc: "Size of font to render. Default is 16dp"}, {Name: "Family", Doc: "Family name for font (inherited): ordered list of comma-separated names\nfrom more general to more specific to use. Use split on, to parse."}, {Name: "Color", Doc: "Color of text."}, {Name: "Align", Doc: "Align specifies how to align text along the relevant\ndimension for the text element."}, {Name: "Padding", Doc: "Padding is used in a case-dependent manner to add\nspace around text elements."}, {Name: "Rotation", Doc: "Rotation of the text, in degrees."}, {Name: "Offset", Doc: "Offset is added directly to the final label location."}}})
// SetSize sets the [TextStyle.Size]:
// Size of font to render. Default is 16dp
func (t *TextStyle) SetSize(v units.Value) *TextStyle {
	t.Size = v
	return t
}

// SetFamily sets the [TextStyle.Family]:
// Family name for font (inherited): ordered list of comma-separated names
// from more general to more specific to use. Use split on, to parse.
func (t *TextStyle) SetFamily(v string) *TextStyle {
	t.Family = v
	return t
}

// SetColor sets the [TextStyle.Color]:
// Color of text.
func (t *TextStyle) SetColor(v image.Image) *TextStyle {
	t.Color = v
	return t
}

// SetAlign sets the [TextStyle.Align]:
// Align specifies how to align text along the relevant
// dimension for the text element.
func (t *TextStyle) SetAlign(v styles.Aligns) *TextStyle {
	t.Align = v
	return t
}

// SetPadding sets the [TextStyle.Padding]:
// Padding is used in a case-dependent manner to add
// space around text elements.
func (t *TextStyle) SetPadding(v units.Value) *TextStyle {
	t.Padding = v
	return t
}

// SetRotation sets the [TextStyle.Rotation]:
// Rotation of the text, in degrees.
func (t *TextStyle) SetRotation(v float32) *TextStyle {
	t.Rotation = v
	return t
}

// SetOffset sets the [TextStyle.Offset]:
// Offset is added directly to the final label location.
func (t *TextStyle) SetOffset(v units.XY) *TextStyle {
	t.Offset = v
	return t
}
// NOTE(review): typegen-style type registrations; appear generated — do not edit by hand.
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.Text", IDName: "text", Doc: "Text specifies a single text element in a plot", Fields: []types.Field{{Name: "Text", Doc: "text string, which can use HTML formatting"}, {Name: "Style", Doc: "styling for this text element"}, {Name: "font", Doc: "font has the full font rendering styles."}, {Name: "PaintText", Doc: "PaintText is the [paint.Text] for the text."}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.Tick", IDName: "tick", Doc: "A Tick is a single tick mark on an axis.", Fields: []types.Field{{Name: "Value", Doc: "Value is the data value marked by this Tick."}, {Name: "Label", Doc: "Label is the text to display at the tick mark.\nIf Label is an empty string then this is a minor tick mark."}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.Ticker", IDName: "ticker", Doc: "Ticker creates Ticks in a specified range", Methods: []types.Method{{Name: "Ticks", Doc: "Ticks returns Ticks in a specified range, with desired number of ticks,\nwhich can be ignored depending on the ticker type.", Args: []string{"min", "max", "nticks"}, Returns: []string{"Tick"}}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.DefaultTicks", IDName: "default-ticks", Doc: "DefaultTicks is suitable for the Ticker field of an Axis,\nit returns a reasonable default set of tick marks."})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.LogTicks", IDName: "log-ticks", Doc: "LogTicks is suitable for the Ticker field of an Axis,\nit returns tick marks suitable for a log-scale axis.", Fields: []types.Field{{Name: "Prec", Doc: "Prec specifies the precision of tick rendering\naccording to the documentation for strconv.FormatFloat."}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.ConstantTicks", IDName: "constant-ticks", Doc: "ConstantTicks is suitable for the Ticker field of an Axis.\nThis function returns the given set of ticks."})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.TimeTicks", IDName: "time-ticks", Doc: "TimeTicks is suitable for axes representing time values.", Fields: []types.Field{{Name: "Ticker", Doc: "Ticker is used to generate a set of ticks.\nIf nil, DefaultTicks will be used."}, {Name: "Format", Doc: "Format is the textual representation of the time value.\nIf empty, time.RFC3339 will be used"}, {Name: "Time", Doc: "Time takes a float32 value and converts it into a time.Time.\nIf nil, UTCUnixTime is used."}}})
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package plotcore
import (
"fmt"
"image"
"image/draw"
"cogentcore.org/core/colors"
"cogentcore.org/core/core"
"cogentcore.org/core/cursors"
"cogentcore.org/core/events"
"cogentcore.org/core/events/key"
"cogentcore.org/core/styles"
"cogentcore.org/core/styles/abilities"
"cogentcore.org/core/styles/states"
"cogentcore.org/core/styles/units"
"cogentcore.org/lab/plot"
)
// Plot is a widget that renders a [plot.Plot] object.
// If it is not [states.ReadOnly], the user can pan and zoom the graph.
// See [PlotEditor] for an interactive interface for selecting columns to view.
type Plot struct {
	core.WidgetBase
	// Plot is the Plot to display in this widget.
	// Set it via [Plot.SetPlot] (the generated setter is disabled with set:"-")
	// so that the existing render image can be reused.
	Plot *plot.Plot `set:"-"`
	// SetRangesFunc, if set, is called to adjust the data ranges
	// after the point when these ranges are updated based on the plot data.
	SetRangesFunc func()
}
// SetPlot sets the plot to given Plot, and calls UpdatePlot to ensure it is
// drawn at the current size of this widget.
func (pt *Plot) SetPlot(pl *plot.Plot) {
	canReusePixels := pl != nil && pt.Plot != nil && pt.Plot.Pixels != nil
	if canReusePixels {
		// Carry over the existing backing image so the new plot
		// does not have to allocate a fresh one.
		pl.DPI = pt.Styles.UnitContext.DPI
		pl.SetPixels(pt.Plot.Pixels)
	}
	pt.Plot = pl
	pt.updatePlot()
}
// updatePlot draws the current plot at the size of the current widget,
// and triggers a Render so the widget will be rendered.
// It is a no-op (beyond requesting a render) when there is no plot,
// and does nothing at all before the widget has a non-zero content size.
func (pt *Plot) updatePlot() {
	if pt.Plot == nil {
		pt.NeedsRender()
		return
	}
	sz := pt.Geom.Size.Actual.Content.ToPoint()
	if sz == (image.Point{}) {
		return // not yet laid out; nothing to draw into
	}
	// DPI and size are applied before ranges and drawing.
	pt.Plot.DPI = pt.Styles.UnitContext.DPI
	pt.Plot.Resize(sz)
	if pt.SetRangesFunc != nil {
		pt.SetRangesFunc() // user hook to adjust data ranges before drawing
	}
	pt.Plot.Draw()
	pt.NeedsRender()
}
// Init sets up styling and, when not read-only, the pan (slide) and
// zoom (scroll) event handlers for the plot widget.
func (pt *Plot) Init() {
	pt.WidgetBase.Init()
	pt.Styler(func(s *styles.Style) {
		s.Min.Set(units.Dp(256))
		ro := pt.IsReadOnly()
		s.SetAbilities(!ro, abilities.Slideable, abilities.Activatable, abilities.Scrollable)
		if !ro {
			// grab/grabbing cursors signal the pan interaction
			if s.Is(states.Active) {
				s.Cursor = cursors.Grabbing
				s.StateLayer = 0
			} else {
				s.Cursor = cursors.Grab
			}
		}
	})
	// Slide (drag) pans the plot; Shift constrains to X-only, Alt to Y-only.
	pt.On(events.SlideMove, func(e events.Event) {
		e.SetHandled()
		if pt.Plot == nil {
			return
		}
		xf, yf := 1.0, 1.0
		if e.HasAnyModifier(key.Shift) {
			yf = 0
		} else if e.HasAnyModifier(key.Alt) {
			xf = 0
		}
		del := e.PrevDelta()
		// deltas are scaled by the current axis range; X is negated
		// (presumably so the data follows the pointer) -- confirm.
		dx := -float64(del.X) * (pt.Plot.X.Range.Range()) * 0.0008 * xf
		dy := float64(del.Y) * (pt.Plot.Y.Range.Range()) * 0.0008 * yf
		pt.Plot.PanZoom.XOffset += dx
		pt.Plot.PanZoom.YOffset += dy
		pt.updatePlot()
	})
	// Scroll zooms the plot; Shift constrains to X-only, Alt to Y-only.
	pt.On(events.Scroll, func(e events.Event) {
		e.SetHandled()
		if pt.Plot == nil {
			return
		}
		se := e.(*events.MouseScroll)
		sc := 1 + (float64(se.Delta.Y) * 0.002)
		xsc, ysc := sc, sc
		if e.HasAnyModifier(key.Shift) {
			ysc = 1
		} else if e.HasAnyModifier(key.Alt) {
			xsc = 1
		}
		pt.Plot.PanZoom.XScale *= xsc
		pt.Plot.PanZoom.YScale *= ysc
		pt.updatePlot()
	})
}
// WidgetTooltip returns a tooltip for the data point closest to the given
// position, when it is within range; otherwise it returns the default
// tooltip. The special position (-1,-1) returns "_" with a zero point.
// It also sets / clears the plot's highlighted plotter and index as a
// side effect, triggering a plot update when that state changes.
func (pt *Plot) WidgetTooltip(pos image.Point) (string, image.Point) {
	if pos == image.Pt(-1, -1) {
		return "_", image.Point{}
	}
	if pt.Plot == nil {
		return pt.Tooltip, pt.DefaultTooltipPos()
	}
	// convert to plot-local pixel coordinates
	wpos := pos.Sub(pt.Geom.ContentBBox.Min)
	plt, _, idx, dist, _, data, legend := pt.Plot.ClosestDataToPixel(wpos.X, wpos.Y)
	if dist > 10 { // more than 10px from any data point: clear highlight
		if pt.Plot.HighlightPlotter != nil {
			pt.Plot.HighlightPlotter = nil
			pt.updatePlot()
		}
		return pt.Tooltip, pt.DefaultTooltipPos()
	}
	pt.Plot.HighlightPlotter = plt
	pt.Plot.HighlightIndex = idx
	pt.updatePlot()
	dx := 0.0
	if data[plot.X] != nil {
		dx = data[plot.X].Float1D(idx)
	}
	dy := 0.0
	if data[plot.Y] != nil {
		dy = data[plot.Y].Float1D(idx)
	}
	return fmt.Sprintf("%s[%d]: (%g, %g)", legend, idx, dx, dy), pos
}
// SizeFinal redraws the plot at the widget's final allocated size.
func (pt *Plot) SizeFinal() {
	pt.WidgetBase.SizeFinal()
	pt.updatePlot()
}
// Render draws the plot pixels into the scene, or fills the content
// area with the surface color when there is no plot image to draw.
func (pt *Plot) Render() {
	pt.WidgetBase.Render()
	bbox := pt.Geom.ContentBBox
	off := pt.Geom.ScrollOffset()
	if pt.Plot != nil && pt.Plot.Pixels != nil {
		draw.Draw(pt.Scene.Pixels, bbox, pt.Plot.Pixels, off, draw.Src)
		return
	}
	draw.Draw(pt.Scene.Pixels, bbox, colors.Scheme.Surface, off, draw.Src)
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package plotcore provides Cogent Core widgets for viewing and editing plots.
package plotcore
//go:generate core generate
import (
"fmt"
"io/fs"
"log/slog"
"path/filepath"
"slices"
"strings"
"time"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/fsx"
"cogentcore.org/core/base/iox/imagex"
"cogentcore.org/core/base/metadata"
"cogentcore.org/core/base/reflectx"
"cogentcore.org/core/colors"
"cogentcore.org/core/core"
"cogentcore.org/core/events"
"cogentcore.org/core/icons"
"cogentcore.org/core/styles"
"cogentcore.org/core/styles/states"
"cogentcore.org/core/system"
"cogentcore.org/core/tree"
"cogentcore.org/lab/plot"
"cogentcore.org/lab/plot/plots"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
"cogentcore.org/lab/tensorcore"
"golang.org/x/exp/maps"
)
// PlotEditor is a widget that provides an interactive 2D plot
// of selected columns of tabular data, represented by a [table.Table].
// Other types of tabular data can be converted into this format.
// The user can change various options for the plot and also modify the underlying data.
type PlotEditor struct { //types:add
	core.Frame

	// table is the table of data being plotted.
	table *table.Table

	// PlotStyle has the overall plot style parameters.
	PlotStyle plot.PlotStyle

	// plot is the plot object.
	plot *plot.Plot

	// svgFile is the current svg file.
	svgFile core.Filename

	// dataFile is the current csv data file.
	dataFile core.Filename

	// inPlot is true while a plot is being generated, preventing
	// reentrant / duplicate updates (see genPlot).
	inPlot bool

	// columnsFrame holds the per-column switches and style buttons.
	columnsFrame *core.Frame

	// plotWidget is the child [Plot] widget that renders the plot.
	plotWidget *Plot

	// plotStyleModified records which PlotStyle fields the user has modified;
	// nil means the style has not yet been initialized from the table.
	plotStyleModified map[string]bool
}
// CopyFieldsFrom copies the frame, plot style and table from the
// given node, which must be a *PlotEditor.
func (pl *PlotEditor) CopyFieldsFrom(frm tree.Node) {
	src := frm.(*PlotEditor)
	pl.Frame.CopyFieldsFrom(&src.Frame)
	pl.PlotStyle = src.PlotStyle
	pl.setTable(src.table)
}
// NewSubPlot returns a [PlotEditor] with its own separate [core.Toolbar],
// suitable for a tab or other element that is not the main plot.
func NewSubPlot(parent ...tree.Node) *PlotEditor {
	frame := core.NewFrame(parent...)
	toolbar := core.NewToolbar(frame)
	editor := NewPlotEditor(frame)
	frame.Styler(func(s *styles.Style) {
		s.Direction = styles.Column
		s.Grow.Set(1, 1)
	})
	toolbar.Maker(editor.MakeToolbar)
	return editor
}
// Init builds the editor: a columns-selection frame and the plot widget,
// laid out in a column in the compact size class, row otherwise.
func (pl *PlotEditor) Init() {
	pl.Frame.Init()
	pl.PlotStyle.Defaults()
	pl.Styler(func(s *styles.Style) {
		s.Grow.Set(1, 1)
		if pl.SizeClass() == core.SizeCompact {
			s.Direction = styles.Column
		}
	})
	pl.OnShow(func(e events.Event) {
		pl.UpdatePlot()
	})
	pl.Updater(func() {
		// pull plot-level style settings from the table's column stylers
		if pl.table != nil {
			pl.plotStyleFromTable(pl.table)
		}
	})
	tree.AddChildAt(pl, "columns", func(w *core.Frame) {
		pl.columnsFrame = w
		w.Styler(func(s *styles.Style) {
			s.Direction = styles.Column
			s.Background = colors.Scheme.SurfaceContainerLow
			if w.SizeClass() == core.SizeCompact {
				s.Grow.Set(1, 0)
			} else {
				s.Grow.Set(0, 1)
				s.Overflow.Y = styles.OverflowAuto
			}
		})
		w.Maker(pl.makeColumns)
	})
	tree.AddChildAt(pl, "plot", func(w *Plot) {
		pl.plotWidget = w
		w.Plot = pl.plot
		w.Styler(func(s *styles.Style) {
			s.Grow.Set(1, 1)
		})
	})
}
// setTable sets the table to view directly (without wrapping it in a new
// view, unlike [PlotEditor.SetTable]), and does Update to update the
// Column list, which will also trigger a Layout and updating of the plot
// on next render pass.
// This is safe to call from a different goroutine.
func (pl *PlotEditor) setTable(tab *table.Table) *PlotEditor {
	pl.table = tab
	pl.Update()
	return pl
}
// SetTable sets the table to a new [table.NewView] of the given table,
// and does Update to update the Column list, which will also trigger a
// Layout and updating of the plot on next render pass.
// This is safe to call from a different goroutine.
func (pl *PlotEditor) SetTable(tab *table.Table) *PlotEditor {
	pl.table = table.NewView(tab)
	pl.Update()
	return pl
}
// SetSlice sets the table to a [table.NewSliceTable] from the given slice.
// Optional styler functions are used for each struct field in sequence,
// and any can contain global plot style.
func (pl *PlotEditor) SetSlice(sl any, stylers ...func(s *plot.Style)) *PlotEditor {
	dt, err := table.NewSliceTable(sl)
	errors.Log(err)
	if dt == nil {
		return nil
	}
	// apply stylers positionally, up to whichever list is shorter
	n := min(dt.NumColumns(), len(stylers))
	for i := 0; i < n; i++ {
		plot.SetStylersTo(dt.Columns.Values[i], plot.Stylers{stylers[i]})
	}
	return pl.SetTable(dt)
}
// SaveSVG saves the plot to an svg -- first updates to ensure that plot is current.
// NOTE: svg saving is not yet implemented (see TODO below): currently this
// only updates the plot and records the filename.
func (pl *PlotEditor) SaveSVG(fname core.Filename) { //types:add
	pl.UpdatePlot()
	// TODO: get plot svg saving working
	// pc := pl.PlotChild()
	// SaveSVGView(string(fname), pl.Plot, sv, 2)
	pl.svgFile = fname
}
// SavePNG saves the current plot to a png, capturing current render.
// Any error saving is logged (previously it was silently dropped).
func (pl *PlotEditor) SavePNG(fname core.Filename) { //types:add
	pl.UpdatePlot()
	errors.Log(imagex.Save(pl.plot.Pixels, string(fname)))
}
// SaveCSV saves the Table data to a csv (comma-separated values) file
// with headers (any delim). Any error saving is logged (previously it
// was silently dropped).
func (pl *PlotEditor) SaveCSV(fname core.Filename, delim tensor.Delims) { //types:add
	errors.Log(pl.table.SaveCSV(fsx.Filename(fname), delim, table.Headers))
	pl.dataFile = fname
}
// SaveAll saves the current plot to a png, svg, and the data to a tsv -- full save.
// Any extension is removed and appropriate extensions are added.
func (pl *PlotEditor) SaveAll(fname core.Filename) { //types:add
	base := strings.TrimSuffix(string(fname), filepath.Ext(string(fname)))
	pl.SaveCSV(core.Filename(base+".tsv"), tensor.Tab)
	pl.SavePNG(core.Filename(base + ".png"))
	pl.SaveSVG(core.Filename(base + ".svg"))
}
// OpenCSV opens the Table data from a csv (comma-separated values) file
// (or any delim), records the filename, and updates the plot.
func (pl *PlotEditor) OpenCSV(filename core.Filename, delim tensor.Delims) { //types:add
	pl.dataFile = filename
	pl.table.OpenCSV(fsx.Filename(filename), delim)
	pl.UpdatePlot()
}
// OpenFS opens the Table data from a csv (comma-separated values) file
// (or any delim) from the given filesystem, records the filename,
// and updates the plot.
func (pl *PlotEditor) OpenFS(fsys fs.FS, filename core.Filename, delim tensor.Delims) {
	pl.dataFile = filename
	pl.table.OpenFS(fsys, string(filename), delim)
	pl.UpdatePlot()
}
// GoUpdatePlot updates the display based on current Indexed view into table.
// This version can be called from goroutines. It does Sequential() on
// the [table.Table], under the assumption that it is used for tracking
// the latest updates of a running process.
func (pl *PlotEditor) GoUpdatePlot() {
	if pl == nil || pl.This == nil {
		return
	}
	if core.TheApp.Platform() == system.Web {
		time.Sleep(time.Millisecond) // critical to prevent hanging!
	}
	// skip when not visible, no data, or a plot is already in progress
	if !pl.IsVisible() || pl.table == nil || pl.inPlot {
		return
	}
	// lock the scene for an async update from another goroutine
	pl.Scene.AsyncLock()
	pl.table.Sequential()
	pl.genPlot()
	pl.NeedsRender()
	pl.Scene.AsyncUnlock()
}
// UpdatePlot updates the display based on current Indexed view into table.
// It does not automatically update the [table.Table] unless it is
// nil or out of date.
func (pl *PlotEditor) UpdatePlot() {
	if pl == nil || pl.This == nil {
		return
	}
	if pl.table == nil || pl.inPlot {
		return
	}
	// rebuild the widget tree if the columns frame / plot child are missing
	if len(pl.Children) != 2 { // || len(pl.Columns) != pl.table.NumColumns() { // todo:
		pl.Update()
	}
	// an empty view means no indexes yet: reset to sequential
	if pl.table.NumRows() == 0 {
		pl.table.Sequential()
	}
	pl.genPlot()
}
// genPlot generates the plot and sets it on the plot widget (which redraws).
// It surrounds the operation with inPlot true / false to prevent multiple
// concurrent updates; the flag is now cleared via defer so it cannot be
// left set by a panic in plot generation.
func (pl *PlotEditor) genPlot() {
	if pl.inPlot {
		slog.Error("plot: in plot already") // note: this never seems to happen -- could probably nuke
		return
	}
	pl.inPlot = true
	defer func() { pl.inPlot = false }()
	if pl.table == nil {
		return
	}
	if len(pl.table.Indexes) == 0 {
		pl.table.Sequential()
	} else {
		// if the last index is out of range, the view is out of date: reset
		lasti := pl.table.Indexes[pl.table.NumRows()-1]
		if lasti >= pl.table.NumRows() {
			pl.table.Sequential()
		}
	}
	var err error
	pl.plot, err = plot.NewTablePlot(pl.table)
	if err != nil {
		core.ErrorSnackbar(pl, fmt.Errorf("%s: %w", pl.PlotStyle.Title, err))
	}
	pl.plotWidget.SetPlot(pl.plot) // redraws etc
}
// plotColumnsHeaderN is the number of fixed header widgets (buttons frame,
// separator, etc.) at the start of the columns frame, before the
// per-column rows; see makeColumns.
const plotColumnsHeaderN = 3
// allColumnsOff turns all columns off.
func (pl *PlotEditor) allColumnsOff() {
	frame := pl.columnsFrame
	for i, child := range frame.Children {
		if i < plotColumnsHeaderN { // skip the fixed header widgets
			continue
		}
		row := child.(*core.Frame)
		sw := row.Child(0).(*core.Switch)
		sw.SetChecked(false)
		sw.SendChange()
	}
	pl.Update()
}
// setColumnsByName turns columns on or off if their name contains
// the given string.
func (pl *PlotEditor) setColumnsByName(nameContains string, on bool) { //types:add
	frame := pl.columnsFrame
	for i, child := range frame.Children {
		if i < plotColumnsHeaderN { // skip the fixed header widgets
			continue
		}
		row := child.(*core.Frame)
		if !strings.Contains(row.Name, nameContains) {
			continue
		}
		sw := row.Child(0).(*core.Switch)
		sw.SetChecked(on)
		sw.SendChange()
	}
	pl.Update()
}
// makeColumns makes the Plans for columns: a header row of Clear / Search
// buttons, a separator, and then one row per table column containing a
// switch (to toggle the column on or off) and a button (to edit its style).
func (pl *PlotEditor) makeColumns(p *tree.Plan) {
	tree.Add(p, func(w *core.Frame) {
		tree.AddChild(w, func(w *core.Button) {
			w.SetText("Clear").SetIcon(icons.ClearAll).SetType(core.ButtonAction)
			w.SetTooltip("Turn all columns off")
			w.OnClick(func(e events.Event) {
				pl.allColumnsOff()
			})
		})
		tree.AddChild(w, func(w *core.Button) {
			w.SetText("Search").SetIcon(icons.Search).SetType(core.ButtonAction)
			w.SetTooltip("Select columns by column name")
			w.OnClick(func(e events.Event) {
				core.CallFunc(pl, pl.setColumnsByName)
			})
		})
	})
	tree.Add(p, func(w *core.Separator) {})
	if pl.table == nil {
		return
	}
	colorIdx := 0 // index for color sequence -- skips various types
	for ci, cl := range pl.table.Columns.Values {
		cnm := pl.table.Columns.Keys[ci]
		tree.AddAt(p, cnm, func(w *core.Frame) {
			// merge the column's existing stylers with the editor's
			// current settings for this column and the overall plot style
			psty := plot.GetStylersFrom(cl)
			cst, mods := pl.defaultColumnStyle(cl, ci, &colorIdx, psty)
			stys := psty
			stys.Add(func(s *plot.Style) {
				// copy only fields the user has actually modified
				mf := modFields(mods)
				errors.Log(reflectx.CopyFields(s, cst, mf...))
				errors.Log(reflectx.CopyFields(&s.Plot, &pl.PlotStyle, modFields(pl.plotStyleModified)...))
			})
			plot.SetStylersTo(cl, stys)
			w.Styler(func(s *styles.Style) {
				s.CenterAll()
			})
			tree.AddChild(w, func(w *core.Switch) {
				w.SetType(core.SwitchCheckbox).SetTooltip("Turn this column on or off")
				w.Styler(func(s *styles.Style) {
					s.Color = cst.Line.Color
				})
				// color the switch icons to match the column's line color
				tree.AddChildInit(w, "stack", func(w *core.Frame) {
					f := func(name string) {
						tree.AddChildInit(w, name, func(w *core.Icon) {
							w.Styler(func(s *styles.Style) {
								s.Color = cst.Line.Color
							})
						})
					}
					f("icon-on")
					f("icon-off")
					f("icon-indeterminate")
				})
				w.OnChange(func(e events.Event) {
					mods["On"] = true
					cst.On = w.IsChecked()
					pl.UpdatePlot()
				})
				w.Updater(func() {
					// the X axis column cannot be toggled on; show as disabled
					xaxis := cst.Role == plot.X // || cp.Column == pl.Options.Legend
					w.SetState(xaxis, states.Disabled, states.Indeterminate)
					if xaxis {
						cst.On = false
					} else {
						w.SetChecked(cst.On)
					}
				})
			})
			tree.AddChild(w, func(w *core.Button) {
				tt := "[Edit all styling options for this column] " + metadata.Doc(cl)
				w.SetText(cnm).SetType(core.ButtonAction).SetTooltip(tt)
				w.OnClick(func(e events.Event) {
					update := func() {
						if core.TheApp.Platform().IsMobile() {
							pl.Update()
							return
						}
						// we must be async on multi-window platforms since
						// it is coming from a separate window
						pl.AsyncLock()
						pl.Update()
						pl.AsyncUnlock()
					}
					d := core.NewBody(cnm + " style properties")
					fm := core.NewForm(d).SetStruct(cst)
					fm.Modified = mods
					fm.OnChange(func(e events.Event) {
						update()
					})
					// d.AddTopBar(func(bar *core.Frame) {
					// 	core.NewToolbar(bar).Maker(func(p *tree.Plan) {
					// 		tree.Add(p, func(w *core.Button) {
					// 			w.SetText("Set x-axis").OnClick(func(e events.Event) {
					// 				pl.Options.XAxis = cp.Column
					// 				update()
					// 			})
					// 		})
					// 		tree.Add(p, func(w *core.Button) {
					// 			w.SetText("Set legend").OnClick(func(e events.Event) {
					// 				pl.Options.Legend = cp.Column
					// 				update()
					// 			})
					// 		})
					// 	})
					// })
					d.RunWindowDialog(pl)
				})
			})
		})
	}
}
// defaultColumnStyle initializes the column style with any existing stylers
// plus additional general defaults, returning the style and the map of
// initially modified field names. colorIdx is advanced for each float
// column with the Y role, so each such column gets the next color in the
// standard spaced color sequence.
func (pl *PlotEditor) defaultColumnStyle(cl tensor.Values, ci int, colorIdx *int, psty plot.Stylers) (*plot.Style, map[string]bool) {
	cst := &plot.Style{}
	cst.Defaults()
	if psty != nil {
		psty.Run(cst)
	}
	mods := map[string]bool{}
	isfloat := reflectx.KindIsFloat(cl.DataType())
	if cst.Plotter == "" {
		// default plotter: XY for float data, Labels for string data
		if isfloat {
			cst.Plotter = plot.PlotterName(plots.XYType)
			mods["Plotter"] = true
		} else if cl.IsString() {
			cst.Plotter = plot.PlotterName(plots.LabelsType)
			mods["Plotter"] = true
		}
	}
	if cst.Role == plot.NoRole {
		// default role: Y for float, Label for string, X otherwise
		mods["Role"] = true
		if isfloat {
			cst.Role = plot.Y
		} else if cl.IsString() {
			cst.Role = plot.Label
		} else {
			cst.Role = plot.X
		}
	}
	if cst.Line.Color == colors.Scheme.OnSurface {
		// color is still the default: assign the next spaced color
		// to distinguish float Y data columns from each other
		if cst.Role == plot.Y && isfloat {
			spclr := colors.Uniform(colors.Spaced(*colorIdx))
			cst.Line.Color = spclr
			mods["Line.Color"] = true
			cst.Point.Color = spclr
			mods["Point.Color"] = true
			if cst.Plotter == plots.BarType {
				cst.Line.Fill = spclr
				mods["Line.Fill"] = true
			}
			(*colorIdx)++
		}
	}
	return cst, mods
}
// plotStyleFromTable initializes PlotStyle from the stylers stored on the
// table's columns, if not already done (a non-nil plotStyleModified map
// means it is already set). It also applies fallback defaults for points
// display and the plot title.
func (pl *PlotEditor) plotStyleFromTable(dt *table.Table) {
	if pl.plotStyleModified != nil { // already set
		return
	}
	pst := &pl.PlotStyle
	mods := map[string]bool{}
	pl.plotStyleModified = mods
	tst := &plot.Style{}
	tst.Defaults()
	tst.Plot.Defaults()
	// accumulate plot-level settings from all column stylers
	for _, cl := range pl.table.Columns.Values {
		stl := plot.GetStylersFrom(cl)
		if stl == nil {
			continue
		}
		stl.Run(tst)
	}
	*pst = tst.Plot
	if pst.PointsOn == plot.Default {
		pst.PointsOn = plot.Off
		mods["PointsOn"] = true
	}
	if pst.Title == "" {
		// fall back on the table's metadata name as the title
		pst.Title = metadata.Name(pl.table)
		if pst.Title != "" {
			mods["Title"] = true
		}
	}
}
// modFields returns the sorted list of modified (true) field names from
// the given map, converting " • " separators to "." field paths.
// It ranges over the map directly instead of materializing the key list
// first, and uses !on instead of the == false comparison.
func modFields(mods map[string]bool) []string {
	rf := make([]string, 0, len(mods))
	for f, on := range mods {
		if !on {
			continue
		}
		rf = append(rf, strings.ReplaceAll(f, " • ", "."))
	}
	slices.Sort(rf) // map iteration order is random; sort for determinism
	return rf
}
// MakeToolbar adds the editor's toolbar items: pan/zoom toggle, full
// update, style and data-table dialogs, a save menu, open csv, and
// filter / unfilter actions. It does nothing if there is no table.
func (pl *PlotEditor) MakeToolbar(p *tree.Plan) {
	if pl.table == nil {
		return
	}
	tree.Add(p, func(w *core.Button) {
		w.SetIcon(icons.PanTool).
			SetTooltip("toggle the ability to zoom and pan the view").OnClick(func(e events.Event) {
			// pan/zoom is enabled when the plot widget is NOT read-only
			pw := pl.plotWidget
			pw.SetReadOnly(!pw.IsReadOnly())
			pw.Restyle()
		})
	})
	// tree.Add(p, func(w *core.Button) {
	// 	w.SetIcon(icons.ArrowForward).
	// 		SetTooltip("turn on select mode for selecting Plot elements").
	// 		OnClick(func(e events.Event) {
	// 			fmt.Println("this will select select mode")
	// 		})
	// })
	tree.Add(p, func(w *core.Separator) {})
	tree.Add(p, func(w *core.Button) {
		w.SetText("Update").SetIcon(icons.Update).
			SetTooltip("update fully redraws display, reflecting any new settings etc").
			OnClick(func(e events.Event) {
				pl.UpdateWidget()
				pl.UpdatePlot()
			})
	})
	tree.Add(p, func(w *core.Button) {
		w.SetText("Style").SetIcon(icons.Settings).
			SetTooltip("Style for how the plot is rendered").
			OnClick(func(e events.Event) {
				d := core.NewBody("Plot style")
				fm := core.NewForm(d).SetStruct(&pl.PlotStyle)
				fm.Modified = pl.plotStyleModified
				fm.OnChange(func(e events.Event) {
					// dialog is a separate window: use the goroutine-safe update
					pl.GoUpdatePlot()
				})
				d.RunWindowDialog(pl)
			})
	})
	tree.Add(p, func(w *core.Button) {
		w.SetText("Table").SetIcon(icons.Edit).
			SetTooltip("open a Table window of the data").
			OnClick(func(e events.Event) {
				d := core.NewBody(pl.Name + " Data")
				tv := tensorcore.NewTable(d).SetTable(pl.table)
				d.AddTopBar(func(bar *core.Frame) {
					core.NewToolbar(bar).Maker(tv.MakeToolbar)
				})
				d.RunWindowDialog(pl)
			})
	})
	tree.Add(p, func(w *core.Separator) {})
	tree.Add(p, func(w *core.Button) {
		w.SetText("Save").SetIcon(icons.Save).SetMenu(func(m *core.Scene) {
			core.NewFuncButton(m).SetFunc(pl.SaveSVG).SetIcon(icons.Save)
			core.NewFuncButton(m).SetFunc(pl.SavePNG).SetIcon(icons.Save)
			core.NewFuncButton(m).SetFunc(pl.SaveCSV).SetIcon(icons.Save)
			core.NewSeparator(m)
			core.NewFuncButton(m).SetFunc(pl.SaveAll).SetIcon(icons.Save)
		})
	})
	tree.Add(p, func(w *core.FuncButton) {
		w.SetFunc(pl.OpenCSV).SetIcon(icons.Open)
	})
	tree.Add(p, func(w *core.Separator) {})
	tree.Add(p, func(w *core.FuncButton) {
		w.SetFunc(pl.table.FilterString).SetText("Filter").SetIcon(icons.FilterAlt)
		w.SetAfterFunc(pl.UpdatePlot)
	})
	tree.Add(p, func(w *core.FuncButton) {
		w.SetFunc(pl.table.Sequential).SetText("Unfilter").SetIcon(icons.FilterAltOff)
		w.SetAfterFunc(pl.UpdatePlot)
	})
}
// SizeFinal updates the plot at the widget's final allocated size.
// The receiver is renamed from pt to pl for consistency with every
// other PlotEditor method in this file.
func (pl *PlotEditor) SizeFinal() {
	pl.Frame.SizeFinal()
	pl.UpdatePlot()
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package plotcore
import (
"slices"
"cogentcore.org/core/core"
"cogentcore.org/lab/plot"
_ "cogentcore.org/lab/plot/plots"
"golang.org/x/exp/maps"
)
// init registers [PlotterChooser] as the value widget used for
// [plot.PlotterName] values in forms and editors.
func init() {
	core.AddValueType[plot.PlotterName, PlotterChooser]()
}
// PlotterChooser represents a [plot.PlotterName] value with a [core.Chooser]
// for selecting a plotter.
type PlotterChooser struct {
	core.Chooser
}
// Init initializes the chooser with the sorted names of all
// registered [plot.Plotters].
func (fc *PlotterChooser) Init() {
	fc.Chooser.Init()
	pnms := maps.Keys(plot.Plotters)
	slices.Sort(pnms)
	fc.SetStrings(pnms...)
}
// Code generated by "core generate"; DO NOT EDIT.
package plotcore
import (
"cogentcore.org/core/tree"
"cogentcore.org/core/types"
"cogentcore.org/lab/plot"
)
// Generated type registrations and constructor / setter helpers for the
// widgets declared in this package; see the type declarations for docs.

var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plotcore.Plot", IDName: "plot", Doc: "Plot is a widget that renders a [plot.Plot] object.\nIf it is not [states.ReadOnly], the user can pan and zoom the graph.\nSee [PlotEditor] for an interactive interface for selecting columns to view.", Embeds: []types.Field{{Name: "WidgetBase"}}, Fields: []types.Field{{Name: "Plot", Doc: "Plot is the Plot to display in this widget"}, {Name: "SetRangesFunc", Doc: "SetRangesFunc, if set, is called to adjust the data ranges\nafter the point when these ranges are updated based on the plot data."}}})

// NewPlot returns a new [Plot] with the given optional parent:
// Plot is a widget that renders a [plot.Plot] object.
// If it is not [states.ReadOnly], the user can pan and zoom the graph.
// See [PlotEditor] for an interactive interface for selecting columns to view.
func NewPlot(parent ...tree.Node) *Plot { return tree.New[Plot](parent...) }

// SetSetRangesFunc sets the [Plot.SetRangesFunc]:
// SetRangesFunc, if set, is called to adjust the data ranges
// after the point when these ranges are updated based on the plot data.
func (t *Plot) SetSetRangesFunc(v func()) *Plot { t.SetRangesFunc = v; return t }

var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plotcore.PlotEditor", IDName: "plot-editor", Doc: "PlotEditor is a widget that provides an interactive 2D plot\nof selected columns of tabular data, represented by a [table.Table] into\na [table.Table]. Other types of tabular data can be converted into this format.\nThe user can change various options for the plot and also modify the underlying data.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Methods: []types.Method{{Name: "SaveSVG", Doc: "SaveSVG saves the plot to an svg -- first updates to ensure that plot is current", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"fname"}}, {Name: "SavePNG", Doc: "SavePNG saves the current plot to a png, capturing current render", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"fname"}}, {Name: "SaveCSV", Doc: "SaveCSV saves the Table data to a csv (comma-separated values) file with headers (any delim)", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"fname", "delim"}}, {Name: "SaveAll", Doc: "SaveAll saves the current plot to a png, svg, and the data to a tsv -- full save\nAny extension is removed and appropriate extensions are added", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"fname"}}, {Name: "OpenCSV", Doc: "OpenCSV opens the Table data from a csv (comma-separated values) file (or any delim)", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"filename", "delim"}}, {Name: "setColumnsByName", Doc: "setColumnsByName turns columns on or off if their name contains\nthe given string.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"nameContains", "on"}}}, Embeds: []types.Field{{Name: "Frame"}}, Fields: []types.Field{{Name: "table", Doc: "table is the table of data being plotted."}, {Name: "PlotStyle", Doc: "PlotStyle has the overall plot style parameters."}, {Name: "plot", Doc: "plot is the plot object."}, {Name: "svgFile", Doc: "current svg file"}, {Name: "dataFile", Doc: "current csv data file"}, {Name: "inPlot", Doc: "currently doing a plot"}, {Name: "columnsFrame"}, {Name: "plotWidget"}, {Name: "plotStyleModified"}}})

// NewPlotEditor returns a new [PlotEditor] with the given optional parent:
// PlotEditor is a widget that provides an interactive 2D plot
// of selected columns of tabular data, represented by a [table.Table] into
// a [table.Table]. Other types of tabular data can be converted into this format.
// The user can change various options for the plot and also modify the underlying data.
func NewPlotEditor(parent ...tree.Node) *PlotEditor { return tree.New[PlotEditor](parent...) }

// SetPlotStyle sets the [PlotEditor.PlotStyle]:
// PlotStyle has the overall plot style parameters.
func (t *PlotEditor) SetPlotStyle(v plot.PlotStyle) *PlotEditor { t.PlotStyle = v; return t }

var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plotcore.PlotterChooser", IDName: "plotter-chooser", Doc: "PlotterChooser represents a [Plottername] value with a [core.Chooser]\nfor selecting a plotter.", Embeds: []types.Field{{Name: "Chooser"}}})

// NewPlotterChooser returns a new [PlotterChooser] with the given optional parent:
// PlotterChooser represents a [Plottername] value with a [core.Chooser]
// for selecting a plotter.
func NewPlotterChooser(parent ...tree.Node) *PlotterChooser {
	return tree.New[PlotterChooser](parent...)
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cluster
//go:generate core generate
import (
"fmt"
"math"
"math/rand"
"cogentcore.org/core/base/indent"
"cogentcore.org/lab/stats/stats"
"cogentcore.org/lab/tensor"
)
// todo: all of this data goes into the tensorfs
// Cluster makes a new dir, stuffs results in there!
// need a global "cwd" that it uses, so basically you cd
// to a dir, then call it.
// Node is one node in the cluster tree.
type Node struct {
	// Index is the index into the original distance matrix;
	// only valid for terminal leaves.
	Index int

	// Dist is the distance value for this node, i.e., how far apart were all
	// the kids from each other when this node was created; is 0 for leaf nodes.
	Dist float64

	// ParDist is total aggregate distance from parents; the X axis offset
	// at which our cluster starts.
	ParDist float64

	// Y is the y-axis value for this node; if a parent, it is the average
	// of its kids Y's, otherwise it counts down.
	Y float64

	// Kids are child nodes under this one.
	Kids []*Node
}
// IsLeaf returns true if node is a leaf of the tree with no kids.
func (nn *Node) IsLeaf() bool {
	return len(nn.Kids) == 0
}
// Sprint returns a string rendering of this node and its children:
// leaves are printed via the given labels tensor (indexed by leaf Index),
// and parent nodes print their distance, indented by depth.
func (nn *Node) Sprint(labels tensor.Tensor, depth int) string {
	if nn.IsLeaf() && labels != nil {
		return labels.String1D(nn.Index) + " "
	}
	sv := fmt.Sprintf("\n%v%v: ", indent.Tabs(depth), nn.Dist)
	for _, kn := range nn.Kids {
		sv += kn.Sprint(labels, depth+1)
	}
	return sv
}
// Indexes collects all the leaf indexes under this node into ix,
// incrementing *ctr as each one is written. ix must be large enough
// to hold all the leaves.
func (nn *Node) Indexes(ix []int, ctr *int) {
	if !nn.IsLeaf() {
		for _, kid := range nn.Kids {
			kid.Indexes(ix, ctr)
		}
		return
	}
	ix[*ctr] = nn.Index
	(*ctr)++
}
// NewNode merges two nodes into a new node with the given distance.
func NewNode(na, nb *Node, dst float64) *Node {
	return &Node{Dist: dst, Kids: []*Node{na, nb}}
}
// TODO: this call signature does not fit with standard
// not sure how one might pack Node into a tensor

// Cluster implements agglomerative clustering, based on a
// distance matrix dmat, e.g., as computed by metric.Matrix method,
// using a metric that increases in value with greater dissimilarity.
// labels provides an optional String tensor list of labels for the elements
// of the distance matrix; note that it is not used by the clustering itself,
// only for later printing / plotting of the result.
// This calls InitAllLeaves to initialize the root node with all of the leaves,
// and then Glom to do the iterative agglomerative clustering process.
// If you want to start with pre-defined initial clusters,
// then call Glom with a root node so-initialized.
func Cluster(funcName string, dmat, labels tensor.Tensor) *Node {
	ntot := dmat.DimSize(0) // number of leaves
	root := InitAllLeaves(ntot)
	return Glom(root, funcName, dmat)
}
// InitAllLeaves returns a standard root node initialized with all of the leaves.
func InitAllLeaves(ntot int) *Node {
	kids := make([]*Node, ntot)
	for i := range kids {
		kids[i] = &Node{Index: i}
	}
	return &Node{Kids: kids}
}
// Glom does the iterative agglomerative clustering,
// based on a raw similarity matrix as given,
// using a root node that has already been initialized
// with the starting clusters, which is all of the
// leaves by default, but could be anything if you want
// to start with predefined clusters.
// The loop now guards on len(root.Kids) > 1, so a root initialized with
// fewer than two kids is returned as-is instead of panicking.
func Glom(root *Node, funcName string, dmat tensor.Tensor) *Node {
	ntot := dmat.DimSize(0) // number of leaves
	mout := tensor.NewFloat64Scalar(0)
	stats.MaxOut(tensor.As1D(dmat), mout)
	maxd := mout.Float1D(0) // max distance, needed by some metrics (e.g., Contrast)
	// reusable scratch buffers for the indexes in each group
	aidx := make([]int, ntot)
	bidx := make([]int, ntot)
	// repeatedly merge the closest pair of kids until one cluster remains
	for len(root.Kids) > 1 {
		var ma, mb []int
		mval := math.MaxFloat64
		for ai, ka := range root.Kids {
			actr := 0
			ka.Indexes(aidx, &actr)
			aix := aidx[0:actr]
			for bi := 0; bi < ai; bi++ {
				kb := root.Kids[bi]
				bctr := 0
				kb.Indexes(bidx, &bctr)
				bix := bidx[0:bctr]
				dv := Call(funcName, aix, bix, ntot, maxd, dmat)
				if dv < mval {
					mval = dv
					ma = []int{ai}
					mb = []int{bi}
				} else if dv == mval { // do all ties at same time
					ma = append(ma, ai)
					mb = append(mb, bi)
				}
			}
		}
		// choose randomly among tied best pairs
		ni := 0
		if len(ma) > 1 {
			ni = rand.Intn(len(ma))
		}
		na := ma[ni]
		nb := mb[ni]
		nn := NewNode(root.Kids[na], root.Kids[nb], mval)
		// remove the two merged kids; descending order keeps indexes valid
		for i := len(root.Kids) - 1; i >= 0; i-- {
			if i == na || i == nb {
				root.Kids = append(root.Kids[:i], root.Kids[i+1:]...)
			}
		}
		root.Kids = append(root.Kids, nn)
	}
	return root
}
// Code generated by "core generate"; DO NOT EDIT.
package cluster
import (
"cogentcore.org/core/enums"
)
// Generated enum support for Metrics: value list, string / description
// maps, and the standard enums interface methods.

var _MetricsValues = []Metrics{0, 1, 2, 3}

// MetricsN is the highest valid value for type Metrics, plus one.
const MetricsN Metrics = 4

var _MetricsValueMap = map[string]Metrics{`Min`: 0, `Max`: 1, `Avg`: 2, `Contrast`: 3}

var _MetricsDescMap = map[Metrics]string{0: `Min is the minimum-distance or single-linkage weighting function.`, 1: `Max is the maximum-distance or complete-linkage weighting function.`, 2: `Avg is the average-distance or average-linkage weighting function.`, 3: `Contrast computes maxd + (average within distance - average between distance).`}

var _MetricsMap = map[Metrics]string{0: `Min`, 1: `Max`, 2: `Avg`, 3: `Contrast`}

// String returns the string representation of this Metrics value.
func (i Metrics) String() string { return enums.String(i, _MetricsMap) }

// SetString sets the Metrics value from its string representation,
// and returns an error if the string is invalid.
func (i *Metrics) SetString(s string) error {
	return enums.SetString(i, s, _MetricsValueMap, "Metrics")
}

// Int64 returns the Metrics value as an int64.
func (i Metrics) Int64() int64 { return int64(i) }

// SetInt64 sets the Metrics value from an int64.
func (i *Metrics) SetInt64(in int64) { *i = Metrics(in) }

// Desc returns the description of the Metrics value.
func (i Metrics) Desc() string { return enums.Desc(i, _MetricsDescMap) }

// MetricsValues returns all possible values for the type Metrics.
func MetricsValues() []Metrics { return _MetricsValues }

// Values returns all possible values for the type Metrics.
func (i Metrics) Values() []enums.Enum { return enums.Values(_MetricsValues) }

// MarshalText implements the [encoding.TextMarshaler] interface.
func (i Metrics) MarshalText() ([]byte, error) { return []byte(i.String()), nil }

// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *Metrics) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "Metrics") }
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cluster
import (
"log/slog"
"math"
"cogentcore.org/lab/tensor"
)
// Metrics are standard clustering distance metric functions,
// specifying how a node computes its distance based on its leaves.
// Each value's corresponding function is registered in [Funcs]
// under the value's String() name (see init in this package).
type Metrics int32 //enums:enum

const (
	// Min is the minimum-distance or single-linkage weighting function.
	Min Metrics = iota

	// Max is the maximum-distance or complete-linkage weighting function.
	Max

	// Avg is the average-distance or average-linkage weighting function.
	Avg

	// Contrast computes maxd + (average within distance - average between distance).
	Contrast
)
// MetricFunc is a clustering distance metric function that evaluates aggregate distance
// between nodes, given the indexes of leaves in a and b clusters,
// which are indexes into an ntot x ntot distance matrix dmat.
// maxd is the maximum distance value in the dmat, which is needed by the
// ContrastFunc and perhaps others.
type MetricFunc func(aix, bix []int, ntot int, maxd float64, dmat tensor.Tensor) float64
// Funcs is a registry of clustering metric functions,
// initialized with the standard options, keyed by the
// String() name of the corresponding [Metrics] enum value.
var Funcs map[string]MetricFunc

func init() {
	Funcs = map[string]MetricFunc{
		Min.String():      MinFunc,
		Max.String():      MaxFunc,
		Avg.String():      AvgFunc,
		Contrast.String(): ContrastFunc,
	}
}
// Call calls a cluster metric function by name, returning 0 (after
// logging an error) if the name is not registered in [Funcs].
// The structured-log key is now "function" — the previous "function:"
// included a stray colon, producing a malformed key in slog output.
func Call(funcName string, aix, bix []int, ntot int, maxd float64, dmat tensor.Tensor) float64 {
	fun, ok := Funcs[funcName]
	if !ok {
		slog.Error("cluster.Call: function not found", "function", funcName)
		return 0
	}
	return fun(aix, bix, ntot, maxd, dmat)
}
// MinFunc is the minimum-distance or single-linkage weighting function for comparing
// two clusters a and b, given by their list of indexes.
// ntot is total number of nodes, and dmat is the square similarity matrix [ntot x ntot].
func MinFunc(aix, bix []int, ntot int, maxd float64, dmat tensor.Tensor) float64 {
	best := math.MaxFloat64
	for _, a := range aix {
		for _, b := range bix {
			if d := dmat.Float(a, b); d < best {
				best = d
			}
		}
	}
	return best
}
// MaxFunc is the maximum-distance or complete-linkage weighting function for comparing
// two clusters a and b, given by their list of indexes.
// ntot is total number of nodes, and dmat is the square similarity matrix [ntot x ntot].
func MaxFunc(aix, bix []int, ntot int, maxd float64, dmat tensor.Tensor) float64 {
	best := -math.MaxFloat64
	for _, a := range aix {
		for _, b := range bix {
			if d := dmat.Float(a, b); d > best {
				best = d
			}
		}
	}
	return best
}
// AvgFunc is the average-distance or average-linkage weighting function for comparing
// two clusters a and b, given by their list of indexes.
// ntot is total number of nodes, and dmat is the square similarity matrix [ntot x ntot].
// Returns 0 when either cluster is empty.
func AvgFunc(aix, bix []int, ntot int, maxd float64, dmat tensor.Tensor) float64 {
	sum := 0.0
	for _, a := range aix {
		for _, b := range bix {
			sum += dmat.Float(a, b)
		}
	}
	// the pair count is exactly the product of the two cluster sizes
	if n := len(aix) * len(bix); n > 0 {
		return sum / float64(n)
	}
	return sum
}
// ContrastFunc computes maxd + (average within distance - average between distance)
// for two clusters a and b, given by their list of indexes.
// avg between is average distance between all items in a & b versus all outside that.
// ntot is total number of nodes, and dmat is the square similarity matrix [ntot x ntot].
// maxd is the maximum distance and is needed to ensure distances are positive.
func ContrastFunc(aix, bix []int, ntot int, maxd float64, dmat tensor.Tensor) float64 {
	wd := AvgFunc(aix, bix, ntot, maxd, dmat)
	nab := len(aix) + len(bix)
	// copy into a fresh slice: append(aix, bix...) could clobber the caller's
	// backing array if aix has spare capacity (slice aliasing).
	abix := make([]int, 0, nab)
	abix = append(abix, aix...)
	abix = append(abix, bix...)
	abmap := make(map[int]struct{}, nab) // holds nab entries, not ntot-nab
	for _, ix := range abix {
		abmap[ix] = struct{}{}
	}
	// oix = all node indexes outside of a and b
	oix := make([]int, 0, ntot-nab)
	for ix := 0; ix < ntot; ix++ {
		if _, has := abmap[ix]; !has {
			oix = append(oix, ix)
		}
	}
	bd := AvgFunc(abix, oix, ntot, maxd, dmat)
	return maxd + (wd - bd)
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cluster
import (
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
)
// Plot sets the rows of given data table to trace out lines with labels that
// will render cluster plot starting at root node when plotted with a standard plotting package.
// The lines double-back on themselves to form a continuous line to be plotted.
func Plot(pt *table.Table, root *Node, dmat, labels tensor.Tensor) {
	pt.DeleteAll()
	pt.AddFloat64Column("X")
	pt.AddFloat64Column("Y")
	pt.AddStringColumn("Label")
	// assign Y coordinates and cumulative parent distances before tracing lines
	y := 0.5
	root.SetYs(&y)
	root.SetParDist(0)
	root.Plot(pt, dmat, labels)
}
// Plot sets the rows of given data table to trace out lines with labels that
// will render this node in a cluster plot when plotted with a standard plotting package.
// The lines double-back on themselves to form a continuous line to be plotted.
// Columns 0, 1, 2 must be X, Y, Label, as created by the package-level [Plot].
func (nn *Node) Plot(pt *table.Table, dmat, labels tensor.Tensor) {
row := pt.NumRows()
xc := pt.ColumnByIndex(0)
yc := pt.ColumnByIndex(1)
lbl := pt.ColumnByIndex(2)
if nn.IsLeaf() {
// leaf: single point at this node's cumulative distance and Y,
// with its label if one exists for its index
pt.SetNumRows(row + 1)
xc.SetFloatRow(nn.ParDist, row, 0)
yc.SetFloatRow(nn.Y, row, 0)
if labels.Len() > nn.Index {
lbl.SetStringRow(labels.StringValue(nn.Index), row, 0)
}
} else {
for _, kn := range nn.Kids {
// out to the child: horizontal segment from ParDist to ParDist+Dist
// at the child's Y coordinate
pt.SetNumRows(row + 2)
xc.SetFloatRow(nn.ParDist, row, 0)
yc.SetFloatRow(kn.Y, row, 0)
row++
xc.SetFloatRow(nn.ParDist+nn.Dist, row, 0)
yc.SetFloatRow(kn.Y, row, 0)
kn.Plot(pt, dmat, labels)
// double back to this node's distance at the child's Y,
// keeping the overall line continuous; the child's recursive
// Plot may have added rows, so re-read the current row count
row = pt.NumRows()
pt.SetNumRows(row + 1)
xc.SetFloatRow(nn.ParDist, row, 0)
yc.SetFloatRow(kn.Y, row, 0)
row++
}
// finish at this node's own Y position
pt.SetNumRows(row + 1)
xc.SetFloatRow(nn.ParDist, row, 0)
yc.SetFloatRow(nn.Y, row, 0)
}
}
// SetYs sets the Y-axis values for the nodes in preparation for plotting.
// Leaves get successive values from nextY (which is incremented by 1 each),
// and interior nodes get the average of their children's Y values.
func (nn *Node) SetYs(nextY *float64) {
	if nn.IsLeaf() {
		nn.Y = *nextY
		*nextY += 1
		return
	}
	sum := 0.0
	for _, kn := range nn.Kids {
		kn.SetYs(nextY)
		sum += kn.Y
	}
	nn.Y = sum / float64(len(nn.Kids))
}
// SetParDist sets the cumulative parent distance for the nodes
// in preparation for plotting.
func (nn *Node) SetParDist(pard float64) {
	nn.ParDist = pard
	if nn.IsLeaf() {
		return
	}
	for _, kn := range nn.Kids {
		kn.SetParDist(pard + nn.Dist)
	}
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package convolve
//go:generate core generate
import (
"errors"
"cogentcore.org/core/base/slicesx"
)
// Slice32 convolves given kernel with given source slice, putting results in
// destination, which is ensured to be the same size as the source slice,
// using existing capacity if available, and otherwise making a new slice.
// The kernel should be normalized, and odd-sized so it is symmetric about 0.
// Returns an error if sizes are not valid.
// No parallelization is used -- see Slice32Parallel for very large slices.
// Edges are handled separately with renormalized kernels -- they can be
// clipped from dest by excluding the kernel half-width from each end.
func Slice32(dest *[]float32, src []float32, kern []float32) error {
sz := len(src)
ksz := len(kern)
if ksz == 0 || sz == 0 {
return errors.New("convolve.Slice32: kernel or source are empty")
}
if ksz%2 == 0 {
return errors.New("convolve.Slice32: kernel is not odd sized")
}
if sz < ksz {
return errors.New("convolve.Slice32: source must be > kernel in size")
}
khalf := (ksz - 1) / 2
*dest = slicesx.SetLength(*dest, sz)
// interior: the full kernel fits within the source
for i := khalf; i < sz-khalf; i++ {
var sum float32
for j := 0; j < ksz; j++ {
sum += src[(i-khalf)+j] * kern[j]
}
(*dest)[i] = sum
}
// left edge: only the right portion of the kernel is in range;
// renormalize by the sum of kernel weights actually used (ksum)
for i := 0; i < khalf; i++ {
var sum, ksum float32
for j := 0; j <= khalf+i; j++ {
ki := (j + khalf) - i // kernel index: khalf-i .. ksz-1
si := i + (ki - khalf) // source index = j: 0 .. i+khalf
// fmt.Printf("i: %d j: %d ki: %d si: %d\n", i, j, ki, si)
sum += src[si] * kern[ki]
ksum += kern[ki]
}
(*dest)[i] = sum / ksum
}
// right edge: mirror of left edge, using the left portion of the kernel
for i := sz - khalf; i < sz; i++ {
var sum, ksum float32
ei := sz - i - 1 // number of valid elements to the right of i
for j := 0; j <= khalf+ei; j++ {
ki := ((ksz - 1) - (j + khalf)) + ei // kernel index: khalf+ei down to 0
si := i + (ki - khalf) // source index: sz-1 down to i-khalf
// fmt.Printf("i: %d j: %d ki: %d si: %d ei: %d\n", i, j, ki, si, ei)
sum += src[si] * kern[ki]
ksum += kern[ki]
}
(*dest)[i] = sum / ksum
}
return nil
}
// Slice64 convolves given kernel with given source slice, putting results in
// destination, which is ensured to be the same size as the source slice,
// using existing capacity if available, and otherwise making a new slice.
// The kernel should be normalized, and odd-sized so it is symmetric about 0.
// Returns an error if sizes are not valid.
// No parallelization is used -- see Slice64Parallel for very large slices.
// Edges are handled separately with renormalized kernels -- they can be
// clipped from dest by excluding the kernel half-width from each end.
func Slice64(dest *[]float64, src []float64, kern []float64) error {
sz := len(src)
ksz := len(kern)
if ksz == 0 || sz == 0 {
return errors.New("convolve.Slice64: kernel or source are empty")
}
if ksz%2 == 0 {
return errors.New("convolve.Slice64: kernel is not odd sized")
}
if sz < ksz {
return errors.New("convolve.Slice64: source must be > kernel in size")
}
khalf := (ksz - 1) / 2
*dest = slicesx.SetLength(*dest, sz)
// interior: the full kernel fits within the source
for i := khalf; i < sz-khalf; i++ {
var sum float64
for j := 0; j < ksz; j++ {
sum += src[(i-khalf)+j] * kern[j]
}
(*dest)[i] = sum
}
// left edge: only the right portion of the kernel is in range;
// renormalize by the sum of kernel weights actually used (ksum)
for i := 0; i < khalf; i++ {
var sum, ksum float64
for j := 0; j <= khalf+i; j++ {
ki := (j + khalf) - i // kernel index: khalf-i .. ksz-1
si := i + (ki - khalf) // source index = j: 0 .. i+khalf
// fmt.Printf("i: %d j: %d ki: %d si: %d\n", i, j, ki, si)
sum += src[si] * kern[ki]
ksum += kern[ki]
}
(*dest)[i] = sum / ksum
}
// right edge: mirror of left edge, using the left portion of the kernel
for i := sz - khalf; i < sz; i++ {
var sum, ksum float64
ei := sz - i - 1 // number of valid elements to the right of i
for j := 0; j <= khalf+ei; j++ {
ki := ((ksz - 1) - (j + khalf)) + ei // kernel index: khalf+ei down to 0
si := i + (ki - khalf) // source index: sz-1 down to i-khalf
// fmt.Printf("i: %d j: %d ki: %d si: %d ei: %d\n", i, j, ki, si, ei)
sum += src[si] * kern[ki]
ksum += kern[ki]
}
(*dest)[i] = sum / ksum
}
return nil
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package convolve
import (
"math"
"cogentcore.org/core/math32"
)
// GaussianKernel32 returns a normalized gaussian kernel for smoothing
// with given half-width and normalized sigma (actual sigma = khalf * sigma).
// A sigma value of .5 is typical for smaller half-widths for containing
// most of the gaussian efficiently -- anything lower than .33 is inefficient --
// generally just use a lower half-width instead.
func GaussianKernel32(khalf int, sigma float32) []float32 {
	sz := 2*khalf + 1
	kern := make([]float32, sz)
	scale := 1 / (sigma * float32(khalf))
	var total float32
	for i := range kern {
		x := scale * float32(i-khalf)
		kern[i] = math32.Exp(-0.5 * x * x)
		total += kern[i]
	}
	// normalize so the kernel sums to 1
	nfac := 1 / total
	for i := range kern {
		kern[i] *= nfac
	}
	return kern
}
// GaussianKernel64 returns a normalized gaussian kernel
// with given half-width and normalized sigma (actual sigma = khalf * sigma)
// A sigma value of .5 is typical for smaller half-widths for containing
// most of the gaussian efficiently -- anything lower than .33 is inefficient --
// generally just use a lower half-width instead.
func GaussianKernel64(khalf int, sigma float64) []float64 {
ksz := khalf*2 + 1
kern := make([]float64, ksz)
sigdiv := 1 / (sigma * float64(khalf))
var sum float64
for i := 0; i < ksz; i++ {
x := sigdiv * float64(i-khalf)
kv := math.Exp(-0.5 * x * x)
kern[i] = kv
sum += kv
}
nfac := 1 / sum
for i := 0; i < ksz; i++ {
kern[i] *= nfac
}
return kern
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package glm
import (
	"errors"
	"fmt"
	"math"

	"cogentcore.org/lab/table"
	"cogentcore.org/lab/tensor"
)
// todo: add tests
// GLM contains results and parameters for running a general
// linear model, which is a general form of multivariate linear
// regression, supporting multiple independent and dependent
// variables. Make a NewGLM and then do Run() on a tensor
// table.Table with the relevant data in columns of the table.
// Batch-mode gradient descent is used and the relevant parameters
// can be altered from defaults before calling Run as needed.
type GLM struct {
// Coeff are the coefficients to map from input independent variables
// to the dependent variables. The first, outer dimension is number of
// dependent variables, and the second, inner dimension is number of
// independent variables plus one for the offset (b) (last element).
Coeff tensor.Float64
// MSE is the mean squared error of the fitted values relative to data.
MSE float64
// R2 is the r^2 total variance accounted for by the linear model,
// for each dependent variable = 1 - (ErrVariance / ObsVariance)
R2 []float64
// ObsVariance is the observed variance of each of the dependent variables to be predicted.
ObsVariance []float64
// ErrVariance is the variance of the error residuals per dependent variable.
ErrVariance []float64
// IndepNames are optional names of the independent variables, for reporting results.
IndepNames []string
// DepNames are optional names of the dependent variables, for reporting results.
DepNames []string
//////// Parameters for the GLM model fitting:
// ZeroOffset restricts the offset of the linear function to 0,
// forcing it to pass through the origin. Otherwise, a constant offset "b"
// is fit during the model fitting process.
ZeroOffset bool
// LRate is the learning rate parameter, which can be adjusted to reduce iterations based on
// specific properties of the data, but the default is reasonable for most "typical" data.
LRate float64 `default:"0.1"`
// StopTolerance is the tolerance on difference in mean squared error (MSE) across iterations to stop
// iterating and consider the result to be converged.
StopTolerance float64 `default:"0.0001"`
// L1Cost is a constant cost factor subtracted from weights, for the L1 norm or "Lasso"
// regression. This is good for producing sparse results but can arbitrarily
// select one of multiple correlated independent variables.
L1Cost float64
// L2Cost is a cost factor proportional to the coefficient value, for the L2 norm or "Ridge"
// regression. This is good for generally keeping weights small and equally
// penalizes correlated independent variables.
L2Cost float64
// CostStartIter is the iteration when we start applying the L1, L2 Cost factors.
// It is often a good idea to have a few unconstrained iterations prior to
// applying the cost factors.
CostStartIter int `default:"5"`
// MaxIters is the maximum number of iterations to perform.
MaxIters int `default:"50"`
//////// Cached values from the table
// Table is the table of data, set by SetTable.
Table *table.Table
// IndepVars, DepVars, PredVars, ErrVars are tensor columns from the table with the respective variables.
IndepVars, DepVars, PredVars, ErrVars tensor.RowMajor
// NIndepVars, NDepVars are the number of independent and dependent variables.
NIndepVars, NDepVars int
}
// NewGLM returns a new GLM with default parameter values.
func NewGLM() *GLM {
	glm := new(GLM)
	glm.Defaults()
	return glm
}
// Defaults sets default parameter values, matching the `default:` field tags.
func (glm *GLM) Defaults() {
	glm.LRate = 0.1
	// was 0.001, inconsistent with the default:"0.0001" tag on StopTolerance
	glm.StopTolerance = 0.0001
	glm.MaxIters = 50
	glm.CostStartIter = 5
}
// init allocates the result and name slices for the given number of
// independent (nIv) and dependent (nDv) variables, and sizes Coeff
// to [nDv][nIv+1], where the final column is the offset (b) term.
func (glm *GLM) init(nIv, nDv int) {
glm.NIndepVars = nIv
glm.NDepVars = nDv
glm.Coeff.SetShapeSizes(nDv, nIv+1)
// glm.Coeff.SetNames("DepVars", "IndepVars")
glm.R2 = make([]float64, nDv)
glm.ObsVariance = make([]float64, nDv)
glm.ErrVariance = make([]float64, nDv)
glm.IndepNames = make([]string, nIv)
glm.DepNames = make([]string, nDv)
}
// SetTable sets the data to use from given indexview of table, where
// each of the Vars args specifies a column in the table, which can have either a
// single scalar value for each row, or a tensor cell with multiple values.
// predVars and errVars (predicted values and error values) are optional.
// Returns an error if the pred / err column shapes do not match depVars.
func (glm *GLM) SetTable(dt *table.Table, indepVars, depVars, predVars, errVars string) error {
	iv := dt.Column(indepVars)
	dv := dt.Column(depVars)
	var pv, ev *tensor.Rows
	if predVars != "" {
		pv = dt.Column(predVars)
	}
	if errVars != "" {
		ev = dt.Column(errVars)
	}
	// errors.New instead of fmt.Errorf: no formatting verbs are used
	if pv != nil && !pv.Shape().IsEqual(dv.Shape()) {
		return errors.New("predVars must have same shape as depVars")
	}
	if ev != nil && !ev.Shape().IsEqual(dv.Shape()) {
		return errors.New("errVars must have same shape as depVars")
	}
	_, nIv := iv.Shape().RowCellSize()
	_, nDv := dv.Shape().RowCellSize()
	glm.init(nIv, nDv)
	glm.Table = dt
	glm.IndepVars = iv
	glm.DepVars = dv
	glm.PredVars = pv
	glm.ErrVars = ev
	return nil
}
// Run performs the multi-variate linear regression using data from the
// SetTable function, learning linear coefficients and an overall static offset
// that best fits the observed dependent variables as a function of the
// independent variables. Initial values of the coefficients, and other
// parameters for the regression, should be set prior to running.
// Uses batch-mode gradient descent, halving the learning rate every 10
// iterations, and stopping when the change in MSE is below StopTolerance
// or MaxIters is reached. Sets MSE, R2, ObsVariance, ErrVariance.
func (glm *GLM) Run() {
	dt := glm.Table
	iv := glm.IndepVars
	dv := glm.DepVars
	pv := glm.PredVars
	ev := glm.ErrVars
	if pv == nil {
		pv = tensor.Clone(dv)
	}
	if ev == nil {
		ev = tensor.Clone(dv)
	}
	nDv := glm.NDepVars
	nIv := glm.NIndepVars
	nCi := nIv + 1 // coefficients per dependent variable, including the offset (b)
	// dc holds the batch gradient for each coefficient
	dc := glm.Coeff.Clone().(*tensor.Float64)
	lastItr := false
	sse := 0.0
	prevmse := 0.0
	n := dt.NumRows()
	norm := 1.0 / float64(n)
	lrate := norm * glm.LRate
	for itr := 0; itr < glm.MaxIters; itr++ {
		for i := range dc.Values {
			dc.Values[i] = 0
		}
		sse = 0
		if (itr+1)%10 == 0 {
			lrate *= 0.5 // anneal the learning rate
		}
		for i := 0; i < n; i++ {
			row := dt.RowIndex(i)
			for di := 0; di < nDv; di++ {
				pred := 0.0
				for ii := 0; ii < nIv; ii++ {
					pred += glm.Coeff.Float(di, ii) * iv.FloatRow(row, ii)
				}
				if !glm.ZeroOffset {
					pred += glm.Coeff.Float(di, nIv)
				}
				targ := dv.FloatRow(row, di)
				err := targ - pred
				sse += err * err
				for ii := 0; ii < nIv; ii++ {
					dc.Values[di*nCi+ii] += err * iv.FloatRow(row, ii)
				}
				if !glm.ZeroOffset {
					dc.Values[di*nCi+nIv] += err
				}
				if lastItr { // record final predictions and errors on the last pass
					pv.SetFloatRow(pred, row, di)
					if ev != nil {
						ev.SetFloatRow(err, row, di)
					}
				}
			}
		}
		// L1 / L2 regularization only applies from CostStartIter onward,
		// per the CostStartIter field documentation (was previously unused).
		l1, l2 := 0.0, 0.0
		if itr >= glm.CostStartIter {
			l1, l2 = glm.L1Cost, glm.L2Cost
		}
		for di := 0; di < nDv; di++ {
			for ii := 0; ii <= nIv; ii++ {
				if glm.ZeroOffset && ii == nIv {
					continue
				}
				// was di*(nCi+1)+ii: wrong stride (out of range for nDv > 1);
				// must match the di*nCi+ii indexing used for dc above.
				idx := di*nCi + ii
				w := glm.Coeff.Values[idx]
				d := dc.Values[idx]
				sgn := 1.0
				if w < 0 {
					sgn = -1.0
				} else if w == 0 {
					sgn = 0
				}
				glm.Coeff.Values[idx] += lrate * (d - l1*sgn - l2*w)
			}
		}
		glm.MSE = norm * sse
		if lastItr {
			break
		}
		if itr > 0 {
			dmse := glm.MSE - prevmse
			if math.Abs(dmse) < glm.StopTolerance || itr == glm.MaxIters-2 {
				lastItr = true // do one more pass to record predictions / errors
			}
		}
		prevmse = glm.MSE
	}

	// compute observed and error-residual variance, and R^2, per dependent variable
	obsMeans := make([]float64, nDv)
	errMeans := make([]float64, nDv)
	for i := 0; i < n; i++ {
		row := dt.RowIndex(i)
		for di := 0; di < nDv; di++ {
			obsMeans[di] += dv.FloatRow(row, di)
			errMeans[di] += ev.FloatRow(row, di)
		}
	}
	for di := 0; di < nDv; di++ {
		obsMeans[di] *= norm
		errMeans[di] *= norm
		glm.ObsVariance[di] = 0
		glm.ErrVariance[di] = 0
	}
	for i := 0; i < n; i++ {
		row := dt.RowIndex(i)
		for di := 0; di < nDv; di++ {
			o := dv.FloatRow(row, di) - obsMeans[di]
			glm.ObsVariance[di] += o * o
			e := ev.FloatRow(row, di) - errMeans[di]
			glm.ErrVariance[di] += e * e
		}
	}
	for di := 0; di < nDv; di++ {
		glm.ObsVariance[di] *= norm
		glm.ErrVariance[di] *= norm
		glm.R2[di] = 1.0 - (glm.ErrVariance[di] / glm.ObsVariance[di])
	}
}
// Variance returns a description of the variance accounted for by the regression
// equation, R^2, for each dependent variable, along with the variances of
// observed and errors (residuals), which are used to compute it.
func (glm *GLM) Variance() string {
	str := ""
	for di, r2 := range glm.R2 {
		name := fmt.Sprintf("DV %d", di)
		if di < len(glm.DepNames) && glm.DepNames[di] != "" {
			name = glm.DepNames[di]
		}
		str += name
		str += fmt.Sprintf("\tR^2: %8.6g\tR: %8.6g\tVar Err: %8.4g\t Obs: %8.4g\n", r2, math.Sqrt(r2), glm.ErrVariance[di], glm.ObsVariance[di])
	}
	return str
}
// Coeffs returns a string describing the coefficients,
// as one linear equation per dependent variable.
func (glm *GLM) Coeffs() string {
	str := ""
	for di := range glm.NDepVars {
		if len(glm.DepNames) > di && glm.DepNames[di] != "" {
			str += glm.DepNames[di]
		} else {
			str += fmt.Sprintf("DV %d", di)
		}
		str += " = "
		for ii := 0; ii <= glm.NIndepVars; ii++ {
			str += fmt.Sprintf("\t%8.6g", glm.Coeff.Float(di, ii))
			if ii < glm.NIndepVars {
				str += " * "
				// index IndepNames by ii (the independent variable);
				// was incorrectly indexed by di.
				if len(glm.IndepNames) > ii && glm.IndepNames[ii] != "" {
					str += glm.IndepNames[ii]
				} else {
					str += fmt.Sprintf("IV_%d", ii)
				}
				str += " + "
			}
		}
		str += "\n"
	}
	return str
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package histogram
//go:generate core generate
import (
"cogentcore.org/core/base/slicesx"
"cogentcore.org/core/math32"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
)
// F64 generates a histogram of counts of values within given
// number of bins and min / max range. hist vals is sized to nBins.
// if value is < min or > max it is ignored.
// If nBins <= 0, hist is emptied; if max <= min, all in-range values
// (i.e., those equal to min) are counted in bin 0.
func F64(hist *[]float64, vals []float64, nBins int, min, max float64) {
	if nBins <= 0 {
		*hist = (*hist)[:0]
		return
	}
	*hist = slicesx.SetLength(*hist, nBins)
	h := *hist
	// 0.1.2.3 = 3-0 = 4 bins
	inc := (max - min) / float64(nBins)
	for i := range h {
		h[i] = 0
	}
	for _, v := range vals {
		if v < min || v > max {
			continue
		}
		bin := 0
		if inc > 0 { // guard against zero-width range (divide-by-zero / NaN bin)
			bin = int((v - min) / inc)
		}
		if bin >= nBins {
			bin = nBins - 1 // v == max falls in the last bin
		}
		h[bin] += 1
	}
}
// F64Table generates an table with a histogram of counts of values within given
// number of bins and min / max range. The table has columns: Value, Count
// if value is < min or > max it is ignored.
// The Value column represents the min value for each bin, with the max being
// the value of the next bin, or the max if at the end.
func F64Table(dt *table.Table, vals []float64, nBins int, min, max float64) {
	dt.DeleteAll()
	dt.AddFloat64Column("Value")
	dt.AddFloat64Column("Count")
	dt.SetNumRows(nBins)
	counts := dt.Columns.Values[1].(*tensor.Float64)
	F64(&counts.Values, vals, nBins, min, max)
	// fill the Value column with each bin's lower bound
	binWidth := (max - min) / float64(nBins)
	binMins := dt.Columns.Values[0].(*tensor.Float64).Values
	for i := range binMins {
		binMins[i] = math32.Truncate64(min+float64(i)*binWidth, 4)
	}
}
// Code generated by "core generate"; DO NOT EDIT.
package metric
import (
"cogentcore.org/core/enums"
)
var _MetricsValues = []Metrics{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}
// MetricsN is the highest valid value for type Metrics, plus one.
const MetricsN Metrics = 13
var _MetricsValueMap = map[string]Metrics{`L2Norm`: 0, `SumSquares`: 1, `L1Norm`: 2, `Hamming`: 3, `L2NormBinTol`: 4, `SumSquaresBinTol`: 5, `InvCosine`: 6, `InvCorrelation`: 7, `CrossEntropy`: 8, `DotProduct`: 9, `Covariance`: 10, `Correlation`: 11, `Cosine`: 12}
var _MetricsDescMap = map[Metrics]string{0: `L2Norm is the square root of the sum of squares differences between tensor values, aka the L2 Norm.`, 1: `SumSquares is the sum of squares differences between tensor values.`, 2: `L1Norm is the sum of the absolute value of differences between tensor values, the L1 Norm.`, 3: `Hamming is the sum of 1s for every element that is different, i.e., "city block" distance.`, 4: `L2NormBinTol is the [L2Norm] square root of the sum of squares differences between tensor values, with binary tolerance: differences < 0.5 are thresholded to 0.`, 5: `SumSquaresBinTol is the [SumSquares] differences between tensor values, with binary tolerance: differences < 0.5 are thresholded to 0.`, 6: `InvCosine is 1-[Cosine], which is useful to convert it to an Increasing metric where more different vectors have larger metric values.`, 7: `InvCorrelation is 1-[Correlation], which is useful to convert it to an Increasing metric where more different vectors have larger metric values.`, 8: `CrossEntropy is a standard measure of the difference between two probabilty distributions, reflecting the additional entropy (uncertainty) associated with measuring probabilities under distribution b when in fact they come from distribution a. It is also the entropy of a plus the divergence between a from b, using Kullback-Leibler (KL) divergence. It is computed as: a * log(a/b) + (1-a) * log(1-a/1-b).`, 9: `DotProduct is the sum of the co-products of the tensor values.`, 10: `Covariance is co-variance between two vectors, i.e., the mean of the co-product of each vector element minus the mean of that vector: cov(A,B) = E[(A - E(A))(B - E(B))].`, 11: `Correlation is the standardized [Covariance] in the range (-1..1), computed as the mean of the co-product of each vector element minus the mean of that vector, normalized by the product of their standard deviations: cor(A,B) = E[(A - E(A))(B - E(B))] / sigma(A) sigma(B). 
Equivalent to the [Cosine] of mean-normalized vectors.`, 12: `Cosine is high-dimensional angle between two vectors, in range (-1..1) as the normalized [DotProduct]: inner product / sqrt(ssA * ssB). See also [Correlation].`}
var _MetricsMap = map[Metrics]string{0: `L2Norm`, 1: `SumSquares`, 2: `L1Norm`, 3: `Hamming`, 4: `L2NormBinTol`, 5: `SumSquaresBinTol`, 6: `InvCosine`, 7: `InvCorrelation`, 8: `CrossEntropy`, 9: `DotProduct`, 10: `Covariance`, 11: `Correlation`, 12: `Cosine`}
// NOTE: generated code (see file header); edit the enum source, not these methods.
// String returns the string representation of this Metrics value.
func (i Metrics) String() string { return enums.String(i, _MetricsMap) }
// SetString sets the Metrics value from its string representation,
// and returns an error if the string is invalid.
func (i *Metrics) SetString(s string) error {
return enums.SetString(i, s, _MetricsValueMap, "Metrics")
}
// Int64 returns the Metrics value as an int64.
func (i Metrics) Int64() int64 { return int64(i) }
// SetInt64 sets the Metrics value from an int64.
func (i *Metrics) SetInt64(in int64) { *i = Metrics(in) }
// Desc returns the description of the Metrics value.
func (i Metrics) Desc() string { return enums.Desc(i, _MetricsDescMap) }
// MetricsValues returns all possible values for the type Metrics.
func MetricsValues() []Metrics { return _MetricsValues }
// Values returns all possible values for the type Metrics.
func (i Metrics) Values() []enums.Enum { return enums.Values(_MetricsValues) }
// MarshalText implements the [encoding.TextMarshaler] interface.
func (i Metrics) MarshalText() ([]byte, error) { return []byte(i.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *Metrics) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "Metrics") }
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package metric
import (
"math"
"cogentcore.org/core/math32"
"cogentcore.org/lab/stats/stats"
"cogentcore.org/lab/tensor"
)
// MetricFunc is the function signature for a metric function,
// which is computed over the outermost row dimension and the
// output is the shape of the remaining inner cells (a scalar for 1D inputs).
// Use [tensor.As1D], [tensor.NewRowCellsView], [tensor.Cells1D] etc
// to reshape and reslice the data as needed.
// All metric functions skip over NaN's, as a missing value,
// and use the min of the length of the two tensors.
// Metric functions cannot be computed in parallel,
// e.g., using VectorizeThreaded or GPU, due to shared writing
// to the same output values. Special implementations are required
// if that is needed.
type MetricFunc = func(a, b tensor.Tensor) tensor.Values
// MetricOutFunc is the function signature for a metric function,
// that takes output values as the final argument. See [MetricFunc].
// This version is for computationally demanding cases and saves
// reallocation of the output tensor.
type MetricOutFunc = func(a, b tensor.Tensor, out tensor.Values) error
// SumSquaresScaleOut64 computes the sum of squares differences between tensor values,
// returning scale and ss factors aggregated separately for better numerical stability, per BLAS.
// The actual sum of squares = scale^2 * ss (see [SumSquaresOut64]).
func SumSquaresScaleOut64(a, b tensor.Tensor) (scale64, ss64 *tensor.Float64, err error) {
if err = tensor.MustBeSameShape(a, b); err != nil {
return
}
scale64, ss64 = Vectorize2Out64(a, b, 0, 1, func(a, b, scale, ss float64) (float64, float64) {
if math.IsNaN(a) || math.IsNaN(b) {
return scale, ss // NaN = missing value: skip
}
d := a - b
if d == 0 {
return scale, ss
}
absxi := math.Abs(d)
// BLAS-style rescaling: scale tracks the max |d| seen so far, and
// ss accumulates sum of (d/scale)^2, avoiding overflow/underflow of d*d
if scale < absxi {
ss = 1 + ss*(scale/absxi)*(scale/absxi)
scale = absxi
} else {
ss = ss + (absxi/scale)*(absxi/scale)
}
return scale, ss
})
return
}
// SumSquaresOut64 computes the sum of squares differences between tensor values,
// and returns the Float64 output values for use in subsequent computations.
// out is shaped to the cell size of a (a scalar for 1D inputs).
func SumSquaresOut64(a, b tensor.Tensor, out tensor.Values) (*tensor.Float64, error) {
scale64, ss64, err := SumSquaresScaleOut64(a, b)
if err != nil {
return nil, err
}
osz := tensor.CellsSize(a.ShapeSizes())
out.SetShapeSizes(osz...)
nsub := out.Len()
for i := range nsub {
scale := scale64.Float1D(i)
ss := ss64.Float1D(i)
v := 0.0
if math.IsInf(scale, 1) {
v = math.Inf(1) // infinity propagates directly
} else {
v = scale * scale * ss // recombine separately-aggregated scale, ss factors
}
scale64.SetFloat1D(v, i) // scale64 is reused to hold the final returned values
out.SetFloat1D(v, i)
}
return scale64, err
}
// SumSquaresOut computes the sum of squares differences between tensor values.
// See [MetricOutFunc] for general information.
func SumSquaresOut(a, b tensor.Tensor, out tensor.Values) error {
	if _, err := SumSquaresOut64(a, b, out); err != nil {
		return err
	}
	return nil
}
// SumSquares computes the sum of squares differences between tensor values.
// See [MetricFunc] for general information.
func SumSquares(a, b tensor.Tensor) tensor.Values {
	res := tensor.CallOut2(SumSquaresOut, a, b)
	return res
}
// L2NormOut computes the L2 Norm: square root of the sum of squares
// differences between tensor values, aka the Euclidean distance.
// out is shaped to the cell size of a (a scalar for 1D inputs).
func L2NormOut(a, b tensor.Tensor, out tensor.Values) error {
scale64, ss64, err := SumSquaresScaleOut64(a, b)
if err != nil {
return err
}
osz := tensor.CellsSize(a.ShapeSizes())
out.SetShapeSizes(osz...)
nsub := out.Len()
for i := range nsub {
scale := scale64.Float1D(i)
ss := ss64.Float1D(i)
v := 0.0
if math.IsInf(scale, 1) {
v = math.Inf(1) // infinity propagates directly
} else {
v = scale * math.Sqrt(ss) // recombine separately-aggregated scale, ss factors
}
scale64.SetFloat1D(v, i) // scale64 also receives the final values
out.SetFloat1D(v, i)
}
return nil
}
// L2Norm computes the L2 Norm: square root of the sum of squares
// differences between tensor values, aka the Euclidean distance.
// See [MetricFunc] for general information.
func L2Norm(a, b tensor.Tensor) tensor.Values {
	res := tensor.CallOut2(L2NormOut, a, b)
	return res
}
// L1NormOut computes the sum of the absolute value of differences between the
// tensor values, the L1 Norm.
// See [MetricOutFunc] for general information.
func L1NormOut(a, b tensor.Tensor, out tensor.Values) error {
	err := tensor.MustBeSameShape(a, b)
	if err != nil {
		return err
	}
	VectorizeOut64(a, b, out, 0, func(av, bv, agg float64) float64 {
		if math.IsNaN(av) || math.IsNaN(bv) {
			return agg // NaN = missing value: skip
		}
		return agg + math.Abs(av-bv)
	})
	return nil
}
// L1Norm computes the sum of the absolute value of differences between the
// tensor values, the L1 Norm.
// See [MetricFunc] for general information.
func L1Norm(a, b tensor.Tensor) tensor.Values {
	res := tensor.CallOut2(L1NormOut, a, b)
	return res
}
// HammingOut computes the sum of 1s for every element that is different,
// i.e., "city block" distance.
// See [MetricOutFunc] for general information.
func HammingOut(a, b tensor.Tensor, out tensor.Values) error {
	err := tensor.MustBeSameShape(a, b)
	if err != nil {
		return err
	}
	VectorizeOut64(a, b, out, 0, func(av, bv, agg float64) float64 {
		switch {
		case math.IsNaN(av) || math.IsNaN(bv):
			return agg // NaN = missing value: skip
		case av == bv:
			return agg
		default:
			return agg + 1
		}
	})
	return nil
}
// Hamming computes the sum of 1s for every element that is different,
// i.e., "city block" distance.
// See [MetricFunc] for general information.
func Hamming(a, b tensor.Tensor) tensor.Values {
	res := tensor.CallOut2(HammingOut, a, b)
	return res
}
// SumSquaresBinTolScaleOut64 computes the sum of squares differences between tensor values,
// with binary tolerance: differences < 0.5 are thresholded to 0.
// returning scale and ss factors aggregated separately for better numerical stability, per BLAS.
// The actual sum of squares = scale^2 * ss.
func SumSquaresBinTolScaleOut64(a, b tensor.Tensor) (scale64, ss64 *tensor.Float64, err error) {
	if err = tensor.MustBeSameShape(a, b); err != nil {
		return
	}
	scale64, ss64 = Vectorize2Out64(a, b, 0, 1, func(a, b, scale, ss float64) (float64, float64) {
		if math.IsNaN(a) || math.IsNaN(b) {
			return scale, ss // NaN = missing value: skip
		}
		// compute the absolute difference once: it is used both for the
		// tolerance test and for the aggregation below (was computed twice)
		absxi := math.Abs(a - b)
		if absxi < 0.5 { // binary tolerance: treat small differences as 0
			return scale, ss
		}
		// BLAS-style rescaling: scale tracks the max |d|, ss sums (d/scale)^2
		if scale < absxi {
			ss = 1 + ss*(scale/absxi)*(scale/absxi)
			scale = absxi
		} else {
			ss = ss + (absxi/scale)*(absxi/scale)
		}
		return scale, ss
	})
	return
}
// L2NormBinTolOut computes the L2 Norm square root of the sum of squares
// differences between tensor values (aka Euclidean distance), with binary tolerance:
// differences < 0.5 are thresholded to 0.
// out is shaped to the cell size of a (a scalar for 1D inputs).
func L2NormBinTolOut(a, b tensor.Tensor, out tensor.Values) error {
scale64, ss64, err := SumSquaresBinTolScaleOut64(a, b)
if err != nil {
return err
}
osz := tensor.CellsSize(a.ShapeSizes())
out.SetShapeSizes(osz...)
nsub := out.Len()
for i := range nsub {
scale := scale64.Float1D(i)
ss := ss64.Float1D(i)
v := 0.0
if math.IsInf(scale, 1) {
v = math.Inf(1) // infinity propagates directly
} else {
v = scale * math.Sqrt(ss) // recombine separately-aggregated scale, ss factors
}
scale64.SetFloat1D(v, i) // scale64 also receives the final values
out.SetFloat1D(v, i)
}
return nil
}
// L2NormBinTol computes the L2 Norm square root of the sum of squares
// differences between tensor values (aka Euclidean distance), with binary tolerance:
// differences < 0.5 are thresholded to 0.
// See [MetricFunc] for general information.
func L2NormBinTol(a, b tensor.Tensor) tensor.Values {
	res := tensor.CallOut2(L2NormBinTolOut, a, b)
	return res
}
// SumSquaresBinTolOut computes the sum of squares differences between tensor values,
// with binary tolerance: differences < 0.5 are thresholded to 0.
// out is shaped to the cell size of a (a scalar for 1D inputs).
func SumSquaresBinTolOut(a, b tensor.Tensor, out tensor.Values) error {
scale64, ss64, err := SumSquaresBinTolScaleOut64(a, b)
if err != nil {
return err
}
osz := tensor.CellsSize(a.ShapeSizes())
out.SetShapeSizes(osz...)
nsub := out.Len()
for i := range nsub {
scale := scale64.Float1D(i)
ss := ss64.Float1D(i)
v := 0.0
if math.IsInf(scale, 1) {
v = math.Inf(1) // infinity propagates directly
} else {
v = scale * scale * ss // recombine separately-aggregated scale, ss factors
}
scale64.SetFloat1D(v, i) // scale64 also receives the final values
out.SetFloat1D(v, i)
}
return nil
}
// SumSquaresBinTol computes the sum of squares differences between tensor values,
// with binary tolerance: differences < 0.5 are thresholded to 0.
// See [MetricFunc] for general information.
func SumSquaresBinTol(a, b tensor.Tensor) tensor.Values {
	res := tensor.CallOut2(SumSquaresBinTolOut, a, b)
	return res
}
// CrossEntropyOut is a standard measure of the difference between two
// probabilty distributions, reflecting the additional entropy (uncertainty) associated
// with measuring probabilities under distribution b when in fact they come from
// distribution a. It is also the entropy of a plus the divergence between a from b,
// using Kullback-Leibler (KL) divergence. It is computed as:
// a * log(a/b) + (1-a) * log(1-a/1-b).
// See [MetricOutFunc] for general information.
func CrossEntropyOut(a, b tensor.Tensor, out tensor.Values) error {
if err := tensor.MustBeSameShape(a, b); err != nil {
return err
}
VectorizeOut64(a, b, out, 0, func(a, b, agg float64) float64 {
if math.IsNaN(a) || math.IsNaN(b) {
return agg // NaN = missing value: skip
}
// clamp b away from 0 and 1 so the logs below remain finite
b = math32.Clamp(b, 0.000001, 0.999999)
// special-case a at the boundaries to avoid 0*log(0) terms
if a >= 1.0 {
agg += -math.Log(b)
} else if a <= 0.0 {
agg += -math.Log(1.0 - b)
} else {
agg += a*math.Log(a/b) + (1-a)*math.Log((1-a)/(1-b))
}
return agg
})
return nil
}
// CrossEntropy is a standard measure of the difference between two
// probabilty distributions, reflecting the additional entropy (uncertainty) associated
// with measuring probabilities under distribution b when in fact they come from
// distribution a. It is also the entropy of a plus the divergence between a from b,
// using Kullback-Leibler (KL) divergence. It is computed as:
// a * log(a/b) + (1-a) * log(1-a/1-b).
// See [MetricFunc] for general information.
func CrossEntropy(a, b tensor.Tensor) tensor.Values {
	res := tensor.CallOut2(CrossEntropyOut, a, b)
	return res
}
// DotProductOut computes the sum of the element-wise products of the
// two tensors (aka the inner product).
// See [MetricOutFunc] for general information.
func DotProductOut(a, b tensor.Tensor, out tensor.Values) error {
	if err := tensor.MustBeSameShape(a, b); err != nil {
		return err
	}
	VectorizeOut64(a, b, out, 0, func(av, bv, agg float64) float64 {
		// NaNs in either input are skipped entirely.
		if math.IsNaN(av) || math.IsNaN(bv) {
			return agg
		}
		return agg + av*bv
	})
	return nil
}
// DotProduct computes the sum of the element-wise products of the
// two tensors (aka the inner product).
// See [MetricFunc] for general information.
func DotProduct(a, b tensor.Tensor) tensor.Values {
	return tensor.CallOut2(DotProductOut, a, b)
}
// CovarianceOut computes the co-variance between two vectors,
// i.e., the mean of the co-product of each vector element minus
// the mean of that vector: cov(A,B) = E[(A - E(A))(B - E(B))].
func CovarianceOut(a, b tensor.Tensor, out tensor.Values) error {
	if err := tensor.MustBeSameShape(a, b); err != nil {
		return err
	}
	// per-cell means of each input; out is used as scratch here and
	// reshaped below.
	amean, acount := stats.MeanOut64(a, out)
	bmean, _ := stats.MeanOut64(b, out)
	cov64 := VectorizePreOut64(a, b, out, 0, amean, bmean, func(av, bv, am, bm, agg float64) float64 {
		if math.IsNaN(av) || math.IsNaN(bv) {
			return agg
		}
		return agg + (av-am)*(bv-bm)
	})
	out.SetShapeSizes(tensor.CellsSize(a.ShapeSizes())...)
	n := out.Len()
	for i := range n {
		count := acount.Float1D(i)
		if count == 0 {
			continue // no valid entries for this cell
		}
		out.SetFloat1D(cov64.Float1D(i)/count, i)
	}
	return nil
}
// Covariance computes the co-variance between two vectors,
// i.e., the mean of the co-product of each vector element minus
// the mean of that vector: cov(A,B) = E[(A - E(A))(B - E(B))].
// See [MetricFunc] for general information.
func Covariance(a, b tensor.Tensor) tensor.Values {
	res := tensor.CallOut2(CovarianceOut, a, b)
	return res
}
// CorrelationOut64 computes the correlation between two vectors,
// in range (-1..1) as the mean of the co-product of each vector
// element minus the mean of that vector, normalized by the product of their
// standard deviations: cor(A,B) = E[(A - E(A))(B - E(B))] / sigma(A) sigma(B).
// (i.e., the standardized covariance).
// Equivalent to the cosine of mean-normalized vectors.
// Returns the Float64 output values for subsequent use.
func CorrelationOut64(a, b tensor.Tensor, out tensor.Values) (*tensor.Float64, error) {
	if err := tensor.MustBeSameShape(a, b); err != nil {
		return nil, err
	}
	amean, _ := stats.MeanOut64(a, out)
	bmean, _ := stats.MeanOut64(b, out)
	// accumulate covariance (between) and per-input variances (within).
	ss64, avar64, bvar64 := VectorizePre3Out64(a, b, 0, 0, 0, amean, bmean, func(av, bv, am, bm, ss, va, vb float64) (float64, float64, float64) {
		if math.IsNaN(av) || math.IsNaN(bv) {
			return ss, va, vb
		}
		da := av - am
		db := bv - bm
		return ss + da*db, va + da*da, vb + db*db
	})
	out.SetShapeSizes(tensor.CellsSize(a.ShapeSizes())...)
	n := out.Len()
	for i := range n {
		cor := ss64.Float1D(i)
		if norm := math.Sqrt(avar64.Float1D(i) * bvar64.Float1D(i)); norm > 0 {
			cor /= norm
		}
		ss64.SetFloat1D(cor, i)
		out.SetFloat1D(cor, i)
	}
	return ss64, nil
}
// CorrelationOut computes the correlation between two vectors,
// in range (-1..1) as the mean of the co-product of each vector
// element minus the mean of that vector, normalized by the product of their
// standard deviations: cor(A,B) = E[(A - E(A))(B - E(B))] / sigma(A) sigma(B).
// (i.e., the standardized [Covariance]).
// Equivalent to the [Cosine] of mean-normalized vectors.
func CorrelationOut(a, b tensor.Tensor, out tensor.Values) error {
	// discard the Float64 working tensor; only the error matters here.
	if _, err := CorrelationOut64(a, b, out); err != nil {
		return err
	}
	return nil
}
// Correlation computes the correlation between two vectors,
// in range (-1..1) as the mean of the co-product of each vector
// element minus the mean of that vector, normalized by the product of their
// standard deviations: cor(A,B) = E[(A - E(A))(B - E(B))] / sigma(A) sigma(B).
// (i.e., the standardized [Covariance]).
// Equivalent to the [Cosine] of mean-normalized vectors.
// See [MetricFunc] for general information.
func Correlation(a, b tensor.Tensor) tensor.Values {
	res := tensor.CallOut2(CorrelationOut, a, b)
	return res
}
// InvCorrelationOut computes 1 minus the correlation between two vectors,
// in range (-1..1) as the mean of the co-product of each vector
// element minus the mean of that vector, normalized by the product of their
// standard deviations: cor(A,B) = E[(A - E(A))(B - E(B))] / sigma(A) sigma(B).
// (i.e., the standardized covariance).
// Equivalent to the [Cosine] of mean-normalized vectors.
// This is useful for a difference measure instead of similarity,
// where more different vectors have larger metric values.
func InvCorrelationOut(a, b tensor.Tensor, out tensor.Values) error {
	cor64, err := CorrelationOut64(a, b, out)
	if err != nil {
		return err
	}
	// invert: identical vectors -> 0, opposite vectors -> 2.
	for i := range out.Len() {
		out.SetFloat1D(1-cor64.Float1D(i), i)
	}
	return nil
}
// InvCorrelation computes 1 minus the correlation between two vectors,
// in range (-1..1) as the mean of the co-product of each vector
// element minus the mean of that vector, normalized by the product of their
// standard deviations: cor(A,B) = E[(A - E(A))(B - E(B))] / sigma(A) sigma(B).
// (i.e., the standardized covariance).
// Equivalent to the [Cosine] of mean-normalized vectors.
// This is useful for a difference measure instead of similarity,
// where more different vectors have larger metric values.
// See [MetricFunc] for general information.
func InvCorrelation(a, b tensor.Tensor) tensor.Values {
	res := tensor.CallOut2(InvCorrelationOut, a, b)
	return res
}
// CosineOut64 computes the high-dimensional angle between two vectors,
// in range (-1..1) as the normalized [Dot]:
// dot product / sqrt(ssA * ssB). See also [Correlation].
// Returns the Float64 output values for subsequent use.
func CosineOut64(a, b tensor.Tensor, out tensor.Values) (*tensor.Float64, error) {
	if err := tensor.MustBeSameShape(a, b); err != nil {
		return nil, err
	}
	// accumulate the dot product plus each input's sum of squares.
	ss64, avar64, bvar64 := Vectorize3Out64(a, b, 0, 0, 0, func(av, bv, ss, va, vb float64) (float64, float64, float64) {
		if math.IsNaN(av) || math.IsNaN(bv) {
			return ss, va, vb
		}
		return ss + av*bv, va + av*av, vb + bv*bv
	})
	out.SetShapeSizes(tensor.CellsSize(a.ShapeSizes())...)
	n := out.Len()
	for i := range n {
		cos := ss64.Float1D(i)
		if norm := math.Sqrt(avar64.Float1D(i) * bvar64.Float1D(i)); norm > 0 {
			cos /= norm
		}
		ss64.SetFloat1D(cos, i)
		out.SetFloat1D(cos, i)
	}
	return ss64, nil
}
// CosineOut computes the high-dimensional angle between two vectors,
// in range (-1..1) as the normalized dot product:
// dot product / sqrt(ssA * ssB). See also [Correlation].
func CosineOut(a, b tensor.Tensor, out tensor.Values) error {
	// discard the Float64 working tensor; only the error matters here.
	if _, err := CosineOut64(a, b, out); err != nil {
		return err
	}
	return nil
}
// Cosine computes the high-dimensional angle between two vectors,
// in range (-1..1) as the normalized dot product:
// dot product / sqrt(ssA * ssB). See also [Correlation].
// See [MetricFunc] for general information.
func Cosine(a, b tensor.Tensor) tensor.Values {
	res := tensor.CallOut2(CosineOut, a, b)
	return res
}
// InvCosineOut computes 1 minus the cosine between two vectors,
// in range (-1..1) as the normalized dot product:
// dot product / sqrt(ssA * ssB).
// This is useful for a difference measure instead of similarity,
// where more different vectors have larger metric values.
func InvCosineOut(a, b tensor.Tensor, out tensor.Values) error {
	cos64, err := CosineOut64(a, b, out)
	if err != nil {
		return err
	}
	// invert: identical directions -> 0, opposite -> 2.
	for i := range out.Len() {
		out.SetFloat1D(1-cos64.Float1D(i), i)
	}
	return nil
}
// InvCosine computes 1 minus the cosine between two vectors,
// in range (-1..1) as the normalized dot product:
// dot product / sqrt(ssA * ssB).
// This is useful for a difference measure instead of similarity,
// where more different vectors have larger metric values.
// See [MetricFunc] for general information.
func InvCosine(a, b tensor.Tensor) tensor.Values {
	res := tensor.CallOut2(InvCosineOut, a, b)
	return res
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package metric
import (
"cogentcore.org/lab/matrix"
"cogentcore.org/lab/tensor"
)
// MatrixOut computes the rows x rows square distance / similarity matrix
// between the patterns for each row of the given higher dimensional input tensor,
// which must have at least 2 dimensions: the outermost rows,
// and within that, 1+dimensional patterns (cells). Use [tensor.NewRowCellsView]
// to organize data into the desired split between a 1D outermost Row dimension
// and the remaining Cells dimension.
// The metric function must have the [MetricFunc] signature.
// The results fill in the elements of the output matrix, which is symmetric,
// and only the lower triangular part is computed, with results copied
// to the upper triangular region, for maximum efficiency.
func MatrixOut(fun any, in tensor.Tensor, out tensor.Values) error {
	mfun, err := AsMetricFunc(fun)
	if err != nil {
		return err
	}
	rows, cells := in.Shape().RowCellSize()
	if rows == 0 || cells == 0 {
		return nil
	}
	out.SetShapeSizes(rows, rows)
	coords := matrix.TriLIndicies(rows)
	n := coords.DimSize(0)
	// flops estimate is ~3 per item on average; varies per metric.
	tensor.VectorizeThreaded(cells*3, func(tsr ...tensor.Tensor) int { return n },
		func(idx int, tsr ...tensor.Tensor) {
			row := coords.Int(idx, 0)
			col := coords.Int(idx, 1)
			mout := mfun(tensor.Cells1D(tsr[0], row), tensor.Cells1D(tsr[0], col))
			tsr[1].SetFloat(mout.Float1D(0), row, col)
		}, in, out)
	// mirror the computed lower triangle into the upper triangle,
	// skipping the diagonal.
	for idx := range n {
		row := coords.Int(idx, 0)
		col := coords.Int(idx, 1)
		if row == col {
			continue
		}
		out.SetFloat(out.Float(row, col), col, row)
	}
	return nil
}
// Matrix computes the rows x rows square distance / similarity matrix
// between the patterns for each row of the given higher dimensional input tensor,
// which must have at least 2 dimensions: the outermost rows,
// and within that, 1+dimensional patterns (cells). Use [tensor.NewRowCellsView]
// to organize data into the desired split between a 1D outermost Row dimension
// and the remaining Cells dimension.
// The metric function must have the [MetricFunc] signature.
// The results fill in the elements of the output matrix, which is symmetric,
// and only the lower triangular part is computed, with results copied
// to the upper triangular region, for maximum efficiency.
func Matrix(fun any, in tensor.Tensor) tensor.Values {
	res := tensor.CallOut1Gen1(MatrixOut, fun, in)
	return res
}
// CrossMatrixOut computes the distance / similarity matrix between
// two different sets of patterns in the two input tensors, where
// the patterns are in the sub-space cells of the tensors,
// which must have at least 2 dimensions: the outermost rows,
// and within that, 1+dimensional patterns that the given distance metric
// function is applied to, with the results filling in the cells of the output matrix.
// The metric function must have the [MetricFunc] signature.
// The rows of the output matrix are the rows of the first input tensor,
// and the columns of the output are the rows of the second input tensor.
func CrossMatrixOut(fun any, a, b tensor.Tensor, out tensor.Values) error {
	mfun, err := AsMetricFunc(fun)
	if err != nil {
		return err
	}
	arows, acells := a.Shape().RowCellSize()
	brows, bcells := b.Shape().RowCellSize()
	// nothing to compute if either input is empty.
	if arows == 0 || acells == 0 {
		return nil
	}
	if brows == 0 || bcells == 0 {
		return nil
	}
	out.SetShapeSizes(arows, brows)
	// flops estimate is ~3 per item on average; varies per metric.
	flops := min(acells, bcells) * 3
	total := arows * brows
	tensor.VectorizeThreaded(flops, func(tsr ...tensor.Tensor) int { return total },
		func(idx int, tsr ...tensor.Tensor) {
			// flat index -> (a row, b row)
			ar := idx / brows
			br := idx % brows
			mout := mfun(tensor.Cells1D(tsr[0], ar), tensor.Cells1D(tsr[1], br))
			tsr[2].SetFloat(mout.Float1D(0), ar, br)
		}, a, b, out)
	return nil
}
// CrossMatrix computes the distance / similarity matrix between
// two different sets of patterns in the two input tensors, where
// the patterns are in the sub-space cells of the tensors,
// which must have at least 2 dimensions: the outermost rows,
// and within that, 1+dimensional patterns that the given distance metric
// function is applied to, with the results filling in the cells of the output matrix.
// The metric function must have the [MetricFunc] signature.
// The rows of the output matrix are the rows of the first input tensor,
// and the columns of the output are the rows of the second input tensor.
func CrossMatrix(fun any, a, b tensor.Tensor) tensor.Values {
	res := tensor.CallOut2Gen1(CrossMatrixOut, fun, a, b)
	return res
}
// CovarianceMatrixOut generates the cells x cells square covariance matrix
// for all per-row cells of the given higher dimensional input tensor,
// which must have at least 2 dimensions: the outermost rows,
// and within that, 1+dimensional patterns (cells).
// Each value in the resulting matrix represents the extent to which the
// value of a given cell covaries across the rows of the tensor with the
// value of another cell.
// Uses the given metric function, typically [Covariance] or [Correlation],
// The metric function must have the [MetricFunc] signature.
// Use Covariance if vars have similar overall scaling,
// which is typical in neural network models, and use
// Correlation if they are on very different scales, because it effectively rescales).
// The resulting matrix can be used as the input to PCA or SVD eigenvalue decomposition.
func CovarianceMatrixOut(fun any, in tensor.Tensor, out tensor.Values) error {
	mfun, err := AsMetricFunc(fun)
	if err != nil {
		return err
	}
	rows, cells := in.Shape().RowCellSize()
	if rows == 0 || cells == 0 {
		return nil
	}
	out.SetShapeSizes(cells, cells)
	// flatten to rows x cells so each cell becomes a column vector.
	flat := tensor.NewReshaped(in, rows, cells)
	coords := matrix.TriLIndicies(cells)
	n := coords.DimSize(0)
	// flops estimate is ~3 per item on average; varies per metric.
	tensor.VectorizeThreaded(rows*3, func(tsr ...tensor.Tensor) int { return n },
		func(idx int, tsr ...tensor.Tensor) {
			row := coords.Int(idx, 0)
			col := coords.Int(idx, 1)
			av := tensor.Reslice(tsr[0], tensor.FullAxis, row)
			bv := tensor.Reslice(tsr[0], tensor.FullAxis, col)
			mout := mfun(av, bv)
			tsr[1].SetFloat(mout.Float1D(0), row, col)
		}, flat, out)
	// mirror the computed lower triangle into the upper triangle,
	// skipping the diagonal.
	for idx := range n {
		row := coords.Int(idx, 0)
		col := coords.Int(idx, 1)
		if row == col {
			continue
		}
		out.SetFloat(out.Float(row, col), col, row)
	}
	return nil
}
// CovarianceMatrix generates the cells x cells square covariance matrix
// for all per-row cells of the given higher dimensional input tensor,
// which must have at least 2 dimensions: the outermost rows,
// and within that, 1+dimensional patterns (cells).
// Each value in the resulting matrix represents the extent to which the
// value of a given cell covaries across the rows of the tensor with the
// value of another cell.
// Uses the given metric function, typically [Covariance] or [Correlation],
// The metric function must have the [MetricFunc] signature.
// Use Covariance if vars have similar overall scaling,
// which is typical in neural network models, and use
// Correlation if they are on very different scales, because it effectively rescales).
// The resulting matrix can be used as the input to PCA or SVD eigenvalue decomposition.
func CovarianceMatrix(fun any, in tensor.Tensor) tensor.Values {
	res := tensor.CallOut1Gen1(CovarianceMatrixOut, fun, in)
	return res
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:generate core generate
package metric
import (
"cogentcore.org/core/base/errors"
"cogentcore.org/lab/tensor"
)
// init registers every standard metric function in the global tensor
// function registry, keyed by its package-qualified [Metrics.FuncName],
// so it can be looked up via tensor.FuncByName (see [Metrics.Func]).
func init() {
	tensor.AddFunc(MetricL2Norm.FuncName(), L2Norm)
	tensor.AddFunc(MetricSumSquares.FuncName(), SumSquares)
	tensor.AddFunc(MetricL1Norm.FuncName(), L1Norm)
	tensor.AddFunc(MetricHamming.FuncName(), Hamming)
	tensor.AddFunc(MetricL2NormBinTol.FuncName(), L2NormBinTol)
	tensor.AddFunc(MetricSumSquaresBinTol.FuncName(), SumSquaresBinTol)
	tensor.AddFunc(MetricInvCosine.FuncName(), InvCosine)
	tensor.AddFunc(MetricInvCorrelation.FuncName(), InvCorrelation)
	tensor.AddFunc(MetricDotProduct.FuncName(), DotProduct)
	tensor.AddFunc(MetricCrossEntropy.FuncName(), CrossEntropy)
	tensor.AddFunc(MetricCovariance.FuncName(), Covariance)
	tensor.AddFunc(MetricCorrelation.FuncName(), Correlation)
	tensor.AddFunc(MetricCosine.FuncName(), Cosine)
}
// Metrics are standard metric functions
type Metrics int32 //enums:enum -trim-prefix Metric

// NOTE: the ordering of these constants is significant: [Metrics.Increasing]
// treats everything from MetricDotProduct onward as a similarity
// (larger = closer) rather than a distance.
const (
	// L2Norm is the square root of the sum of squares differences
	// between tensor values, aka the L2 Norm.
	MetricL2Norm Metrics = iota

	// SumSquares is the sum of squares differences between tensor values.
	MetricSumSquares

	// L1Norm is the sum of the absolute value of differences
	// between tensor values, the L1 Norm.
	MetricL1Norm

	// Hamming is the sum of 1s for every element that is different,
	// i.e., "city block" distance.
	MetricHamming

	// L2NormBinTol is the [L2Norm] square root of the sum of squares
	// differences between tensor values, with binary tolerance:
	// differences < 0.5 are thresholded to 0.
	MetricL2NormBinTol

	// SumSquaresBinTol is the [SumSquares] differences between tensor values,
	// with binary tolerance: differences < 0.5 are thresholded to 0.
	MetricSumSquaresBinTol

	// InvCosine is 1-[Cosine], which is useful to convert it
	// to an Increasing metric where more different vectors have larger metric values.
	MetricInvCosine

	// InvCorrelation is 1-[Correlation], which is useful to convert it
	// to an Increasing metric where more different vectors have larger metric values.
	MetricInvCorrelation

	// CrossEntropy is a standard measure of the difference between two
	// probability distributions, reflecting the additional entropy (uncertainty) associated
	// with measuring probabilities under distribution b when in fact they come from
	// distribution a. It is also the entropy of a plus the divergence between a from b,
	// using Kullback-Leibler (KL) divergence. It is computed as:
	// a * log(a/b) + (1-a) * log(1-a/1-b).
	MetricCrossEntropy

	//////// Everything below here is !Increasing -- larger = closer, not farther

	// DotProduct is the sum of the co-products of the tensor values.
	MetricDotProduct

	// Covariance is co-variance between two vectors,
	// i.e., the mean of the co-product of each vector element minus
	// the mean of that vector: cov(A,B) = E[(A - E(A))(B - E(B))].
	MetricCovariance

	// Correlation is the standardized [Covariance] in the range (-1..1),
	// computed as the mean of the co-product of each vector
	// element minus the mean of that vector, normalized by the product of their
	// standard deviations: cor(A,B) = E[(A - E(A))(B - E(B))] / sigma(A) sigma(B).
	// Equivalent to the [Cosine] of mean-normalized vectors.
	MetricCorrelation

	// Cosine is high-dimensional angle between two vectors,
	// in range (-1..1) as the normalized [DotProduct]:
	// inner product / sqrt(ssA * ssB). See also [Correlation].
	MetricCosine
)
// FuncName returns the package-qualified function name to use
// in tensor.Call to call this function.
func (m Metrics) FuncName() string {
	name := "metric." + m.String()
	return name
}
// Func returns function for given metric.
// A lookup failure is logged via errors.Log1.
func (m Metrics) Func() MetricFunc {
	entry := errors.Log1(tensor.FuncByName(m.FuncName()))
	return entry.Fun.(MetricFunc)
}
// Call calls a standard Metrics enum function on given tensors,
// returning the result as a new tensor.
func (m Metrics) Call(a, b tensor.Tensor) tensor.Values {
	fn := m.Func()
	return fn(a, b)
}
// Increasing returns true if the distance metric is such that metric
// values increase as a function of distance (e.g., L2Norm)
// and false if metric values decrease as a function of distance
// (e.g., Cosine, Correlation).
func (m Metrics) Increasing() bool {
	// everything from MetricDotProduct onward is a similarity measure.
	return m < MetricDotProduct
}
// AsMetricFunc returns given function as a [MetricFunc] function,
// or an error if it does not fit that signature.
func AsMetricFunc(fun any) (MetricFunc, error) {
	if mfun, ok := fun.(MetricFunc); ok {
		return mfun, nil
	}
	return nil, errors.New("metric.AsMetricFunc: function does not fit the MetricFunc signature")
}
// AsMetricOutFunc returns given function as a [MetricOutFunc] function,
// or an error if it does not fit that signature.
func AsMetricOutFunc(fun any) (MetricOutFunc, error) {
	if mfun, ok := fun.(MetricOutFunc); ok {
		return mfun, nil
	}
	return nil, errors.New("metric.AsMetricOutFunc: function does not fit the MetricOutFunc signature")
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package metric
import (
"math"
"cogentcore.org/lab/tensor"
)
// ClosestRow returns the closest fit between probe pattern and patterns in
// a "vocabulary" tensor with outermost row dimension, using given metric
// function, which must fit the MetricFunc signature.
// The metric *must have the Increasing property*, i.e., larger = further.
// Output is a 1D tensor with 2 elements: the row index and metric value for that row.
// Note: this does _not_ use any existing Indexes for the probe,
// but does for the vocab, and the returned index is the logical index
// into any existing Indexes.
func ClosestRow(fun any, probe, vocab tensor.Tensor) tensor.Values {
	res := tensor.CallOut2Gen1(ClosestRowOut, fun, probe, vocab)
	return res
}
// ClosestRowOut returns the closest fit between probe pattern and patterns in
// a "vocabulary" tensor with outermost row dimension, using given metric
// function, which must fit the MetricFunc signature.
// The metric *must have the Increasing property*, i.e., larger = further.
// Output is a 1D tensor with 2 elements: the row index and metric value for that row.
// Note: this does _not_ use any existing Indexes for the probe,
// but does for the vocab, and the returned index is the logical index
// into any existing Indexes.
func ClosestRowOut(fun any, probe, vocab tensor.Tensor, out tensor.Values) error {
	out.SetShapeSizes(2)
	mfun, err := AsMetricFunc(fun)
	if err != nil {
		return err
	}
	rows, _ := vocab.Shape().RowCellSize()
	bestRow := -1
	bestDist := math.MaxFloat64
	probe1d := tensor.As1D(probe)
	// linear scan over vocab rows, keeping the minimum metric value.
	for ri := range rows {
		mout := mfun(probe1d, tensor.Cells1D(vocab, ri))
		if d := mout.Float1D(0); d < bestDist {
			bestRow = ri
			bestDist = d
		}
	}
	out.SetFloat1D(float64(bestRow), 0)
	out.SetFloat1D(bestDist, 1)
	return nil
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package metric
import (
"cogentcore.org/lab/tensor"
)
// VectorizeOut64 is the general compute function for metric.
// This version makes a Float64 output tensor for aggregating
// and computing values, and then copies the results back to the
// original output. This allows metric functions to operate directly
// on integer valued inputs and produce sensible results.
// It returns the Float64 output tensor for further processing as needed.
// a and b are already enforced to be the same shape.
func VectorizeOut64(a, b tensor.Tensor, out tensor.Values, ini float64, fun func(a, b, agg float64) float64) *tensor.Float64 {
	rows, cells := a.Shape().RowCellSize()
	o64 := tensor.NewFloat64(cells)
	if rows <= 0 {
		// no data: return the freshly-made Float64 without touching out.
		return o64
	}
	if cells == 1 {
		// scalar case: one aggregate over all rows.
		// The nested type switches dispatch on the concrete Float64/Float32
		// types so that Float1D calls go to the concrete method rather than
		// through the tensor.Tensor interface; the default cases handle any
		// other tensor type generically. The loop bodies are otherwise
		// identical in every branch.
		out.SetShapeSizes(1)
		agg := ini
		switch x := a.(type) {
		case *tensor.Float64:
			switch y := b.(type) {
			case *tensor.Float64:
				for i := range rows {
					agg = fun(x.Float1D(i), y.Float1D(i), agg)
				}
			case *tensor.Float32:
				for i := range rows {
					agg = fun(x.Float1D(i), y.Float1D(i), agg)
				}
			default:
				for i := range rows {
					agg = fun(x.Float1D(i), b.Float1D(i), agg)
				}
			}
		case *tensor.Float32:
			switch y := b.(type) {
			case *tensor.Float64:
				for i := range rows {
					agg = fun(x.Float1D(i), y.Float1D(i), agg)
				}
			case *tensor.Float32:
				for i := range rows {
					agg = fun(x.Float1D(i), y.Float1D(i), agg)
				}
			default:
				for i := range rows {
					agg = fun(x.Float1D(i), b.Float1D(i), agg)
				}
			}
		default:
			switch y := b.(type) {
			case *tensor.Float64:
				for i := range rows {
					agg = fun(a.Float1D(i), y.Float1D(i), agg)
				}
			case *tensor.Float32:
				for i := range rows {
					agg = fun(a.Float1D(i), y.Float1D(i), agg)
				}
			default:
				for i := range rows {
					agg = fun(a.Float1D(i), b.Float1D(i), agg)
				}
			}
		}
		o64.SetFloat1D(agg, 0)
		out.SetFloat1D(agg, 0)
		return o64
	}
	// cell-wise case: out takes the cell shape and each cell j aggregates
	// across rows at flat index i*cells + j.
	osz := tensor.CellsSize(a.ShapeSizes())
	out.SetShapeSizes(osz...)
	for i := range cells {
		o64.SetFloat1D(ini, i)
	}
	// same concrete-type dispatch as above, applied per cell.
	switch x := a.(type) {
	case *tensor.Float64:
		switch y := b.(type) {
		case *tensor.Float64:
			for i := range rows {
				si := i * cells
				for j := range cells {
					o64.SetFloat1D(fun(x.Float1D(si+j), y.Float1D(si+j), o64.Float1D(j)), j)
				}
			}
		case *tensor.Float32:
			for i := range rows {
				si := i * cells
				for j := range cells {
					o64.SetFloat1D(fun(x.Float1D(si+j), y.Float1D(si+j), o64.Float1D(j)), j)
				}
			}
		default:
			for i := range rows {
				si := i * cells
				for j := range cells {
					o64.SetFloat1D(fun(x.Float1D(si+j), b.Float1D(si+j), o64.Float1D(j)), j)
				}
			}
		}
	case *tensor.Float32:
		switch y := b.(type) {
		case *tensor.Float64:
			for i := range rows {
				si := i * cells
				for j := range cells {
					o64.SetFloat1D(fun(x.Float1D(si+j), y.Float1D(si+j), o64.Float1D(j)), j)
				}
			}
		case *tensor.Float32:
			for i := range rows {
				si := i * cells
				for j := range cells {
					o64.SetFloat1D(fun(x.Float1D(si+j), y.Float1D(si+j), o64.Float1D(j)), j)
				}
			}
		default:
			for i := range rows {
				si := i * cells
				for j := range cells {
					o64.SetFloat1D(fun(x.Float1D(si+j), b.Float1D(si+j), o64.Float1D(j)), j)
				}
			}
		}
	default:
		switch y := b.(type) {
		case *tensor.Float64:
			for i := range rows {
				si := i * cells
				for j := range cells {
					o64.SetFloat1D(fun(a.Float1D(si+j), y.Float1D(si+j), o64.Float1D(j)), j)
				}
			}
		case *tensor.Float32:
			for i := range rows {
				si := i * cells
				for j := range cells {
					o64.SetFloat1D(fun(a.Float1D(si+j), y.Float1D(si+j), o64.Float1D(j)), j)
				}
			}
		default:
			for i := range rows {
				si := i * cells
				for j := range cells {
					o64.SetFloat1D(fun(a.Float1D(si+j), b.Float1D(si+j), o64.Float1D(j)), j)
				}
			}
		}
	}
	// copy the Float64 aggregates back to the caller's output tensor.
	for j := range cells {
		out.SetFloat1D(o64.Float1D(j), j)
	}
	return o64
}
// VectorizePreOut64 is a version of [VectorizeOut64] that takes additional
// tensor.Float64 inputs of pre-computed values, e.g., the means of each output cell.
// preA and preB are indexed by output cell (index 0 in the scalar case).
func VectorizePreOut64(a, b tensor.Tensor, out tensor.Values, ini float64, preA, preB *tensor.Float64, fun func(a, b, preA, preB, agg float64) float64) *tensor.Float64 {
	rows, cells := a.Shape().RowCellSize()
	o64 := tensor.NewFloat64(cells)
	if rows <= 0 {
		// no data: return the freshly-made Float64 without touching out.
		return o64
	}
	if cells == 1 {
		// scalar case: one aggregate over all rows, using the single
		// pre-computed value from each of preA / preB.
		// The nested type switches dispatch on the concrete Float64/Float32
		// types so that Float1D calls go to the concrete method rather than
		// through the tensor.Tensor interface; the loop bodies are otherwise
		// identical in every branch.
		out.SetShapeSizes(1)
		agg := ini
		prevA := preA.Float1D(0)
		prevB := preB.Float1D(0)
		switch x := a.(type) {
		case *tensor.Float64:
			switch y := b.(type) {
			case *tensor.Float64:
				for i := range rows {
					agg = fun(x.Float1D(i), y.Float1D(i), prevA, prevB, agg)
				}
			case *tensor.Float32:
				for i := range rows {
					agg = fun(x.Float1D(i), y.Float1D(i), prevA, prevB, agg)
				}
			default:
				for i := range rows {
					agg = fun(x.Float1D(i), b.Float1D(i), prevA, prevB, agg)
				}
			}
		case *tensor.Float32:
			switch y := b.(type) {
			case *tensor.Float64:
				for i := range rows {
					agg = fun(x.Float1D(i), y.Float1D(i), prevA, prevB, agg)
				}
			case *tensor.Float32:
				for i := range rows {
					agg = fun(x.Float1D(i), y.Float1D(i), prevA, prevB, agg)
				}
			default:
				for i := range rows {
					agg = fun(x.Float1D(i), b.Float1D(i), prevA, prevB, agg)
				}
			}
		default:
			switch y := b.(type) {
			case *tensor.Float64:
				for i := range rows {
					agg = fun(a.Float1D(i), y.Float1D(i), prevA, prevB, agg)
				}
			case *tensor.Float32:
				for i := range rows {
					agg = fun(a.Float1D(i), y.Float1D(i), prevA, prevB, agg)
				}
			default:
				for i := range rows {
					agg = fun(a.Float1D(i), b.Float1D(i), prevA, prevB, agg)
				}
			}
		}
		o64.SetFloat1D(agg, 0)
		out.SetFloat1D(agg, 0)
		return o64
	}
	// cell-wise case: out takes the cell shape; each cell j aggregates
	// across rows at flat index i*cells + j, with per-cell pre values.
	osz := tensor.CellsSize(a.ShapeSizes())
	out.SetShapeSizes(osz...)
	for j := range cells {
		o64.SetFloat1D(ini, j)
	}
	// same concrete-type dispatch as above, applied per cell.
	switch x := a.(type) {
	case *tensor.Float64:
		switch y := b.(type) {
		case *tensor.Float64:
			for i := range rows {
				si := i * cells
				for j := range cells {
					o64.SetFloat1D(fun(x.Float1D(si+j), y.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), o64.Float1D(j)), j)
				}
			}
		case *tensor.Float32:
			for i := range rows {
				si := i * cells
				for j := range cells {
					o64.SetFloat1D(fun(x.Float1D(si+j), y.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), o64.Float1D(j)), j)
				}
			}
		default:
			for i := range rows {
				si := i * cells
				for j := range cells {
					o64.SetFloat1D(fun(x.Float1D(si+j), b.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), o64.Float1D(j)), j)
				}
			}
		}
	case *tensor.Float32:
		switch y := b.(type) {
		case *tensor.Float64:
			for i := range rows {
				si := i * cells
				for j := range cells {
					o64.SetFloat1D(fun(x.Float1D(si+j), y.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), o64.Float1D(j)), j)
				}
			}
		case *tensor.Float32:
			for i := range rows {
				si := i * cells
				for j := range cells {
					o64.SetFloat1D(fun(x.Float1D(si+j), y.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), o64.Float1D(j)), j)
				}
			}
		default:
			for i := range rows {
				si := i * cells
				for j := range cells {
					o64.SetFloat1D(fun(x.Float1D(si+j), b.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), o64.Float1D(j)), j)
				}
			}
		}
	default:
		switch y := b.(type) {
		case *tensor.Float64:
			for i := range rows {
				si := i * cells
				for j := range cells {
					o64.SetFloat1D(fun(a.Float1D(si+j), y.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), o64.Float1D(j)), j)
				}
			}
		case *tensor.Float32:
			for i := range rows {
				si := i * cells
				for j := range cells {
					o64.SetFloat1D(fun(a.Float1D(si+j), y.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), o64.Float1D(j)), j)
				}
			}
		default:
			for i := range rows {
				si := i * cells
				for j := range cells {
					o64.SetFloat1D(fun(a.Float1D(si+j), b.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), o64.Float1D(j)), j)
				}
			}
		}
	}
	// copy the Float64 aggregates back to the caller's output tensor.
	for i := range cells {
		out.SetFloat1D(o64.Float1D(i), i)
	}
	return o64
}
// Vectorize2Out64 is a version of [VectorizeOut64] that separately aggregates
// two output values, x and y as tensor.Float64.
// fun is called once per row for each cell, receiving the current a and b
// values plus the running x, y aggregates, and returns the updated aggregates.
// The outputs have one value per cell of a; if rows <= 0 they are returned
// still holding their zero values (not iniX / iniY).
func Vectorize2Out64(a, b tensor.Tensor, iniX, iniY float64, fun func(a, b, ox, oy float64) (float64, float64)) (ox64, oy64 *tensor.Float64) {
	rows, cells := a.Shape().RowCellSize()
	ox64 = tensor.NewFloat64(cells)
	oy64 = tensor.NewFloat64(cells)
	if rows <= 0 {
		return
	}
	if cells == 1 {
		// Fast path for scalar cells: accumulate into two plain float64
		// locals and write them out once at the end.
		ox := iniX
		oy := iniY
		// Nested type switches specialize the inner loops on the concrete
		// Float64 / Float32 tensor types, presumably to avoid interface
		// dispatch on Float1D in the hot loop — TODO confirm.
		switch x := a.(type) {
		case *tensor.Float64:
			switch y := b.(type) {
			case *tensor.Float64:
				for i := range rows {
					ox, oy = fun(x.Float1D(i), y.Float1D(i), ox, oy)
				}
			case *tensor.Float32:
				for i := range rows {
					ox, oy = fun(x.Float1D(i), y.Float1D(i), ox, oy)
				}
			default:
				for i := range rows {
					ox, oy = fun(x.Float1D(i), b.Float1D(i), ox, oy)
				}
			}
		case *tensor.Float32:
			switch y := b.(type) {
			case *tensor.Float64:
				for i := range rows {
					ox, oy = fun(x.Float1D(i), y.Float1D(i), ox, oy)
				}
			case *tensor.Float32:
				for i := range rows {
					ox, oy = fun(x.Float1D(i), y.Float1D(i), ox, oy)
				}
			default:
				for i := range rows {
					ox, oy = fun(x.Float1D(i), b.Float1D(i), ox, oy)
				}
			}
		default:
			switch y := b.(type) {
			case *tensor.Float64:
				for i := range rows {
					ox, oy = fun(a.Float1D(i), y.Float1D(i), ox, oy)
				}
			case *tensor.Float32:
				for i := range rows {
					ox, oy = fun(a.Float1D(i), y.Float1D(i), ox, oy)
				}
			default:
				for i := range rows {
					ox, oy = fun(a.Float1D(i), b.Float1D(i), ox, oy)
				}
			}
		}
		ox64.SetFloat1D(ox, 0)
		oy64.SetFloat1D(oy, 0)
		return
	}
	// General case: seed each cell's aggregates with the initial values.
	for j := range cells {
		ox64.SetFloat1D(iniX, j)
		oy64.SetFloat1D(iniY, j)
	}
	// si indexes the start of row i in flat 1D order (row-major).
	switch x := a.(type) {
	case *tensor.Float64:
		switch y := b.(type) {
		case *tensor.Float64:
			for i := range rows {
				si := i * cells
				for j := range cells {
					ox, oy := fun(x.Float1D(si+j), y.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j))
					ox64.SetFloat1D(ox, j)
					oy64.SetFloat1D(oy, j)
				}
			}
		case *tensor.Float32:
			for i := range rows {
				si := i * cells
				for j := range cells {
					ox, oy := fun(x.Float1D(si+j), y.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j))
					ox64.SetFloat1D(ox, j)
					oy64.SetFloat1D(oy, j)
				}
			}
		default:
			for i := range rows {
				si := i * cells
				for j := range cells {
					ox, oy := fun(x.Float1D(si+j), b.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j))
					ox64.SetFloat1D(ox, j)
					oy64.SetFloat1D(oy, j)
				}
			}
		}
	case *tensor.Float32:
		switch y := b.(type) {
		case *tensor.Float64:
			for i := range rows {
				si := i * cells
				for j := range cells {
					ox, oy := fun(x.Float1D(si+j), y.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j))
					ox64.SetFloat1D(ox, j)
					oy64.SetFloat1D(oy, j)
				}
			}
		case *tensor.Float32:
			for i := range rows {
				si := i * cells
				for j := range cells {
					ox, oy := fun(x.Float1D(si+j), y.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j))
					ox64.SetFloat1D(ox, j)
					oy64.SetFloat1D(oy, j)
				}
			}
		default:
			for i := range rows {
				si := i * cells
				for j := range cells {
					ox, oy := fun(x.Float1D(si+j), b.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j))
					ox64.SetFloat1D(ox, j)
					oy64.SetFloat1D(oy, j)
				}
			}
		}
	default:
		switch y := b.(type) {
		case *tensor.Float64:
			for i := range rows {
				si := i * cells
				for j := range cells {
					ox, oy := fun(a.Float1D(si+j), y.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j))
					ox64.SetFloat1D(ox, j)
					oy64.SetFloat1D(oy, j)
				}
			}
		case *tensor.Float32:
			for i := range rows {
				si := i * cells
				for j := range cells {
					ox, oy := fun(a.Float1D(si+j), y.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j))
					ox64.SetFloat1D(ox, j)
					oy64.SetFloat1D(oy, j)
				}
			}
		default:
			for i := range rows {
				si := i * cells
				for j := range cells {
					ox, oy := fun(a.Float1D(si+j), b.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j))
					ox64.SetFloat1D(ox, j)
					oy64.SetFloat1D(oy, j)
				}
			}
		}
	}
	return
}
// Vectorize3Out64 is a version of [VectorizeOut64] that has 3 outputs instead of 1.
// fun is called once per row for each cell, receiving the current a and b
// values plus the running x, y, z aggregates, and returns the updated aggregates.
// The outputs have one value per cell of a; if rows <= 0 they are returned
// still holding their zero values (not iniX / iniY / iniZ).
func Vectorize3Out64(a, b tensor.Tensor, iniX, iniY, iniZ float64, fun func(a, b, ox, oy, oz float64) (float64, float64, float64)) (ox64, oy64, oz64 *tensor.Float64) {
	rows, cells := a.Shape().RowCellSize()
	ox64 = tensor.NewFloat64(cells)
	oy64 = tensor.NewFloat64(cells)
	oz64 = tensor.NewFloat64(cells)
	if rows <= 0 {
		return
	}
	if cells == 1 {
		// Fast path for scalar cells: accumulate into three plain float64
		// locals and write them out once at the end.
		ox := iniX
		oy := iniY
		oz := iniZ
		// Nested type switches specialize the inner loops on the concrete
		// Float64 / Float32 tensor types, presumably to avoid interface
		// dispatch on Float1D in the hot loop — TODO confirm.
		switch x := a.(type) {
		case *tensor.Float64:
			switch y := b.(type) {
			case *tensor.Float64:
				for i := range rows {
					ox, oy, oz = fun(x.Float1D(i), y.Float1D(i), ox, oy, oz)
				}
			case *tensor.Float32:
				for i := range rows {
					ox, oy, oz = fun(x.Float1D(i), y.Float1D(i), ox, oy, oz)
				}
			default:
				for i := range rows {
					ox, oy, oz = fun(x.Float1D(i), b.Float1D(i), ox, oy, oz)
				}
			}
		case *tensor.Float32:
			switch y := b.(type) {
			case *tensor.Float64:
				for i := range rows {
					ox, oy, oz = fun(x.Float1D(i), y.Float1D(i), ox, oy, oz)
				}
			case *tensor.Float32:
				for i := range rows {
					ox, oy, oz = fun(x.Float1D(i), y.Float1D(i), ox, oy, oz)
				}
			default:
				for i := range rows {
					ox, oy, oz = fun(x.Float1D(i), b.Float1D(i), ox, oy, oz)
				}
			}
		default:
			switch y := b.(type) {
			case *tensor.Float64:
				for i := range rows {
					ox, oy, oz = fun(a.Float1D(i), y.Float1D(i), ox, oy, oz)
				}
			case *tensor.Float32:
				for i := range rows {
					ox, oy, oz = fun(a.Float1D(i), y.Float1D(i), ox, oy, oz)
				}
			default:
				for i := range rows {
					ox, oy, oz = fun(a.Float1D(i), b.Float1D(i), ox, oy, oz)
				}
			}
		}
		ox64.SetFloat1D(ox, 0)
		oy64.SetFloat1D(oy, 0)
		oz64.SetFloat1D(oz, 0)
		return
	}
	// General case: seed each cell's aggregates with the initial values.
	for j := range cells {
		ox64.SetFloat1D(iniX, j)
		oy64.SetFloat1D(iniY, j)
		oz64.SetFloat1D(iniZ, j)
	}
	// si indexes the start of row i in flat 1D order (row-major).
	switch x := a.(type) {
	case *tensor.Float64:
		switch y := b.(type) {
		case *tensor.Float64:
			for i := range rows {
				si := i * cells
				for j := range cells {
					ox, oy, oz := fun(x.Float1D(si+j), y.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j))
					ox64.SetFloat1D(ox, j)
					oy64.SetFloat1D(oy, j)
					oz64.SetFloat1D(oz, j)
				}
			}
		case *tensor.Float32:
			for i := range rows {
				si := i * cells
				for j := range cells {
					ox, oy, oz := fun(x.Float1D(si+j), y.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j))
					ox64.SetFloat1D(ox, j)
					oy64.SetFloat1D(oy, j)
					oz64.SetFloat1D(oz, j)
				}
			}
		default:
			for i := range rows {
				si := i * cells
				for j := range cells {
					ox, oy, oz := fun(x.Float1D(si+j), b.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j))
					ox64.SetFloat1D(ox, j)
					oy64.SetFloat1D(oy, j)
					oz64.SetFloat1D(oz, j)
				}
			}
		}
	case *tensor.Float32:
		switch y := b.(type) {
		case *tensor.Float64:
			for i := range rows {
				si := i * cells
				for j := range cells {
					ox, oy, oz := fun(x.Float1D(si+j), y.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j))
					ox64.SetFloat1D(ox, j)
					oy64.SetFloat1D(oy, j)
					oz64.SetFloat1D(oz, j)
				}
			}
		case *tensor.Float32:
			for i := range rows {
				si := i * cells
				for j := range cells {
					ox, oy, oz := fun(x.Float1D(si+j), y.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j))
					ox64.SetFloat1D(ox, j)
					oy64.SetFloat1D(oy, j)
					oz64.SetFloat1D(oz, j)
				}
			}
		default:
			for i := range rows {
				si := i * cells
				for j := range cells {
					ox, oy, oz := fun(x.Float1D(si+j), b.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j))
					ox64.SetFloat1D(ox, j)
					oy64.SetFloat1D(oy, j)
					oz64.SetFloat1D(oz, j)
				}
			}
		}
	default:
		switch y := b.(type) {
		case *tensor.Float64:
			for i := range rows {
				si := i * cells
				for j := range cells {
					ox, oy, oz := fun(a.Float1D(si+j), y.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j))
					ox64.SetFloat1D(ox, j)
					oy64.SetFloat1D(oy, j)
					oz64.SetFloat1D(oz, j)
				}
			}
		case *tensor.Float32:
			for i := range rows {
				si := i * cells
				for j := range cells {
					ox, oy, oz := fun(a.Float1D(si+j), y.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j))
					ox64.SetFloat1D(ox, j)
					oy64.SetFloat1D(oy, j)
					oz64.SetFloat1D(oz, j)
				}
			}
		default:
			for i := range rows {
				si := i * cells
				for j := range cells {
					ox, oy, oz := fun(a.Float1D(si+j), b.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j))
					ox64.SetFloat1D(ox, j)
					oy64.SetFloat1D(oy, j)
					oz64.SetFloat1D(oz, j)
				}
			}
		}
	}
	return
}
// VectorizePre3Out64 is a version of [VectorizePreOut64] that takes additional
// tensor.Float64 inputs of pre-computed values, e.g., the means of each output cell,
// and has 3 outputs instead of 1.
// fun is called once per row for each cell, receiving the current a and b values,
// the cell's pre-computed preA and preB values, and the running x, y, z aggregates,
// and returns the updated aggregates.
// The outputs have one value per cell of a; if rows <= 0 they are returned
// still holding their zero values (not iniX / iniY / iniZ).
func VectorizePre3Out64(a, b tensor.Tensor, iniX, iniY, iniZ float64, preA, preB *tensor.Float64, fun func(a, b, preA, preB, ox, oy, oz float64) (float64, float64, float64)) (ox64, oy64, oz64 *tensor.Float64) {
	rows, cells := a.Shape().RowCellSize()
	ox64 = tensor.NewFloat64(cells)
	oy64 = tensor.NewFloat64(cells)
	oz64 = tensor.NewFloat64(cells)
	if rows <= 0 {
		return
	}
	if cells == 1 {
		// Fast path for scalar cells: accumulate into three plain float64
		// locals, with the single pre-computed value from each of preA/preB.
		ox := iniX
		oy := iniY
		oz := iniZ
		prevA := preA.Float1D(0)
		prevB := preB.Float1D(0)
		// Nested type switches specialize the inner loops on the concrete
		// Float64 / Float32 tensor types, presumably to avoid interface
		// dispatch on Float1D in the hot loop — TODO confirm.
		switch x := a.(type) {
		case *tensor.Float64:
			switch y := b.(type) {
			case *tensor.Float64:
				for i := range rows {
					ox, oy, oz = fun(x.Float1D(i), y.Float1D(i), prevA, prevB, ox, oy, oz)
				}
			case *tensor.Float32:
				for i := range rows {
					ox, oy, oz = fun(x.Float1D(i), y.Float1D(i), prevA, prevB, ox, oy, oz)
				}
			default:
				for i := range rows {
					ox, oy, oz = fun(x.Float1D(i), b.Float1D(i), prevA, prevB, ox, oy, oz)
				}
			}
		case *tensor.Float32:
			switch y := b.(type) {
			case *tensor.Float64:
				for i := range rows {
					ox, oy, oz = fun(x.Float1D(i), y.Float1D(i), prevA, prevB, ox, oy, oz)
				}
			case *tensor.Float32:
				for i := range rows {
					ox, oy, oz = fun(x.Float1D(i), y.Float1D(i), prevA, prevB, ox, oy, oz)
				}
			default:
				for i := range rows {
					ox, oy, oz = fun(x.Float1D(i), b.Float1D(i), prevA, prevB, ox, oy, oz)
				}
			}
		default:
			switch y := b.(type) {
			case *tensor.Float64:
				for i := range rows {
					ox, oy, oz = fun(a.Float1D(i), y.Float1D(i), prevA, prevB, ox, oy, oz)
				}
			case *tensor.Float32:
				for i := range rows {
					ox, oy, oz = fun(a.Float1D(i), y.Float1D(i), prevA, prevB, ox, oy, oz)
				}
			default:
				for i := range rows {
					ox, oy, oz = fun(a.Float1D(i), b.Float1D(i), prevA, prevB, ox, oy, oz)
				}
			}
		}
		ox64.SetFloat1D(ox, 0)
		oy64.SetFloat1D(oy, 0)
		oz64.SetFloat1D(oz, 0)
		return
	}
	// General case: seed each cell's aggregates with the initial values.
	for j := range cells {
		ox64.SetFloat1D(iniX, j)
		oy64.SetFloat1D(iniY, j)
		oz64.SetFloat1D(iniZ, j)
	}
	// si indexes the start of row i in flat 1D order (row-major);
	// preA / preB are indexed per-cell (j), not per-element.
	switch x := a.(type) {
	case *tensor.Float64:
		switch y := b.(type) {
		case *tensor.Float64:
			for i := range rows {
				si := i * cells
				for j := range cells {
					ox, oy, oz := fun(x.Float1D(si+j), y.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j))
					ox64.SetFloat1D(ox, j)
					oy64.SetFloat1D(oy, j)
					oz64.SetFloat1D(oz, j)
				}
			}
		case *tensor.Float32:
			for i := range rows {
				si := i * cells
				for j := range cells {
					ox, oy, oz := fun(x.Float1D(si+j), y.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j))
					ox64.SetFloat1D(ox, j)
					oy64.SetFloat1D(oy, j)
					oz64.SetFloat1D(oz, j)
				}
			}
		default:
			for i := range rows {
				si := i * cells
				for j := range cells {
					ox, oy, oz := fun(x.Float1D(si+j), b.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j))
					ox64.SetFloat1D(ox, j)
					oy64.SetFloat1D(oy, j)
					oz64.SetFloat1D(oz, j)
				}
			}
		}
	case *tensor.Float32:
		switch y := b.(type) {
		case *tensor.Float64:
			for i := range rows {
				si := i * cells
				for j := range cells {
					ox, oy, oz := fun(x.Float1D(si+j), y.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j))
					ox64.SetFloat1D(ox, j)
					oy64.SetFloat1D(oy, j)
					oz64.SetFloat1D(oz, j)
				}
			}
		case *tensor.Float32:
			for i := range rows {
				si := i * cells
				for j := range cells {
					ox, oy, oz := fun(x.Float1D(si+j), y.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j))
					ox64.SetFloat1D(ox, j)
					oy64.SetFloat1D(oy, j)
					oz64.SetFloat1D(oz, j)
				}
			}
		default:
			for i := range rows {
				si := i * cells
				for j := range cells {
					ox, oy, oz := fun(x.Float1D(si+j), b.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j))
					ox64.SetFloat1D(ox, j)
					oy64.SetFloat1D(oy, j)
					oz64.SetFloat1D(oz, j)
				}
			}
		}
	default:
		switch y := b.(type) {
		case *tensor.Float64:
			for i := range rows {
				si := i * cells
				for j := range cells {
					ox, oy, oz := fun(a.Float1D(si+j), y.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j))
					ox64.SetFloat1D(ox, j)
					oy64.SetFloat1D(oy, j)
					oz64.SetFloat1D(oz, j)
				}
			}
		case *tensor.Float32:
			for i := range rows {
				si := i * cells
				for j := range cells {
					ox, oy, oz := fun(a.Float1D(si+j), y.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j))
					ox64.SetFloat1D(ox, j)
					oy64.SetFloat1D(oy, j)
					oz64.SetFloat1D(oz, j)
				}
			}
		default:
			for i := range rows {
				si := i * cells
				for j := range cells {
					ox, oy, oz := fun(a.Float1D(si+j), b.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j))
					ox64.SetFloat1D(ox, j)
					oy64.SetFloat1D(oy, j)
					oz64.SetFloat1D(oz, j)
				}
			}
		}
	}
	return
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package stats
import (
"strconv"
"cogentcore.org/core/base/metadata"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
"cogentcore.org/lab/tensorfs"
)
// DescriptiveStats are the standard descriptive stats used in the [Describe]
// function: Count, Mean, Std, Sem, Min, Max, Q1, Median, Q3.
// The final 3 sort-based stats (Q1, Median, Q3) cannot be applied
// to higher-dimensional data.
var DescriptiveStats = []Stats{StatCount, StatMean, StatStd, StatSem, StatMin, StatMax, StatQ1, StatMedian, StatQ3}
// Describe adds standard descriptive statistics for the given tensors
// to the given [tensorfs] directory, creating a "Describe" subdirectory
// with one directory per tensor, containing one scalar per statistic.
// This is an easy way to provide a comprehensive description of data.
// The [DescriptiveStats] list is: [Count], [Mean], [Std], [Sem],
// [Min], [Max], [Q1], [Median], [Q3]
func Describe(dir *tensorfs.Node, tsrs ...tensor.Tensor) {
	descDir := dir.Dir("Describe")
	for i, tsr := range tsrs {
		if tsr.DimSize(0) == 0 { // no rows: nothing to describe
			continue
		}
		name := metadata.Name(tsr)
		if name == "" {
			name = strconv.Itoa(i) // fall back to positional index
		}
		tsrDir := descDir.Dir(name)
		for _, stat := range DescriptiveStats {
			sv := tensorfs.Scalar[float64](tsrDir, stat.String())
			sv.CopyFrom(stat.Call(tsr))
		}
	}
}
// DescribeTable runs [Describe] on the named columns in the given table.
func DescribeTable(dir *tensorfs.Node, dt *table.Table, columns ...string) {
	cols := dt.ColumnList(columns...)
	Describe(dir, cols...)
}
// DescribeTableAll runs [Describe] on all non-string (numeric) columns
// in the given table.
func DescribeTableAll(dir *tensorfs.Node, dt *table.Table) {
	var cols []string
	for i, col := range dt.Columns.Values {
		if col.IsString() { // skip string columns: stats are numeric
			continue
		}
		cols = append(cols, dt.ColumnName(i))
	}
	Describe(dir, dt.ColumnList(cols...)...)
}
// Code generated by "core generate"; DO NOT EDIT.
package stats
import (
"cogentcore.org/core/enums"
)
// _StatsValues lists all defined Stats values, in order.
var _StatsValues = []Stats{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21}

// StatsN is the highest valid value for type Stats, plus one.
const StatsN Stats = 22

// _StatsValueMap maps Stats names to values; used by SetString.
var _StatsValueMap = map[string]Stats{`Count`: 0, `Sum`: 1, `L1Norm`: 2, `Prod`: 3, `Min`: 4, `Max`: 5, `MinAbs`: 6, `MaxAbs`: 7, `Mean`: 8, `Var`: 9, `Std`: 10, `Sem`: 11, `SumSq`: 12, `L2Norm`: 13, `VarPop`: 14, `StdPop`: 15, `SemPop`: 16, `Median`: 17, `Q1`: 18, `Q3`: 19, `First`: 20, `Final`: 21}

// _StatsDescMap maps Stats values to their descriptions; used by Desc.
var _StatsDescMap = map[Stats]string{0: `count of number of elements.`, 1: `sum of elements.`, 2: `L1 Norm: sum of absolute values of elements.`, 3: `product of elements.`, 4: `minimum value.`, 5: `maximum value.`, 6: `minimum of absolute values.`, 7: `maximum of absolute values.`, 8: `mean value = sum / count.`, 9: `sample variance (squared deviations from mean, divided by n-1).`, 10: `sample standard deviation (sqrt of Var).`, 11: `sample standard error of the mean (Std divided by sqrt(n)).`, 12: `sum of squared values.`, 13: `L2 Norm: square-root of sum-of-squares.`, 14: `population variance (squared diffs from mean, divided by n).`, 15: `population standard deviation (sqrt of VarPop).`, 16: `population standard error of the mean (StdPop divided by sqrt(n)).`, 17: `middle value in sorted ordering.`, 18: `Q1 first quartile = 25%ile value = .25 quantile value.`, 19: `Q3 third quartile = 75%ile value = .75 quantile value.`, 20: `first item in the set of data: for data with a natural ordering.`, 21: `final item in the set of data: for data with a natural ordering.`}

// _StatsMap maps Stats values to their names; used by String.
var _StatsMap = map[Stats]string{0: `Count`, 1: `Sum`, 2: `L1Norm`, 3: `Prod`, 4: `Min`, 5: `Max`, 6: `MinAbs`, 7: `MaxAbs`, 8: `Mean`, 9: `Var`, 10: `Std`, 11: `Sem`, 12: `SumSq`, 13: `L2Norm`, 14: `VarPop`, 15: `StdPop`, 16: `SemPop`, 17: `Median`, 18: `Q1`, 19: `Q3`, 20: `First`, 21: `Final`}
// String returns the string representation of this Stats value,
// implementing the [fmt.Stringer] interface.
func (i Stats) String() string { return enums.String(i, _StatsMap) }
// SetString sets the Stats value from its string representation,
// and returns an error if the string is not a valid Stats name.
func (i *Stats) SetString(s string) error { return enums.SetString(i, s, _StatsValueMap, "Stats") }
// Int64 returns the Stats value as an int64.
func (i Stats) Int64() int64 { return int64(i) }
// SetInt64 sets the Stats value from an int64 (no validity check).
func (i *Stats) SetInt64(in int64) { *i = Stats(in) }
// Desc returns the description of the Stats value.
func (i Stats) Desc() string { return enums.Desc(i, _StatsDescMap) }
// StatsValues returns all possible values for the type Stats.
func StatsValues() []Stats { return _StatsValues }
// Values returns all possible values for the type Stats,
// as a slice of [enums.Enum].
func (i Stats) Values() []enums.Enum { return enums.Values(_StatsValues) }
// MarshalText implements the [encoding.TextMarshaler] interface,
// encoding the value as its string name.
func (i Stats) MarshalText() ([]byte, error) { return []byte(i.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface,
// decoding the value from its string name.
func (i *Stats) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "Stats") }
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package stats
import (
"math"
"cogentcore.org/lab/tensor"
)
// StatsFunc is the function signature for a stats function that
// returns a new output vector. This can be less efficient for repeated
// computations where the output can be re-used: see [StatsOutFunc].
// But this version can be directly chained with other function calls.
// Function is computed over the outermost row dimension and the
// output is the shape of the remaining inner cells (a scalar for 1D inputs).
// Use [tensor.As1D], [tensor.NewRowCellsView], [tensor.Cells1D] etc
// to reshape and reslice the data as needed.
// All stats functions skip over NaN's, as a missing value.
// Stats functions cannot be computed in parallel,
// e.g., using VectorizeThreaded or GPU, due to shared writing
// to the same output values. Special implementations are required
// if that is needed.
type StatsFunc = func(in tensor.Tensor) tensor.Values
// StatsOutFunc is the function signature for a stats function
// that takes the output values as its final argument. See [StatsFunc].
// This version is for computationally demanding cases and saves
// reallocation of the output.
type StatsOutFunc = func(in tensor.Tensor, out tensor.Values) error
// CountOut64 computes the count of non-NaN tensor values,
// and returns the Float64 output values for subsequent use.
func CountOut64(in tensor.Tensor, out tensor.Values) *tensor.Float64 {
	counter := func(val, agg float64) float64 {
		if !math.IsNaN(val) {
			agg++
		}
		return agg
	}
	return VectorizeOut64(in, out, 0, counter)
}
// Count computes the count of non-NaN tensor values over the
// row dimension. See [StatsFunc] for general information.
func Count(in tensor.Tensor) tensor.Values {
	res := tensor.CallOut1(CountOut, in)
	return res
}
// CountOut computes the count of non-NaN tensor values into out.
// See [StatsOutFunc] for general information.
func CountOut(in tensor.Tensor, out tensor.Values) error {
	_ = CountOut64(in, out) // float64 intermediate not needed here
	return nil
}
// SumOut64 computes the sum of tensor values,
// and returns the Float64 output values for subsequent use.
func SumOut64(in tensor.Tensor, out tensor.Values) *tensor.Float64 {
	add := func(val, agg float64) float64 {
		if !math.IsNaN(val) {
			agg += val
		}
		return agg
	}
	return VectorizeOut64(in, out, 0, add)
}
// SumOut computes the sum of tensor values into out.
// See [StatsOutFunc] for general information.
func SumOut(in tensor.Tensor, out tensor.Values) error {
	_ = SumOut64(in, out) // float64 intermediate not needed here
	return nil
}
// Sum computes the sum of tensor values.
// See [StatsFunc] for general information.
func Sum(in tensor.Tensor) tensor.Values {
	// Use tensor.CallOut1 like every other stats wrapper in this file
	// (Count, L1Norm, Prod, Min, Max, ...) for consistent output
	// allocation and handling, instead of hand-rolling NewOfType.
	return tensor.CallOut1(SumOut, in)
}
// L1Norm computes the sum of absolute-value-of tensor values.
// See [StatsFunc] for general information.
func L1Norm(in tensor.Tensor) tensor.Values {
	res := tensor.CallOut1(L1NormOut, in)
	return res
}
// L1NormOut computes the sum of absolute-value-of tensor values into out.
// See [StatsOutFunc] for general information.
func L1NormOut(in tensor.Tensor, out tensor.Values) error {
	VectorizeOut64(in, out, 0, func(val, agg float64) float64 {
		if !math.IsNaN(val) {
			agg += math.Abs(val)
		}
		return agg
	})
	return nil
}
// Prod computes the product of tensor values.
// See [StatsFunc] for general information.
func Prod(in tensor.Tensor) tensor.Values {
	res := tensor.CallOut1(ProdOut, in)
	return res
}
// ProdOut computes the product of tensor values into out.
// Aggregation starts at 1, the multiplicative identity.
// See [StatsOutFunc] for general information.
func ProdOut(in tensor.Tensor, out tensor.Values) error {
	VectorizeOut64(in, out, 1, func(val, agg float64) float64 {
		if !math.IsNaN(val) {
			agg *= val
		}
		return agg
	})
	return nil
}
// Min computes the min of tensor values.
// See [StatsFunc] for general information.
func Min(in tensor.Tensor) tensor.Values {
	res := tensor.CallOut1(MinOut, in)
	return res
}
// MinOut computes the min of tensor values into out.
// Aggregation starts at math.MaxFloat64, which remains the result
// when there are no non-NaN values.
// See [StatsOutFunc] for general information.
func MinOut(in tensor.Tensor, out tensor.Values) error {
	VectorizeOut64(in, out, math.MaxFloat64, func(val, agg float64) float64 {
		if !math.IsNaN(val) {
			agg = math.Min(agg, val)
		}
		return agg
	})
	return nil
}
// Max computes the max of tensor values.
// See [StatsFunc] for general information.
func Max(in tensor.Tensor) tensor.Values {
	res := tensor.CallOut1(MaxOut, in)
	return res
}
// MaxOut computes the max of tensor values into out.
// Aggregation starts at -math.MaxFloat64, which remains the result
// when there are no non-NaN values.
// See [StatsOutFunc] for general information.
func MaxOut(in tensor.Tensor, out tensor.Values) error {
	VectorizeOut64(in, out, -math.MaxFloat64, func(val, agg float64) float64 {
		if !math.IsNaN(val) {
			agg = math.Max(agg, val)
		}
		return agg
	})
	return nil
}
// MinAbs computes the min of absolute-value-of tensor values.
// See [StatsFunc] for general information.
func MinAbs(in tensor.Tensor) tensor.Values {
	res := tensor.CallOut1(MinAbsOut, in)
	return res
}
// MinAbsOut computes the min of absolute-value-of tensor values into out.
// Aggregation starts at math.MaxFloat64.
// See [StatsOutFunc] for general information.
func MinAbsOut(in tensor.Tensor, out tensor.Values) error {
	VectorizeOut64(in, out, math.MaxFloat64, func(val, agg float64) float64 {
		if !math.IsNaN(val) {
			agg = math.Min(agg, math.Abs(val))
		}
		return agg
	})
	return nil
}
// MaxAbs computes the max of absolute-value-of tensor values.
// See [StatsFunc] for general information.
func MaxAbs(in tensor.Tensor) tensor.Values {
	res := tensor.CallOut1(MaxAbsOut, in)
	return res
}
// MaxAbsOut computes the max of absolute-value-of tensor values into out.
// Aggregation starts at -math.MaxFloat64.
// See [StatsOutFunc] for general information.
func MaxAbsOut(in tensor.Tensor, out tensor.Values) error {
	VectorizeOut64(in, out, -math.MaxFloat64, func(val, agg float64) float64 {
		if !math.IsNaN(val) {
			agg = math.Max(agg, math.Abs(val))
		}
		return agg
	})
	return nil
}
// MeanOut64 computes the mean of tensor values into out,
// and returns the Float64 mean and count values for subsequent use.
// Cells with a zero count are left at the raw (zero) sum.
func MeanOut64(in tensor.Tensor, out tensor.Values) (mean64, count64 *tensor.Float64) {
	accum := func(val, sum, n float64) (float64, float64) {
		if math.IsNaN(val) {
			return sum, n
		}
		return sum + val, n + 1
	}
	sums, counts := Vectorize2Out64(in, 0, 0, accum)
	out.SetShapeSizes(tensor.CellsSize(in.ShapeSizes())...)
	for i := range out.Len() {
		n := counts.Float1D(i)
		if n == 0 {
			continue // avoid 0/0: leave sum (0) in place
		}
		avg := sums.Float1D(i) / n
		sums.SetFloat1D(avg, i) // sums tensor is converted to means in place
		out.SetFloat1D(avg, i)
	}
	return sums, counts
}
// Mean computes the mean of tensor values.
// See [StatsFunc] for general information.
func Mean(in tensor.Tensor) tensor.Values {
	res := tensor.CallOut1(MeanOut, in)
	return res
}
// MeanOut computes the mean of tensor values into out.
// See [StatsOutFunc] for general information.
func MeanOut(in tensor.Tensor, out tensor.Values) error {
	_, _ = MeanOut64(in, out) // float64 intermediates not needed here
	return nil
}
// SumSqDevOut64 computes the sum of squared mean deviates of tensor values
// into out, and returns the Float64 ssd, mean, and count values
// for subsequent use.
func SumSqDevOut64(in tensor.Tensor, out tensor.Values) (ssd64, mean64, count64 *tensor.Float64) {
	mean64, count64 = MeanOut64(in, out)
	ssd64 = VectorizePreOut64(in, out, 0, mean64, func(val, mean, agg float64) float64 {
		if math.IsNaN(val) {
			return agg
		}
		dev := val - mean
		agg += dev * dev
		return agg
	})
	return
}
// VarOut64 computes the sample variance of tensor values into out,
// and returns the Float64 variance, mean, and count values
// for subsequent use. Cells with count < 2 keep the raw
// sum-of-squared-deviates value.
func VarOut64(in tensor.Tensor, out tensor.Values) (var64, mean64, count64 *tensor.Float64) {
	var64, mean64, count64 = SumSqDevOut64(in, out)
	for i := range out.Len() {
		n := count64.Float1D(i)
		if n < 2 {
			continue // n-1 would be <= 0
		}
		v := var64.Float1D(i) / (n - 1)
		var64.SetFloat1D(v, i)
		out.SetFloat1D(v, i)
	}
	return
}
// Var computes the sample variance of tensor values:
// squared deviations from mean, divided by n-1. See also [VarPop].
// See [StatsFunc] for general information.
func Var(in tensor.Tensor) tensor.Values {
	res := tensor.CallOut1(VarOut, in)
	return res
}
// VarOut computes the sample variance of tensor values into out:
// squared deviations from mean, divided by n-1. See also [VarPopOut].
// See [StatsOutFunc] for general information.
func VarOut(in tensor.Tensor, out tensor.Values) error {
	_, _, _ = VarOut64(in, out) // float64 intermediates not needed here
	return nil
}
// StdOut64 computes the sample standard deviation of tensor values into out,
// and returns the Float64 std, mean, and count values for subsequent use.
func StdOut64(in tensor.Tensor, out tensor.Values) (std64, mean64, count64 *tensor.Float64) {
	std64, mean64, count64 = VarOut64(in, out)
	for i := range out.Len() {
		sd := math.Sqrt(std64.Float1D(i))
		std64.SetFloat1D(sd, i) // variance tensor becomes std in place
		out.SetFloat1D(sd, i)
	}
	return
}
// Std computes the sample standard deviation of tensor values:
// sqrt of the sample variance. See also [StdPop].
// See [StatsFunc] for general information.
func Std(in tensor.Tensor) tensor.Values {
	res := tensor.CallOut1(StdOut, in)
	return res
}
// StdOut computes the sample standard deviation of tensor values into out:
// sqrt of the sample variance. See also [StdPopOut].
// See [StatsOutFunc] for general information.
func StdOut(in tensor.Tensor, out tensor.Values) error {
	_, _, _ = StdOut64(in, out) // float64 intermediates not needed here
	return nil
}
// Sem computes the sample standard error of the mean of tensor values:
// sample standard deviation / sqrt(n). See also [SemPop].
// See [StatsFunc] for general information.
func Sem(in tensor.Tensor) tensor.Values {
	res := tensor.CallOut1(SemOut, in)
	return res
}
// SemOut computes the sample standard error of the mean of tensor values
// into out: sample standard deviation / sqrt(n). See also [SemPopOut].
// For counts below 2 the plain standard deviation is used (no sqrt(n)
// division). See [StatsOutFunc] for general information.
func SemOut(in tensor.Tensor, out tensor.Values) error {
	vr, _, counts := VarOut64(in, out)
	for i := range out.Len() {
		sem := math.Sqrt(vr.Float1D(i))
		if n := counts.Float1D(i); n >= 2 {
			sem /= math.Sqrt(n)
		}
		out.SetFloat1D(sem, i)
	}
	return nil
}
// VarPopOut64 computes the population variance of tensor values into out,
// and returns the Float64 variance, mean, and count values for
// subsequent use. Cells with a zero count keep the raw
// sum-of-squared-deviates value.
func VarPopOut64(in tensor.Tensor, out tensor.Values) (var64, mean64, count64 *tensor.Float64) {
	var64, mean64, count64 = SumSqDevOut64(in, out)
	for i := range out.Len() {
		n := count64.Float1D(i)
		if n == 0 {
			continue // avoid division by zero
		}
		v := var64.Float1D(i) / n
		var64.SetFloat1D(v, i)
		out.SetFloat1D(v, i)
	}
	return
}
// VarPop computes the population variance of tensor values:
// squared deviations from mean, divided by n. See also [Var].
// See [StatsFunc] for general information.
func VarPop(in tensor.Tensor) tensor.Values {
	res := tensor.CallOut1(VarPopOut, in)
	return res
}
// VarPopOut computes the population variance of tensor values into out:
// squared deviations from mean, divided by n. See also [VarOut].
// See [StatsOutFunc] for general information.
func VarPopOut(in tensor.Tensor, out tensor.Values) error {
	_, _, _ = VarPopOut64(in, out) // float64 intermediates not needed here
	return nil
}
// StdPop computes the population standard deviation of tensor values:
// sqrt of the population variance. See also [Std].
// See [StatsFunc] for general information.
func StdPop(in tensor.Tensor) tensor.Values {
	res := tensor.CallOut1(StdPopOut, in)
	return res
}
// StdPopOut computes the population standard deviation of tensor values
// into out: sqrt of the population variance. See also [StdOut].
// See [StatsOutFunc] for general information.
func StdPopOut(in tensor.Tensor, out tensor.Values) error {
	vr, _, _ := VarPopOut64(in, out)
	for i := range out.Len() {
		out.SetFloat1D(math.Sqrt(vr.Float1D(i)), i)
	}
	return nil
}
// SemPop computes the population standard error of the mean of tensor
// values: population standard deviation / sqrt(n). See also [Sem].
// See [StatsFunc] for general information.
func SemPop(in tensor.Tensor) tensor.Values {
	res := tensor.CallOut1(SemPopOut, in)
	return res
}
// SemPopOut computes the population standard error of the mean of tensor
// values into out: population standard deviation / sqrt(n). See also
// [SemOut]. For counts below 2 the plain standard deviation is used
// (no sqrt(n) division). See [StatsOutFunc] for general information.
func SemPopOut(in tensor.Tensor, out tensor.Values) error {
	vr, _, counts := VarPopOut64(in, out)
	for i := range out.Len() {
		sem := math.Sqrt(vr.Float1D(i))
		if n := counts.Float1D(i); n >= 2 {
			sem /= math.Sqrt(n)
		}
		out.SetFloat1D(sem, i)
	}
	return nil
}
// SumSqScaleOut64 is a helper for sum-of-squares, returning scale and ss
// factors aggregated separately for better numerical stability, per BLAS.
// The total sum of squares is scale^2 * ss; see [SumSqOut64].
// Returns the Float64 output values for subsequent use.
func SumSqScaleOut64(in tensor.Tensor) (scale64, ss64 *tensor.Float64) {
	agg := func(val, scale, ss float64) (float64, float64) {
		if val == 0 || math.IsNaN(val) {
			return scale, ss
		}
		av := math.Abs(val)
		if scale < av {
			// new largest magnitude: rescale the running ss to it
			r := scale / av
			return av, 1 + ss*r*r
		}
		r := av / scale
		return scale, ss + r*r
	}
	scale64, ss64 = Vectorize2Out64(in, 0, 1, agg)
	return
}
// SumSqOut64 computes the sum of squares of tensor values into out,
// and returns the Float64 output values for subsequent use.
// Uses the scale/ss decomposition from [SumSqScaleOut64]:
// sum of squares = scale^2 * ss.
func SumSqOut64(in tensor.Tensor, out tensor.Values) *tensor.Float64 {
	scale64, ss64 := SumSqScaleOut64(in)
	out.SetShapeSizes(tensor.CellsSize(in.ShapeSizes())...)
	for i := range out.Len() {
		scale := scale64.Float1D(i)
		v := math.Inf(1)
		if !math.IsInf(scale, 1) {
			v = scale * scale * ss64.Float1D(i)
		}
		scale64.SetFloat1D(v, i) // scale tensor becomes the result in place
		out.SetFloat1D(v, i)
	}
	return scale64
}
// SumSq computes the sum of squares of tensor values.
// See [StatsFunc] for general information.
func SumSq(in tensor.Tensor) tensor.Values {
	res := tensor.CallOut1(SumSqOut, in)
	return res
}
// SumSqOut computes the sum of squares of tensor values into out.
// See [StatsOutFunc] for general information.
func SumSqOut(in tensor.Tensor, out tensor.Values) error {
	_ = SumSqOut64(in, out) // float64 intermediate not needed here
	return nil
}
// L2NormOut64 computes the square root of the sum of squares of tensor
// values, known as the L2 norm, into out, and returns the Float64 output
// values for use in subsequent computations.
// Uses the scale/ss decomposition from [SumSqScaleOut64]:
// L2 norm = scale * sqrt(ss).
func L2NormOut64(in tensor.Tensor, out tensor.Values) *tensor.Float64 {
	scale64, ss64 := SumSqScaleOut64(in)
	out.SetShapeSizes(tensor.CellsSize(in.ShapeSizes())...)
	for i := range out.Len() {
		scale := scale64.Float1D(i)
		nrm := math.Inf(1)
		if !math.IsInf(scale, 1) {
			nrm = scale * math.Sqrt(ss64.Float1D(i))
		}
		scale64.SetFloat1D(nrm, i) // scale tensor becomes the result in place
		out.SetFloat1D(nrm, i)
	}
	return scale64
}
// L2Norm computes the square root of the sum of squares of tensor values,
// known as the L2 norm.
// See [StatsFunc] for general information.
func L2Norm(in tensor.Tensor) tensor.Values {
	res := tensor.CallOut1(L2NormOut, in)
	return res
}
// L2NormOut computes the square root of the sum of squares of tensor
// values, known as the L2 norm, into out.
// See [StatsOutFunc] for general information.
func L2NormOut(in tensor.Tensor, out tensor.Values) error {
	_ = L2NormOut64(in, out) // float64 intermediate not needed here
	return nil
}
// First returns the first tensor value(s), as a stats function,
// for the starting point in a naturally-ordered set of data.
// See [StatsFunc] for general information.
func First(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1(FirstOut, in)
}
// FirstOut returns the first tensor value(s), as a stats function,
// for the starting point in a naturally-ordered set of data.
// See [StatsOutFunc] for general information.
func FirstOut(in tensor.Tensor, out tensor.Values) error {
	nrows, ncells := in.Shape().RowCellSize()
	if ncells == 1 {
		// scalar case: a single output value taken from row 0, if any.
		out.SetShapeSizes(1)
		if nrows > 0 {
			out.SetFloat1D(in.Float1D(0), 0)
		}
		return nil
	}
	// cell case: output has the shape of one row (the cells) of the input.
	out.SetShapeSizes(tensor.CellsSize(in.ShapeSizes())...)
	if nrows == 0 {
		return nil
	}
	// copy the first row, which occupies the first ncells 1D elements.
	for ci := 0; ci < ncells; ci++ {
		out.SetFloat1D(in.Float1D(ci), ci)
	}
	return nil
}
// Final returns the final tensor value(s), as a stats function,
// for the ending point in a naturally-ordered set of data.
// See [StatsFunc] for general information.
func Final(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1(FinalOut, in)
}
// FinalOut returns the final tensor value(s), as a stats function,
// for the ending point in a naturally-ordered set of data.
// See [StatsOutFunc] for general information.
func FinalOut(in tensor.Tensor, out tensor.Values) error {
	rows, cells := in.Shape().RowCellSize()
	if cells == 1 {
		// scalar case: a single output value taken from the last row, if any.
		out.SetShapeSizes(1)
		if rows > 0 {
			out.SetFloat1D(in.Float1D(rows-1), 0)
		}
		return nil
	}
	// cell case: output has the shape of one row (the cells) of the input.
	osz := tensor.CellsSize(in.ShapeSizes())
	out.SetShapeSizes(osz...)
	if rows == 0 {
		return nil
	}
	st := (rows - 1) * cells // 1D offset of the start of the last row
	for i := range cells {
		out.SetFloat1D(in.Float1D(st+i), i)
	}
	return nil
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package stats
import (
"strconv"
"strings"
"cogentcore.org/core/base/metadata"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
"cogentcore.org/lab/tensorfs"
)
// Groups generates indexes for each unique value in each of the given tensors.
// One can then use the resulting indexes for the [tensor.Rows] indexes to
// perform computations restricted to grouped subsets of data, as in the
// [GroupStats] function. See [GroupCombined] for function that makes a
// "Combined" Group that has a unique group for each _combination_ of
// the separate, independent groups created by this function.
// It creates subdirectories in a "Groups" directory within given [tensorfs],
// for each tensor passed in here, using the metadata Name property for
// names (index if empty).
// Within each subdirectory there are int tensors for each unique 1D
// row-wise value of elements in the input tensor, named as the string
// representation of the value, where the int tensor contains a list of
// row-wise indexes corresponding to the source rows having that value.
// Note that these indexes are directly in terms of the underlying [Tensor] data
// rows, indirected through any existing indexes on the inputs, so that
// the results can be used directly as Indexes into the corresponding tensor data.
// Uses a stable sort on columns, so ordering of other dimensions is preserved.
func Groups(dir *tensorfs.Node, tsrs ...tensor.Tensor) error {
	gd := dir.Dir("Groups")
	// makeIdxs records sorted rows [start, r) as an int tensor named val.
	makeIdxs := func(dir *tensorfs.Node, srt *tensor.Rows, val string, start, r int) {
		n := r - start
		it := tensorfs.Value[int](dir, val, n)
		for j := range n {
			it.SetIntRow(srt.Indexes[start+j], j, 0) // key to indirect through sort indexes
		}
	}
	for i, tsr := range tsrs {
		nr := tsr.DimSize(0)
		if nr == 0 {
			continue
		}
		nm := metadata.Name(tsr)
		if nm == "" {
			nm = strconv.Itoa(i) // fall back on positional index as the name
		}
		td := gd.Dir(nm)
		srt := tensor.AsRows(tsr).CloneIndexes()
		srt.SortStable(tensor.Ascending)
		start := 0
		if tsr.IsString() {
			lastVal := srt.StringRow(0, 0)
			for r := range nr {
				v := srt.StringRow(r, 0)
				if v != lastVal {
					makeIdxs(td, srt, lastVal, start, r)
					start = r
					lastVal = v
				}
			}
			// Always emit the trailing group: start < nr here, and the
			// prior `start != nr-1` guard incorrectly dropped the final
			// group when it contained exactly one row (singleton groups
			// in the middle of the data were still emitted).
			makeIdxs(td, srt, lastVal, start, nr)
		} else {
			lastVal := srt.FloatRow(0, 0)
			for r := range nr {
				v := srt.FloatRow(r, 0)
				if v != lastVal {
					makeIdxs(td, srt, tensor.Float64ToString(lastVal), start, r)
					start = r
					lastVal = v
				}
			}
			makeIdxs(td, srt, tensor.Float64ToString(lastVal), start, nr)
		}
	}
	return nil
}
// TableGroups runs [Groups] on the given columns from given [table.Table].
func TableGroups(dir *tensorfs.Node, dt *table.Table, columns ...string) error {
	dv := table.NewView(dt)
	// important for consistency across columns, to do full outer product sort first.
	dv.SortColumns(tensor.Ascending, tensor.StableSort, columns...)
	return Groups(dir, dv.ColumnList(columns...)...)
}
// GroupAll copies all indexes from the first given tensor,
// into an "All/All" tensor in the given [tensorfs], which can then
// be used with [GroupStats] to generate summary statistics across
// all the data. See [Groups] for more general documentation.
// It is a no-op if no tensors (or an empty tensor) are passed.
func GroupAll(dir *tensorfs.Node, tsrs ...tensor.Tensor) error {
	if len(tsrs) == 0 {
		return nil // guard: previously panicked on tsrs[0] with no args
	}
	gd := dir.Dir("Groups")
	tsr := tensor.AsRows(tsrs[0])
	nr := tsr.NumRows()
	if nr == 0 {
		return nil
	}
	td := gd.Dir("All")
	it := tensorfs.Value[int](td, "All", nr)
	for j := range nr {
		it.SetIntRow(tsr.RowIndex(j), j, 0) // key to indirect through any existing indexes
	}
	return nil
}
// todo: GroupCombined
// GroupStats computes the given stats function on the unique grouped indexes
// produced by the [Groups] function, in the given [tensorfs] directory,
// applied to each of the tensors passed here.
// It creates a "Stats" subdirectory in given directory, with
// subdirectories with the name of each value tensor (if it does not
// yet exist), and then creates a subdirectory within that
// for the statistic name. Within that statistic directory, it creates
// a String tensor with the unique values of each source [Groups] tensor,
// and a aligned Float64 tensor with the statistics results for each such
// unique group value. See the README.md file for a diagram of the results.
func GroupStats(dir *tensorfs.Node, stat Stats, tsrs ...tensor.Tensor) error {
	gd := dir.Dir("Groups")
	sd := dir.Dir("Stats")
	stnm := StripPackage(stat.String()) // stat name without the package prefix
	groups, _ := gd.Nodes()
	for _, gp := range groups {
		gpnm := gp.Name()
		ggd := gd.Dir(gpnm)
		vals := ggd.ValuesFunc(nil) // one index tensor per unique group value
		nv := len(vals)
		if nv == 0 {
			continue
		}
		sgd := sd.Dir(gpnm)
		gv := sgd.Node(gpnm)
		if gv == nil {
			// first stat for this group: record the unique group values.
			gtsr := tensorfs.Value[string](sgd, gpnm, nv)
			for i, v := range vals {
				gtsr.SetStringRow(metadata.Name(v), i, 0)
			}
		}
		for _, tsr := range tsrs {
			vd := sgd.Dir(metadata.Name(tsr))
			sv := tensorfs.Value[float64](vd, stnm, nv)
			for i, v := range vals {
				// compute the stat on the subset of rows for this group value.
				idx := tensor.AsIntSlice(v)
				sg := tensor.NewRows(tsr.AsValues(), idx...)
				stout := stat.Call(sg)
				sv.SetFloatRow(stout.Float1D(0), i, 0)
			}
		}
	}
	return nil
}
// TableGroupStats runs [GroupStats] using standard [Stats]
// on the given columns from given [table.Table].
func TableGroupStats(dir *tensorfs.Node, stat Stats, dt *table.Table, columns ...string) error {
	return GroupStats(dir, stat, dt.ColumnList(columns...)...)
}
// GroupDescribe runs standard descriptive statistics on given tensor data
// using [GroupStats] function, with [DescriptiveStats] list of stats.
func GroupDescribe(dir *tensorfs.Node, tsrs ...tensor.Tensor) error {
	for _, stat := range DescriptiveStats {
		if err := GroupStats(dir, stat, tsrs...); err != nil {
			return err
		}
	}
	return nil
}
// TableGroupDescribe runs [GroupDescribe] on the given columns from given [table.Table].
func TableGroupDescribe(dir *tensorfs.Node, dt *table.Table, columns ...string) error {
	return GroupDescribe(dir, dt.ColumnList(columns...)...)
}
// GroupStatsAsTable returns the results from [GroupStats] in given directory
// as a [table.Table], using [tensorfs.DirTable] function.
func GroupStatsAsTable(dir *tensorfs.Node) *table.Table {
	return tensorfs.DirTable(dir.Node("Stats"), nil)
}
// GroupStatsAsTableNoStatName returns the results from [GroupStats]
// in given directory as a [table.Table], using [tensorfs.DirTable] function.
// Column names are updated to not include the stat name, if there is only
// one statistic such that the resulting name will still be unique.
// Otherwise, column names are Value/Stat.
func GroupStatsAsTableNoStatName(dir *tensorfs.Node) *table.Table {
	dt := tensorfs.DirTable(dir.Node("Stats"), nil)
	// cols maps the bare value name to the first full column name seen for it.
	cols := make(map[string]string)
	for _, nm := range dt.Columns.Keys {
		vn := nm
		si := strings.Index(nm, "/")
		if si > 0 {
			vn = nm[:si] // strip the "/Stat" suffix
		}
		// NOTE(review): when the same value name appears with multiple stats,
		// only the first column is recorded here and renamed below, which
		// may leave mixed "Value" and "Value/Stat" names — confirm intent.
		if _, exists := cols[vn]; exists {
			continue
		}
		cols[vn] = nm
	}
	for k, v := range cols {
		ci := dt.Columns.IndexByKey(v)
		dt.Columns.RenameIndex(ci, k)
	}
	return dt
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package stats
import (
"cogentcore.org/core/math32"
"cogentcore.org/lab/tensor"
"cogentcore.org/lab/tensor/tmath"
)
// ZScore computes Z-normalized values, subtracting the Mean and dividing
// by the standard deviation, returning the result as a newly created tensor.
func ZScore(a tensor.Tensor) tensor.Values {
	return tensor.CallOut1(ZScoreOut, a)
}
// ZScoreOut computes Z-normalized values into given output tensor,
// subtracting the Mean and dividing by the standard deviation.
func ZScoreOut(a tensor.Tensor, out tensor.Values) error {
	mout := tensor.NewFloat64()
	std, mean, _ := StdOut64(a, mout) // third return (count) not needed here
	tmath.SubOut(a, mean, out)
	tmath.DivOut(out, std, out)
	return nil
}
// UnitNorm computes unit normalized values, subtracting the Min value and
// dividing by the Max of the remaining numbers, returning a new tensor.
func UnitNorm(a tensor.Tensor) tensor.Values {
	return tensor.CallOut1(UnitNormOut, a)
}
// UnitNormOut computes unit normalized values into given output tensor,
// subtracting the Min value and dividing by the Max of the remaining numbers.
func UnitNormOut(a tensor.Tensor, out tensor.Values) error {
	mout := tensor.NewFloat64()
	if err := MinOut(a, mout); err != nil {
		return err
	}
	tmath.SubOut(a, mout, out)
	// check the MaxOut error too, consistent with MinOut above
	// (it was previously silently discarded).
	if err := MaxOut(out, mout); err != nil {
		return err
	}
	tmath.DivOut(out, mout, out)
	return nil
}
// Clamp ensures that all values are within min, max limits, clamping
// values to those bounds if they exceed them. min and max args are
// treated as scalars (first value used).
func Clamp(in, minv, maxv tensor.Tensor) tensor.Values {
	// bug fix: the upper bound is maxv (minv was previously passed twice,
	// so the max limit was never applied).
	return tensor.CallOut3(ClampOut, in, minv, maxv)
}
// ClampOut ensures that all values are within min, max limits, clamping
// values to those bounds if they exceed them. min and max args are
// treated as scalars (first value used).
func ClampOut(in, minv, maxv tensor.Tensor, out tensor.Values) error {
	tensor.SetShapeFrom(out, in)
	mn := minv.Float1D(0) // scalar lower bound
	mx := maxv.Float1D(0) // scalar upper bound
	tensor.VectorizeThreaded(1, tensor.NFirstLen, func(idx int, tsr ...tensor.Tensor) {
		tsr[1].SetFloat1D(math32.Clamp(tsr[0].Float1D(idx), mn, mx), idx)
	}, in, out)
	return nil
}
// Binarize results in a binary-valued output by setting
// values >= the threshold to 1, else 0. threshold is
// treated as a scalar (first value used).
func Binarize(in, threshold tensor.Tensor) tensor.Values {
	return tensor.CallOut2(BinarizeOut, in, threshold)
}
// BinarizeOut results in a binary-valued output by setting
// values >= the threshold to 1, else 0. threshold is
// treated as a scalar (first value used).
func BinarizeOut(in, threshold tensor.Tensor, out tensor.Values) error {
	tensor.SetShapeFrom(out, in)
	thr := threshold.Float1D(0) // scalar threshold
	tensor.VectorizeThreaded(1, tensor.NFirstLen, func(idx int, tsr ...tensor.Tensor) {
		b := 0.0
		if tsr[0].Float1D(idx) >= thr {
			b = 1
		}
		tsr[1].SetFloat1D(b, idx)
	}, in, out)
	return nil
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package stats
import (
"math"
"cogentcore.org/core/base/errors"
"cogentcore.org/lab/tensor"
)
// Quantiles returns the given quantile(s) of non-NaN elements in given
// 1D tensor. Because sorting uses indexes, this only works for 1D case.
// If needed for a sub-space of values, that can be extracted through slicing
// and then used. Logs an error if not 1D.
// qs are 0-1 values, 0 = min, 1 = max, .5 = median, etc.
// Uses linear interpolation.
// Because this requires a sort, it is more efficient to get as many quantiles
// as needed in one pass.
func Quantiles(in, qs tensor.Tensor) tensor.Values {
	return tensor.CallOut2(QuantilesOut, in, qs)
}
// QuantilesOut returns the given quantile(s) of non-NaN elements in given
// 1D tensor. Because sorting uses indexes, this only works for 1D case.
// If needed for a sub-space of values, that can be extracted through slicing
// and then used. Returns and logs an error if not 1D.
// qs are 0-1 values, 0 = min, 1 = max, .5 = median, etc.
// Uses linear interpolation.
// Because this requires a sort, it is more efficient to get as many quantiles
// as needed in one pass.
func QuantilesOut(in, qs tensor.Tensor, out tensor.Values) error {
	if in.NumDims() != 1 {
		return errors.Log(errors.New("stats.QuantilesOut: only 1D input tensors allowed"))
	}
	if qs.NumDims() != 1 {
		return errors.Log(errors.New("stats.QuantilesOut: only 1D quantile tensors allowed"))
	}
	// The output has one value per requested quantile, so it is shaped
	// from qs, not in (shaping from in left an input-sized output with
	// only the first nq elements set, e.g. Median was not a scalar).
	tensor.SetShapeFrom(out, qs)
	sin := tensor.AsRows(in.AsValues())
	sin.ExcludeMissing() // drop NaN rows before sorting
	sin.Sort(tensor.Ascending)
	sz := len(sin.Indexes) - 1 // index of the last valid (non-NaN) row
	if sz <= 0 {
		out.SetZeros()
		return nil
	}
	fsz := float64(sz)
	nq := qs.Len()
	for i := range nq {
		q := qs.Float1D(i)
		val := 0.0
		qi := q * fsz
		lwi := math.Floor(qi)
		lwii := int(lwi)
		if lwii >= sz {
			val = sin.FloatRow(sz, 0) // q >= 1: maximum
		} else if lwii < 0 {
			val = sin.FloatRow(0, 0) // q < 0: minimum
		} else {
			// linear interpolation between the two bracketing values.
			phi := qi - lwi
			lwv := sin.FloatRow(lwii, 0)
			hiv := sin.FloatRow(lwii+1, 0)
			val = (1-phi)*lwv + phi*hiv
		}
		out.SetFloat1D(val, i)
	}
	return nil
}
// Median computes the median (50% quantile) of tensor values.
// See [StatsFunc] for general information.
func Median(in tensor.Tensor) tensor.Values {
	return Quantiles(in, tensor.NewFloat64Scalar(.5))
}
// Q1 computes the first quartile (25% quantile) of tensor values.
// See [StatsFunc] for general information.
func Q1(in tensor.Tensor) tensor.Values {
	return Quantiles(in, tensor.NewFloat64Scalar(.25))
}
// Q3 computes the third quartile (75% quantile) of tensor values.
// See [StatsFunc] for general information.
func Q3(in tensor.Tensor) tensor.Values {
	return Quantiles(in, tensor.NewFloat64Scalar(.75))
}
// MedianOut computes the median (50% quantile) of tensor values,
// storing the result in the given output tensor.
// See [StatsOutFunc] for general information.
func MedianOut(in tensor.Tensor, out tensor.Values) error {
	return QuantilesOut(in, tensor.NewFloat64Scalar(.5), out)
}
// Q1Out computes the first quartile (25% quantile) of tensor values,
// storing the result in the given output tensor.
// See [StatsOutFunc] for general information.
func Q1Out(in tensor.Tensor, out tensor.Values) error {
	return QuantilesOut(in, tensor.NewFloat64Scalar(.25), out)
}
// Q3Out computes the third quartile (75% quantile) of tensor values,
// storing the result in the given output tensor.
// See [StatsOutFunc] for general information.
func Q3Out(in tensor.Tensor, out tensor.Values) error {
	return QuantilesOut(in, tensor.NewFloat64Scalar(.75), out)
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package stats
import (
"strings"
"cogentcore.org/core/base/errors"
"cogentcore.org/lab/tensor"
)
//go:generate core generate
// init registers all of the standard stats functions by their
// package-qualified names (see [Stats.FuncName]) so they can be
// looked up dynamically via [tensor.FuncByName].
func init() {
	tensor.AddFunc(StatCount.FuncName(), Count)
	tensor.AddFunc(StatSum.FuncName(), Sum)
	tensor.AddFunc(StatL1Norm.FuncName(), L1Norm)
	tensor.AddFunc(StatProd.FuncName(), Prod)
	tensor.AddFunc(StatMin.FuncName(), Min)
	tensor.AddFunc(StatMax.FuncName(), Max)
	tensor.AddFunc(StatMinAbs.FuncName(), MinAbs)
	tensor.AddFunc(StatMaxAbs.FuncName(), MaxAbs)
	tensor.AddFunc(StatMean.FuncName(), Mean)
	tensor.AddFunc(StatVar.FuncName(), Var)
	tensor.AddFunc(StatStd.FuncName(), Std)
	tensor.AddFunc(StatSem.FuncName(), Sem)
	tensor.AddFunc(StatSumSq.FuncName(), SumSq)
	tensor.AddFunc(StatL2Norm.FuncName(), L2Norm)
	tensor.AddFunc(StatVarPop.FuncName(), VarPop)
	tensor.AddFunc(StatStdPop.FuncName(), StdPop)
	tensor.AddFunc(StatSemPop.FuncName(), SemPop)
	tensor.AddFunc(StatMedian.FuncName(), Median)
	tensor.AddFunc(StatQ1.FuncName(), Q1)
	tensor.AddFunc(StatQ3.FuncName(), Q3)
	tensor.AddFunc(StatFirst.FuncName(), First)
	tensor.AddFunc(StatFinal.FuncName(), Final)
}
// Stats is a list of different standard aggregation functions, which can be
// used to choose an aggregation function.
type Stats int32 //enums:enum -trim-prefix Stat

const (
	// count of number of elements.
	StatCount Stats = iota

	// sum of elements.
	StatSum

	// L1 Norm: sum of absolute values of elements.
	StatL1Norm

	// product of elements.
	StatProd

	// minimum value.
	StatMin

	// maximum value.
	StatMax

	// minimum of absolute values.
	StatMinAbs

	// maximum of absolute values.
	StatMaxAbs

	// mean value = sum / count.
	StatMean

	// sample variance (squared deviations from mean, divided by n-1).
	StatVar

	// sample standard deviation (sqrt of Var).
	StatStd

	// sample standard error of the mean (Std divided by sqrt(n)).
	StatSem

	// sum of squared values.
	StatSumSq

	// L2 Norm: square-root of sum-of-squares.
	StatL2Norm

	// population variance (squared diffs from mean, divided by n).
	StatVarPop

	// population standard deviation (sqrt of VarPop).
	StatStdPop

	// population standard error of the mean (StdPop divided by sqrt(n)).
	StatSemPop

	// middle value in sorted ordering.
	StatMedian

	// Q1 first quartile = 25%ile value = .25 quantile value.
	StatQ1

	// Q3 third quartile = 75%ile value = .75 quantile value.
	StatQ3

	// first item in the set of data: for data with a natural ordering.
	StatFirst

	// final item in the set of data: for data with a natural ordering.
	StatFinal
)
// FuncName returns the package-qualified function name to use
// in tensor.Call to call this function.
func (s Stats) FuncName() string {
	return "stats." + s.String()
}
// Func returns the function for given stat, looked up by its
// registered [Stats.FuncName]. A lookup failure is logged by
// errors.Log1; the subsequent type assertion assumes the
// registered function has the [StatsFunc] signature.
func (s Stats) Func() StatsFunc {
	fn := errors.Log1(tensor.FuncByName(s.FuncName()))
	return fn.Fun.(StatsFunc)
}
// Call calls this statistic function on given tensor,
// returning output as a newly created tensor.
func (s Stats) Call(in tensor.Tensor) tensor.Values {
	return s.Func()(in)
}
// StripPackage removes any package name from given string,
// used for naming based on FuncName() which could be custom
// or have a package prefix.
func StripPackage(name string) string {
	// everything after the last "." is the bare name.
	if dot := strings.LastIndex(name, "."); dot >= 0 {
		return name[dot+1:]
	}
	return name
}
// AsStatsFunc returns given function as a [StatsFunc] function,
// or an error if it does not fit that signature.
func AsStatsFunc(fun any) (StatsFunc, error) {
	sfun, ok := fun.(StatsFunc)
	if !ok {
		// message fixed: this is package stats, not metric
		// (previous text was a copy-paste from the metric package).
		return nil, errors.New("stats.AsStatsFunc: function does not fit the StatsFunc signature")
	}
	return sfun, nil
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package stats
import (
"cogentcore.org/lab/tensor"
)
// VectorizeOut64 is the general compute function for stats.
// This version makes a Float64 output tensor for aggregating
// and computing values, and then copies the results back to the
// original output. This allows stats functions to operate directly
// on integer valued inputs and produce sensible results.
// It returns the Float64 output tensor for further processing as needed.
// fun folds each input value into the running aggregate, starting from ini.
func VectorizeOut64(in tensor.Tensor, out tensor.Values, ini float64, fun func(val, agg float64) float64) *tensor.Float64 {
	rows, cells := in.Shape().RowCellSize()
	o64 := tensor.NewFloat64(cells)
	if rows <= 0 {
		return o64 // no data: out shape is left unset; o64 holds zeros
	}
	if cells == 1 {
		// scalar case: single aggregate over all rows.
		out.SetShapeSizes(1)
		agg := ini
		// type switch provides non-interface fast paths for the
		// common concrete float types; default goes through the interface.
		switch x := in.(type) {
		case *tensor.Float64:
			for i := range rows {
				agg = fun(x.Float1D(i), agg)
			}
		case *tensor.Float32:
			for i := range rows {
				agg = fun(x.Float1D(i), agg)
			}
		default:
			for i := range rows {
				agg = fun(in.Float1D(i), agg)
			}
		}
		o64.SetFloat1D(agg, 0)
		out.SetFloat1D(agg, 0)
		return o64
	}
	// cell case: one aggregate per cell, iterating row-major.
	osz := tensor.CellsSize(in.ShapeSizes())
	out.SetShapeSizes(osz...)
	for i := range cells {
		o64.SetFloat1D(ini, i)
	}
	switch x := in.(type) {
	case *tensor.Float64:
		for i := range rows {
			for j := range cells {
				o64.SetFloat1D(fun(x.Float1D(i*cells+j), o64.Float1D(j)), j)
			}
		}
	case *tensor.Float32:
		for i := range rows {
			for j := range cells {
				o64.SetFloat1D(fun(x.Float1D(i*cells+j), o64.Float1D(j)), j)
			}
		}
	default:
		for i := range rows {
			for j := range cells {
				o64.SetFloat1D(fun(in.Float1D(i*cells+j), o64.Float1D(j)), j)
			}
		}
	}
	// copy float64 aggregates back to the (possibly integer) output.
	for j := range cells {
		out.SetFloat1D(o64.Float1D(j), j)
	}
	return o64
}
// VectorizePreOut64 is a version of [VectorizeOut64] that takes an additional
// tensor.Float64 input of pre-computed values, e.g., the means of each output cell.
// fun folds each input value, with the corresponding pre value, into the
// running aggregate, starting from ini.
func VectorizePreOut64(in tensor.Tensor, out tensor.Values, ini float64, pre *tensor.Float64, fun func(val, pre, agg float64) float64) *tensor.Float64 {
	rows, cells := in.Shape().RowCellSize()
	o64 := tensor.NewFloat64(cells)
	if rows <= 0 {
		return o64 // no data: out shape is left unset; o64 holds zeros
	}
	if cells == 1 {
		// scalar case: single aggregate over all rows, with scalar pre value.
		out.SetShapeSizes(1)
		agg := ini
		prev := pre.Float1D(0)
		// type switch provides non-interface fast paths for float types.
		switch x := in.(type) {
		case *tensor.Float64:
			for i := range rows {
				agg = fun(x.Float1D(i), prev, agg)
			}
		case *tensor.Float32:
			for i := range rows {
				agg = fun(x.Float1D(i), prev, agg)
			}
		default:
			for i := range rows {
				agg = fun(in.Float1D(i), prev, agg)
			}
		}
		o64.SetFloat1D(agg, 0)
		out.SetFloat1D(agg, 0)
		return o64
	}
	// cell case: one aggregate per cell, each with its own pre value.
	osz := tensor.CellsSize(in.ShapeSizes())
	out.SetShapeSizes(osz...)
	for j := range cells {
		o64.SetFloat1D(ini, j)
	}
	switch x := in.(type) {
	case *tensor.Float64:
		for i := range rows {
			for j := range cells {
				o64.SetFloat1D(fun(x.Float1D(i*cells+j), pre.Float1D(j), o64.Float1D(j)), j)
			}
		}
	case *tensor.Float32:
		for i := range rows {
			for j := range cells {
				o64.SetFloat1D(fun(x.Float1D(i*cells+j), pre.Float1D(j), o64.Float1D(j)), j)
			}
		}
	default:
		for i := range rows {
			for j := range cells {
				o64.SetFloat1D(fun(in.Float1D(i*cells+j), pre.Float1D(j), o64.Float1D(j)), j)
			}
		}
	}
	// copy float64 aggregates back to the (possibly integer) output.
	for i := range cells {
		out.SetFloat1D(o64.Float1D(i), i)
	}
	return o64
}
// Vectorize2Out64 is a version of [VectorizeOut64] that separately aggregates
// two output values, x and y as tensor.Float64.
// fun folds each input value into both running aggregates, which start
// from iniX and iniY respectively. Note: there is no out tensor here;
// callers copy from the returned Float64 tensors as needed.
func Vectorize2Out64(in tensor.Tensor, iniX, iniY float64, fun func(val, ox, oy float64) (float64, float64)) (ox64, oy64 *tensor.Float64) {
	rows, cells := in.Shape().RowCellSize()
	ox64 = tensor.NewFloat64(cells)
	oy64 = tensor.NewFloat64(cells)
	if rows <= 0 {
		return ox64, oy64 // no data: both outputs hold zeros
	}
	if cells == 1 {
		// scalar case: single pair of aggregates over all rows.
		ox := iniX
		oy := iniY
		// type switch provides non-interface fast paths for float types.
		switch x := in.(type) {
		case *tensor.Float64:
			for i := range rows {
				ox, oy = fun(x.Float1D(i), ox, oy)
			}
		case *tensor.Float32:
			for i := range rows {
				ox, oy = fun(x.Float1D(i), ox, oy)
			}
		default:
			for i := range rows {
				ox, oy = fun(in.Float1D(i), ox, oy)
			}
		}
		ox64.SetFloat1D(ox, 0)
		oy64.SetFloat1D(oy, 0)
		return
	}
	// cell case: one pair of aggregates per cell, iterating row-major.
	for j := range cells {
		ox64.SetFloat1D(iniX, j)
		oy64.SetFloat1D(iniY, j)
	}
	switch x := in.(type) {
	case *tensor.Float64:
		for i := range rows {
			for j := range cells {
				ox, oy := fun(x.Float1D(i*cells+j), ox64.Float1D(j), oy64.Float1D(j))
				ox64.SetFloat1D(ox, j)
				oy64.SetFloat1D(oy, j)
			}
		}
	case *tensor.Float32:
		for i := range rows {
			for j := range cells {
				ox, oy := fun(x.Float1D(i*cells+j), ox64.Float1D(j), oy64.Float1D(j))
				ox64.SetFloat1D(ox, j)
				oy64.SetFloat1D(oy, j)
			}
		}
	default:
		for i := range rows {
			for j := range cells {
				ox, oy := fun(in.Float1D(i*cells+j), ox64.Float1D(j), oy64.Float1D(j))
				ox64.SetFloat1D(ox, j)
				oy64.SetFloat1D(oy, j)
			}
		}
	}
	return
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package table
import (
"cogentcore.org/core/base/keylist"
"cogentcore.org/core/base/metadata"
"cogentcore.org/lab/tensor"
)
// Columns is the underlying column list and number of rows for Table.
// Each column is a raw [tensor.Values] tensor, and [Table]
// provides a [tensor.Rows] indexed view onto the Columns.
type Columns struct {
	keylist.List[string, tensor.Values]

	// number of rows, which is enforced to be the size of the
	// outermost row dimension of the column tensors.
	Rows int `edit:"-"`
}
// NewColumns returns a new Columns with zero rows and no columns.
func NewColumns() *Columns {
	return &Columns{}
}
// SetNumRows sets the number of rows in the table, across all columns.
// It is safe to set this to 0. For incrementally growing tables (e.g., a log)
// it is best to first set the anticipated full size, which allocates the
// full amount of memory, and then set to 0 and grow incrementally.
// Returns the receiver for chaining.
func (cl *Columns) SetNumRows(rows int) *Columns { //types:add
	cl.Rows = rows // can be 0
	for _, tsr := range cl.Values {
		tsr.SetNumRows(rows)
	}
	return cl
}
// AddColumn adds the given tensor (as a [tensor.Values]) as a column,
// returning an error and not adding if the name is not unique.
// Automatically adjusts the shape to fit the current number of rows,
// (setting Rows if this is the first column added)
// and calls the metadata SetName with column name.
func (cl *Columns) AddColumn(name string, tsr tensor.Values) error {
	if cl.Len() == 0 {
		// first column establishes the row count from its outer dimension.
		cl.Rows = tsr.DimSize(0)
	}
	err := cl.Add(name, tsr)
	if err != nil {
		return err // duplicate name: not added
	}
	tsr.SetNumRows(cl.Rows)
	metadata.SetName(tsr, name)
	return nil
}
// InsertColumn inserts the given tensor as a column at given index,
// returning an error and not adding if the name is not unique.
// Automatically adjusts the shape to fit the current number of rows.
func (cl *Columns) InsertColumn(idx int, name string, tsr tensor.Values) error {
	// NOTE(review): the result of Insert is not checked and this always
	// returns nil, so a duplicate name is not actually reported as an
	// error here, unlike AddColumn — confirm intended behavior.
	cl.Insert(idx, name, tsr)
	tsr.SetNumRows(cl.Rows)
	return nil
}
// Clone returns a complete copy of this set of columns,
// with cloned column tensors.
func (cl *Columns) Clone() *Columns {
	cp := NewColumns().SetNumRows(cl.Rows)
	for i, nm := range cl.Keys {
		tsr := cl.Values[i]
		cp.AddColumn(nm, tsr.Clone())
	}
	// bug fix: return the copy; previously returned the receiver cl,
	// discarding the clone entirely.
	return cp
}
// AppendRows appends shared columns in both tables with input table rows.
// Columns present only in one of the two are skipped.
func (cl *Columns) AppendRows(cl2 *Columns) {
	for i, nm := range cl.Keys {
		c2 := cl2.At(nm)
		if c2 == nil {
			continue // no matching column in cl2
		}
		c1 := cl.Values[i]
		c1.AppendFrom(c2)
	}
	// enforce the combined row count across all columns.
	cl.SetNumRows(cl.Rows + cl2.Rows)
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package table
import (
"math/rand"
"slices"
"sort"
"cogentcore.org/lab/tensor"
)
// RowIndex returns the actual index into underlying tensor row based on given
// index value. If Indexes == nil, index is passed through.
func (dt *Table) RowIndex(idx int) int {
	if dt.Indexes == nil {
		return idx
	}
	return dt.Indexes[idx]
}
// NumRows returns the number of rows, which is the number of Indexes if present,
// else actual number of [Columns.Rows].
func (dt *Table) NumRows() int {
	if dt.Indexes == nil {
		return dt.Columns.Rows
	}
	return len(dt.Indexes)
}
// Sequential sets Indexes to nil, resulting in sequential row-wise access into tensor.
func (dt *Table) Sequential() { //types:add
	dt.Indexes = nil
}
// IndexesNeeded is called prior to an operation that needs actual indexes,
// e.g., Sort, Filter. If Indexes == nil, they are set to all rows, otherwise
// current indexes are left as is. Use Sequential, then IndexesNeeded to ensure
// all rows are represented.
func (dt *Table) IndexesNeeded() {
	if dt.Indexes != nil {
		return // existing indexes are preserved
	}
	idxs := make([]int, dt.Columns.Rows)
	for i := range idxs {
		idxs[i] = i // identity mapping covering all rows
	}
	dt.Indexes = idxs
}
// IndexesFromTensor copies Indexes from the given [tensor.Rows] tensor,
// including if they are nil. This allows column-specific Sort, Filter and
// other such methods to be applied to the entire table.
func (dt *Table) IndexesFromTensor(ix *tensor.Rows) {
	dt.Indexes = ix.Indexes
}
// ValidIndexes deletes all invalid indexes from the list.
// Call this if rows (could) have been deleted from table.
func (dt *Table) ValidIndexes() {
	if dt.Columns.Rows <= 0 || dt.Indexes == nil {
		dt.Indexes = nil
		return
	}
	// idiomatic in-place removal of out-of-range indexes, replacing the
	// manual reverse append-delete loop; order is preserved.
	dt.Indexes = slices.DeleteFunc(dt.Indexes, func(ix int) bool {
		return ix >= dt.Columns.Rows
	})
}
// Permuted sets indexes to a permuted order. If indexes already exist
// then existing list of indexes is permuted; otherwise a new set of
// permuted indexes are generated.
func (dt *Table) Permuted() {
	if dt.Columns.Rows <= 0 {
		dt.Indexes = nil
		return
	}
	if dt.Indexes == nil {
		dt.Indexes = rand.Perm(dt.Columns.Rows)
	} else {
		rand.Shuffle(len(dt.Indexes), func(i, j int) {
			dt.Indexes[i], dt.Indexes[j] = dt.Indexes[j], dt.Indexes[i]
		})
	}
}
// SortColumn sorts the indexes into our Table according to values in
// given column, using either ascending or descending order,
// (use [tensor.Ascending] or [tensor.Descending] for self-documentation).
// Uses first cell of higher dimensional data.
// Returns error if column name not found.
func (dt *Table) SortColumn(columnName string, ascending bool) error { //types:add
	dt.IndexesNeeded()
	cl, err := dt.ColumnTry(columnName)
	if err != nil {
		return err
	}
	cl.Sort(ascending)
	// adopt the column's sorted indexes for the whole table.
	dt.IndexesFromTensor(cl)
	return nil
}
// SortFunc sorts the indexes into our Table using given compare function.
// The compare function operates directly on row numbers into the Table
// as these row numbers have already been projected through the indexes.
// cmp(a, b) should return a negative number when a < b, a positive
// number when a > b and zero when a == b.
func (dt *Table) SortFunc(cmp func(dt *Table, i, j int) int) {
	dt.IndexesNeeded()
	slices.SortFunc(dt.Indexes, func(a, b int) int {
		return cmp(dt, a, b) // key point: these are already indirected through indexes!!
	})
}
// SortStableFunc stably sorts the indexes into our Table using given compare function.
// The compare function operates directly on row numbers into the Table
// as these row numbers have already been projected through the indexes.
// cmp(a, b) should return a negative number when a < b, a positive
// number when a > b and zero when a == b.
// It is *essential* that it always returns 0 when the two are equal
// for the stable function to actually work.
func (dt *Table) SortStableFunc(cmp func(dt *Table, i, j int) int) {
	dt.IndexesNeeded()
	slices.SortStableFunc(dt.Indexes, func(a, b int) int {
		return cmp(dt, a, b) // key point: these are already indirected through indexes!!
	})
}
// SortColumns sorts the indexes into our Table according to values in
// given column names, using either ascending or descending order
// (use [tensor.Ascending] or [tensor.Descending] for self-documentation),
// and optionally using a stable sort.
// Uses first cell of higher dimensional data.
func (dt *Table) SortColumns(ascending, stable bool, columns ...string) { //types:add
	dt.SortColumnIndexes(ascending, stable, dt.ColumnIndexList(columns...)...)
}
// SortColumnIndexes sorts the indexes into our Table according to values in
// given list of column indexes, using either ascending or descending order for
// all of the columns. Uses first cell of higher dimensional data.
func (dt *Table) SortColumnIndexes(ascending, stable bool, colIndexes ...int) {
	dt.IndexesNeeded()
	sf := dt.SortFunc
	if stable {
		sf = dt.SortStableFunc
	}
	sf(func(dt *Table, i, j int) int {
		// lexicographic comparison over the columns, in the given order:
		// the first non-equal column decides.
		for _, ci := range colIndexes {
			cl := dt.ColumnByIndex(ci).Tensor
			if cl.IsString() {
				v := tensor.CompareAscending(cl.StringRow(i, 0), cl.StringRow(j, 0), ascending)
				if v != 0 {
					return v
				}
			} else {
				v := tensor.CompareAscending(cl.FloatRow(i, 0), cl.FloatRow(j, 0), ascending)
				if v != 0 {
					return v
				}
			}
		}
		return 0
	})
}
// SortIndexes sorts the indexes into our Table directly in
// numerical order, producing the native ordering, while preserving
// any filtering that might have occurred.
func (dt *Table) SortIndexes() {
	if dt.Indexes == nil {
		return // already in native sequential order
	}
	sort.Ints(dt.Indexes)
}
// FilterFunc is a function used for filtering that returns
// true if Table row should be included in the current filtered
// view of the table, and false if it should be removed.
type FilterFunc func(dt *Table, row int) bool
// Filter filters the indexes into our Table using given Filter function
// (see [FilterFunc]).
// The Filter function operates directly on row numbers into the Table
// as these row numbers have already been projected through the indexes.
func (dt *Table) Filter(filterer func(dt *Table, row int) bool) {
	dt.IndexesNeeded()
	sz := len(dt.Indexes)
	for i := sz - 1; i >= 0; i-- { // always go in reverse for filtering
		if !filterer(dt, dt.Indexes[i]) { // delete
			dt.Indexes = append(dt.Indexes[:i], dt.Indexes[i+1:]...)
		}
	}
}
// FilterString filters the indexes using string values in the named column
// compared to the given string. Includes rows with matching values unless
// the Exclude option is set. If the Contains option is set, it only checks
// whether the row contains the string; if IgnoreCase, matching ignores case,
// otherwise filtering is case sensitive.
// Uses the first cell from higher dimensions.
// Returns an error if the column name is not found.
func (dt *Table) FilterString(columnName string, str string, opts tensor.FilterOptions) error { //types:add
	dt.IndexesNeeded()
	col, err := dt.ColumnTry(columnName)
	if err != nil {
		return err
	}
	col.FilterString(str, opts)
	// pull the filtered indexes from the column view back into the table
	dt.IndexesFromTensor(col)
	return nil
}
// New returns a new table with column data organized according to
// the indexes. If Indexes are nil, a clone of the current tensor is returned
// but this function is only sensible if there is an indexed view in place.
func (dt *Table) New() *Table {
	if dt.Indexes == nil {
		return dt.Clone()
	}
	rows := len(dt.Indexes)
	nt := dt.Clone()
	nt.Indexes = nil
	nt.SetNumRows(rows) // resize the clone to the indexed view's row count
	if rows == 0 {
		return nt
	}
	for ci, cl := range nt.Columns.Values {
		scl := dt.Columns.Values[ci] // source column corresponding to clone column
		_, csz := cl.Shape().RowCellSize()
		for i, srw := range dt.Indexes {
			// copy one row's worth of cells: dest row i <- source row srw
			cl.CopyCellsFrom(scl, i*csz, srw*csz, csz)
		}
	}
	return nt
}
// DeleteRows deletes n rows of Indexes starting at given index in the list of indexes.
// This does not affect the underlying tensor data; To create an actual in-memory
// ordering with rows deleted, use [Table.New].
func (dt *Table) DeleteRows(at, n int) {
	dt.IndexesNeeded()
	// slices.Delete also zeroes the freed tail, so stale index values are
	// not retained in the backing array (unlike the manual append splice).
	dt.Indexes = slices.Delete(dt.Indexes, at, at+n)
}
// Swap switches the indexes for i and j.
func (dt *Table) Swap(i, j int) {
	ix := dt.Indexes
	ix[i], ix[j] = ix[j], ix[i]
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package table
import (
"bufio"
"encoding/csv"
"fmt"
"io"
"io/fs"
"log"
"log/slog"
"math"
"os"
"reflect"
"strconv"
"strings"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/fsx"
"cogentcore.org/lab/tensor"
)
const (
	// Headers is passed to CSV methods for the headers arg, to use headers
	// that capture full type and tensor shape information.
	Headers = true

	// NoHeaders is passed to CSV methods for the headers arg, to not use headers.
	NoHeaders = false
)
// SaveCSV writes a table to a comma-separated-values (CSV) file
// (where comma = any delimiter, specified in the delim arg).
// If headers = true then generate column headers that capture the type
// and tensor cell geometry of the columns, enabling full reloading
// of exactly the same table format and data (recommended).
// Otherwise, only the data is written.
func (dt *Table) SaveCSV(filename fsx.Filename, delim tensor.Delims, headers bool) error { //types:add
	fp, err := os.Create(string(filename))
	if err != nil { // check before deferring Close: fp is nil on failure
		log.Println(err)
		return err
	}
	defer fp.Close()
	bw := bufio.NewWriter(fp)
	if err := dt.WriteCSV(bw, delim, headers); err != nil {
		return err
	}
	// report buffered-write errors instead of silently dropping them
	return bw.Flush()
}
// OpenCSV reads a table from a comma-separated-values (CSV) file
// (where comma = any delimiter, specified in the delim arg),
// using the Go standard encoding/csv reader conforming to the official CSV standard.
// If the table does not currently have any columns, the first row of the file
// is assumed to be headers, and columns are constructed therefrom.
// If the file was saved from table with headers, then these have full configuration
// information for tensor type and dimensionality.
// If the table DOES have existing columns, then those are used robustly
// for whatever information fits from each row of the file.
func (dt *Table) OpenCSV(filename fsx.Filename, delim tensor.Delims) error { //types:add
	f, err := os.Open(string(filename))
	if err != nil {
		return errors.Log(err)
	}
	defer f.Close()
	return dt.ReadCSV(bufio.NewReader(f), delim)
}
// OpenFS is the version of [Table.OpenCSV] that uses an [fs.FS] filesystem.
func (dt *Table) OpenFS(fsys fs.FS, filename string, delim tensor.Delims) error {
	f, err := fsys.Open(filename)
	if err != nil {
		return errors.Log(err)
	}
	defer f.Close()
	return dt.ReadCSV(bufio.NewReader(f), delim)
}
// ReadCSV reads a table from a comma-separated-values (CSV) file
// (where comma = any delimiter, specified in the delim arg),
// using the Go standard encoding/csv reader conforming to the official CSV standard.
// If the table does not currently have any columns, the first row of the file
// is assumed to be headers, and columns are constructed therefrom.
// If the file was saved from table with headers, then these have full configuration
// information for tensor type and dimensionality.
// If the table DOES have existing columns, then those are used robustly
// for whatever information fits from each row of the file.
func (dt *Table) ReadCSV(r io.Reader, delim tensor.Delims) error {
	dt.Sequential()
	cr := csv.NewReader(r)
	cr.Comma = delim.Rune()
	rec, err := cr.ReadAll() // todo: lazy, avoid resizing
	if err != nil || len(rec) == 0 {
		return err
	}
	rows := len(rec)
	strow := 0 // starting row of data; becomes 1 if first record is a header
	if dt.NumColumns() == 0 || DetectTableHeaders(rec[0]) {
		// (re)configure columns from the header row (or inferred from data)
		dt.DeleteAll()
		err := ConfigFromHeaders(dt, rec[0], rec)
		if err != nil {
			log.Println(err.Error())
			return err
		}
		strow++
		rows--
	}
	dt.SetNumRows(rows)
	for ri := 0; ri < rows; ri++ {
		dt.ReadCSVRow(rec[ri+strow], ri)
	}
	return nil
}
// ReadCSVRow reads a record of CSV data into given row in table.
// If the record is shorter than the table's total cell count, the
// remaining cells are left unchanged.
func (dt *Table) ReadCSVRow(rec []string, row int) {
	ci := 0
	if rec[0] == "_D:" { // data-row marker: skip this leading field
		ci++
	}
	nan := math.NaN()
	for _, tsr := range dt.Columns.Values {
		_, csz := tsr.Shape().RowCellSize() // number of cells per row for this column
		stoff := row * csz                  // starting flat index of this row's cells
		for cc := 0; cc < csz; cc++ {
			str := rec[ci]
			if !tsr.IsString() {
				// empty and non-finite numeric strings are stored as NaN
				if str == "" || str == "NaN" || str == "-NaN" || str == "Inf" || str == "-Inf" {
					tsr.SetFloat1D(nan, stoff+cc)
				} else {
					tsr.SetString1D(strings.TrimSpace(str), stoff+cc)
				}
			} else {
				tsr.SetString1D(strings.TrimSpace(str), stoff+cc)
			}
			ci++
			if ci >= len(rec) { // record exhausted: stop early
				return
			}
		}
	}
}
// ConfigFromHeaders attempts to configure Table based on the headers.
// For non-table headers, the data is examined to determine column types.
func ConfigFromHeaders(dt *Table, hdrs []string, rec [][]string) error {
	if !DetectTableHeaders(hdrs) {
		return ConfigFromDataValues(dt, hdrs, rec)
	}
	return ConfigFromTableHeaders(dt, hdrs)
}
// DetectTableHeaders looks for the special table-header characters,
// returning true if all non-empty headers carry them (or the "_H:" marker
// is present).
func DetectTableHeaders(hdrs []string) bool {
	for _, h := range hdrs {
		h = strings.TrimSpace(h)
		switch {
		case h == "":
			continue
		case h == "_H:":
			return true
		}
		if _, ok := TableHeaderToType[h[0]]; !ok { // every header must be table style
			return false
		}
	}
	return true
}
// ConfigFromTableHeaders attempts to configure a Table based on special table headers.
func ConfigFromTableHeaders(dt *Table, hdrs []string) error {
	for _, hd := range hdrs {
		hd = strings.TrimSpace(hd)
		if hd == "" || hd == "_H:" { // skip empties and the header-row marker
			continue
		}
		typ, hd := TableColumnType(hd)
		// a header of the form "name[n:0,0..]<n:d,d..>" marks the first cell
		// of a tensor column; the <...> suffix carries the cell shape
		dimst := strings.Index(hd, "]<")
		if dimst > 0 {
			dims := hd[dimst+2 : len(hd)-1] // shape spec between "<" and ">"
			lbst := strings.Index(hd, "[")
			hd = hd[:lbst] // bare column name before "["
			csh := ShapeFromString(dims)
			// new tensor starting
			dt.AddColumnOfType(hd, typ, csh...)
			continue
		}
		dimst = strings.Index(hd, "[")
		if dimst > 0 {
			// subsequent cell of a tensor column already added above
			continue
		}
		dt.AddColumnOfType(hd, typ)
	}
	return nil
}
// TableHeaderToType maps special header prefix characters to the data type
// they encode; see [TableHeaderChar] for the inverse mapping.
var TableHeaderToType = map[byte]reflect.Kind{
	'$': reflect.String,
	'%': reflect.Float32,
	'#': reflect.Float64,
	'|': reflect.Int,
	'^': reflect.Bool,
}
// TableHeaderChar returns the special header character based on given data type
func TableHeaderChar(typ reflect.Kind) byte {
switch {
case typ == reflect.Bool:
return '^'
case typ == reflect.Float32:
return '%'
case typ == reflect.Float64:
return '#'
case typ >= reflect.Int && typ <= reflect.Uintptr:
return '|'
default:
return '$'
}
}
// TableColumnType parses the column header for special table type information,
// returning the data type and the header with any type-prefix character removed.
// An absent or unrecognized prefix yields reflect.String, the most general default.
func TableColumnType(nm string) (reflect.Kind, string) {
	if nm == "" { // guard: indexing nm[0] below would panic on an empty header
		return reflect.String, nm
	}
	typ, ok := TableHeaderToType[nm[0]]
	if ok {
		nm = nm[1:]
	} else {
		typ = reflect.String // most general, default
	}
	return typ, nm
}
// ShapeFromString parses the string representation of a tensor cell shape,
// formatted as "N:d,d,..." where N is the number of dimensions and each d
// is a dimension size. Malformed or missing size fields yield 0 for that
// dimension instead of panicking (the previous index arithmetic could
// slice out of range on truncated input).
func ShapeFromString(dims string) []int {
	ndStr, rest, _ := strings.Cut(dims, ":")
	nd, _ := strconv.Atoi(ndStr)
	sh := make([]int, nd)
	for i, ds := range strings.Split(rest, ",") {
		if i >= nd { // ignore extra fields beyond the declared dimension count
			break
		}
		sh[i], _ = strconv.Atoi(ds)
	}
	return sh
}
// ConfigFromDataValues configures a Table based on data types inferred
// from the string representation of given records, using header names if present.
func ConfigFromDataValues(dt *Table, hdrs []string, rec [][]string) error {
	nr := len(rec)
	for ci, hd := range hdrs {
		hd = strings.TrimSpace(hd)
		if hd == "" {
			hd = fmt.Sprintf("col_%d", ci)
		}
		nmatch := 0
		typ := reflect.String
		// A bare `break` inside a switch only exits the switch, not the loop;
		// previously the "definitive" and "good enough" breaks were no-ops,
		// which allowed a definitively string-typed column to be demoted by
		// later numeric rows. The label makes those cases end the row scan.
	scan:
		for ri := 1; ri < nr; ri++ {
			rv := rec[ri][ci]
			if rv == "" {
				continue
			}
			ctyp := InferDataType(rv)
			switch {
			case ctyp == reflect.String: // definitive
				typ = ctyp
				break scan
			case typ == ctyp && (nmatch > 1 || ri == nr-1): // good enough
				break scan
			case typ == ctyp: // gather more info
				nmatch++
			case typ == reflect.String: // always upgrade from string default
				nmatch = 0
				typ = ctyp
			case typ == reflect.Int && ctyp == reflect.Float64: // upgrade
				nmatch = 0
				typ = ctyp
			}
		}
		dt.AddColumnOfType(hd, typ)
	}
	return nil
}
// InferDataType returns the inferred data type for the given string
// only deals with float64, int, and string types
func InferDataType(str string) reflect.Kind {
if strings.Contains(str, ".") {
_, err := strconv.ParseFloat(str, 64)
if err == nil {
return reflect.Float64
}
}
_, err := strconv.ParseInt(str, 10, 64)
if err == nil {
return reflect.Int
}
// try float again just in case..
_, err = strconv.ParseFloat(str, 64)
if err == nil {
return reflect.Float64
}
return reflect.String
}
//////// WriteCSV
// WriteCSV writes only the rows in the table's indexed view to a
// comma-separated-values (CSV) file (where comma = any delimiter,
// specified in the delim arg).
// If headers = true then generate column headers that capture the type
// and tensor cell geometry of the columns, enabling full reloading
// of exactly the same table format and data (recommended).
// Otherwise, only the data is written.
func (dt *Table) WriteCSV(w io.Writer, delim tensor.Delims, headers bool) error {
	ncol := 0
	if headers {
		var err error
		ncol, err = dt.WriteCSVHeaders(w, delim)
		if err != nil {
			log.Println(err)
			return err
		}
	}
	cw := csv.NewWriter(w)
	cw.Comma = delim.Rune()
	for ri := range dt.NumRows() {
		// RowIndex projects through the indexed view
		if err := dt.WriteCSVRowWriter(cw, dt.RowIndex(ri), ncol); err != nil {
			log.Println(err)
			return err
		}
	}
	cw.Flush()
	return nil
}
// WriteCSVHeaders writes headers to a comma-separated-values (CSV) file
// (where comma = any delimiter, specified in the delim arg).
// Returns the number of columns in the header.
func (dt *Table) WriteCSVHeaders(w io.Writer, delim tensor.Delims) (int, error) {
	cw := csv.NewWriter(w)
	cw.Comma = delim.Rune()
	hdrs := dt.TableHeaders()
	if err := cw.Write(hdrs); err != nil {
		return len(hdrs), err
	}
	cw.Flush()
	return len(hdrs), nil
}
// WriteCSVRow writes the given row to a comma-separated-values (CSV) file
// (where comma = any delimiter, specified in the delim arg).
func (dt *Table) WriteCSVRow(w io.Writer, row int, delim tensor.Delims) error {
	writer := csv.NewWriter(w)
	writer.Comma = delim.Rune()
	defer writer.Flush()
	return dt.WriteCSVRowWriter(writer, row, 0)
}
// WriteCSVRowWriter uses csv.Writer to write one row of the table.
// ncol, if > 0, pre-allocates the record buffer; float precision comes
// from the table metadata via [tensor.Precision] (-1 = default formatting).
func (dt *Table) WriteCSVRowWriter(cw *csv.Writer, row int, ncol int) error {
	prec := -1
	if ps, err := tensor.Precision(dt); err == nil {
		prec = ps
	}
	rec := make([]string, 0, max(ncol, 0))
	for _, tsr := range dt.Columns.Values {
		// nc is the number of cells this column contributes per row:
		// 1 for scalar columns, the cell-shape size for tensor columns.
		nc := 1
		if tsr.NumDims() > 1 {
			nc = tensor.NewShape(tsr.ShapeSizes()[1:]...).Len()
		}
		// unified loop replaces the previous duplicated scalar/tensor
		// branches; the old `rec[rc] = vl` path was unreachable because
		// rec always grew in lockstep with rc.
		for ti := 0; ti < nc; ti++ {
			idx := row*nc + ti
			var vl string
			if prec <= 0 || tsr.IsString() {
				vl = tsr.String1D(idx)
			} else {
				vl = strconv.FormatFloat(tsr.Float1D(idx), 'g', prec, 64)
			}
			rec = append(rec, vl)
		}
	}
	return cw.Write(rec)
}
// TableHeaders generates special header strings from the table
// with full information about type and tensor cell dimensionality.
// Scalar columns get a single "Xname" header (X = type char); tensor
// columns get one header per cell: the first is
// "Xname[n:0,0..]<n:d,d..>" carrying the cell shape, and subsequent
// cells are "Xname[i,j..]" with their n-dimensional cell index.
func (dt *Table) TableHeaders() []string {
	hdrs := []string{}
	for i, nm := range dt.Columns.Keys {
		tsr := dt.Columns.Values[i]
		nm = string([]byte{TableHeaderChar(tsr.DataType())}) + nm
		if tsr.NumDims() == 1 {
			hdrs = append(hdrs, nm)
		} else {
			csh := tensor.NewShape(tsr.ShapeSizes()[1:]...) // cell shape
			tc := csh.Len()
			nd := csh.NumDims()
			fnm := nm + fmt.Sprintf("[%v:", nd) // shared "[n:" prefix for all cells
			dn := fmt.Sprintf("<%v:", nd)       // "<n:d,d..>" shape suffix, first cell only
			ffnm := fnm
			for di := 0; di < nd; di++ {
				ffnm += "0" // first cell index is all zeros
				dn += fmt.Sprintf("%v", csh.DimSize(di))
				if di < nd-1 {
					ffnm += ","
					dn += ","
				}
			}
			ffnm += "]" + dn + ">"
			hdrs = append(hdrs, ffnm)
			for ti := 1; ti < tc; ti++ {
				idx := csh.IndexFrom1D(ti) // n-dimensional index of this cell
				ffnm := fnm
				for di := 0; di < nd; di++ {
					ffnm += fmt.Sprintf("%v", idx[di])
					if di < nd-1 {
						ffnm += ","
					}
				}
				ffnm += "]"
				hdrs = append(hdrs, ffnm)
			}
		}
	}
	return hdrs
}
// CleanCatTSV cleans a TSV file formed by concatenating multiple files together.
// Removes redundant headers and then sorts by given set of columns.
func CleanCatTSV(filename string, sorts ...string) error {
	str, err := os.ReadFile(filename)
	if err != nil {
		slog.Error(err.Error())
		return err
	}
	lns := strings.Split(string(str), "\n")
	if len(lns) == 0 {
		return nil
	}
	hdr := lns[0]
	f, err := os.Create(filename)
	if err != nil {
		slog.Error(err.Error())
		return err
	}
	for i, ln := range lns {
		if i > 0 && ln == hdr { // drop repeated headers from the concatenation
			continue
		}
		// check write errors: silently dropping them could truncate the file
		if _, err := io.WriteString(f, ln+"\n"); err != nil {
			f.Close()
			slog.Error(err.Error())
			return err
		}
	}
	if err := f.Close(); err != nil { // Close can report final write failures
		slog.Error(err.Error())
		return err
	}
	dt := New()
	if err := dt.OpenCSV(fsx.Filename(filename), tensor.Detect); err != nil {
		slog.Error(err.Error())
		return err
	}
	dt.SortColumns(tensor.Ascending, tensor.StableSort, sorts...)
	st := dt.New()
	err = st.SaveCSV(fsx.Filename(filename), tensor.Tab, true)
	if err != nil {
		slog.Error(err.Error())
	}
	return err
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package table
import (
"os"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/metadata"
"cogentcore.org/lab/tensor"
)
// setLogRow records the last row written to the log file in the table
// metadata, used by [Table.WriteToLog] to write only new rows.
func setLogRow(dt *Table, row int) {
	metadata.SetTo(dt, "LogRow", row)
}

// logRow returns the last row written to the log file from the table
// metadata (zero value if never set).
func logRow(dt *Table) int {
	return errors.Ignore1(metadata.GetFrom[int](dt, "LogRow"))
}

// setLogDelim records the log file delimiter in the table metadata.
func setLogDelim(dt *Table, delim tensor.Delims) {
	metadata.SetTo(dt, "LogDelim", delim)
}

// logDelim returns the log file delimiter from the table metadata
// (zero value if never set).
func logDelim(dt *Table) tensor.Delims {
	return errors.Ignore1(metadata.GetFrom[tensor.Delims](dt, "LogDelim"))
}
// OpenLog opens a log file for this table, which supports incremental
// output of table data as it is generated, using the standard [Table.SaveCSV]
// output formatting, using given delimiter between values on a line.
// Call [Table.WriteToLog] to write any new data rows to
// the open log file, and [Table.CloseLog] to close the file.
func (dt *Table) OpenLog(filename string, delim tensor.Delims) error {
	lf, err := os.Create(filename)
	if err != nil {
		return err
	}
	metadata.SetFile(dt, lf)
	setLogRow(dt, 0)
	setLogDelim(dt, delim)
	return nil
}
var (
	// ErrLogNoNewRows is returned by [Table.WriteToLog] when there are
	// no new rows available to write to the log file.
	ErrLogNoNewRows = errors.New("no new rows to write")
)
// WriteToLog writes any accumulated rows in the table to the file
// opened by [Table.OpenLog]. A Header row is written for the first output.
// If the current number of rows is less than the last number of rows,
// all of those rows are written under the assumption that the rows
// were reset via [Table.SetNumRows].
// Returns error for any failure, including [ErrLogNoNewRows] if
// no new rows are available to write.
func (dt *Table) WriteToLog() error {
	f := metadata.File(dt)
	if f == nil {
		return errors.New("tensor.Table.WriteToLog: log file was not opened")
	}
	delim := logDelim(dt)
	lrow := logRow(dt)
	nr := dt.NumRows()
	if nr == 0 || lrow == nr {
		return ErrLogNoNewRows
	}
	if lrow == 0 {
		// propagate header-write failures: doc promises an error for any failure
		if _, err := dt.WriteCSVHeaders(f, delim); err != nil {
			return err
		}
	}
	sr := lrow
	if nr < lrow { // rows were reset: rewrite all of them
		sr = 0
	}
	for r := sr; r < nr; r++ {
		if err := dt.WriteCSVRow(f, r, delim); err != nil {
			return err
		}
	}
	setLogRow(dt, nr)
	return nil
}
// CloseLog closes the log file opened by [Table.OpenLog].
// It is a no-op if no log file was opened.
func (dt *Table) CloseLog() {
	// guard against a nil file when OpenLog was never called (WriteToLog
	// performs the same check)
	if f := metadata.File(dt); f != nil {
		f.Close()
	}
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package table
import (
"fmt"
"reflect"
"cogentcore.org/core/base/reflectx"
)
// NewSliceTable returns a new Table with data from the given slice
// of structs. Only fields with basic (non-composite) kinds become columns.
func NewSliceTable(st any) (*Table, error) {
	sv := reflectx.NonPointerValue(reflect.ValueOf(st))
	if sv.Kind() != reflect.Slice {
		return nil, fmt.Errorf("NewSliceTable: not a slice")
	}
	styp := reflectx.NonPointerType(sv.Type().Elem())
	if styp.Kind() != reflect.Struct {
		return nil, fmt.Errorf("NewSliceTable: element type is not a struct")
	}
	dt := New()
	for fi := 0; fi < styp.NumField(); fi++ {
		fld := styp.Field(fi)
		if k := fld.Type.Kind(); reflectx.KindIsBasic(k) {
			dt.AddColumnOfType(fld.Name, k)
		}
	}
	UpdateSliceTable(st, dt)
	return dt, nil
}
// UpdateSliceTable updates given Table with data from the given slice
// of structs, which must be the same type as used to configure the table.
func UpdateSliceTable(st any, dt *Table) {
	sv := reflectx.NonPointerValue(reflect.ValueOf(st))
	styp := reflectx.NonPointerType(sv.Type().Elem())
	rows := sv.Len()
	dt.SetNumRows(rows)
	for row := 0; row < rows; row++ {
		for fi := 0; fi < styp.NumField(); fi++ {
			fld := styp.Field(fi)
			kind := fld.Type.Kind()
			if !reflectx.KindIsBasic(kind) {
				continue // composite fields have no corresponding column
			}
			fv := sv.Index(row).Field(fi).Interface()
			col := dt.Column(fld.Name)
			if kind == reflect.String {
				col.SetStringRow(fv.(string), row, 0)
			} else {
				v, _ := reflectx.ToFloat(fv)
				col.SetFloatRow(v, row, 0)
			}
		}
	}
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package table
//go:generate core generate
import (
"fmt"
"reflect"
"slices"
"cogentcore.org/core/base/metadata"
"cogentcore.org/lab/tensor"
)
// Table is a table of Tensor columns aligned by a common outermost row dimension.
// Use the [Table.Column] (by name) and [Table.ColumnByIndex] methods to obtain a
// [tensor.Rows] view of the column, using the shared [Table.Indexes] of the Table.
// Thus, a coordinated sorting and filtered view of the column data is automatically
// available for any of the tensor package functions that use [tensor.Tensor] as the one
// common data representation for all operations.
// Tensor Columns are always raw value types and support SubSpace operations on cells.
type Table struct { //types:add

	// Columns has the list of column tensor data for this table.
	// Different tables can provide different indexed views onto the same Columns.
	Columns *Columns

	// Indexes are the indexes into Tensor rows, with nil = sequential.
	// Only set if order is different from default sequential order.
	// These indexes are shared into the `tensor.Rows` Column values
	// to provide a coordinated indexed view into the underlying data.
	Indexes []int

	// Meta is misc metadata for the table. Use lower-case key names
	// following the struct tag convention:
	//	- name string = name of table
	//	- doc string = documentation, description
	//	- read-only bool = gui is read-only
	//	- precision int = n for precision to write out floats in csv.
	Meta metadata.Data
}
// New returns a new Table with its own (empty) set of Columns.
// Can pass an optional name which calls metadata SetName.
func New(name ...string) *Table {
	dt := &Table{Columns: NewColumns()}
	if len(name) > 0 {
		metadata.SetName(dt, name[0])
	}
	return dt
}
// NewView returns a new Table with its own Rows view into the
// same underlying set of Column tensor data as the source table.
// Indexes are copied from the existing table -- use Sequential
// to reset to full sequential view.
func NewView(src *Table) *Table {
	view := &Table{Columns: src.Columns}
	view.Meta.CopyFrom(src.Meta)
	if src.Indexes != nil {
		view.Indexes = slices.Clone(src.Indexes)
	}
	return view
}
// Init initializes a new empty table with [NewColumns].
func (dt *Table) Init() {
	dt.Columns = NewColumns()
}

// Metadata returns the misc metadata for this table.
func (dt *Table) Metadata() *metadata.Data { return &dt.Meta }
// IsValidRow returns error if the row is invalid, if error checking is needed.
func (dt *Table) IsValidRow(row int) error {
	if row >= 0 && row < dt.NumRows() {
		return nil
	}
	return fmt.Errorf("table.Table IsValidRow: row %d is out of valid range [0..%d]", row, dt.NumRows())
}
// NumColumns returns the number of columns.
func (dt *Table) NumColumns() int { return dt.Columns.Len() }
// Column returns the tensor with given column name, as a [tensor.Rows]
// with the shared [Table.Indexes] from this table. It is best practice to
// access columns by name, and direct access through [Table.Columns] does not
// provide the shared table-wide Indexes.
// Returns nil if not found.
func (dt *Table) Column(name string) *tensor.Rows {
	tsr := dt.Columns.At(name)
	if tsr == nil {
		return nil
	}
	return tensor.NewRows(tsr, dt.Indexes...)
}
// ColumnTry is a version of [Table.Column] that also returns an error
// if the column name is not found, for cases when error is needed.
func (dt *Table) ColumnTry(name string) (*tensor.Rows, error) {
	if cl := dt.Column(name); cl != nil {
		return cl, nil
	}
	return nil, fmt.Errorf("table.Table: Column named %q not found", name)
}
// ColumnByIndex returns the tensor at the given column index,
// as a [tensor.Rows] with the shared [Table.Indexes] from this table.
// It is best practice to instead access columns by name using [Table.Column].
// Direct access through [Table.Columns] does not provide the shared table-wide Indexes.
// Will panic if out of range.
func (dt *Table) ColumnByIndex(idx int) *tensor.Rows {
	cl := dt.Columns.Values[idx]
	return tensor.NewRows(cl, dt.Indexes...)
}
// ColumnList returns a list of tensors with given column names,
// as [tensor.Rows] with the shared [Table.Indexes] from this table.
// Names that are not found are silently skipped.
func (dt *Table) ColumnList(names ...string) []tensor.Tensor {
	list := make([]tensor.Tensor, 0, len(names))
	for _, name := range names {
		if cl := dt.Column(name); cl != nil {
			list = append(list, cl)
		}
	}
	return list
}
// ColumnName returns the name of given column.
func (dt *Table) ColumnName(i int) string {
	return dt.Columns.Keys[i]
}

// ColumnIndex returns the index for given column name.
// A negative value indicates the name was not found
// (see its use in [Table.ColumnIndexList]).
func (dt *Table) ColumnIndex(name string) int {
	return dt.Columns.IndexByKey(name)
}
// ColumnIndexList returns a list of indexes to columns of given names.
// Names that are not found are silently skipped.
func (dt *Table) ColumnIndexList(names ...string) []int {
	list := make([]int, 0, len(names))
	for _, name := range names {
		if ci := dt.ColumnIndex(name); ci >= 0 {
			list = append(list, ci)
		}
	}
	return list
}
// AddColumn adds a new column to the table, of given type and column name
// (which must be unique). If no cellSizes are specified, it holds scalar values,
// otherwise the cells are n-dimensional tensors of given size.
func AddColumn[T tensor.DataTypes](dt *Table, name string, cellSizes ...int) tensor.Tensor {
	rows := dt.Columns.Rows // size new column to the current row count
	sz := append([]int{rows}, cellSizes...)
	tsr := tensor.New[T](sz...)
	// tsr.SetNames("Row")
	dt.AddColumn(name, tsr)
	return tsr
}

// InsertColumn inserts a new column to the table, of given type and column name
// (which must be unique), at given index.
// If no cellSizes are specified, it holds scalar values,
// otherwise the cells are n-dimensional tensors of given size.
func InsertColumn[T tensor.DataTypes](dt *Table, name string, idx int, cellSizes ...int) tensor.Tensor {
	rows := dt.Columns.Rows // size new column to the current row count
	sz := append([]int{rows}, cellSizes...)
	tsr := tensor.New[T](sz...)
	// tsr.SetNames("Row")
	dt.InsertColumn(idx, name, tsr)
	return tsr
}

// AddColumn adds the given [tensor.Values] as a column to the table,
// returning an error and not adding if the name is not unique.
// Automatically adjusts the shape to fit the current number of rows.
func (dt *Table) AddColumn(name string, tsr tensor.Values) error {
	return dt.Columns.AddColumn(name, tsr)
}

// InsertColumn inserts the given [tensor.Values] as a column to the table at given index,
// returning an error and not adding if the name is not unique.
// Automatically adjusts the shape to fit the current number of rows.
func (dt *Table) InsertColumn(idx int, name string, tsr tensor.Values) error {
	return dt.Columns.InsertColumn(idx, name, tsr)
}
// AddColumnOfType adds a new column to the table, of the given reflect type
// and column name (which must be unique).
// If no cellSizes are specified, it holds scalar values,
// otherwise the cells are n-dimensional tensors of given size.
// Supported types include string, bool (for [tensor.Bool]), float32, float64, int, int32, and byte.
func (dt *Table) AddColumnOfType(name string, typ reflect.Kind, cellSizes ...int) tensor.Tensor {
	rows := dt.Columns.Rows // size new column to the current row count
	sz := append([]int{rows}, cellSizes...)
	tsr := tensor.NewOfType(typ, sz...)
	// tsr.SetNames("Row")
	dt.AddColumn(name, tsr)
	return tsr
}
// AddStringColumn adds a new String column with given name.
// If no cellSizes are specified, it holds scalar values,
// otherwise the cells are n-dimensional tensors of given size.
func (dt *Table) AddStringColumn(name string, cellSizes ...int) *tensor.String {
	return AddColumn[string](dt, name, cellSizes...).(*tensor.String)
}

// AddFloat64Column adds a new float64 column with given name.
// If no cellSizes are specified, it holds scalar values,
// otherwise the cells are n-dimensional tensors of given size.
func (dt *Table) AddFloat64Column(name string, cellSizes ...int) *tensor.Float64 {
	return AddColumn[float64](dt, name, cellSizes...).(*tensor.Float64)
}

// AddFloat32Column adds a new float32 column with given name.
// If no cellSizes are specified, it holds scalar values,
// otherwise the cells are n-dimensional tensors of given size.
func (dt *Table) AddFloat32Column(name string, cellSizes ...int) *tensor.Float32 {
	return AddColumn[float32](dt, name, cellSizes...).(*tensor.Float32)
}

// AddIntColumn adds a new int column with given name.
// If no cellSizes are specified, it holds scalar values,
// otherwise the cells are n-dimensional tensors of given size.
func (dt *Table) AddIntColumn(name string, cellSizes ...int) *tensor.Int {
	return AddColumn[int](dt, name, cellSizes...).(*tensor.Int)
}
// DeleteColumnName deletes the column of the given name,
// returning false if not found.
func (dt *Table) DeleteColumnName(name string) bool {
	return dt.Columns.DeleteByKey(name)
}

// DeleteColumnByIndex deletes columns within the index range [i:j].
func (dt *Table) DeleteColumnByIndex(i, j int) {
	dt.Columns.DeleteByIndex(i, j)
}

// DeleteAll deletes all columns, does full reset.
func (dt *Table) DeleteAll() {
	dt.Indexes = nil
	dt.Columns.Reset()
}
// AddRows adds n rows to end of underlying Table, and to the indexes in this view.
func (dt *Table) AddRows(n int) *Table { //types:add
	return dt.SetNumRows(dt.Columns.Rows + n)
}

// InsertRows adds n rows to end of underlying Table, and to the indexes starting at
// given index in this view, providing an efficient insertion operation that only
// exists in the indexed view. To create an in-memory ordering, use [Table.New].
func (dt *Table) InsertRows(at, n int) *Table {
	dt.IndexesNeeded()
	strow := dt.Columns.Rows // first new underlying row number
	stidx := len(dt.Indexes) // position where the new indexes get appended
	dt.SetNumRows(strow + n) // adds n indexes to end of list
	// move those indexes to at:at+n in index list
	dt.Indexes = append(dt.Indexes[:at], append(dt.Indexes[stidx:], dt.Indexes[at:]...)...)
	dt.Indexes = dt.Indexes[:strow+n] // trim the duplicated tail left by the splice
	return dt
}
// SetNumRows sets the number of rows in the table, across all columns.
// If rows = 0 then effective number of rows in tensors is 1, as this dim cannot be 0.
// If indexes are in place and rows are added, indexes for the new rows are added.
func (dt *Table) SetNumRows(rows int) *Table { //types:add
	prev := dt.Columns.Rows
	dt.Columns.SetNumRows(rows)
	switch {
	case dt.Indexes == nil:
		// sequential view: nothing further to update
	case rows > prev:
		for i := prev; i < rows; i++ {
			dt.Indexes = append(dt.Indexes, i)
		}
	default:
		dt.ValidIndexes() // prune indexes that now fall out of range
	}
	return dt
}
// SetNumRowsToMax gets the current max number of rows across all the column tensors,
// and sets the number of rows to that. This will automatically pad shorter columns
// so they all have the same number of rows. If a table has columns that are not fully
// under its own control, they can change size, so this reestablishes
// a common row dimension.
func (dt *Table) SetNumRowsToMax() {
	rows := 0
	for _, tsr := range dt.Columns.Values {
		if nr := tsr.DimSize(0); nr > rows {
			rows = nr
		}
	}
	dt.SetNumRows(rows)
}
// note: no really clean definition of CopyFrom -- no point of re-using existing
// table -- just clone it.
// Clone returns a complete copy of this table, including cloning
// the underlying Columns tensors, and the current [Table.Indexes].
// See also [Table.New] to flatten the current indexes.
func (dt *Table) Clone() *Table {
	cp := &Table{Columns: dt.Columns.Clone()}
	cp.Meta.CopyFrom(dt.Meta)
	if dt.Indexes != nil {
		cp.Indexes = slices.Clone(dt.Indexes)
	}
	return cp
}
// AppendRows appends shared columns in both tables with input table rows.
func (dt *Table) AppendRows(dt2 *Table) {
	start := dt.Columns.Rows
	added := dt2.Columns.Rows
	dt.Columns.AppendRows(dt2.Columns)
	if dt.Indexes == nil {
		return // sequential view: no index bookkeeping needed
	}
	for i := start; i < start+added; i++ {
		dt.Indexes = append(dt.Indexes, i)
	}
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package table
import (
"errors"
"fmt"
"reflect"
"strings"
"cogentcore.org/lab/tensor"
)
// InsertKeyColumns returns a copy of the given Table with new columns
// having given values, inserted at the start, used as legend keys etc.
// args must be in pairs: column name, value. All rows get the same value.
func (dt *Table) InsertKeyColumns(args ...string) *Table {
	if len(args)%2 != 0 {
		fmt.Println("InsertKeyColumns requires even number of args as column name, value pairs")
		return dt
	}
	c := dt.Clone()
	for j := 0; j < len(args); j += 2 {
		name, val := args[j], args[j+1]
		col := tensor.NewString(c.Columns.Rows)
		for i := range col.Values {
			col.Values[i] = val // same key value on every row
		}
		c.InsertColumn(0, name, col)
	}
	return c
}
// ConfigFromTable configures the columns of this table according to the
// values in the first two columns of given format table, conventionally named
// Name, Type (but names are not used), which must be of the string type.
func (dt *Table) ConfigFromTable(ft *Table) error {
	nmcol := ft.ColumnByIndex(0)
	tycol := ft.ColumnByIndex(1)
	kinds := map[string]reflect.Kind{
		"string":  reflect.String,
		"bool":    reflect.Bool,
		"float32": reflect.Float32,
		"float64": reflect.Float64,
		"int":     reflect.Int,
		"int32":   reflect.Int32,
		"byte":    reflect.Uint8,
		"uint8":   reflect.Uint8,
	}
	var errs []error
	for i := range ft.NumRows() {
		name := nmcol.String1D(i)
		typ := strings.ToLower(tycol.String1D(i))
		kind, ok := kinds[typ]
		if !ok {
			// column is still added, defaulting to float64, with an error recorded
			errs = append(errs, fmt.Errorf("ConfigFromTable: type string %q not recognized", typ))
			kind = reflect.Float64
		}
		dt.AddColumnOfType(name, kind)
	}
	return errors.Join(errs...)
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
import (
"fmt"
"slices"
"cogentcore.org/core/base/errors"
)
// AlignShapes aligns the shapes of two tensors, a and b, for a binary
// computation producing an output, returning the effective aligned shapes
// for a, b, and the output, all with the same number of dimensions.
// Alignment proceeds from the innermost dimension out, with 1s provided
// beyond the number of dimensions for a or b.
// The output has the max of the dimension sizes for each dimension.
// An error is returned if the rules of alignment are violated:
// each dimension size must be either the same, or one of them
// is equal to 1. This corresponds to the "broadcasting" logic of NumPy.
func AlignShapes(a, b Tensor) (as, bs, os *Shape, err error) {
	asz := a.ShapeSizes()
	bsz := b.ShapeSizes()
	an := len(asz)
	bn := len(bsz)
	n := max(an, bn)
	osizes := make([]int, n)
	asizes := make([]int, n)
	bsizes := make([]int, n)
	for d := range n {
		// walk from the innermost (last) dimension outward
		ai := an - 1 - d
		bi := bn - 1 - d
		oi := n - 1 - d
		ad := 1 // implicit size-1 dims beyond the rank of a or b
		bd := 1
		if ai >= 0 {
			ad = asz[ai]
		}
		if bi >= 0 {
			bd = bsz[bi]
		}
		if ad != bd && !(ad == 1 || bd == 1) {
			err = fmt.Errorf("tensor.AlignShapes: output dimension %d does not align for a=%d b=%d: must be either the same or one of them is a 1", oi, ad, bd)
			return
		}
		od := max(ad, bd)
		osizes[oi] = od
		asizes[oi] = ad
		bsizes[oi] = bd
	}
	as = NewShape(asizes...)
	bs = NewShape(bsizes...)
	os = NewShape(osizes...)
	return
}
// WrapIndex1D returns the 1d flat index for given n-dimensional index
// based on given shape, where any singleton dimension sizes cause the
// resulting index value to remain at 0, effectively causing that dimension
// to wrap around, while the other tensor is presumably using the full range
// of the values along this dimension. See [AlignShapes] for more info.
func WrapIndex1D(sh *Shape, i ...int) int {
	wi := slices.Clone(i)
	for d := 0; d < sh.NumDims(); d++ {
		if sh.DimSize(d) == 1 { // singleton: always index 0 (wraps)
			wi[d] = 0
		}
	}
	return sh.IndexTo1D(wi...)
}
// AlignForAssign ensures that the shapes of two tensors, a and b
// have the proper alignment for assigning b into a.
// Alignment proceeds from the innermost dimension out, with 1s provided
// beyond the number of dimensions for a or b.
// An error is returned if the rules of alignment are violated:
// each dimension size must be either the same, or b is equal to 1.
// This corresponds to the "broadcasting" logic of NumPy.
func AlignForAssign(a, b Tensor) (as, bs *Shape, err error) {
	asz := a.ShapeSizes()
	bsz := b.ShapeSizes()
	an := len(asz)
	bn := len(bsz)
	n := max(an, bn)
	asizes := make([]int, n)
	bsizes := make([]int, n)
	for d := range n {
		ai := an - 1 - d
		bi := bn - 1 - d
		oi := n - 1 - d
		ad := 1 // implicit 1 beyond the dims of a
		bd := 1 // implicit 1 beyond the dims of b
		if ai >= 0 {
			ad = asz[ai]
		}
		if bi >= 0 {
			bd = bsz[bi]
		}
		if ad != bd && bd != 1 {
			// report under this function's name (was a copy-paste of "tensor.AlignShapes")
			err = fmt.Errorf("tensor.AlignForAssign: dimension %d does not align for a=%d b=%d: must be either the same or b is a 1", oi, ad, bd)
			return
		}
		asizes[oi] = ad
		bsizes[oi] = bd
	}
	as = NewShape(asizes...)
	bs = NewShape(bsizes...)
	return
}
// SplitAtInnerDims returns the sizes of the given tensor's shape
// with the given number of inner-most dimensions retained as is,
// and those above collapsed to a single dimension.
// If the total number of dimensions is < nInner the result is nil.
func SplitAtInnerDims(tsr Tensor, nInner int) []int {
	sizes := tsr.ShapeSizes()
	if len(sizes) < nInner {
		return nil
	}
	split := len(sizes) - nInner
	out := make([]int, nInner+1)
	copy(out[1:], sizes[split:])
	// collapse all outer dims into a single leading "rows" dimension
	out[0] = 1
	for _, sz := range sizes[:split] {
		out[0] *= sz
	}
	return out
}
// FloatAssignFunc sets a to a binary function of a and b float64 values.
func FloatAssignFunc(fun func(a, b float64) float64, a, b Tensor) error {
	as, bs, err := AlignForAssign(a, b)
	if err != nil {
		return err
	}
	n := as.Len()
	VectorizeThreaded(1, func(tsr ...Tensor) int { return n },
		func(idx int, tsr ...Tensor) {
			// b is indexed via wrapping to broadcast singleton dims into a
			bi := WrapIndex1D(bs, as.IndexFrom1D(idx)...)
			av := tsr[0].Float1D(idx)
			tsr[0].SetFloat1D(fun(av, tsr[1].Float1D(bi)), idx)
		}, a, b)
	return nil
}
// StringAssignFunc sets a to a binary function of a and b string values.
func StringAssignFunc(fun func(a, b string) string, a, b Tensor) error {
	as, bs, err := AlignForAssign(a, b)
	if err != nil {
		return err
	}
	n := as.Len()
	VectorizeThreaded(1, func(tsr ...Tensor) int { return n },
		func(idx int, tsr ...Tensor) {
			// b is indexed via wrapping to broadcast singleton dims into a
			bi := WrapIndex1D(bs, as.IndexFrom1D(idx)...)
			av := tsr[0].String1D(idx)
			tsr[0].SetString1D(fun(av, tsr[1].String1D(bi)), idx)
		}, a, b)
	return nil
}
// FloatBinaryFunc sets output to a binary function of a, b float64 values.
// The flops (floating point operations) estimate is used to control parallel
// threading using goroutines, and should reflect number of flops in the function.
// See [VectorizeThreaded] for more information.
// This is a convenience wrapper that allocates the output via [FloatBinaryFuncOut].
func FloatBinaryFunc(flops int, fun func(a, b float64) float64, a, b Tensor) Tensor {
	return CallOut2Gen2(FloatBinaryFuncOut, flops, fun, a, b)
}
// FloatBinaryFuncOut sets output to a binary function of a, b float64 values.
func FloatBinaryFuncOut(flops int, fun func(a, b float64) float64, a, b Tensor, out Values) error {
	as, bs, os, err := AlignShapes(a, b)
	if err != nil {
		return err
	}
	out.SetShapeSizes(os.Sizes...)
	n := os.Len()
	VectorizeThreaded(flops, func(tsr ...Tensor) int { return n },
		func(idx int, tsr ...Tensor) {
			// wrap singleton dims of a and b into the output index space
			oi := os.IndexFrom1D(idx)
			av := tsr[0].Float1D(WrapIndex1D(as, oi...))
			bv := tsr[1].Float1D(WrapIndex1D(bs, oi...))
			out.SetFloat1D(fun(av, bv), idx)
		}, a, b, out)
	return nil
}
// StringBinaryFunc sets output to a binary function of a, b string values.
// This is a convenience wrapper that allocates the output via [StringBinaryFuncOut].
func StringBinaryFunc(fun func(a, b string) string, a, b Tensor) Tensor {
	return CallOut2Gen1(StringBinaryFuncOut, fun, a, b)
}
// StringBinaryFuncOut sets output to a binary function of a, b string values.
func StringBinaryFuncOut(fun func(a, b string) string, a, b Tensor, out Values) error {
	as, bs, os, err := AlignShapes(a, b)
	if err != nil {
		return err
	}
	out.SetShapeSizes(os.Sizes...)
	n := os.Len()
	VectorizeThreaded(1, func(tsr ...Tensor) int { return n },
		func(idx int, tsr ...Tensor) {
			// wrap singleton dims of a and b into the output index space
			oi := os.IndexFrom1D(idx)
			av := tsr[0].String1D(WrapIndex1D(as, oi...))
			bv := tsr[1].String1D(WrapIndex1D(bs, oi...))
			out.SetString1D(fun(av, bv), idx)
		}, a, b, out)
	return nil
}
// FloatFunc sets output to a function of tensor float64 values.
// The flops (floating point operations) estimate is used to control parallel
// threading using goroutines, and should reflect number of flops in the function.
// See [VectorizeThreaded] for more information.
// This is a convenience wrapper that allocates the output via [FloatFuncOut].
func FloatFunc(flops int, fun func(in float64) float64, in Tensor) Values {
	return CallOut1Gen2(FloatFuncOut, flops, fun, in)
}
// FloatFuncOut sets output to a function of tensor float64 values.
func FloatFuncOut(flops int, fun func(in float64) float64, in Tensor, out Values) error {
	SetShapeFrom(out, in)
	total := in.Len()
	VectorizeThreaded(flops, func(tsr ...Tensor) int { return total },
		func(idx int, tsr ...Tensor) {
			// tsr[0] is in, tsr[1] is out
			tsr[1].SetFloat1D(fun(tsr[0].Float1D(idx)), idx)
		}, in, out)
	return nil
}
// FloatSetFunc sets tensor float64 values from a function,
// which gets the index. Must be parallel threadsafe.
// The flops (floating point operations) estimate is used to control parallel
// threading using goroutines, and should reflect number of flops in the function.
// See [VectorizeThreaded] for more information.
func FloatSetFunc(flops int, fun func(idx int) float64, a Tensor) error {
	total := a.Len()
	VectorizeThreaded(flops, func(tsr ...Tensor) int { return total },
		func(idx int, tsr ...Tensor) {
			tsr[0].SetFloat1D(fun(idx), idx)
		}, a)
	return nil
}
//////// Bool

// BoolStringsFunc sets boolean output value based on a function involving
// string values from the two tensors.
func BoolStringsFunc(fun func(a, b string) bool, a, b Tensor) *Bool {
	result := NewBool()
	errors.Log(BoolStringsFuncOut(fun, a, b, result))
	return result
}
// BoolStringsFuncOut sets boolean output value based on a function involving
// string values from the two tensors.
func BoolStringsFuncOut(fun func(a, b string) bool, a, b Tensor, out *Bool) error {
	as, bs, os, err := AlignShapes(a, b)
	if err != nil {
		return err
	}
	out.SetShapeSizes(os.Sizes...)
	n := os.Len()
	VectorizeThreaded(5, func(tsr ...Tensor) int { return n },
		func(idx int, tsr ...Tensor) {
			// wrap singleton dims of a and b into the output index space
			oi := os.IndexFrom1D(idx)
			av := tsr[0].String1D(WrapIndex1D(as, oi...))
			bv := tsr[1].String1D(WrapIndex1D(bs, oi...))
			out.SetBool1D(fun(av, bv), idx)
		}, a, b, out)
	return nil
}
// BoolFloatsFunc sets boolean output value based on a function involving
// float64 values from the two tensors.
func BoolFloatsFunc(fun func(a, b float64) bool, a, b Tensor) *Bool {
	result := NewBool()
	errors.Log(BoolFloatsFuncOut(fun, a, b, result))
	return result
}
// BoolFloatsFuncOut sets boolean output value based on a function involving
// float64 values from the two tensors.
func BoolFloatsFuncOut(fun func(a, b float64) bool, a, b Tensor, out *Bool) error {
	as, bs, os, err := AlignShapes(a, b)
	if err != nil {
		return err
	}
	out.SetShapeSizes(os.Sizes...)
	n := os.Len()
	VectorizeThreaded(5, func(tsr ...Tensor) int { return n },
		func(idx int, tsr ...Tensor) {
			// wrap singleton dims of a and b into the output index space
			oi := os.IndexFrom1D(idx)
			av := tsr[0].Float1D(WrapIndex1D(as, oi...))
			bv := tsr[1].Float1D(WrapIndex1D(bs, oi...))
			out.SetBool1D(fun(av, bv), idx)
		}, a, b, out)
	return nil
}
// BoolIntsFunc sets boolean output value based on a function involving
// int values from the two tensors.
func BoolIntsFunc(fun func(a, b int) bool, a, b Tensor) *Bool {
	result := NewBool()
	errors.Log(BoolIntsFuncOut(fun, a, b, result))
	return result
}
// BoolIntsFuncOut sets boolean output value based on a function involving
// int values from the two tensors.
func BoolIntsFuncOut(fun func(a, b int) bool, a, b Tensor, out *Bool) error {
	as, bs, os, err := AlignShapes(a, b)
	if err != nil {
		return err
	}
	out.SetShapeSizes(os.Sizes...)
	n := os.Len()
	VectorizeThreaded(5, func(tsr ...Tensor) int { return n },
		func(idx int, tsr ...Tensor) {
			// wrap singleton dims of a and b into the output index space
			oi := os.IndexFrom1D(idx)
			av := tsr[0].Int1D(WrapIndex1D(as, oi...))
			bv := tsr[1].Int1D(WrapIndex1D(bs, oi...))
			out.SetBool1D(fun(av, bv), idx)
		}, a, b, out)
	return nil
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
import (
"reflect"
"slices"
"unsafe"
"cogentcore.org/core/base/metadata"
"cogentcore.org/core/base/reflectx"
"cogentcore.org/core/base/slicesx"
)
// Base is the base Tensor implementation for given type.
type Base[T any] struct {
	// shape holds the dimension sizes of the tensor.
	shape Shape
	// Values is the flat backing storage for all elements.
	Values []T
	// Meta holds metadata, e.g., name and plotting options.
	Meta metadata.Data
}
// Metadata returns the metadata for this tensor, which can be used
// to encode plotting options, etc.
func (tsr *Base[T]) Metadata() *metadata.Data { return &tsr.Meta }

// Shape returns a pointer to the shape of this tensor.
func (tsr *Base[T]) Shape() *Shape { return &tsr.shape }

// ShapeSizes returns the sizes of each dimension as a slice of ints.
// The returned slice is a copy and can be modified without side effects.
func (tsr *Base[T]) ShapeSizes() []int { return slices.Clone(tsr.shape.Sizes) }
// SetShapeSizes sets the dimension sizes of the tensor, and resizes
// backing storage appropriately, retaining all existing data that fits.
func (tsr *Base[T]) SetShapeSizes(sizes ...int) {
	tsr.shape.SetShapeSizes(sizes...)
	tsr.Values = slicesx.SetLength(tsr.Values, tsr.shape.Len())
}
// Len returns the number of elements in the tensor (product of shape dimensions).
func (tsr *Base[T]) Len() int { return tsr.shape.Len() }

// NumDims returns the total number of dimensions.
func (tsr *Base[T]) NumDims() int { return tsr.shape.NumDims() }

// DimSize returns size of given dimension.
func (tsr *Base[T]) DimSize(dim int) int { return tsr.shape.DimSize(dim) }
// DataType returns the type of the data elements in the tensor.
// Bool is returned for the Bool tensor type.
func (tsr *Base[T]) DataType() reflect.Kind {
	// reflect.TypeFor (Go 1.22) resolves the type parameter directly,
	// and unlike reflect.TypeOf on a zero value, it also works when T
	// is an interface type (where TypeOf would return nil and panic on Kind).
	return reflect.TypeFor[T]().Kind()
}
// Sizeof returns the number of bytes of backing storage:
// element size times number of elements.
func (tsr *Base[T]) Sizeof() int64 {
	var v T
	return int64(unsafe.Sizeof(v)) * int64(tsr.Len())
}

// Bytes returns the underlying values as a raw byte slice.
func (tsr *Base[T]) Bytes() []byte {
	return slicesx.ToBytes(tsr.Values)
}
// Value returns the value at the given n-dimensional index.
func (tsr *Base[T]) Value(i ...int) T {
	return tsr.Values[tsr.shape.IndexTo1D(i...)]
}

// ValuePtr returns a pointer to the value at the given n-dimensional index.
func (tsr *Base[T]) ValuePtr(i ...int) *T {
	return &tsr.Values[tsr.shape.IndexTo1D(i...)]
}

// Set sets the value at the given n-dimensional index.
func (tsr *Base[T]) Set(val T, i ...int) {
	tsr.Values[tsr.shape.IndexTo1D(i...)] = val
}

// Value1D returns the value at the given flat 1D index.
func (tsr *Base[T]) Value1D(i int) T { return tsr.Values[i] }

// Set1D sets the value at the given flat 1D index.
func (tsr *Base[T]) Set1D(val T, i int) { tsr.Values[i] = val }
// SetNumRows sets the number of rows (outermost dimension) in a RowMajor organized tensor.
// It is safe to set this to 0. For incrementally growing tensors (e.g., a log)
// it is best to first set the anticipated full size, which allocates the
// full amount of memory, and then set to 0 and grow incrementally.
func (tsr *Base[T]) SetNumRows(rows int) {
	if tsr.NumDims() == 0 { // ensure there is an outermost dimension to set
		tsr.SetShapeSizes(0)
	}
	_, cells := tsr.shape.RowCellSize() // cells = elements per row
	nln := rows * cells
	tsr.shape.Sizes[0] = rows // update outermost dimension size in place
	tsr.Values = slicesx.SetLength(tsr.Values, nln)
}
// subSpaceImpl returns a new tensor with innermost subspace at given
// offset(s) in outermost dimension(s) (len(offs) < NumDims).
// The new tensor points to the values of the this tensor (i.e., modifications
// will affect both), as its Values slice is a view onto the original (which
// is why only inner-most contiguous supsaces are supported).
// Use AsValues() method to separate the two.
func (tsr *Base[T]) subSpaceImpl(offs ...int) *Base[T] {
	nd := tsr.NumDims()
	od := len(offs)
	if od > nd { // more offsets than dimensions: invalid
		return nil
	}
	var ssz []int
	if od == nd { // scalar subspace
		ssz = []int{1}
	} else {
		ssz = tsr.shape.Sizes[od:] // remaining inner dims form the subspace shape
	}
	stsr := &Base[T]{}
	stsr.SetShapeSizes(ssz...)
	// compute the flat start offset: offs for outer dims, 0 for inner dims
	sti := make([]int, nd)
	copy(sti, offs)
	stoff := tsr.shape.IndexTo1D(sti...)
	sln := stsr.Len()
	// re-point Values at the original storage so both share the same data
	stsr.Values = tsr.Values[stoff : stoff+sln]
	return stsr
}
//////// Strings

// StringValue returns the value at the given n-dimensional index as a string.
func (tsr *Base[T]) StringValue(i ...int) string {
	return reflectx.ToString(tsr.Values[tsr.shape.IndexTo1D(i...)])
}

// String1D returns the value at the given flat 1D index as a string.
func (tsr *Base[T]) String1D(i int) string {
	return reflectx.ToString(tsr.Values[i])
}

// StringRow returns the value at the given row and cell index as a string.
func (tsr *Base[T]) StringRow(row, cell int) string {
	_, sz := tsr.shape.RowCellSize()
	return reflectx.ToString(tsr.Values[row*sz+cell])
}
// Label satisfies the core.Labeler interface for a summary description of the tensor,
// combining its metadata name with a description of its shape.
func (tsr *Base[T]) Label() string {
	return label(metadata.Name(tsr), &tsr.shape)
}
// Copyright (c) 2024, The Cogent Core Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package bitslice implements a simple slice-of-bits using a []byte slice for storage,
// which is used for efficient storage of boolean data, such as projection connectivity patterns.
package bitslice
import "fmt"
// bitslice.Slice is the slice of []byte that holds the bits.
// first byte maintains the number of bits used in the last byte (0-7).
// when 0 then prior byte is all full and a new one must be added for append.
type Slice []byte
// BitIndex returns, for the given overall bit index, the index of the
// containing byte within the data (excluding the header byte) and the
// bit position within that byte.
// The first named result was renamed from `byte`, which shadowed the
// predeclared byte type; renaming named results is caller-compatible.
func BitIndex(idx int) (byteIndex int, bit uint32) {
	return idx / 8, uint32(idx % 8)
}
// Make makes a new bitslice of given length and capacity (optional, pass 0 for default)
// *bits* (rounds up 1 for both).
// also reserves first byte for extra bits value
func Make(ln, cp int) Slice {
	nbytes, extra := BitIndex(ln)
	if extra != 0 { // partial last byte still needs a full byte of storage
		nbytes++
	}
	var sl Slice
	if cp > 0 {
		sl = make(Slice, nbytes+1, (cp/8)+2)
	} else {
		sl = make(Slice, nbytes+1)
	}
	sl[0] = byte(extra) // header: bits used in last byte
	return sl
}
// Len returns the length of the slice in bits
func (bs *Slice) Len() int {
	nb := len(*bs)
	if nb == 0 {
		return 0
	}
	extra := int((*bs)[0]) // bits used in the last byte (0 = full)
	full := nb - 1         // data bytes, excluding header
	if extra != 0 {
		full-- // last byte is partial, counted via extra
	}
	return full*8 + extra
}
// Cap returns the capacity of the slice in bits -- always modulo 8
func (bs *Slice) Cap() int {
	// first byte is the extra-bits header and does not hold data
	return (cap(*bs) - 1) * 8
}
// SetLen sets the length of the slice, copying values if a new allocation is required
func (bs *Slice) SetLen(ln int) {
	by, bi := BitIndex(ln)
	bln := by
	if bi != 0 { // partial last byte needs a full byte of storage
		bln++
	}
	if cap(*bs) >= bln+1 {
		// existing capacity suffices: reslice and update the header
		*bs = (*bs)[0 : bln+1]
		(*bs)[0] = byte(bi)
		return
	}
	sl := make(Slice, bln+1)
	copy(sl, *bs)
	// set the header AFTER copying: previously the copy of the old slice
	// overwrote the just-written header byte with the stale extra-bits value
	sl[0] = byte(bi)
	*bs = sl
}
// Set sets value of given bit index -- no extra range checking is performed -- will panic if out of range
func (bs *Slice) Set(val bool, idx int) {
	by, bi := BitIndex(idx)
	mask := byte(1) << bi
	if val {
		(*bs)[by+1] |= mask
	} else {
		(*bs)[by+1] &^= mask
	}
}
// Index returns bit value at given bit index
func (bs *Slice) Index(idx int) bool {
	by, bi := BitIndex(idx)
	mask := byte(1) << bi
	return (*bs)[by+1]&mask != 0
}
// Append adds a bit to the slice and returns possibly new slice, possibly old slice..
func (bs *Slice) Append(val bool) Slice {
	if len(*bs) == 0 { // empty: allocate storage for one bit
		*bs = Make(1, 0)
		bs.Set(val, 0)
		return *bs
	}
	ln := bs.Len()
	eb := (*bs)[0] // header: bits used in the last byte (0 = full)
	if eb == 0 {
		*bs = append(*bs, 0) // now we add
		(*bs)[0] = 1
	} else if eb < 7 {
		(*bs)[0]++
	} else {
		(*bs)[0] = 0 // this bit fills the last byte completely
	}
	bs.Set(val, ln)
	return *bs
}
// SetAll sets all values to either on or off -- much faster than setting individual bits
func (bs *Slice) SetAll(val bool) {
	var fill byte
	if val {
		fill = 0xFF
	}
	for i := 1; i < len(*bs); i++ { // skip header byte at index 0
		(*bs)[i] = fill
	}
}
// ToBools converts to a []bool slice
func (bs *Slice) ToBools() []bool {
	// length must be in bits: len(*bs) is the byte length including the
	// header byte, which previously produced a wrongly-sized result
	ln := bs.Len()
	bb := make([]bool, ln)
	for i := 0; i < ln; i++ {
		bb[i] = bs.Index(i)
	}
	return bb
}
// Clone creates a new copy of this bitslice with separate memory
func (bs *Slice) Clone() Slice {
	dup := make(Slice, len(*bs))
	copy(dup, *bs)
	return dup
}
// SubSlice returns a new Slice from given start, end range indexes of this slice
// if end is <= 0 then the length of the source slice is used (equivalent to omitting
// the number after the : in a Go subslice expression)
func (bs *Slice) SubSlice(start, end int) Slice {
	ln := bs.Len()
	if end <= 0 {
		end = ln
	}
	if end > ln {
		panic("bitslice.SubSlice: end index is beyond length of slice")
	}
	if start > end {
		panic("bitslice.SubSlice: start index greater than end index")
	}
	nln := end - start
	if nln <= 0 {
		return Slice{}
	}
	// bits cannot be aliased at byte granularity, so copy one at a time
	ss := Make(nln, 0)
	for i := 0; i < nln; i++ {
		ss.Set(bs.Index(i+start), i)
	}
	return ss
}
// Delete returns a new bit slice with N elements removed starting at given index.
// This must be a copy given the nature of the 8-bit aliasing.
func (bs *Slice) Delete(start, n int) Slice {
	ln := bs.Len()
	if n <= 0 {
		panic("bitslice.Delete: n <= 0")
	}
	if start >= ln {
		panic("bitslice.Delete: start index >= length")
	}
	end := start + n
	if end > ln {
		panic("bitslice.Delete: end index greater than length")
	}
	nln := ln - n
	if nln <= 0 {
		return Slice{}
	}
	ss := Make(nln, 0)
	// copy bits before the deleted range as-is
	for i := 0; i < start; i++ {
		ss.Set(bs.Index(i), i)
	}
	// copy bits after the deleted range, shifted down by n
	for i := end; i < ln; i++ {
		ss.Set(bs.Index(i), i-n)
	}
	return ss
}
// Insert returns a new bit slice with N false elements inserted starting at given index.
// This must be a copy given the nature of the 8-bit aliasing.
func (bs *Slice) Insert(start, n int) Slice {
	ln := bs.Len()
	if n <= 0 {
		panic("bitslice.Insert: n <= 0")
	}
	if start > ln {
		panic("bitslice.Insert: start index greater than length")
	}
	nln := ln + n
	ss := Make(nln, 0) // Make zeroes storage, so inserted bits are false
	// copy bits before the insertion point as-is
	for i := 0; i < start; i++ {
		ss.Set(bs.Index(i), i)
	}
	// copy remaining bits, shifted up by n
	for i := start; i < ln; i++ {
		ss.Set(bs.Index(i), i+n)
	}
	return ss
}
// String satisfies the fmt.Stringer interface
func (bs *Slice) String() string {
	ln := bs.Len()
	if ln == 0 {
		if *bs == nil {
			return "nil"
		}
		return "[]"
	}
	mx := min(ln, 1000) // cap output at 1000 bits
	str := "["
	for i := 0; i < mx; i++ {
		if bs.Index(i) {
			str += "1 "
		} else {
			str += "0 "
		}
		if (i+1)%80 == 0 { // wrap long outputs
			str += "\n"
		}
	}
	if ln > mx {
		str += fmt.Sprintf("...(len=%v)", ln)
	}
	return str + "]"
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
import (
"fmt"
"reflect"
"slices"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/metadata"
"cogentcore.org/core/base/num"
"cogentcore.org/core/base/reflectx"
"cogentcore.org/core/base/slicesx"
"cogentcore.org/lab/tensor/bitslice"
)
// Bool is a tensor of bits backed by a [bitslice.Slice] for efficient storage
// of binary, boolean data. Bool does not support [RowMajor.SubSpace] access
// and related methods due to the nature of the underlying data representation.
type Bool struct {
	// shape holds the dimension sizes of the tensor.
	shape Shape
	// Values is the bit-packed backing storage.
	Values bitslice.Slice
	// Meta holds metadata, e.g., name and plotting options.
	Meta metadata.Data
}
// NewBool returns a new n-dimensional tensor of bit values
// with the given sizes per dimension (shape).
func NewBool(sizes ...int) *Bool {
	bt := &Bool{}
	bt.SetShapeSizes(sizes...)
	bt.Values = bitslice.Make(bt.Len(), 0)
	return bt
}
// NewBoolShape returns a new n-dimensional tensor of bit values
// using given shape.
func NewBoolShape(shape *Shape) *Bool {
	bt := &Bool{}
	bt.shape.CopyFrom(shape)
	bt.Values = bitslice.Make(bt.Len(), 0)
	return bt
}
// Float64ToBool converts float64 value to bool.
func Float64ToBool(val float64) bool {
	return num.ToBool(val)
}

// BoolToFloat64 converts bool to float64 value.
func BoolToFloat64(bv bool) float64 {
	return num.FromBool[float64](bv)
}

// IntToBool converts int value to bool.
func IntToBool(val int) bool {
	return num.ToBool(val)
}

// BoolToInt converts bool to int value.
func BoolToInt(bv bool) int {
	return num.FromBool[int](bv)
}
// String satisfies the fmt.Stringer interface for string of tensor data.
func (tsr *Bool) String() string { return Sprintf("", tsr, 0) }

// Label satisfies the core.Labeler interface for a summary description of the tensor
func (tsr *Bool) Label() string {
	return label(metadata.Name(tsr), tsr.Shape())
}

// IsString returns false: Bool holds boolean, not string, values.
func (tsr *Bool) IsString() bool { return false }

// AsValues returns this tensor itself, as it is already a Values.
func (tsr *Bool) AsValues() Values { return tsr }

// DataType returns the type of the data elements in the tensor.
// Bool is returned for the Bool tensor type.
func (tsr *Bool) DataType() reflect.Kind { return reflect.Bool }

// Sizeof returns the byte size of the bit-packed backing storage.
func (tsr *Bool) Sizeof() int64 { return int64(len(tsr.Values)) }

// Bytes returns the underlying bit-packed values as a raw byte slice.
func (tsr *Bool) Bytes() []byte { return slicesx.ToBytes(tsr.Values) }
// Shape returns a pointer to the shape of this tensor.
func (tsr *Bool) Shape() *Shape { return &tsr.shape }

// ShapeSizes returns the sizes of each dimension as a slice of ints.
// The returned slice is a copy and can be modified without side effects.
func (tsr *Bool) ShapeSizes() []int { return slices.Clone(tsr.shape.Sizes) }

// Metadata returns the metadata for this tensor, which can be used
// to encode plotting options, etc.
func (tsr *Bool) Metadata() *metadata.Data { return &tsr.Meta }

// Len returns the number of elements in the tensor (product of shape dimensions).
func (tsr *Bool) Len() int { return tsr.shape.Len() }

// NumDims returns the total number of dimensions.
func (tsr *Bool) NumDims() int { return tsr.shape.NumDims() }

// DimSize returns size of given dimension
func (tsr *Bool) DimSize(dim int) int { return tsr.shape.DimSize(dim) }
// SetShapeSizes sets the dimension sizes of the tensor, and resizes
// the bit-packed backing storage appropriately.
func (tsr *Bool) SetShapeSizes(sizes ...int) {
	tsr.shape.SetShapeSizes(sizes...)
	tsr.Values.SetLen(tsr.Len())
}
// SetNumRows sets the number of rows (outermost dimension) in a RowMajor organized tensor.
// It is safe to set this to 0. For incrementally growing tensors (e.g., a log)
// it is best to first set the anticipated full size, which allocates the
// full amount of memory, and then set to 0 and grow incrementally.
func (tsr *Bool) SetNumRows(rows int) {
	if tsr.NumDims() == 0 {
		// consistent with Base.SetNumRows: ensure an outermost dimension
		// exists so the Sizes[0] write below cannot panic on a 0-dim tensor
		tsr.SetShapeSizes(0)
	}
	_, cells := tsr.shape.RowCellSize() // cells = elements per row
	nln := rows * cells
	tsr.shape.Sizes[0] = rows
	tsr.Values.SetLen(nln)
}
// SubSpace is not possible with Bool.
func (tsr *Bool) SubSpace(offs ...int) Values { return nil }

// RowTensor not possible with Bool.
func (tsr *Bool) RowTensor(row int) Values { return nil }

// SetRowTensor not possible with Bool.
func (tsr *Bool) SetRowTensor(val Values, row int) {}

// AppendRow not possible with Bool.
func (tsr *Bool) AppendRow(val Values) {}
/////// Bool

// Value returns the bool value at the given n-dimensional index.
func (tsr *Bool) Value(i ...int) bool {
	return tsr.Values.Index(tsr.shape.IndexTo1D(i...))
}

// Set sets the bool value at the given n-dimensional index.
func (tsr *Bool) Set(val bool, i ...int) {
	tsr.Values.Set(val, tsr.shape.IndexTo1D(i...))
}

// Value1D returns the bool value at the given flat 1D index.
func (tsr *Bool) Value1D(i int) bool { return tsr.Values.Index(i) }

// Set1D sets the bool value at the given flat 1D index.
func (tsr *Bool) Set1D(val bool, i int) { tsr.Values.Set(val, i) }
/////// Strings

// String1D returns the value at the given flat 1D index as a string.
func (tsr *Bool) String1D(off int) string {
	return reflectx.ToString(tsr.Values.Index(off))
}

// SetString1D sets the value at the given flat 1D index from a string.
// Unparseable values are silently ignored (best-effort conversion).
func (tsr *Bool) SetString1D(val string, off int) {
	if bv, err := reflectx.ToBool(val); err == nil {
		tsr.Values.Set(bv, off)
	}
}

// StringValue returns the value at the given n-dimensional index as a string.
func (tsr *Bool) StringValue(i ...int) string {
	return reflectx.ToString(tsr.Values.Index(tsr.shape.IndexTo1D(i...)))
}

// SetString sets the value at the given n-dimensional index from a string.
// Unparseable values are silently ignored (best-effort conversion).
func (tsr *Bool) SetString(val string, i ...int) {
	if bv, err := reflectx.ToBool(val); err == nil {
		tsr.Values.Set(bv, tsr.shape.IndexTo1D(i...))
	}
}

// StringRow returns the value at the given row and cell index as a string.
func (tsr *Bool) StringRow(row, cell int) string {
	_, sz := tsr.shape.RowCellSize()
	return reflectx.ToString(tsr.Values.Index(row*sz + cell))
}

// SetStringRow sets the value at the given row and cell index from a string.
// Unparseable values are silently ignored (best-effort conversion).
func (tsr *Bool) SetStringRow(val string, row, cell int) {
	if bv, err := reflectx.ToBool(val); err == nil {
		_, sz := tsr.shape.RowCellSize()
		tsr.Values.Set(bv, row*sz+cell)
	}
}

// AppendRowString not possible with Bool.
func (tsr *Bool) AppendRowString(val ...string) {}
/////// Floats

// Float returns the value at the given n-dimensional index as a float64 (0 or 1).
func (tsr *Bool) Float(i ...int) float64 {
	return BoolToFloat64(tsr.Values.Index(tsr.shape.IndexTo1D(i...)))
}

// SetFloat sets the value at the given n-dimensional index from a float64.
func (tsr *Bool) SetFloat(val float64, i ...int) {
	tsr.Values.Set(Float64ToBool(val), tsr.shape.IndexTo1D(i...))
}

// Float1D returns the value at the given flat 1D index as a float64 (0 or 1).
func (tsr *Bool) Float1D(off int) float64 {
	return BoolToFloat64(tsr.Values.Index(off))
}

// SetFloat1D sets the value at the given flat 1D index from a float64.
func (tsr *Bool) SetFloat1D(val float64, off int) {
	tsr.Values.Set(Float64ToBool(val), off)
}

// FloatRow returns the value at the given row and cell index as a float64.
func (tsr *Bool) FloatRow(row, cell int) float64 {
	_, sz := tsr.shape.RowCellSize()
	return BoolToFloat64(tsr.Values.Index(row*sz + cell))
}

// SetFloatRow sets the value at the given row and cell index from a float64.
func (tsr *Bool) SetFloatRow(val float64, row, cell int) {
	_, sz := tsr.shape.RowCellSize()
	tsr.Values.Set(Float64ToBool(val), row*sz+cell)
}

// AppendRowFloat not possible with Bool.
func (tsr *Bool) AppendRowFloat(val ...float64) {}
/////// Ints

// Int returns the value at the given n-dimensional index as an int (0 or 1).
func (tsr *Bool) Int(i ...int) int {
	return BoolToInt(tsr.Values.Index(tsr.shape.IndexTo1D(i...)))
}

// SetInt sets the value at the given n-dimensional index from an int.
func (tsr *Bool) SetInt(val int, i ...int) {
	tsr.Values.Set(IntToBool(val), tsr.shape.IndexTo1D(i...))
}

// Int1D returns the value at the given flat 1D index as an int (0 or 1).
func (tsr *Bool) Int1D(off int) int {
	return BoolToInt(tsr.Values.Index(off))
}

// SetInt1D sets the value at the given flat 1D index from an int.
func (tsr *Bool) SetInt1D(val int, off int) {
	tsr.Values.Set(IntToBool(val), off)
}

// IntRow returns the value at the given row and cell index as an int.
func (tsr *Bool) IntRow(row, cell int) int {
	_, sz := tsr.shape.RowCellSize()
	return BoolToInt(tsr.Values.Index(row*sz + cell))
}

// SetIntRow sets the value at the given row and cell index from an int.
func (tsr *Bool) SetIntRow(val int, row, cell int) {
	_, sz := tsr.shape.RowCellSize()
	tsr.Values.Set(IntToBool(val), row*sz+cell)
}

// AppendRowInt not possible with Bool.
func (tsr *Bool) AppendRowInt(val ...int) {}
/////// Bools

// Bool returns the value at the given n-dimensional index.
func (tsr *Bool) Bool(i ...int) bool {
	return tsr.Values.Index(tsr.shape.IndexTo1D(i...))
}

// SetBool sets the value at the given n-dimensional index.
func (tsr *Bool) SetBool(val bool, i ...int) {
	tsr.Values.Set(val, tsr.shape.IndexTo1D(i...))
}

// Bool1D returns the value at the given flat 1D index.
func (tsr *Bool) Bool1D(off int) bool {
	return tsr.Values.Index(off)
}

// SetBool1D sets the value at the given flat 1D index.
func (tsr *Bool) SetBool1D(val bool, off int) {
	tsr.Values.Set(val, off)
}
// SetZeros is a convenience function initialize all values to 0 (false).
func (tsr *Bool) SetZeros() {
	n := tsr.Len()
	for i := 0; i < n; i++ {
		tsr.Values.Set(false, i)
	}
}
// SetTrue is simple convenience function initialize all values to 1 (true).
func (tsr *Bool) SetTrue() {
	ln := tsr.Len()
	for j := 0; j < ln; j++ {
		tsr.Values.Set(true, j)
	}
}
// Clone clones this tensor, creating a duplicate copy of itself with its
// own separate memory representation of all the values, and returns
// that as a Tensor (which can be converted into the known type as needed).
func (tsr *Bool) Clone() Values {
	csr := NewBoolShape(&tsr.shape)
	csr.Values = tsr.Values.Clone() // deep copy of the bit storage
	return csr
}
// CopyFrom copies all avail values from other tensor into this tensor, with an
// optimized implementation if the other tensor is of the same type, and
// otherwise it goes through appropriate standard type.
func (tsr *Bool) CopyFrom(frm Values) {
	if fsm, ok := frm.(*Bool); ok {
		copy(tsr.Values, fsm.Values) // fast path: raw byte copy of bit storage
		return
	}
	// compare lengths in number of values (bits): len(tsr.Values) is the
	// byte length of the bitslice including its header byte, which
	// previously caused only a fraction of the values to be copied
	sz := min(tsr.Len(), frm.Len())
	for i := 0; i < sz; i++ {
		tsr.Values.Set(Float64ToBool(frm.Float1D(i)), i)
	}
}
// AppendFrom appends values from other tensor into this tensor,
// which must have the same cell size as this tensor.
// It uses and optimized implementation if the other tensor
// is of the same type, and otherwise it goes through
// appropriate standard type.
func (tsr *Bool) AppendFrom(frm Values) Values {
	rows, cell := tsr.shape.RowCellSize()
	frows, fcell := frm.Shape().RowCellSize()
	if cell != fcell {
		errors.Log(fmt.Errorf("tensor.AppendFrom: cell sizes do not match: %d != %d", cell, fcell))
		return tsr
	}
	tsr.SetNumRows(rows + frows) // grow to hold the appended rows
	st := rows * cell            // flat start index of the newly added region
	fsz := frows * fcell         // number of values to append
	if fsm, ok := frm.(*Bool); ok {
		// NOTE(review): copies raw bitslice bytes at a flat offset; only
		// byte-aligned when st is a multiple of 8 -- confirm against callers
		copy(tsr.Values[st:st+fsz], fsm.Values)
		return tsr
	}
	for i := 0; i < fsz; i++ {
		tsr.Values.Set(Float64ToBool(frm.Float1D(i)), st+i)
	}
	return tsr
}
// CopyCellsFrom copies given range of values from other tensor into this tensor,
// using flat 1D indexes: to = starting index in this Tensor to start copying into,
// start = starting index on from Tensor to start copying from, and n = number of
// values to copy. Uses an optimized implementation if the other tensor is
// of the same type, and otherwise it goes through appropriate standard type.
func (tsr *Bool) CopyCellsFrom(frm Values, to, start, n int) {
	if fsm, ok := frm.(*Bool); ok {
		// same type: copy bits directly without float conversion
		for i := range n {
			tsr.Values.Set(fsm.Values.Index(start+i), to+i)
		}
		return
	}
	for i := range n {
		tsr.Values.Set(Float64ToBool(frm.Float1D(start+i)), to+i)
	}
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
import (
"math"
"cogentcore.org/core/base/errors"
)
// Clone returns a copy of the given tensor.
// If it is raw [Values] then a [Values.Clone] is returned.
// Otherwise if it is a view, then [Tensor.AsValues] is returned.
// This is equivalent to the NumPy copy function.
func Clone(tsr Tensor) Values {
	vl, ok := tsr.(Values)
	if !ok {
		return tsr.AsValues()
	}
	return vl.Clone()
}
// Flatten returns a copy of the given tensor as a 1D flat list
// of values, by calling Clone(As1D(tsr)).
// It is equivalent to the NumPy flatten function.
func Flatten(tsr Tensor) Values {
	if m, ok := tsr.(*Masked); ok {
		// masked tensors materialize only the unmasked values
		return m.AsValues()
	}
	return Clone(As1D(tsr))
}
// Squeeze a [Reshaped] view of given tensor with all singleton
// (size = 1) dimensions removed (if none, just returns the tensor).
func Squeeze(tsr Tensor) Tensor {
	sizes := tsr.ShapeSizes()
	keep := make([]int, 0, len(sizes))
	for _, sz := range sizes {
		if sz > 1 {
			keep = append(keep, sz)
		}
	}
	if len(keep) == tsr.NumDims() { // nothing to squeeze
		return tsr
	}
	return NewReshaped(tsr, keep...)
}
// As1D returns a 1D tensor, which is either the input tensor if it is
// already 1D, or a new [Reshaped] 1D view of it.
// This can be useful e.g., for stats and metric functions that operate
// on a 1D list of values. See also [Flatten].
func As1D(tsr Tensor) Tensor {
	if tsr.NumDims() != 1 {
		return NewReshaped(tsr, tsr.Len())
	}
	return tsr
}
// Cells1D returns a flat 1D view of the innermost cells for given row index.
// For a [RowMajor] tensor, it uses the [RowTensor] subspace directly,
// otherwise it uses [Sliced] to extract the cells. In either case,
// [As1D] is used to ensure the result is a 1D tensor.
func Cells1D(tsr Tensor, row int) Tensor {
	rm, ok := tsr.(RowMajor)
	if !ok {
		return As1D(NewSliced(tsr, []int{row}))
	}
	return As1D(rm.RowTensor(row))
}
// MustBeValues returns the given tensor as a [Values] subtype, or nil and
// an error if it is not one. Typically outputs of compute operations must
// be values, and are reshaped to hold the results as needed.
func MustBeValues(tsr Tensor) (Values, error) {
	if vl, ok := tsr.(Values); ok {
		return vl, nil
	}
	return nil, errors.New("tensor.MustBeValues: tensor must be a Values type")
}
// MustBeSameShape returns an error if the two tensors do not have the same shape.
func MustBeSameShape(a, b Tensor) error {
	if a.Shape().IsEqual(b.Shape()) {
		return nil
	}
	return errors.New("tensor.MustBeSameShape: tensors must have the same shape")
}
// SetShape sets the dimension sizes from given Shape
func SetShape(vals Values, sh *Shape) {
vals.SetShapeSizes(sh.Sizes...)
}
// SetShapeSizesFromTensor sets the dimension sizes of vals from the values
// of the given tensor, read as a flat list of ints.
// The backing storage is resized appropriately, retaining existing data that fits.
func SetShapeSizesFromTensor(vals Values, sizes Tensor) {
	vals.SetShapeSizes(AsIntSlice(sizes)...)
}
// SetShapeFrom sets the shape of vals to match the source tensor's shape.
func SetShapeFrom(vals Values, from Tensor) {
	vals.SetShapeSizes(from.ShapeSizes()...)
}
// AsFloat64Scalar returns the value at flat index 0 as a float64,
// or 0 for an empty tensor.
func AsFloat64Scalar(tsr Tensor) float64 {
	if tsr.Len() > 0 {
		return tsr.Float1D(0)
	}
	return 0
}
// AsIntScalar returns the value at flat index 0 as an int,
// or 0 for an empty tensor.
func AsIntScalar(tsr Tensor) int {
	if tsr.Len() > 0 {
		return tsr.Int1D(0)
	}
	return 0
}
// AsStringScalar returns the value at flat index 0 as a string,
// or "" for an empty tensor.
func AsStringScalar(tsr Tensor) string {
	if tsr.Len() > 0 {
		return tsr.String1D(0)
	}
	return ""
}
// AsFloat64Slice returns all the tensor values as a freshly allocated
// []float64. Because it copies every element, avoid it in
// performance-critical code.
func AsFloat64Slice(tsr Tensor) []float64 {
	n := tsr.Len()
	if n == 0 {
		return nil
	}
	out := make([]float64, n)
	for i := range out {
		out[i] = tsr.Float1D(i)
	}
	return out
}
// AsIntSlice returns all the tensor values as a freshly allocated
// []int. Because it copies every element, avoid it in
// performance-critical code.
func AsIntSlice(tsr Tensor) []int {
	n := tsr.Len()
	if n == 0 {
		return nil
	}
	out := make([]int, n)
	for i := range out {
		out[i] = tsr.Int1D(i)
	}
	return out
}
// AsStringSlice returns all the tensor values as a freshly allocated
// []string. Because it copies every element, avoid it in
// performance-critical code.
func AsStringSlice(tsr Tensor) []string {
	n := tsr.Len()
	if n == 0 {
		return nil
	}
	out := make([]string, n)
	for i := range out {
		out[i] = tsr.String1D(i)
	}
	return out
}
// AsFloat64 returns the tensor as a [Float64] tensor: the input itself
// when it already is one, otherwise a newly allocated Float64 of the
// same shape with all values copied over.
// Use this for interfacing with gonum or other APIs that only
// operate on float64 types.
func AsFloat64(tsr Tensor) *Float64 {
	if f64, ok := tsr.(*Float64); ok {
		return f64
	}
	out := NewFloat64(tsr.ShapeSizes()...)
	out.CopyFrom(tsr.AsValues())
	return out
}
// AsFloat32 returns the tensor as a [Float32] tensor: the input itself
// when it already is one, otherwise a newly allocated Float32 of the
// same shape with all values copied over.
func AsFloat32(tsr Tensor) *Float32 {
	if f32, ok := tsr.(*Float32); ok {
		return f32
	}
	out := NewFloat32(tsr.ShapeSizes()...)
	out.CopyFrom(tsr.AsValues())
	return out
}
// AsString returns the tensor as a [String] tensor: the input itself
// when it already is one, otherwise a newly allocated String of the
// same shape with all values copied over.
func AsString(tsr Tensor) *String {
	if st, ok := tsr.(*String); ok {
		return st
	}
	out := NewString(tsr.ShapeSizes()...)
	out.CopyFrom(tsr.AsValues())
	return out
}
// AsInt returns the tensor as an [Int] tensor: the input itself
// when it already is one, otherwise a newly allocated Int of the
// same shape with all values copied over.
func AsInt(tsr Tensor) *Int {
	if it, ok := tsr.(*Int); ok {
		return it
	}
	out := NewInt(tsr.ShapeSizes()...)
	out.CopyFrom(tsr.AsValues())
	return out
}
// Range returns the min, max (and associated indexes, -1 = no values)
// for the tensor, skipping NaN values.
// This is needed for display and is thus in the tensor api on Values.
func Range(vals Values) (min, max float64, minIndex, maxIndex int) {
	minIndex = -1
	maxIndex = -1
	n := vals.Len()
	for j := range n {
		fv := vals.Float1D(j) // BUG fix: was Float1D(n) — indexed with the length, not the loop variable
		if math.IsNaN(fv) {
			continue
		}
		if fv < min || minIndex < 0 { // minIndex < 0: first non-NaN value initializes min
			min = fv
			minIndex = j
		}
		if fv > max || maxIndex < 0 { // maxIndex < 0: first non-NaN value initializes max
			max = fv
			maxIndex = j
		}
	}
	return
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
import (
"math/rand"
"slices"
)
// NewFloat64Scalar wraps a single float64 value in a scalar [Float64] tensor.
func NewFloat64Scalar(val float64) *Float64 {
	return NewNumberFromValues(val)
}
// NewFloat32Scalar wraps a single float32 value in a scalar [Float32] tensor.
func NewFloat32Scalar(val float32) *Float32 {
	return NewNumberFromValues(val)
}
// NewIntScalar wraps a single int value in a scalar [Int] tensor.
func NewIntScalar(val int) *Int {
	return NewNumberFromValues(val)
}
// NewStringScalar wraps a single string value in a scalar [String] tensor.
func NewStringScalar(val string) *String {
	return NewStringFromValues(val)
}
// NewFloat64FromValues returns a new 1-dimensional [Float64] tensor
// that wraps the given slice values directly, without copying.
func NewFloat64FromValues(vals ...float64) *Float64 {
	return NewNumberFromValues(vals...)
}
// NewFloat32FromValues returns a new 1-dimensional [Float32] tensor
// that wraps the given slice values directly, without copying.
func NewFloat32FromValues(vals ...float32) *Float32 {
	return NewNumberFromValues(vals...)
}
// NewIntFromValues returns a new 1-dimensional [Int] tensor
// that wraps the given slice values directly, without copying.
func NewIntFromValues(vals ...int) *Int {
	return NewNumberFromValues(vals...)
}
// NewStringFromValues returns a new 1-dimensional [String] tensor
// that wraps the given slice values directly, without copying.
func NewStringFromValues(vals ...string) *String {
	tsr := &String{}
	tsr.Values = vals
	tsr.SetShapeSizes(len(vals))
	return tsr
}
// SetAllFloat64 sets every element of the given tensor to val,
// using threaded vectorization for large tensors.
func SetAllFloat64(tsr Tensor, val float64) {
	nfun := func(tsr ...Tensor) int { return tsr[0].Len() }
	sfun := func(idx int, tsr ...Tensor) { tsr[0].SetFloat1D(val, idx) }
	VectorizeThreaded(1, nfun, sfun, tsr)
}
// SetAllInt sets every element of the given tensor to val,
// using threaded vectorization for large tensors.
func SetAllInt(tsr Tensor, val int) {
	nfun := func(tsr ...Tensor) int { return tsr[0].Len() }
	sfun := func(idx int, tsr ...Tensor) { tsr[0].SetInt1D(val, idx) }
	VectorizeThreaded(1, nfun, sfun, tsr)
}
// SetAllString sets every element of the given tensor to val,
// using threaded vectorization for large tensors.
func SetAllString(tsr Tensor, val string) {
	nfun := func(tsr ...Tensor) int { return tsr[0].Len() }
	sfun := func(idx int, tsr ...Tensor) { tsr[0].SetString1D(val, idx) }
	VectorizeThreaded(1, nfun, sfun, tsr)
}
// NewFloat64Full returns a new [Float64] tensor of the given shape,
// with every element set to val.
func NewFloat64Full(val float64, sizes ...int) *Float64 {
	out := NewFloat64(sizes...)
	SetAllFloat64(out, val)
	return out
}
// NewFloat64Ones returns a new [Float64] tensor of the given shape,
// with every element set to 1.
func NewFloat64Ones(sizes ...int) *Float64 {
	out := NewFloat64(sizes...)
	SetAllFloat64(out, 1.0)
	return out
}
// NewIntFull returns a new [Int] tensor of the given shape,
// with every element set to val.
func NewIntFull(val int, sizes ...int) *Int {
	out := NewInt(sizes...)
	SetAllInt(out, val)
	return out
}
// NewStringFull returns a new [String] tensor of the given shape,
// with every element set to val.
func NewStringFull(val string, sizes ...int) *String {
	out := NewString(sizes...)
	SetAllString(out, val)
	return out
}
// NewFloat64Rand returns a new [Float64] tensor of the given shape,
// filled with random numbers from the global random source.
func NewFloat64Rand(sizes ...int) *Float64 {
	out := NewFloat64(sizes...)
	FloatSetFunc(1, func(idx int) float64 { return rand.Float64() }, out)
	return out
}
// NewIntRange returns a new [Int] [Tensor] with given [Slice]
// range parameters, with the same semantics as NumPy arange based on
// the number of arguments passed:
//   - 1 = stop
//   - 2 = start, stop
//   - 3 = start, stop, step
func NewIntRange(svals ...int) *Int {
	n := len(svals)
	if n == 0 {
		return NewInt()
	}
	var sl Slice
	switch n {
	case 1:
		sl.Stop = svals[0]
	case 2:
		sl.Start, sl.Stop = svals[0], svals[1]
	case 3:
		sl.Start, sl.Stop, sl.Step = svals[0], svals[1], svals[2]
	}
	// note: more than 3 args leaves sl at its zero value, as before
	return sl.IntTensor(sl.Stop)
}
// NewFloat64SpacedLinear returns a new [Float64] tensor with num linearly
// spaced numbers between start and stop values, as tensors, which
// must be the same length and determine the cell shape of the output.
// If num is 0, then a default of 50 is used.
// If endpoint = true, then the stop value is _inclusive_, i.e., it will
// be the final value, otherwise it is exclusive.
// This corresponds to the NumPy linspace function.
func NewFloat64SpacedLinear(start, stop Tensor, num int, endpoint bool) *Float64 {
	if num <= 0 {
		num = 50
	}
	fnum := float64(num)
	if endpoint {
		fnum -= 1
	}
	if fnum < 1 {
		// num == 1 with endpoint would divide by zero below, producing
		// Inf/NaN; the single output row is just start, so step is irrelevant.
		fnum = 1
	}
	step := Clone(start)
	n := step.Len()
	for i := range n {
		step.SetFloat1D((stop.Float1D(i)-start.Float1D(i))/fnum, i)
	}
	var tsr *Float64
	if start.Len() == 1 {
		tsr = NewFloat64(num)
	} else {
		// output shape is num rows of cells shaped like start
		tsz := slices.Clone(start.Shape().Sizes)
		tsz = append([]int{num}, tsz...)
		tsr = NewFloat64(tsz...)
	}
	for r := range num {
		for i := range n {
			tsr.SetFloatRow(start.Float1D(i)+float64(r)*step.Float1D(i), r, i)
		}
	}
	return tsr
}
// Code generated by "core generate"; DO NOT EDIT.
package tensor
import (
"cogentcore.org/core/enums"
)
// _DelimsValues lists all defined Delims values, in order.
var _DelimsValues = []Delims{0, 1, 2, 3}

// DelimsN is the highest valid value for type Delims, plus one.
const DelimsN Delims = 4

// _DelimsValueMap maps names to values, used by SetString.
var _DelimsValueMap = map[string]Delims{`Tab`: 0, `Comma`: 1, `Space`: 2, `Detect`: 3}

// _DelimsDescMap holds the doc comment for each value, used by Desc.
var _DelimsDescMap = map[Delims]string{0: `Tab is the tab rune delimiter, for TSV tab separated values`, 1: `Comma is the comma rune delimiter, for CSV comma separated values`, 2: `Space is the space rune delimiter, for SSV space separated value`, 3: `Detect is used during reading a file -- reads the first line and detects tabs or commas`}

// _DelimsMap maps values to names, used by String.
var _DelimsMap = map[Delims]string{0: `Tab`, 1: `Comma`, 2: `Space`, 3: `Detect`}

// String returns the string representation of this Delims value.
func (i Delims) String() string { return enums.String(i, _DelimsMap) }

// SetString sets the Delims value from its string representation,
// and returns an error if the string is invalid.
func (i *Delims) SetString(s string) error { return enums.SetString(i, s, _DelimsValueMap, "Delims") }

// Int64 returns the Delims value as an int64.
func (i Delims) Int64() int64 { return int64(i) }

// SetInt64 sets the Delims value from an int64.
func (i *Delims) SetInt64(in int64) { *i = Delims(in) }

// Desc returns the description of the Delims value.
func (i Delims) Desc() string { return enums.Desc(i, _DelimsDescMap) }

// DelimsValues returns all possible values for the type Delims.
func DelimsValues() []Delims { return _DelimsValues }

// Values returns all possible values for the type Delims.
func (i Delims) Values() []enums.Enum { return enums.Values(_DelimsValues) }

// MarshalText implements the [encoding.TextMarshaler] interface.
func (i Delims) MarshalText() ([]byte, error) { return []byte(i.String()), nil }

// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *Delims) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "Delims") }
// _SlicesMagicValues lists all defined SlicesMagic values, in order.
var _SlicesMagicValues = []SlicesMagic{0, 1, 2}

// SlicesMagicN is the highest valid value for type SlicesMagic, plus one.
const SlicesMagicN SlicesMagic = 3

// _SlicesMagicValueMap maps names to values, used by SetString.
var _SlicesMagicValueMap = map[string]SlicesMagic{`FullAxis`: 0, `NewAxis`: 1, `Ellipsis`: 2}

// _SlicesMagicDescMap holds the doc comment for each value, used by Desc.
var _SlicesMagicDescMap = map[SlicesMagic]string{0: `FullAxis indicates that the full existing axis length should be used. This is equivalent to Slice{}, but is more semantic. In NumPy it is equivalent to a single : colon.`, 1: `NewAxis creates a new singleton (length=1) axis, used to to reshape without changing the size. Can also be used in [Reshaped].`, 2: `Ellipsis (...) is used in [NewSliced] expressions to produce a flexibly-sized stretch of FullAxis dimensions, which automatically aligns the remaining slice elements based on the source dimensionality.`}

// _SlicesMagicMap maps values to names, used by String.
var _SlicesMagicMap = map[SlicesMagic]string{0: `FullAxis`, 1: `NewAxis`, 2: `Ellipsis`}

// String returns the string representation of this SlicesMagic value.
func (i SlicesMagic) String() string { return enums.String(i, _SlicesMagicMap) }

// SetString sets the SlicesMagic value from its string representation,
// and returns an error if the string is invalid.
func (i *SlicesMagic) SetString(s string) error {
	return enums.SetString(i, s, _SlicesMagicValueMap, "SlicesMagic")
}

// Int64 returns the SlicesMagic value as an int64.
func (i SlicesMagic) Int64() int64 { return int64(i) }

// SetInt64 sets the SlicesMagic value from an int64.
func (i *SlicesMagic) SetInt64(in int64) { *i = SlicesMagic(in) }

// Desc returns the description of the SlicesMagic value.
func (i SlicesMagic) Desc() string { return enums.Desc(i, _SlicesMagicDescMap) }

// SlicesMagicValues returns all possible values for the type SlicesMagic.
func SlicesMagicValues() []SlicesMagic { return _SlicesMagicValues }

// Values returns all possible values for the type SlicesMagic.
func (i SlicesMagic) Values() []enums.Enum { return enums.Values(_SlicesMagicValues) }

// MarshalText implements the [encoding.TextMarshaler] interface.
func (i SlicesMagic) MarshalText() ([]byte, error) { return []byte(i.String()), nil }

// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *SlicesMagic) UnmarshalText(text []byte) error {
	return enums.UnmarshalText(i, text, "SlicesMagic")
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
import (
"fmt"
"reflect"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/metadata"
)
// Func represents a registered tensor function, which has
// In number of input Tensor arguments, and Out number of output
// arguments (typically 1). There can also be an 'any' first
// argument to support other kinds of parameters.
// This is used to make tensor functions available to the Goal language.
type Func struct {
	// Name is the original CamelCase Go name for function.
	Name string

	// Fun is the function, which must _only_ take some number of Tensor
	// args, with an optional any first arg.
	Fun any

	// Args has parsed information about the function args, for Goal.
	// It is populated by [Func.GetArgs], called from [NewFunc].
	Args []*Arg
}
// Arg has key information that Goal needs about each arg, for converting
// expressions into the appropriate type.
type Arg struct {
	// Type has full reflection type info.
	Type reflect.Type

	// IsTensor is true if the arg type satisfies the Tensor interface.
	IsTensor bool

	// IsInt is true if Kind = Int (or slice of Int), for shape, slice etc params.
	IsInt bool

	// IsVariadic is true if this is the last arg and has ...; type will be an array.
	IsVariadic bool
}
// NewFunc creates a new [Func] description of the given
// function, which must have a signature like this:
// func([opt any,] a, b, out tensor.Tensor) error
// i.e., taking some specific number of Tensor arguments (up to 5).
// Functions can also take an 'any' first argument to handle other
// non-tensor inputs (e.g., function pointer, dirfs directory, etc).
// The name should be a standard 'package.FuncName' qualified, exported
// CamelCase name, with 'out' indicating the number of output arguments,
// and an optional arg indicating an 'any' first argument.
// The remaining arguments in the function (automatically
// determined) are classified as input arguments.
func NewFunc(name string, fun any) (*Func, error) {
	fn := &Func{Name: name, Fun: fun}
	fn.GetArgs() // populate Args from reflection
	return fn, nil
}
// GetArgs gets key info about each arg, for use by the Goal transpiler.
func (fn *Func) GetArgs() {
	ft := reflect.TypeOf(fn.Fun)
	n := ft.NumIn()
	if n == 0 {
		return
	}
	tsrt := reflect.TypeFor[Tensor]()
	args := make([]*Arg, n)
	for i := range n {
		at := ft.In(i)
		ag := &Arg{Type: at}
		// only the final arg of a variadic function is variadic
		ag.IsVariadic = ft.IsVariadic() && i == n-1
		switch {
		case at.Kind() == reflect.Int:
			ag.IsInt = true
		case at.Kind() == reflect.Slice && at.Elem().Kind() == reflect.Int:
			ag.IsInt = true
		case at.Implements(tsrt):
			ag.IsTensor = true
		}
		args[i] = ag
	}
	fn.Args = args
}
// String returns a human-readable signature for the function,
// e.g., "pkg.Name(tensor.Tensor, ...any)".
func (fn *Func) String() string {
	s := fn.Name + "("
	last := len(fn.Args) - 1
	for i, a := range fn.Args {
		if a.IsVariadic {
			s += "..."
		}
		ts := a.Type.String()
		if ts == "interface {}" {
			ts = "any" // modern spelling
		}
		s += ts
		if i != last {
			s += ", "
		}
	}
	return s + ")"
}
// Funcs is the global tensor named function registry, keyed by name.
// All functions must have a signature like this:
// func([opt any,] a, b, out tensor.Tensor) error
// i.e., taking some specific number of Tensor arguments (up to 5),
// with the number of output vs. input arguments registered.
// Functions can also take an 'any' first argument to handle other
// non-tensor inputs (e.g., function pointer, dirfs directory, etc).
// This is used to make tensor functions available to the Goal
// language. Register via [AddFunc]; look up via [FuncByName].
var Funcs map[string]*Func
// AddFunc adds given named function to the global tensor named function
// registry, which is used by Goal to call functions by name.
// Returns an error if a function of the same name is already registered.
// See [NewFunc] for more information.
func AddFunc(name string, fun any) error {
	if Funcs == nil {
		Funcs = make(map[string]*Func)
	}
	if _, ok := Funcs[name]; ok {
		return fmt.Errorf("tensor.AddFunc: function of name %q already exists, not added", name)
	}
	fn, err := NewFunc(name, fun)
	if errors.Log(err) != nil {
		return err
	}
	Funcs[name] = fn
	// note: can record orig camel name if needed for docs etc later.
	return nil
}
// FuncByName finds the function of given name in the [Funcs] registry,
// returning an error if the function name has not been registered.
func FuncByName(name string) (*Func, error) {
	if fn, ok := Funcs[name]; ok {
		return fn, nil
	}
	return nil, fmt.Errorf("tensor.FuncByName: function of name %q not registered", name)
}
// These generic functions provide a one liner for wrapping functions
// that take an output Tensor as the last argument, which is important
// for memory re-use of the output in performance-critical cases.
// The names indicate the number of input tensor arguments.
// Additional generic non-Tensor inputs are supported up to 2,
// with Gen1 and Gen2 versions.
// FloatPromoteType returns the DataType for the given Tensor(s) that
// promotes the Float type if any of the elements are of that type:
// Float64 wins over everything; Float32 wins over non-float types.
// Otherwise it returns the type of the first tensor.
func FloatPromoteType(tsr ...Tensor) reflect.Kind {
	ft := tsr[0].DataType()
	for _, t := range tsr[1:] {
		k := t.DataType()
		switch {
		case k == reflect.Float64:
			ft = k
		case k == reflect.Float32 && ft != reflect.Float64:
			ft = k
		}
	}
	return ft
}
// CallOut1 calls fun with a newly allocated output [Values] tensor
// of the same data type as a, returning the output. Errors are logged.
func CallOut1(fun func(a Tensor, out Values) error, a Tensor) Values {
	out := NewOfType(a.DataType())
	errors.Log(fun(a, out))
	return out
}
// CallOut1Float64 calls fun with a newly allocated [Float64] output
// [Values] tensor, returning the output. Errors are logged.
func CallOut1Float64(fun func(a Tensor, out Values) error, a Tensor) Values {
	out := NewFloat64()
	errors.Log(fun(a, out))
	return out
}
// CallOut2Float64 calls the two-input fun with a newly allocated
// [Float64] output [Values] tensor, returning the output. Errors are logged.
func CallOut2Float64(fun func(a, b Tensor, out Values) error, a, b Tensor) Values {
	out := NewFloat64()
	errors.Log(fun(a, b, out))
	return out
}
// CallOut2 calls the two-input fun with a newly allocated output
// [Values] tensor of the [FloatPromoteType] of the inputs. Errors are logged.
func CallOut2(fun func(a, b Tensor, out Values) error, a, b Tensor) Values {
	out := NewOfType(FloatPromoteType(a, b))
	errors.Log(fun(a, b, out))
	return out
}
// CallOut3 calls the three-input fun with a newly allocated output
// [Values] tensor of the [FloatPromoteType] of the inputs. Errors are logged.
func CallOut3(fun func(a, b, c Tensor, out Values) error, a, b, c Tensor) Values {
	out := NewOfType(FloatPromoteType(a, b, c))
	errors.Log(fun(a, b, c, out))
	return out
}
// CallOut2Bool calls the two-input fun with a newly allocated [Bool]
// output tensor, returning the output. Errors are logged.
func CallOut2Bool(fun func(a, b Tensor, out *Bool) error, a, b Tensor) *Bool {
	out := NewBool()
	errors.Log(fun(a, b, out))
	return out
}
// CallOut1Gen1 calls fun, which takes one generic non-tensor arg g and
// one input tensor, with a newly allocated output [Values] tensor of the
// same data type as a. Errors are logged.
func CallOut1Gen1[T any](fun func(g T, a Tensor, out Values) error, g T, a Tensor) Values {
	out := NewOfType(a.DataType())
	errors.Log(fun(g, a, out))
	return out
}
// CallOut1Gen2 calls fun, which takes two generic non-tensor args g, h and
// one input tensor, with a newly allocated output [Values] tensor of the
// same data type as a. Errors are logged.
func CallOut1Gen2[T any, S any](fun func(g T, h S, a Tensor, out Values) error, g T, h S, a Tensor) Values {
	out := NewOfType(a.DataType())
	errors.Log(fun(g, h, a, out))
	return out
}
// CallOut2Gen1 calls fun, which takes one generic non-tensor arg g and
// two input tensors, with a newly allocated output [Values] tensor of the
// [FloatPromoteType] of the inputs. Errors are logged.
func CallOut2Gen1[T any](fun func(g T, a, b Tensor, out Values) error, g T, a, b Tensor) Values {
	out := NewOfType(FloatPromoteType(a, b))
	errors.Log(fun(g, a, b, out))
	return out
}
// CallOut2Gen2 calls fun, which takes two generic non-tensor args g, h and
// two input tensors, with a newly allocated output [Values] tensor of the
// [FloatPromoteType] of the inputs. Errors are logged.
func CallOut2Gen2[T any, S any](fun func(g T, h S, a, b Tensor, out Values) error, g T, h S, a, b Tensor) Values {
	out := NewOfType(FloatPromoteType(a, b))
	errors.Log(fun(g, h, a, b, out))
	return out
}
//////// Metadata
// SetCalcFunc stores fun under the "CalcFunc" metadata key of the tensor,
// to be invoked later by [Calc] to compute an updated value.
func SetCalcFunc(tsr Tensor, fun func() error) {
	tsr.Metadata().Set("CalcFunc", fun)
}
// Calc invokes the function previously stored by [SetCalcFunc] under the
// "CalcFunc" metadata key, to compute an updated value for the tensor.
// Returns an error if the function was not set, or any error from
// the function itself.
func Calc(tsr Tensor) error {
	fun, err := metadata.Get[func() error](*tsr.Metadata(), "CalcFunc")
	if err != nil {
		return err
	}
	return fun()
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
import (
"reflect"
"slices"
"cogentcore.org/core/base/metadata"
"cogentcore.org/core/base/reflectx"
)
// Indexed provides an arbitrarily indexed view onto another "source" [Tensor]
// with each index value providing a full n-dimensional index into the source.
// The shape of this view is determined by the shape of the [Indexed.Indexes]
// tensor up to the final innermost dimension, which holds the index values.
// Thus the innermost dimension size of the indexes is equal to the number
// of dimensions in the source tensor. Given the essential role of the
// indexes in this view, it is not usable without the indexes.
// This view is not memory-contiguous and does not support the [RowMajor]
// interface or efficient access to inner-dimensional subspaces.
// To produce a new concrete [Values] that has raw data actually
// organized according to the indexed order (i.e., the copy function
// of numpy), call [Indexed.AsValues].
type Indexed struct { //types:add
	// Tensor source that we are an indexed view onto.
	Tensor Tensor

	// Indexes is the list of indexes into the source tensor,
	// with the innermost dimension providing the index values
	// (size = number of dimensions in the source tensor), and
	// the remaining outer dimensions determine the shape
	// of this [Indexed] tensor view.
	Indexes *Int
}
// NewIndexed returns a new [Indexed] view of the given tensor,
// using idx as the tensor of indexes into the source tensor.
func NewIndexed(tsr Tensor, idx *Int) *Indexed {
	return &Indexed{Tensor: tsr, Indexes: idx}
}
// AsIndexed returns the tensor as an [Indexed] view, if it is one.
// Otherwise, it returns nil; there is no usable "null" Indexed view.
func AsIndexed(tsr Tensor) *Indexed {
	ix, _ := tsr.(*Indexed) // nil when not an *Indexed
	return ix
}
// SetTensor sets the source tensor for this view.
// NOTE(review): the previous comment claimed "sequential initial indexes"
// are set, but the code leaves Indexes untouched — the caller must ensure
// the existing Indexes remain valid for the new tensor.
func (ix *Indexed) SetTensor(tsr Tensor) {
	ix.Tensor = tsr
}
// SourceIndexes returns the actual indexes into the underlying source tensor
// based on the given list of indexes into the [Indexed.Indexes] tensor,
// _excluding_ the final innermost dimension.
func (ix *Indexed) SourceIndexes(i ...int) []int {
	// append 0 to address the first element of the innermost (index-value) dim
	idx := append(slices.Clone(i), 0)
	start := ix.Indexes.Shape().IndexTo1D(idx...)
	return ix.Indexes.Values[start : start+ix.Tensor.NumDims()]
}
// SourceIndexesFrom1D returns the full indexes into the source tensor based
// on the given 1d index, which is based on the outer dimensions, excluding
// the final innermost dimension.
func (ix *Indexed) SourceIndexesFrom1D(oned int) []int {
	nd := ix.Tensor.NumDims()
	start := oned * nd // each element owns nd consecutive index values
	return ix.Indexes.Values[start : start+nd]
}
// Label returns the metadata name (if set) plus the shape string.
func (ix *Indexed) Label() string { return label(metadata.Name(ix), ix.Shape()) }

// String returns the default-formatted contents of this view.
func (ix *Indexed) String() string { return Sprintf("", ix, 0) }

// Metadata returns the source tensor's metadata (shared, not copied).
func (ix *Indexed) Metadata() *metadata.Data { return ix.Tensor.Metadata() }

// IsString delegates to the source tensor.
func (ix *Indexed) IsString() bool { return ix.Tensor.IsString() }

// DataType delegates to the source tensor.
func (ix *Indexed) DataType() reflect.Kind { return ix.Tensor.DataType() }

// Shape constructs the view shape from [Indexed.ShapeSizes].
func (ix *Indexed) Shape() *Shape { return NewShape(ix.ShapeSizes()...) }

// Len is the total number of elements in the view.
func (ix *Indexed) Len() int { return ix.Shape().Len() }

// NumDims is one less than the Indexes dims: the innermost Indexes
// dimension holds the per-element source index values.
func (ix *Indexed) NumDims() int { return ix.Indexes.NumDims() - 1 }

// DimSize returns the size of the given (outer) dimension of the Indexes.
func (ix *Indexed) DimSize(dim int) int { return ix.Indexes.DimSize(dim) }

// ShapeSizes is the Indexes shape without the final innermost dimension.
func (ix *Indexed) ShapeSizes() []int {
	si := slices.Clone(ix.Indexes.ShapeSizes())
	return si[:len(si)-1] // exclude last dim
}
// AsValues returns a copy of this tensor as raw [Values].
// This "renders" the Indexed view into a fully contiguous
// and optimized memory representation of that view, which will be faster
// to access for further processing, and enables all the additional
// functionality provided by the [Values] interface.
func (ix *Indexed) AsValues() Values {
	dt := ix.Tensor.DataType()
	vt := NewOfType(dt, ix.ShapeSizes()...)
	n := ix.Len()
	if ix.Tensor.IsString() {
		for i := range n {
			vt.SetString1D(ix.String1D(i), i)
		}
	} else if reflectx.KindIsFloat(dt) {
		for i := range n {
			vt.SetFloat1D(ix.Float1D(i), i)
		}
	} else {
		// all remaining kinds are copied as ints
		for i := range n {
			vt.SetInt1D(ix.Int1D(i), i)
		}
	}
	return vt
}
//////// Floats

// Float returns the value of given index as a float64.
// The indexes are indirected through the [Indexed.Indexes].
func (ix *Indexed) Float(i ...int) float64 {
	return ix.Tensor.Float(ix.SourceIndexes(i...)...)
}

// SetFloat sets the value of given index as a float64.
// The indexes are indirected through the [Indexed.Indexes].
func (ix *Indexed) SetFloat(val float64, i ...int) {
	ix.Tensor.SetFloat(val, ix.SourceIndexes(i...)...)
}

// Float1D returns the value at the given flat index as a float64.
// It is somewhat expensive, because it needs to convert the flat index
// back into a full n-dimensional index and then use that api.
func (ix *Indexed) Float1D(i int) float64 {
	return ix.Tensor.Float(ix.SourceIndexesFrom1D(i)...)
}

// SetFloat1D sets the value at the given flat index as a float64.
// It is somewhat expensive, because it needs to convert the flat index
// back into a full n-dimensional index and then use that api.
func (ix *Indexed) SetFloat1D(val float64, i int) {
	ix.Tensor.SetFloat(val, ix.SourceIndexesFrom1D(i)...)
}
//////// Strings

// StringValue returns the value of given index as a string.
// The indexes are indirected through the [Indexed.Indexes].
func (ix *Indexed) StringValue(i ...int) string {
	return ix.Tensor.StringValue(ix.SourceIndexes(i...)...)
}

// SetString sets the value of given index as a string.
// The indexes are indirected through the [Indexed.Indexes].
func (ix *Indexed) SetString(val string, i ...int) {
	ix.Tensor.SetString(val, ix.SourceIndexes(i...)...)
}

// String1D returns the value at the given flat index as a string.
// It is somewhat expensive, because it needs to convert the flat index
// back into a full n-dimensional index and then use that api.
func (ix *Indexed) String1D(i int) string {
	return ix.Tensor.StringValue(ix.SourceIndexesFrom1D(i)...)
}

// SetString1D sets the value at the given flat index as a string.
// It is somewhat expensive, because it needs to convert the flat index
// back into a full n-dimensional index and then use that api.
func (ix *Indexed) SetString1D(val string, i int) {
	ix.Tensor.SetString(val, ix.SourceIndexesFrom1D(i)...)
}
//////// Ints

// Int returns the value of given index as an int.
// The indexes are indirected through the [Indexed.Indexes].
func (ix *Indexed) Int(i ...int) int {
	return ix.Tensor.Int(ix.SourceIndexes(i...)...)
}

// SetInt sets the value of given index as an int.
// The indexes are indirected through the [Indexed.Indexes].
func (ix *Indexed) SetInt(val int, i ...int) {
	ix.Tensor.SetInt(val, ix.SourceIndexes(i...)...)
}

// Int1D returns the value at the given flat index as an int.
// It is somewhat expensive, because it needs to convert the flat index
// back into a full n-dimensional index and then use that api.
func (ix *Indexed) Int1D(i int) int {
	return ix.Tensor.Int(ix.SourceIndexesFrom1D(i)...)
}

// SetInt1D sets the value at the given flat index as an int.
// It is somewhat expensive, because it needs to convert the flat index
// back into a full n-dimensional index and then use that api.
func (ix *Indexed) SetInt1D(val int, i int) {
	ix.Tensor.SetInt(val, ix.SourceIndexesFrom1D(i)...)
}

// compile-time check that Indexed implements the Tensor interface
var _ Tensor = (*Indexed)(nil)
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
import (
"bytes"
"encoding/csv"
"fmt"
"io"
"log"
"os"
"strconv"
"strings"
"cogentcore.org/core/base/fsx"
"cogentcore.org/core/base/metadata"
"cogentcore.org/core/base/reflectx"
)
// Delims are the standard CSV delimiter options (Tab, Comma, Space).
type Delims int32 //enums:enum

const (
	// Tab is the tab rune delimiter, for TSV tab separated values
	Tab Delims = iota

	// Comma is the comma rune delimiter, for CSV comma separated values
	Comma

	// Space is the space rune delimiter, for SSV space separated value
	Space

	// Detect is used during reading a file -- reads the first line and detects tabs or commas
	Detect
)
// Rune returns the rune for the delimiter; Tab, Detect, and any
// unrecognized value all map to '\t'.
func (dl Delims) Rune() rune {
	switch dl {
	case Comma:
		return ','
	case Space:
		return ' '
	default:
		return '\t'
	}
}
// SetPrecision stores the "Precision" metadata value on obj, which
// determines the precision used when writing floating point numbers to files.
func SetPrecision(obj any, prec int) {
	metadata.SetTo(obj, "Precision", prec)
}
// Precision retrieves the "Precision" metadata value from obj, which
// determines the precision used when writing floating point numbers to files.
// Returns an error if it has not been set.
func Precision(obj any) (int, error) {
	return metadata.GetFrom[int](obj, "Precision")
}
// SaveCSV writes a tensor to a comma-separated-values (CSV) file
// (where comma = any delimiter, specified in the delim arg).
// Outer-most dims are rows in the file, and inner-most is column --
// Reading just grabs all values and doesn't care about shape.
func SaveCSV(tsr Tensor, filename fsx.Filename, delim Delims) error {
	fp, err := os.Create(string(filename))
	if err != nil {
		// check the error before deferring Close: fp is nil on failure
		log.Println(err)
		return err
	}
	defer fp.Close()
	// propagate the write error (previously silently discarded)
	return WriteCSV(tsr, fp, delim)
}
// OpenCSV reads a tensor from a comma-separated-values (CSV) file
// (where comma = any delimiter, specified in the delim arg),
// using the Go standard encoding/csv reader conforming
// to the official CSV standard.
// Reads all values and assigns as many as fit.
func OpenCSV(tsr Tensor, filename fsx.Filename, delim Delims) error {
	fp, err := os.Open(string(filename))
	if err != nil {
		// check the error before deferring Close: fp is nil on failure
		log.Println(err)
		return err
	}
	defer fp.Close()
	return ReadCSV(tsr, fp, delim)
}
//////// WriteCSV

// WriteCSV writes a tensor to a comma-separated-values (CSV) file
// (where comma = any delimiter, specified in the delim arg).
// Outer-most dims are rows in the file, and inner-most is column --
// Reading just grabs all values and doesn't care about shape.
func WriteCSV(tsr Tensor, w io.Writer, delim Delims) error {
	prec := -1 // -1 = shortest representation for FormatFloat
	if ps, err := Precision(tsr); err == nil {
		prec = ps
	}
	nrow := tsr.DimSize(0)
	if nrow == 0 {
		return nil // nothing to write; also avoids divide-by-zero below
	}
	cw := csv.NewWriter(w)
	cw.Comma = delim.Rune()
	nin := tsr.Len() / nrow // number of inner (column) values per row
	rec := make([]string, nin)
	str := tsr.IsString()
	for ri := 0; ri < nrow; ri++ {
		for ci := 0; ci < nin; ci++ {
			idx := ri*nin + ci
			if str {
				rec[ci] = tsr.String1D(idx)
			} else {
				rec[ci] = strconv.FormatFloat(tsr.Float1D(idx), 'g', prec, 64)
			}
		}
		if err := cw.Write(rec); err != nil {
			log.Println(err)
			return err
		}
	}
	cw.Flush()
	// Flush does not return errors; they must be retrieved via Error()
	return cw.Error()
}
// ReadCSV reads a tensor from a comma-separated-values (CSV) file
// (where comma = any delimiter, specified in the delim arg),
// using the Go standard encoding/csv reader conforming
// to the official CSV standard.
// Reads all values and assigns as many as fit.
func ReadCSV(tsr Tensor, r io.Reader, delim Delims) error {
	cr := csv.NewReader(r)
	cr.Comma = delim.Rune()
	rec, err := cr.ReadAll() // todo: lazy, avoid resizing
	if err != nil || len(rec) == 0 {
		return err
	}
	rows := len(rec)
	cols := len(rec[0])
	sz := tsr.Len()
	idx := 0
	for ri := 0; ri < rows; ri++ {
		for ci := 0; ci < cols; ci++ {
			tsr.SetString1D(rec[ri][ci], idx)
			idx++
			if idx >= sz {
				return nil // tensor is full; ignore remaining values
			}
		}
	}
	return nil
}
// label joins an optional name with the shape's string description.
func label(nm string, sh *Shape) string {
	if nm == "" {
		return sh.String()
	}
	return nm + " " + sh.String()
}
// padToLength returns the given string with trailing spaces appended
// to reach the target length; at least one space is always added.
func padToLength(str string, tlen int) string {
	if pad := tlen - len(str); pad > 1 {
		return str + strings.Repeat(" ", pad)
	}
	return str + " "
}
// prepadToLength returns the given string with leading spaces (for
// right-aligning numbers) to reach the target length, plus one trailing
// space; at least one space is always added.
func prepadToLength(str string, tlen int) string {
	if pad := tlen - len(str); pad > 1 {
		return strings.Repeat(" ", pad-1) + str + " "
	}
	return str + " "
}
// MaxPrintLineWidth is the maximum line width in characters
// to generate for the tensor [Sprintf] function; rows wider than
// this wrap onto indented continuation lines.
var MaxPrintLineWidth = 80
// Sprintf returns a string representation of the given tensor,
// with a maximum length of as given: output is terminated
// when it exceeds that length. If maxLen = 0, [MaxSprintLength] is used.
// The format is the per-element format string.
// If empty it uses general %g for number or %s for string.
func Sprintf(format string, tsr Tensor, maxLen int) string {
	if maxLen == 0 {
		maxLen = MaxSprintLength
	}
	defFmt := format == ""
	if defFmt {
		switch {
		case tsr.IsString():
			format = "%s"
		case reflectx.KindIsInt(tsr.DataType()):
			// NOTE(review): identical to the default branch below; if int
			// tensors were meant to use a distinct format, this is not
			// doing so -- confirm intent.
			format = "%.10g"
		default:
			format = "%.10g"
		}
	}
	nd := tsr.NumDims()
	if nd == 1 && tsr.DimSize(0) == 1 { // scalar special case
		if tsr.IsString() {
			return fmt.Sprintf(format, tsr.String1D(0))
		} else {
			return fmt.Sprintf(format, tsr.Float1D(0))
		}
	}
	// First pass over the elements that will be printed: find the widest
	// formatted value so columns can be aligned.
	mxwd := 0
	n := min(tsr.Len(), maxLen)
	for i := range n {
		s := ""
		if tsr.IsString() {
			s = fmt.Sprintf(format, tsr.String1D(i))
		} else {
			s = fmt.Sprintf(format, tsr.Float1D(i))
		}
		if len(s) > mxwd {
			mxwd = len(s)
		}
	}
	// Collapse the n-dimensional shape onto a 2D row x col projection.
	onedRow := false
	shp := tsr.Shape()
	rowShape, colShape, _, colIdxs := Projection2DDimShapes(shp, onedRow)
	rows, cols, _, _ := Projection2DShape(shp, onedRow)
	rowWd := len(rowShape.String()) + 1
	legend := ""
	if nd > 2 {
		// Legend marks which dims project to rows (r) vs. columns (c),
		// e.g. "[r r c]" for a 3D tensor.
		leg := bytes.Repeat([]byte("r "), nd)
		for _, i := range colIdxs {
			leg[2*i] = 'c'
		}
		legend = "[" + string(leg[:len(leg)-1]) + "]"
	}
	rowWd = max(rowWd, len(legend)+1)
	hdrWd := len(colShape.String()) + 1
	colWd := mxwd + 1
	var b strings.Builder
	b.WriteString(tsr.Label())
	noidx := false
	if tsr.NumDims() == 1 {
		// 1D: values go on the same line as the label, no index column.
		b.WriteString(" ")
		rowWd = len(tsr.Label()) + 1
		noidx = true
	} else {
		b.WriteString("\n")
	}
	if !noidx && nd > 1 && cols > 1 {
		// Column header row showing the projected column coordinates,
		// wrapped to MaxPrintLineWidth.
		colWd = max(colWd, hdrWd)
		b.WriteString(padToLength(legend, rowWd))
		totWd := rowWd
		for c := 0; c < cols; c++ {
			_, cc := Projection2DCoords(shp, onedRow, 0, c)
			s := prepadToLength(fmt.Sprintf("%v", cc), colWd)
			if totWd+len(s) > MaxPrintLineWidth {
				b.WriteString("\n" + strings.Repeat(" ", rowWd))
				totWd = rowWd
			}
			b.WriteString(s)
			totWd += len(s)
		}
		b.WriteString("\n")
	}
	// Emit one output line per projected row, wrapping long rows and
	// stopping with "..." once maxLen values have been emitted.
	ctr := 0
	for r := range rows {
		rc, _ := Projection2DCoords(shp, onedRow, r, 0)
		if !noidx {
			b.WriteString(padToLength(fmt.Sprintf("%v", rc), rowWd))
		}
		ri := r
		totWd := rowWd
		for c := 0; c < cols; c++ {
			s := ""
			if tsr.IsString() {
				s = padToLength(fmt.Sprintf(format, Projection2DString(tsr, onedRow, ri, c)), colWd)
			} else {
				s = prepadToLength(fmt.Sprintf(format, Projection2DValue(tsr, onedRow, ri, c)), colWd)
			}
			if totWd+len(s) > MaxPrintLineWidth {
				b.WriteString("\n" + strings.Repeat(" ", rowWd))
				totWd = rowWd
			}
			b.WriteString(s)
			totWd += len(s)
		}
		b.WriteString("\n")
		ctr += cols
		if ctr > maxLen {
			b.WriteString("...\n")
			break
		}
	}
	return b.String()
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
import (
"math"
"reflect"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/metadata"
"cogentcore.org/core/base/reflectx"
)
// Masked is a filtering wrapper around another "source" [Tensor],
// that provides a bit-masked view onto the Tensor defined by a [Bool] [Values]
// tensor with a matching shape. If the bool mask has a 'false'
// then the corresponding value cannot be Set, and Float access returns
// NaN indicating missing data (other type access returns the zero value).
// A new Masked view defaults to a full transparent view of the source tensor.
// To produce a new [Values] tensor with only the 'true' cases,
// (i.e., the copy function of numpy), call [Masked.AsValues].
type Masked struct { //types:add
	// Tensor source that we are a masked view onto.
	Tensor Tensor

	// Mask is a Bool tensor with the same shape as the source tensor,
	// where false marks a value as masked out (unsettable, reads as NaN/zero).
	Mask *Bool
}
// NewMasked returns a new [Masked] view of given tensor,
// with given [Bool] mask values. If no mask is provided,
// a default full transparent (all bool values = true) mask is used.
func NewMasked(tsr Tensor, mask ...*Bool) *Masked {
	mv := &Masked{Tensor: tsr}
	if len(mask) != 1 {
		// no (or multiple) masks given: start fully transparent
		mv.Mask = NewBoolShape(tsr.Shape())
		mv.Mask.SetTrue()
		return mv
	}
	mv.Mask = mask[0]
	mv.SyncShape()
	return mv
}
// Mask is the general purpose masking function, which checks
// if the mask arg is a Bool and uses if so.
// Otherwise, it logs an error.
func Mask(tsr, mask Tensor) Tensor {
	mb, ok := mask.(*Bool)
	if !ok {
		errors.Log(errors.New("tensor.Mask: provided tensor is not a Bool tensor"))
		return tsr
	}
	return NewMasked(tsr, mb)
}
// AsMasked returns the tensor as a [Masked] view.
// If it already is one, then it is returned, otherwise it is wrapped
// with an initially fully transparent mask.
func AsMasked(tsr Tensor) *Masked {
	ms, ok := tsr.(*Masked)
	if !ok {
		ms = NewMasked(tsr)
	}
	return ms
}
// SetTensor sets the given source tensor. If the shape does not match
// the current Mask, then a new transparent mask is established.
func (ms *Masked) SetTensor(tsr Tensor) {
	ms.Tensor = tsr
	ms.SyncShape() // re-establish mask shape to match new source
}
// SyncShape ensures that [Masked.Mask] shape is the same as source tensor.
// If the Mask does not exist or is a different shape from the source,
// then it is created or reshaped, and all values set to true ("transparent").
func (ms *Masked) SyncShape() {
	switch {
	case ms.Mask == nil:
		ms.Mask = NewBoolShape(ms.Tensor.Shape())
		ms.Mask.SetTrue()
	case !ms.Mask.Shape().IsEqual(ms.Tensor.Shape()):
		SetShapeFrom(ms.Mask, ms.Tensor)
		ms.Mask.SetTrue()
	}
}
// Label returns the metadata name (if set) plus the shape string.
func (ms *Masked) Label() string { return label(metadata.Name(ms), ms.Shape()) }

// String returns the default-formatted string representation.
func (ms *Masked) String() string { return Sprintf("", ms, 0) }

// Metadata returns the source tensor's metadata.
func (ms *Masked) Metadata() *metadata.Data { return ms.Tensor.Metadata() }

// The following shape and type accessors all delegate to the source tensor.
func (ms *Masked) IsString() bool { return ms.Tensor.IsString() }

func (ms *Masked) DataType() reflect.Kind { return ms.Tensor.DataType() }

func (ms *Masked) ShapeSizes() []int { return ms.Tensor.ShapeSizes() }

func (ms *Masked) Shape() *Shape { return ms.Tensor.Shape() }

func (ms *Masked) Len() int { return ms.Tensor.Len() }

func (ms *Masked) NumDims() int { return ms.Tensor.NumDims() }

func (ms *Masked) DimSize(dim int) int { return ms.Tensor.DimSize(dim) }
// AsValues returns a copy of this tensor as raw [Values].
// This "renders" the Masked view into a fully contiguous
// and optimized memory representation of that view.
// Because the masking pattern is unpredictable, only a 1D shape is possible.
func (ms *Masked) AsValues() Values {
	dt := ms.Tensor.DataType()
	n := ms.Len()
	switch {
	case ms.Tensor.IsString():
		out := make([]string, 0, n)
		for i := range n {
			if ms.Mask.Bool1D(i) {
				out = append(out, ms.Tensor.String1D(i))
			}
		}
		return NewStringFromValues(out...)
	case reflectx.KindIsFloat(dt):
		out := make([]float64, 0, n)
		for i := range n {
			if ms.Mask.Bool1D(i) {
				out = append(out, ms.Tensor.Float1D(i))
			}
		}
		return NewFloat64FromValues(out...)
	default:
		// all remaining (integer) kinds go through int
		out := make([]int, 0, n)
		for i := range n {
			if ms.Mask.Bool1D(i) {
				out = append(out, ms.Tensor.Int1D(i))
			}
		}
		return NewIntFromValues(out...)
	}
}
// SourceIndexes returns a flat [Int] tensor of the mask values
// that match the given getTrue argument state.
// These can be used as indexes in the [Indexed] view, for example.
// The resulting tensor is 2D with inner dimension = number of source
// tensor dimensions, to hold the indexes, and outer dimension = number
// of indexes.
func (ms *Masked) SourceIndexes(getTrue bool) *Int {
	nd := ms.Tensor.NumDims()
	flat := make([]int, 0, ms.Len()*nd)
	for i := range ms.Len() {
		if ms.Mask.Bool1D(i) != getTrue {
			continue
		}
		// expand the flat index into full n-dimensional coordinates
		flat = append(flat, ms.Tensor.Shape().IndexFrom1D(i)...)
	}
	it := NewIntFromValues(flat...)
	it.SetShapeSizes(len(flat)/nd, nd)
	return it
}
//////// Floats

// Float returns the value at the n-dimensional index, or NaN if masked out.
func (ms *Masked) Float(i ...int) float64 {
	if !ms.Mask.Bool(i...) {
		return math.NaN()
	}
	return ms.Tensor.Float(i...)
}

// SetFloat sets the value at the n-dimensional index, unless masked out.
func (ms *Masked) SetFloat(val float64, i ...int) {
	if !ms.Mask.Bool(i...) {
		return
	}
	ms.Tensor.SetFloat(val, i...)
}

// Float1D returns the value at the flat 1D index, or NaN if masked out.
func (ms *Masked) Float1D(i int) float64 {
	if !ms.Mask.Bool1D(i) {
		return math.NaN()
	}
	return ms.Tensor.Float1D(i)
}

// SetFloat1D sets the value at the flat 1D index, unless masked out.
func (ms *Masked) SetFloat1D(val float64, i int) {
	if !ms.Mask.Bool1D(i) {
		return
	}
	ms.Tensor.SetFloat1D(val, i)
}
//////// Strings

// StringValue returns the value at the n-dimensional index,
// or "" if masked out.
func (ms *Masked) StringValue(i ...int) string {
	if !ms.Mask.Bool(i...) {
		return ""
	}
	return ms.Tensor.StringValue(i...)
}

// SetString sets the value at the n-dimensional index, unless masked out.
func (ms *Masked) SetString(val string, i ...int) {
	if !ms.Mask.Bool(i...) {
		return
	}
	ms.Tensor.SetString(val, i...)
}

// String1D returns the value at the flat 1D index, or "" if masked out.
func (ms *Masked) String1D(i int) string {
	if !ms.Mask.Bool1D(i) {
		return ""
	}
	return ms.Tensor.String1D(i)
}

// SetString1D sets the value at the flat 1D index, unless masked out.
func (ms *Masked) SetString1D(val string, i int) {
	if !ms.Mask.Bool1D(i) {
		return
	}
	ms.Tensor.SetString1D(val, i)
}
//////// Ints

// Int returns the value at the n-dimensional index, or 0 if masked out.
func (ms *Masked) Int(i ...int) int {
	if !ms.Mask.Bool(i...) {
		return 0
	}
	return ms.Tensor.Int(i...)
}

// SetInt sets the value at the n-dimensional index, unless masked out.
func (ms *Masked) SetInt(val int, i ...int) {
	if !ms.Mask.Bool(i...) {
		return
	}
	ms.Tensor.SetInt(val, i...)
}

// Int1D returns the value at the flat 1D index, or 0 if masked out.
func (ms *Masked) Int1D(i int) int {
	if !ms.Mask.Bool1D(i) {
		return 0
	}
	return ms.Tensor.Int1D(i)
}

// SetInt1D sets the value at the flat 1D index, unless masked out.
// (The previous comment about index conversion cost was copied from
// another type and did not describe this method.)
func (ms *Masked) SetInt1D(val int, i int) {
	if !ms.Mask.Bool1D(i) {
		return
	}
	ms.Tensor.SetInt1D(val, i)
}
// Filter sets the mask values using given Filter function.
// The filter function gets the 1D index into the source tensor.
func (ms *Masked) Filter(filterer func(tsr Tensor, idx int) bool) {
	for i := range ms.Tensor.Len() {
		ms.Mask.SetBool1D(filterer(ms.Tensor, i), i)
	}
}
// check for interface impl
var _ Tensor = (*Masked)(nil)
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
import (
"fmt"
"strconv"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/num"
"cogentcore.org/core/base/reflectx"
)
// Number is a tensor of numerical values, parameterized by the
// concrete element type (see the aliases below for common cases).
type Number[T num.Number] struct {
	Base[T]
}

// Float64 is an alias for Number[float64].
type Float64 = Number[float64]

// Float32 is an alias for Number[float32].
type Float32 = Number[float32]

// Int is an alias for Number[int].
type Int = Number[int]

// Int32 is an alias for Number[int32].
type Int32 = Number[int32]

// Uint32 is an alias for Number[uint32].
type Uint32 = Number[uint32]

// Byte is an alias for Number[byte].
type Byte = Number[byte]
// NewFloat32 returns a new [Float32] tensor
// with the given sizes per dimension (shape).
func NewFloat32(sizes ...int) *Float32 {
	return New[float32](sizes...).(*Float32)
}

// NewFloat64 returns a new [Float64] tensor
// with the given sizes per dimension (shape).
func NewFloat64(sizes ...int) *Float64 {
	return New[float64](sizes...).(*Float64)
}

// NewInt returns a new [Int] tensor
// with the given sizes per dimension (shape).
func NewInt(sizes ...int) *Int {
	return New[int](sizes...).(*Int)
}

// NewInt32 returns a new [Int32] tensor
// with the given sizes per dimension (shape).
func NewInt32(sizes ...int) *Int32 {
	return New[int32](sizes...).(*Int32)
}

// NewUint32 returns a new [Uint32] tensor
// with the given sizes per dimension (shape).
func NewUint32(sizes ...int) *Uint32 {
	return New[uint32](sizes...).(*Uint32)
}

// NewByte returns a new [Byte] tensor
// with the given sizes per dimension (shape).
func NewByte(sizes ...int) *Byte {
	return New[uint8](sizes...).(*Byte)
}
// NewNumber returns a new n-dimensional tensor of numerical values
// with the given sizes per dimension (shape).
func NewNumber[T num.Number](sizes ...int) *Number[T] {
	nt := &Number[T]{}
	nt.SetShapeSizes(sizes...)
	nt.Values = make([]T, nt.Len())
	return nt
}
// NewNumberShape returns a new n-dimensional tensor of numerical values
// using given shape.
func NewNumberShape[T num.Number](shape *Shape) *Number[T] {
	nt := &Number[T]{}
	nt.shape.CopyFrom(shape)
	nt.Values = make([]T, nt.Len())
	return nt
}
// todo: this should in principle work with yaegi:add but it is crashing
// will come back to it later.

// NewNumberFromValues returns a new 1-dimensional tensor of given value type
// initialized directly from the given slice values, which are not copied.
// The resulting Tensor thus "wraps" the given values.
func NewNumberFromValues[T num.Number](vals ...T) *Number[T] {
	nt := &Number[T]{}
	nt.Values = vals
	nt.SetShapeSizes(len(vals))
	return nt
}
// String satisfies the fmt.Stringer interface for string of tensor data.
func (tsr *Number[T]) String() string { return Sprintf("", tsr, 0) }

// IsString returns false: Number tensors are numeric.
func (tsr *Number[T]) IsString() bool { return false }

// AsValues returns this tensor itself: it is already a raw [Values].
func (tsr *Number[T]) AsValues() Values { return tsr }

// SetAdd adds val to the element at the given n-dimensional index.
func (tsr *Number[T]) SetAdd(val T, i ...int) {
	tsr.Values[tsr.shape.IndexTo1D(i...)] += val
}

// SetSub subtracts val from the element at the given n-dimensional index.
func (tsr *Number[T]) SetSub(val T, i ...int) {
	tsr.Values[tsr.shape.IndexTo1D(i...)] -= val
}

// SetMul multiplies the element at the given n-dimensional index by val.
func (tsr *Number[T]) SetMul(val T, i ...int) {
	tsr.Values[tsr.shape.IndexTo1D(i...)] *= val
}

// SetDiv divides the element at the given n-dimensional index by val.
func (tsr *Number[T]) SetDiv(val T, i ...int) {
	tsr.Values[tsr.shape.IndexTo1D(i...)] /= val
}
/////// Strings

// SetString sets the element at the given n-dimensional index from a
// string parsed as float64; unparseable strings are silently ignored.
func (tsr *Number[T]) SetString(val string, i ...int) {
	if fv, err := strconv.ParseFloat(val, 64); err == nil {
		tsr.Values[tsr.shape.IndexTo1D(i...)] = T(fv)
	}
}

// SetString1D sets the element at the given flat 1D index from a
// string parsed as float64; unparseable strings are silently ignored.
// Uses a pointer receiver for consistency with all other Number methods
// (the previous value receiver copied the struct on every call).
func (tsr *Number[T]) SetString1D(val string, i int) {
	if fv, err := strconv.ParseFloat(val, 64); err == nil {
		tsr.Values[i] = T(fv)
	}
}

// SetStringRow sets the element at the given row and cell offset from a
// string parsed as float64; unparseable strings are silently ignored.
func (tsr *Number[T]) SetStringRow(val string, row, cell int) {
	if fv, err := strconv.ParseFloat(val, 64); err == nil {
		_, sz := tsr.shape.RowCellSize()
		tsr.Values[row*sz+cell] = T(fv)
	}
}
// AppendRowString adds a row and sets string value(s), up to number of cells.
func (tsr *Number[T]) AppendRowString(val ...string) {
	if tsr.NumDims() == 0 {
		tsr.SetShapeSizes(0) // establish a 1D shape to append into
	}
	row, sz := tsr.shape.RowCellSize()
	tsr.SetNumRows(row + 1)
	for i := range min(sz, len(val)) {
		tsr.SetStringRow(val[i], row, i)
	}
}
/////// Floats

// Float returns the value at the given n-dimensional index as float64.
func (tsr *Number[T]) Float(i ...int) float64 {
	return float64(tsr.Values[tsr.shape.IndexTo1D(i...)])
}

// SetFloat sets the value at the given n-dimensional index from float64.
func (tsr *Number[T]) SetFloat(val float64, i ...int) {
	tsr.Values[tsr.shape.IndexTo1D(i...)] = T(val)
}

// Float1D returns the value at the given flat 1D index
// (negative indexes count back from the end, via NegIndex).
func (tsr *Number[T]) Float1D(i int) float64 {
	return float64(tsr.Values[NegIndex(i, len(tsr.Values))])
}

// SetFloat1D sets the value at the given flat 1D index
// (negative indexes count back from the end, via NegIndex).
func (tsr *Number[T]) SetFloat1D(val float64, i int) {
	tsr.Values[NegIndex(i, len(tsr.Values))] = T(val)
}

// FloatRow returns the value at the given row and cell offset.
func (tsr *Number[T]) FloatRow(row, cell int) float64 {
	_, sz := tsr.shape.RowCellSize()
	i := row*sz + cell
	return float64(tsr.Values[NegIndex(i, len(tsr.Values))])
}

// SetFloatRow sets the value at the given row and cell offset.
// NOTE(review): unlike FloatRow, this does not apply NegIndex,
// so negative rows are not supported here -- confirm if intended.
func (tsr *Number[T]) SetFloatRow(val float64, row, cell int) {
	_, sz := tsr.shape.RowCellSize()
	tsr.Values[row*sz+cell] = T(val)
}

// AppendRowFloat adds a row and sets float value(s), up to number of cells.
func (tsr *Number[T]) AppendRowFloat(val ...float64) {
	if tsr.NumDims() == 0 {
		tsr.SetShapeSizes(0)
	}
	nrow, sz := tsr.shape.RowCellSize()
	tsr.SetNumRows(nrow + 1)
	mx := min(sz, len(val))
	for i := range mx {
		tsr.SetFloatRow(val[i], nrow, i)
	}
}
/////// Ints

// Int returns the value at the given n-dimensional index as int.
func (tsr *Number[T]) Int(i ...int) int {
	return int(tsr.Values[tsr.shape.IndexTo1D(i...)])
}

// SetInt sets the value at the given n-dimensional index from int.
func (tsr *Number[T]) SetInt(val int, i ...int) {
	tsr.Values[tsr.shape.IndexTo1D(i...)] = T(val)
}

// Int1D returns the value at the given flat 1D index
// (negative indexes count back from the end, via NegIndex).
func (tsr *Number[T]) Int1D(i int) int {
	return int(tsr.Values[NegIndex(i, len(tsr.Values))])
}

// SetInt1D sets the value at the given flat 1D index
// (negative indexes count back from the end, via NegIndex).
func (tsr *Number[T]) SetInt1D(val int, i int) {
	tsr.Values[NegIndex(i, len(tsr.Values))] = T(val)
}

// IntRow returns the value at the given row and cell offset.
// NOTE(review): unlike FloatRow, this does not apply NegIndex to i --
// confirm whether negative rows should be supported here too.
func (tsr *Number[T]) IntRow(row, cell int) int {
	_, sz := tsr.shape.RowCellSize()
	i := row*sz + cell
	return int(tsr.Values[i])
}

// SetIntRow sets the value at the given row and cell offset.
func (tsr *Number[T]) SetIntRow(val int, row, cell int) {
	_, sz := tsr.shape.RowCellSize()
	tsr.Values[row*sz+cell] = T(val)
}

// AppendRowInt adds a row and sets int value(s), up to number of cells.
func (tsr *Number[T]) AppendRowInt(val ...int) {
	if tsr.NumDims() == 0 {
		tsr.SetShapeSizes(0)
	}
	nrow, sz := tsr.shape.RowCellSize()
	tsr.SetNumRows(nrow + 1)
	mx := min(sz, len(val))
	for i := range mx {
		tsr.SetIntRow(val[i], nrow, i)
	}
}
// SetZeros is a simple convenience function to initialize all values to 0.
func (tsr *Number[T]) SetZeros() {
	// builtin clear (Go 1.21+) lowers to memclr, faster than an explicit loop
	clear(tsr.Values)
}
// Clone clones this tensor, creating a duplicate copy of itself with its
// own separate memory representation of all the values.
func (tsr *Number[T]) Clone() Values {
	cp := NewNumberShape[T](&tsr.shape)
	copy(cp.Values, tsr.Values)
	return cp
}
// CopyFrom copies all avail values from other tensor into this tensor, with an
// optimized implementation if the other tensor is of the same type, and
// otherwise it goes through appropriate standard type.
func (tsr *Number[T]) CopyFrom(frm Values) {
	if same, ok := frm.(*Number[T]); ok {
		copy(tsr.Values, same.Values) // fast path: direct memory copy
		return
	}
	n := min(tsr.Len(), frm.Len())
	useInt := reflectx.KindIsInt(tsr.DataType())
	for i := range n {
		if useInt {
			tsr.Values[i] = T(frm.Int1D(i))
		} else {
			tsr.Values[i] = T(frm.Float1D(i))
		}
	}
}
// AppendFrom appends values from other tensor into this tensor,
// which must have the same cell size as this tensor.
// It uses and optimized implementation if the other tensor
// is of the same type, and otherwise it goes through
// appropriate standard type.
func (tsr *Number[T]) AppendFrom(frm Values) Values {
	rows, cell := tsr.shape.RowCellSize()
	frows, fcell := frm.Shape().RowCellSize()
	if cell != fcell {
		errors.Log(fmt.Errorf("tensor.AppendFrom: cell sizes do not match: %d != %d", cell, fcell))
		return tsr
	}
	tsr.SetNumRows(rows + frows)
	st := rows * cell // start of the appended region
	n := frows * fcell
	if same, ok := frm.(*Number[T]); ok {
		copy(tsr.Values[st:st+n], same.Values) // fast path
		return tsr
	}
	for i := range n {
		tsr.Values[st+i] = T(frm.Float1D(i))
	}
	return tsr
}
// CopyCellsFrom copies given range of values from other tensor into this tensor,
// using flat 1D indexes: to = starting index in this Tensor to start copying into,
// start = starting index on from Tensor to start copying from, and n = number of
// values to copy. Uses an optimized implementation if the other tensor is
// of the same type, and otherwise it goes through appropriate standard type.
func (tsr *Number[T]) CopyCellsFrom(frm Values, to, start, n int) {
	if same, ok := frm.(*Number[T]); ok {
		copy(tsr.Values[to:to+n], same.Values[start:start+n])
		return
	}
	for i := 0; i < n; i++ {
		tsr.Values[to+i] = T(frm.Float1D(start + i))
	}
}
// SubSpace returns a new tensor with innermost subspace at given
// offset(s) in outermost dimension(s) (len(offs) < NumDims).
// The new tensor points to the values of this tensor (i.e., modifications
// will affect both), as its Values slice is a view onto the original (which
// is why only inner-most contiguous subspaces are supported).
// Use AsValues() method to separate the two.
func (tsr *Number[T]) SubSpace(offs ...int) Values {
	b := tsr.subSpaceImpl(offs...)
	rt := &Number[T]{Base: *b}
	return rt
}
// RowTensor is a convenience version of [RowMajor.SubSpace] to return the
// SubSpace for the outermost row dimension. [Rows] defines a version
// of this that indirects through the row indexes.
func (tsr *Number[T]) RowTensor(row int) Values {
	return tsr.SubSpace(row)
}
// SetRowTensor sets the values of the SubSpace at given row to given values.
func (tsr *Number[T]) SetRowTensor(val Values, row int) {
	_, cells := tsr.shape.RowCellSize()
	n := min(val.Len(), cells) // copy at most one row's worth
	tsr.CopyCellsFrom(val, row*cells, 0, n)
}
// AppendRow adds a row and sets values to given values.
func (tsr *Number[T]) AppendRow(val Values) {
	if tsr.NumDims() == 0 {
		tsr.SetShapeSizes(0) // establish a 1D shape to append into
	}
	row := tsr.DimSize(0)
	tsr.SetNumRows(row + 1)
	tsr.SetRowTensor(val, row)
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
const (
	// OnedRow is for onedRow arguments to Projection2D functions,
	// specifies that the 1D case goes along the row.
	OnedRow = true

	// OnedColumn is for onedRow arguments to Projection2D functions,
	// specifies that the 1D case goes along the column.
	OnedColumn = false
)
// Projection2DShape returns the size of a 2D projection of the given tensor Shape,
// collapsing higher dimensions down to 2D (and 1D up to 2D).
// For the 1D case, onedRow determines if the values are row-wise or not.
// Even multiples of inner-most dimensions are placed along the row, odd in the column.
// If there are an odd number of dimensions, the first dimension is row-wise, and
// the remaining inner dimensions use the above logic from there, as if it was even.
// rowEx returns the number of "extra" (outer-dimensional) rows
// and colEx returns the number of extra cols, to add extra spacing between these dimensions.
func Projection2DShape(shp *Shape, onedRow bool) (rows, cols, rowEx, colEx int) {
	if shp.Len() == 0 {
		return 1, 1, 0, 0
	}
	switch shp.NumDims() {
	case 1:
		if onedRow {
			return shp.DimSize(0), 1, 0, 0
		}
		return 1, shp.DimSize(0), 0, 0
	case 2:
		return shp.DimSize(0), shp.DimSize(1), 0, 0
	}
	rowShape, colShape, rowIdxs, colIdxs := Projection2DDimShapes(shp, onedRow)
	rows = rowShape.Len()
	cols = colShape.Len()
	// extra rows/cols = product of all but the innermost projected dim
	if n := len(rowIdxs); n > 1 {
		rowEx = 1
		for _, d := range rowIdxs[:n-1] {
			rowEx *= shp.DimSize(d)
		}
	}
	if n := len(colIdxs); n > 1 {
		colEx = 1
		for _, d := range colIdxs[:n-1] {
			colEx *= shp.DimSize(d)
		}
	}
	return
}
// Projection2DDimShapes returns the shapes and dimension indexes for a 2D projection
// of given tensor Shape, collapsing higher dimensions down to 2D (and 1D up to 2D).
// For the 1D case, onedRow determines if the values are row-wise or not.
// Even multiples of inner-most dimensions are placed along the row, odd in the column.
// If there are an odd number of dimensions, the first dimension is row-wise, and
// the remaining inner dimensions use the above logic from there, as if it was even.
// This is the main organizing function for all Projection2D calls.
func Projection2DDimShapes(shp *Shape, onedRow bool) (rowShape, colShape *Shape, rowIdxs, colIdxs []int) {
	nd := shp.NumDims()
	switch nd {
	case 1:
		if onedRow {
			return NewShape(shp.DimSize(0)), NewShape(1), []int{0}, nil
		}
		return NewShape(1), NewShape(shp.DimSize(0)), nil, []int{0}
	case 2:
		return NewShape(shp.DimSize(0)), NewShape(shp.DimSize(1)), []int{0}, []int{1}
	}
	var rs, cs []int
	sd := 0
	end := nd
	if nd%2 == 1 {
		// odd: first dim goes to the row, then pair up the rest
		end = nd - 1
		sd = 1
		rs = []int{shp.DimSize(0)}
		rowIdxs = []int{0}
	}
	for d := 0; d < end; d++ {
		ad := d + sd
		if d%2 == 0 { // even goes to row
			rs = append(rs, shp.DimSize(ad))
			rowIdxs = append(rowIdxs, ad)
		} else { // odd goes to column
			cs = append(cs, shp.DimSize(ad))
			colIdxs = append(colIdxs, ad)
		}
	}
	return NewShape(rs...), NewShape(cs...), rowIdxs, colIdxs
}
// Projection2DIndex returns the flat 1D index for given row, col coords for a 2D projection
// of the given tensor shape, collapsing higher dimensions down to 2D (and 1D up to 2D).
// See [Projection2DShape] for full info.
func Projection2DIndex(shp *Shape, onedRow bool, row, col int) int {
	if shp.Len() == 0 {
		return 0
	}
	switch shp.NumDims() {
	case 1:
		if onedRow {
			return row
		}
		return col
	case 2:
		return shp.IndexTo1D(row, col)
	}
	rowShape, colShape, rowIdxs, colIdxs := Projection2DDimShapes(shp, onedRow)
	// scatter the projected row/col coordinates back into full-dim indexes
	ixs := make([]int, shp.NumDims())
	for i, v := range rowShape.IndexFrom1D(row) {
		ixs[rowIdxs[i]] = v
	}
	for i, v := range colShape.IndexFrom1D(col) {
		ixs[colIdxs[i]] = v
	}
	return shp.IndexTo1D(ixs...)
}
// Projection2DCoords returns the corresponding full-dimensional coordinates
// that go into the given row, col coords for a 2D projection of the given tensor,
// collapsing higher dimensions down to 2D (and 1D up to 2D).
// See [Projection2DShape] for full info.
func Projection2DCoords(shp *Shape, onedRow bool, row, col int) (rowCoords, colCoords []int) {
	if shp.Len() == 0 {
		return []int{0}, []int{0}
	}
	dims := shp.IndexFrom1D(Projection2DIndex(shp, onedRow, row, col))
	switch shp.NumDims() {
	case 1:
		if onedRow {
			return dims, []int{0}
		}
		return []int{0}, dims
	case 2:
		return dims[:1], dims[1:]
	}
	_, _, rowIdxs, colIdxs := Projection2DDimShapes(shp, onedRow)
	rowCoords = make([]int, len(rowIdxs))
	for i, ri := range rowIdxs {
		rowCoords[i] = dims[ri]
	}
	colCoords = make([]int, len(colIdxs))
	for i, ci := range colIdxs {
		colCoords[i] = dims[ci]
	}
	return
}
// Projection2DValue returns the float64 value at given row, col coords for a 2D projection
// of the given tensor, collapsing higher dimensions down to 2D (and 1D up to 2D).
// See [Projection2DShape] for full info.
func Projection2DValue(tsr Tensor, onedRow bool, row, col int) float64 {
	idx := Projection2DIndex(tsr.Shape(), onedRow, row, col)
	return tsr.Float1D(idx)
}

// Projection2DString returns the string value at given row, col coords for a 2D projection
// of the given tensor, collapsing higher dimensions down to 2D (and 1D up to 2D).
// See [Projection2DShape] for full info.
func Projection2DString(tsr Tensor, onedRow bool, row, col int) string {
	idx := Projection2DIndex(tsr.Shape(), onedRow, row, col)
	return tsr.String1D(idx)
}

// Projection2DSet sets a float64 value at given row, col coords for a 2D projection
// of the given tensor, collapsing higher dimensions down to 2D (and 1D up to 2D).
// See [Projection2DShape] for full info.
func Projection2DSet(tsr Tensor, onedRow bool, row, col int, val float64) {
	idx := Projection2DIndex(tsr.Shape(), onedRow, row, col)
	tsr.SetFloat1D(val, idx)
}

// Projection2DSetString sets a string value at given row, col coords for a 2D projection
// of the given tensor, collapsing higher dimensions down to 2D (and 1D up to 2D).
// See [Projection2DShape] for full info.
func Projection2DSetString(tsr Tensor, onedRow bool, row, col int, val string) {
	idx := Projection2DIndex(tsr.Shape(), onedRow, row, col)
	tsr.SetString1D(val, idx)
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
import (
"reflect"
"slices"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/metadata"
)
// Reshaped is a reshaping wrapper around another "source" [Tensor],
// that provides a length-preserving reshaped view onto the source Tensor.
// Reshaping by adding new size=1 dimensions (via [NewAxis] value) is
// often important for properly aligning two tensors in a computationally
// compatible manner; see the [AlignShapes] function.
// [Reshaped.AsValues] on this view returns a new [Values] with the view
// shape, calling [Clone] on the source tensor to get the values.
type Reshaped struct { //types:add
	// Tensor source that we are a reshaped view onto.
	Tensor Tensor

	// Reshape is the effective shape we use for access.
	// This must have the same Len() as the source Tensor.
	Reshape Shape
}
// NewReshaped returns a new [Reshaped] view of given tensor, with given shape
// sizes. If no such sizes are provided, the source shape is used.
// A single -1 value can be used to automatically specify the remaining tensor
// length, as long as the other sizes are an even multiple of the total length.
// A single -1 returns a 1D view of the entire tensor.
func NewReshaped(tsr Tensor, sizes ...int) *Reshaped {
	rs := &Reshaped{Tensor: tsr}
	if len(sizes) > 0 {
		errors.Log(rs.SetShapeSizes(sizes...))
	} else {
		rs.Reshape.CopyFrom(tsr.Shape()) // default: mirror the source shape
	}
	return rs
}
// Reshape returns a view of the given tensor with given shape sizes.
// A single -1 value can be used to automatically specify the remaining tensor
// length, as long as the other sizes are an even multiple of the total length.
// A single -1 returns a 1D view of the entire tensor.
func Reshape(tsr Tensor, sizes ...int) Tensor {
	switch {
	case len(sizes) == 0:
		errors.Log(errors.New("tensor.Reshape: must pass shape sizes"))
		return tsr
	case len(sizes) == 1 && sizes[0] == -1:
		return As1D(tsr) // single -1: flat 1D view of everything
	}
	rs := &Reshaped{Tensor: tsr}
	errors.Log(rs.SetShapeSizes(sizes...))
	return rs
}
// Transpose returns a new [Reshaped] tensor with the strides
// switched so that rows and column dimensions are effectively
// reversed.
func Transpose(tsr Tensor) Tensor {
	rs := &Reshaped{Tensor: tsr}
	rs.Reshape.CopyFrom(tsr.Shape())
	// same sizes, column-major strides: reverses the row/column traversal
	rs.Reshape.Strides = ColumnMajorStrides(rs.Reshape.Sizes...)
	return rs
}
// NewRowCellsView returns a 2D [Reshaped] view onto the given tensor,
// with a single outer "row" dimension and a single inner "cells" dimension,
// with the given 'split' dimension specifying where the cells start.
// All dimensions prior to split are collapsed to form the new outer row dimension,
// and the remainder are collapsed to form the 1D cells dimension.
// This is useful for stats, metrics and other packages that operate
// on data in this shape.
func NewRowCellsView(tsr Tensor, split int) *Reshaped {
	sizes := tsr.ShapeSizes()
	nr, nc := 1, 1
	for _, s := range sizes[:split] { // outer dims collapse into rows
		nr *= s
	}
	for _, s := range sizes[split:] { // inner dims collapse into cells
		nc *= s
	}
	return NewReshaped(tsr, nr, nc)
}
// AsReshaped returns the tensor as a [Reshaped] view.
// If it already is one, then it is returned, otherwise it is wrapped
// with an initial shape equal to the source tensor.
func AsReshaped(tsr Tensor) *Reshaped {
	rs, ok := tsr.(*Reshaped)
	if !ok {
		rs = NewReshaped(tsr)
	}
	return rs
}
// SetShapeSizes sets our shape sizes to the given values, which must result in
// the same length as the source tensor. An error is returned if not.
// If a different subset of content is desired, use another view such as [Sliced].
// Note that any number of size = 1 dimensions can be added without affecting
// the length, and the [NewAxis] value can be used to semantically
// indicate when such a new dimension is being inserted. This is often useful
// for aligning two tensors to achieve a desired computation; see [AlignShapes]
// function. A single -1 can be used to specify a dimension size that takes the
// remaining length, as long as the other sizes are an even multiple of the length.
// A single -1 indicates to use the full length.
func (rs *Reshaped) SetShapeSizes(sizes ...int) error {
	sln := rs.Tensor.Len()
	if sln == 0 {
		return nil
	}
	// Single -1 size: 1D view of the entire source tensor.
	// (Fixed: previously tested sln == 1 instead of len(sizes) == 1,
	// which both missed this shortcut and indexed sizes[0] without
	// knowing any sizes were passed.)
	if len(sizes) == 1 && sizes[0] < 0 {
		rs.Reshape.SetShapeSizes(sln)
		return nil
	}
	sz := slices.Clone(sizes)
	ln := 1
	negIdx := -1
	for i, s := range sz {
		if s < 0 {
			negIdx = i // placeholder to be filled with the remaining length
		} else {
			ln *= s
		}
	}
	if negIdx >= 0 {
		if sln%ln != 0 {
			return errors.New("tensor.Reshaped SetShapeSizes: -1 cannot be used because the remaining dimensions are not an even multiple of the source tensor length")
		}
		sz[negIdx] = sln / ln
	}
	rs.Reshape.SetShapeSizes(sz...)
	if rs.Reshape.Len() != sln {
		return errors.New("tensor.Reshaped SetShapeSizes: new length is different from source tensor; use Sliced or other views to change view content")
	}
	return nil
}
// Label returns the metadata name (if set) plus the view shape string.
func (rs *Reshaped) Label() string { return label(metadata.Name(rs), rs.Shape()) }

// String returns the default-formatted string representation.
func (rs *Reshaped) String() string { return Sprintf("", rs, 0) }

// Metadata returns the source tensor's metadata.
func (rs *Reshaped) Metadata() *metadata.Data { return rs.Tensor.Metadata() }

// IsString and DataType delegate to the source tensor; the shape
// accessors below report the view's Reshape, not the source shape.
func (rs *Reshaped) IsString() bool { return rs.Tensor.IsString() }

func (rs *Reshaped) DataType() reflect.Kind { return rs.Tensor.DataType() }

func (rs *Reshaped) ShapeSizes() []int { return slices.Clone(rs.Reshape.Sizes) }

func (rs *Reshaped) Shape() *Shape { return &rs.Reshape }

func (rs *Reshaped) Len() int { return rs.Reshape.Len() }

func (rs *Reshaped) NumDims() int { return rs.Reshape.NumDims() }

func (rs *Reshaped) DimSize(dim int) int { return rs.Reshape.DimSize(dim) }
// AsValues returns a copy of this tensor as raw [Values], with
// the same shape as our view. This calls [Clone] on the source
// tensor to get the Values and then sets our shape sizes to it.
func (rs *Reshaped) AsValues() Values {
	vals := Clone(rs.Tensor)
	vals.SetShapeSizes(rs.Reshape.Sizes...) // impose the view shape on the copy
	return vals
}
//////// Floats

// Float returns the float64 value at the given n-dimensional index,
// translated through the view shape into the source tensor's flat index.
func (rs *Reshaped) Float(i ...int) float64 {
return rs.Tensor.Float1D(rs.Reshape.IndexTo1D(i...))
}

// SetFloat sets the float64 value at the given n-dimensional index,
// translated through the view shape into the source tensor's flat index.
func (rs *Reshaped) SetFloat(val float64, i ...int) {
rs.Tensor.SetFloat1D(val, rs.Reshape.IndexTo1D(i...))
}

// Float1D returns the float64 value at the given flat index; a reshape
// preserves flat ordering, so this passes straight through to the source.
func (rs *Reshaped) Float1D(i int) float64 { return rs.Tensor.Float1D(i) }

// SetFloat1D sets the float64 value at the given flat index, directly in the source.
func (rs *Reshaped) SetFloat1D(val float64, i int) { rs.Tensor.SetFloat1D(val, i) }
//////// Strings

// StringValue returns the string value at the given n-dimensional index,
// translated through the view shape into the source tensor's flat index.
func (rs *Reshaped) StringValue(i ...int) string {
return rs.Tensor.String1D(rs.Reshape.IndexTo1D(i...))
}

// SetString sets the string value at the given n-dimensional index,
// translated through the view shape into the source tensor's flat index.
func (rs *Reshaped) SetString(val string, i ...int) {
rs.Tensor.SetString1D(val, rs.Reshape.IndexTo1D(i...))
}

// String1D returns the string value at the given flat index; a reshape
// preserves flat ordering, so this passes straight through to the source.
func (rs *Reshaped) String1D(i int) string { return rs.Tensor.String1D(i) }

// SetString1D sets the string value at the given flat index, directly in the source.
func (rs *Reshaped) SetString1D(val string, i int) { rs.Tensor.SetString1D(val, i) }
//////// Ints

// Int returns the int value at the given n-dimensional index,
// translated through the view shape into the source tensor's flat index.
func (rs *Reshaped) Int(i ...int) int {
return rs.Tensor.Int1D(rs.Reshape.IndexTo1D(i...))
}

// SetInt sets the int value at the given n-dimensional index,
// translated through the view shape into the source tensor's flat index.
func (rs *Reshaped) SetInt(val int, i ...int) {
rs.Tensor.SetInt1D(val, rs.Reshape.IndexTo1D(i...))
}

// Int1D returns the int value at the given flat index; a reshape
// preserves flat ordering, so this passes straight through to the source.
func (rs *Reshaped) Int1D(i int) int { return rs.Tensor.Int1D(i) }

// SetInt1D sets the int value at the given flat index, directly in the source.
func (rs *Reshaped) SetInt1D(val int, i int) { rs.Tensor.SetInt1D(val, i) }

// compile-time check that Reshaped implements the Tensor interface
var _ Tensor = (*Reshaped)(nil)
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
import (
"cmp"
"math"
"math/rand"
"reflect"
"slices"
"sort"
"strings"
"cogentcore.org/core/base/metadata"
)
// Rows is a row-indexed wrapper view around a [Values] [Tensor] that allows
// arbitrary row-wise ordering and filtering according to the [Rows.Indexes].
// Sorting and filtering a tensor along this outermost row dimension only
// requires updating the indexes while leaving the underlying Tensor alone.
// Unlike the more general [Sliced] view, Rows maintains memory contiguity
// for the inner dimensions ("cells") within each row, and supports the [RowMajor]
// interface, with the [Set]FloatRow[Cell] methods providing efficient access.
// Use [Rows.AsValues] to obtain a concrete [Values] representation with the
// current row sorting.
type Rows struct { //types:add

// Tensor source that we are an indexed view onto.
// Note that this must be a concrete [Values] tensor, to enable efficient
// [RowMajor] access and subspace functions.
Tensor Values

// Indexes are the indexes into Tensor rows, with nil = sequential.
// Only set if order is different from default sequential order.
// Methods such as Sort and Filter manage this list.
// Use the [Rows.RowIndex] method for nil-aware logic.
Indexes []int
}
// NewRows returns a new [Rows] view of given tensor,
// with optional list of indexes (none / nil = sequential).
// The indexes are copied, so the caller's slice is not shared.
func NewRows(tsr Values, idxs ...int) *Rows {
	return &Rows{Tensor: tsr, Indexes: slices.Clone(idxs)}
}
// AsRows returns the tensor as a [Rows] view.
// If it already is one, then it is returned, otherwise
// a new Rows is created to wrap around the given tensor, which is
// enforced to be a [Values] tensor either because it already is one,
// or by calling [Tensor.AsValues] on it.
func AsRows(tsr Tensor) *Rows {
	rw, ok := tsr.(*Rows)
	if !ok {
		rw = NewRows(tsr.AsValues())
	}
	return rw
}
// SetTensor sets as indexes into given [Values] tensor with sequential initial indexes.
func (rw *Rows) SetTensor(tsr Values) {
rw.Tensor = tsr
rw.Sequential()
}

// IsString returns true if the underlying source tensor holds string values.
func (rw *Rows) IsString() bool { return rw.Tensor.IsString() }

// DataType returns the reflect.Kind of the underlying source tensor elements.
func (rw *Rows) DataType() reflect.Kind { return rw.Tensor.DataType() }

// RowIndex returns the actual index into underlying tensor row based on given
// index value. If Indexes == nil, index is passed through.
// No bounds checking is done here beyond the slice access itself.
func (rw *Rows) RowIndex(idx int) int {
if rw.Indexes == nil {
return idx
}
return rw.Indexes[idx]
}
// NumRows returns the effective number of rows in this Rows view:
// the length of the index list when present, otherwise the size of the
// source tensor's outermost (row) dimension (full sequential view).
func (rw *Rows) NumRows() int {
	if rw.Indexes != nil {
		return len(rw.Indexes)
	}
	return rw.Tensor.DimSize(0)
}
// String returns a string representation of the view.
// NOTE(review): this formats rw.Tensor directly, so any row Indexes are
// not reflected in the output — confirm this is intended.
func (rw *Rows) String() string { return Sprintf("", rw.Tensor, 0) }

// Label returns the human-readable label of the source tensor.
func (rw *Rows) Label() string { return rw.Tensor.Label() }

// Metadata returns the metadata of the source tensor.
func (rw *Rows) Metadata() *metadata.Data { return rw.Tensor.Metadata() }

// NumDims returns the number of dimensions, which is that of the source tensor.
func (rw *Rows) NumDims() int { return rw.Tensor.NumDims() }
// ShapeSizes returns the dimension sizes of this view.
// If we have Indexes, this is the effective shape sizes using
// the current number of indexes as the outermost row dimension size.
func (rw *Rows) ShapeSizes() []int {
	if rw.Indexes == nil || rw.Tensor.NumDims() == 0 {
		return rw.Tensor.ShapeSizes()
	}
	sizes := slices.Clone(rw.Tensor.ShapeSizes())
	sizes[0] = len(rw.Indexes)
	return sizes
}
// Shape returns a [Shape] representation of the tensor shape
// (dimension sizes). If we have Indexes, this is the effective
// shape using the current number of indexes as the outermost row dimension size.
func (rw *Rows) Shape() *Shape {
	if rw.Indexes != nil {
		return NewShape(rw.ShapeSizes()...)
	}
	return rw.Tensor.Shape()
}
// Len returns the total number of elements in the tensor,
// taking into account the Indexes via [Rows],
// as NumRows() * cell size.
func (rw *Rows) Len() int {
	_, cells := rw.Tensor.Shape().RowCellSize()
	return rw.NumRows() * cells
}
// DimSize returns the size of the given dimension, returning NumRows()
// for the first (row) dimension.
func (rw *Rows) DimSize(dim int) int {
	if dim > 0 {
		return rw.Tensor.DimSize(dim)
	}
	return rw.NumRows()
}
// RowCellSize returns the size of the outermost Row shape dimension
// (via [Rows.NumRows] method), and the size of all the remaining
// inner dimensions (the "cell" size).
func (rw *Rows) RowCellSize() (rows, cells int) {
	rows = rw.NumRows()
	_, cells = rw.Tensor.Shape().RowCellSize()
	return rows, cells
}
// ValidIndexes deletes all invalid indexes from the list.
// Call this if rows (could) have been deleted from tensor.
func (rw *Rows) ValidIndexes() {
	mx := rw.Tensor.DimSize(0)
	if mx <= 0 || rw.Indexes == nil {
		rw.Indexes = nil
		return
	}
	// single-pass in-place filter, replacing the previous reverse iteration
	// with a repeated append copy per deleted index (O(n^2) worst case).
	rw.Indexes = slices.DeleteFunc(rw.Indexes, func(ix int) bool {
		return ix >= mx
	})
}
// Sequential sets Indexes to nil, resulting in sequential row-wise access into tensor.
// Any prior sorting or filtering of this view is discarded.
func (rw *Rows) Sequential() { //types:add
rw.Indexes = nil
}
// IndexesNeeded is called prior to an operation that needs actual indexes,
// e.g., Sort, Filter. If Indexes == nil, they are set to all rows, otherwise
// current indexes are left as is. Use Sequential, then IndexesNeeded to ensure
// all rows are represented.
func (rw *Rows) IndexesNeeded() {
	nr := rw.Tensor.DimSize(0)
	if nr <= 0 {
		rw.Indexes = nil
		return
	}
	if rw.Indexes != nil {
		return
	}
	idxs := make([]int, nr)
	for i := range idxs {
		idxs[i] = i
	}
	rw.Indexes = idxs
}
// ExcludeMissing deletes indexes where the values are missing, as indicated by NaN.
// Uses first cell of higher dimensional data.
func (rw *Rows) ExcludeMissing() { //types:add
	if rw.Tensor.DimSize(0) <= 0 {
		rw.Indexes = nil
		return
	}
	rw.IndexesNeeded()
	// single-pass in-place filter, replacing the previous reverse iteration
	// with a repeated append copy per deleted index (O(n^2) worst case).
	rw.Indexes = slices.DeleteFunc(rw.Indexes, func(ix int) bool {
		return math.IsNaN(rw.Tensor.FloatRow(ix, 0))
	})
}
// Permuted sets indexes to a permuted order. If indexes already exist,
// the existing list of indexes is shuffled in place; otherwise a new set of
// permuted indexes covering all rows is generated.
func (rw *Rows) Permuted() {
	if rw.Tensor.DimSize(0) <= 0 {
		rw.Indexes = nil
		return
	}
	if rw.Indexes != nil {
		rand.Shuffle(len(rw.Indexes), func(i, j int) {
			rw.Indexes[i], rw.Indexes[j] = rw.Indexes[j], rw.Indexes[i]
		})
		return
	}
	rw.Indexes = rand.Perm(rw.Tensor.DimSize(0))
}
const (
// Ascending specifies an ascending sort direction for tensor Sort routines.
Ascending = true

// Descending specifies a descending sort direction for tensor Sort routines.
Descending = false

// StableSort specifies using the stable, original order-preserving sort, which is slower.
StableSort = true

// UnstableSort specifies using the faster but unstable sort.
UnstableSort = false
)
// SortFunc sorts the row-wise indexes using given compare function.
// The compare function operates directly on row numbers into the Tensor
// as these row numbers have already been projected through the indexes.
// cmpFunc(a, b) should return a negative number when a < b, a positive
// number when a > b and zero when a == b.
func (rw *Rows) SortFunc(cmpFunc func(tsr Values, i, j int) int) {
	rw.IndexesNeeded()
	// parameter renamed from cmp to avoid shadowing the stdlib cmp package.
	slices.SortFunc(rw.Indexes, func(a, b int) int {
		return cmpFunc(rw.Tensor, a, b) // key point: these are already indirected through indexes!!
	})
}
// SortIndexes sorts the indexes into our Tensor directly in
// numerical order, producing the native ordering, while preserving
// any filtering that might have occurred.
func (rw *Rows) SortIndexes() {
	if rw.Indexes == nil {
		return
	}
	// slices.Sort avoids the interface-based sort.Ints path.
	slices.Sort(rw.Indexes)
}
// CompareAscending is a sort compare function that reverses direction
// based on the ascending bool: ascending compares a to b, descending
// compares b to a.
func CompareAscending[T cmp.Ordered](a, b T, ascending bool) int {
	if !ascending {
		a, b = b, a
	}
	return cmp.Compare(a, b)
}
// Sort does default alpha or numeric sort of row-wise data.
// Uses first cell of higher dimensional data.
func (rw *Rows) Sort(ascending bool) {
	if rw.Tensor.IsString() {
		rw.SortFunc(func(tsr Values, i, j int) int {
			return CompareAscending(tsr.StringRow(i, 0), tsr.StringRow(j, 0), ascending)
		})
		return
	}
	rw.SortFunc(func(tsr Values, i, j int) int {
		return CompareAscending(tsr.FloatRow(i, 0), tsr.FloatRow(j, 0), ascending)
	})
}
// SortStableFunc stably sorts the row-wise indexes using given compare function.
// The compare function operates directly on row numbers into the Tensor
// as these row numbers have already been projected through the indexes.
// cmpFunc(a, b) should return a negative number when a < b, a positive
// number when a > b and zero when a == b.
// It is *essential* that it always returns 0 when the two are equal
// for the stable function to actually work.
func (rw *Rows) SortStableFunc(cmpFunc func(tsr Values, i, j int) int) {
	rw.IndexesNeeded()
	// parameter renamed from cmp to avoid shadowing the stdlib cmp package.
	slices.SortStableFunc(rw.Indexes, func(a, b int) int {
		return cmpFunc(rw.Tensor, a, b) // key point: these are already indirected through indexes!!
	})
}
// SortStable does stable default alpha or numeric sort.
// Uses first cell of higher dimensional data.
func (rw *Rows) SortStable(ascending bool) {
	if rw.Tensor.IsString() {
		rw.SortStableFunc(func(tsr Values, i, j int) int {
			return CompareAscending(tsr.StringRow(i, 0), tsr.StringRow(j, 0), ascending)
		})
		return
	}
	rw.SortStableFunc(func(tsr Values, i, j int) int {
		return CompareAscending(tsr.FloatRow(i, 0), tsr.FloatRow(j, 0), ascending)
	})
}
// FilterFunc is a function used for filtering that returns
// true if Tensor row should be included in the current filtered
// view of the tensor, and false if it should be removed.
// The row argument is a direct row number into the Tensor,
// already projected through any indexes.
type FilterFunc func(tsr Values, row int) bool
// Filter filters the indexes using given Filter function.
// The Filter function operates directly on row numbers into the Tensor
// as these row numbers have already been projected through the indexes.
func (rw *Rows) Filter(filterer func(tsr Values, row int) bool) {
	rw.IndexesNeeded()
	// single-pass in-place filter, replacing the previous reverse iteration
	// with a repeated append copy per deleted row (O(n^2) worst case).
	rw.Indexes = slices.DeleteFunc(rw.Indexes, func(ix int) bool {
		return !filterer(rw.Tensor, ix)
	})
}
// FilterOptions are options to a Filter function
// determining how the string filter value is used for matching.
type FilterOptions struct { //types:add

// Exclude means to exclude matches,
// with the default (false) being to include.
Exclude bool

// Contains means the string only needs to contain the target string,
// with the default (false) requiring a complete match to entire string.
Contains bool

// IgnoreCase means that differences in case are ignored in comparing strings,
// with the default (false) using case.
IgnoreCase bool
}
// FilterString filters the indexes using string values compared to given
// string. Includes rows with matching values unless the Exclude option is set.
// If Contains option is set, it only checks if row contains string;
// if IgnoreCase, ignores case, otherwise filtering is case sensitive.
// Uses first cell of higher dimensional data.
func (rw *Rows) FilterString(str string, opts FilterOptions) { //types:add
	lowstr := strings.ToLower(str)
	// match reports whether a row value matches per the options.
	match := func(val string) bool {
		switch {
		case opts.Contains && opts.IgnoreCase:
			return strings.Contains(strings.ToLower(val), lowstr)
		case opts.Contains:
			return strings.Contains(val, str)
		case opts.IgnoreCase:
			return strings.EqualFold(val, str)
		default:
			return val == str
		}
	}
	rw.Filter(func(tsr Values, row int) bool {
		has := match(tsr.StringRow(row, 0))
		if opts.Exclude {
			return !has
		}
		return has
	})
}
// AsValues returns this tensor as raw [Values].
// If the row [Rows.Indexes] are nil, then the wrapped Values tensor
// is returned. Otherwise, it "renders" the Rows view into a fully contiguous
// and optimized memory representation of that view, which will be faster
// to access for further processing, and enables all the additional
// functionality provided by the [Values] interface.
func (rw *Rows) AsValues() Values {
	if rw.Indexes == nil {
		return rw.Tensor
	}
	vt := NewOfType(rw.Tensor.DataType(), rw.ShapeSizes()...)
	nr := rw.NumRows()
	for ri := 0; ri < nr; ri++ {
		vt.SetRowTensor(rw.RowTensor(ri), ri)
	}
	return vt
}
// CloneIndexes returns a copy of the current Rows view with new indexes,
// with a pointer to the same underlying Tensor as the source.
func (rw *Rows) CloneIndexes() *Rows {
	cp := &Rows{Tensor: rw.Tensor}
	cp.CopyIndexes(rw)
	return cp
}

// CopyIndexes copies indexes from the other Rows view,
// cloning the slice so the two views do not share storage.
func (rw *Rows) CopyIndexes(oix *Rows) {
	rw.Indexes = nil
	if oix.Indexes != nil {
		rw.Indexes = slices.Clone(oix.Indexes)
	}
}
// addRowsIndexes adds n rows to the indexes, starting at the end of the
// current underlying tensor size; no-op if Indexes is nil (sequential).
func (rw *Rows) addRowsIndexes(n int) { //types:add
	if rw.Indexes == nil {
		return
	}
	start := rw.Tensor.DimSize(0)
	for i := range n {
		rw.Indexes = append(rw.Indexes, start+i)
	}
}
// AddRows adds n rows to end of underlying Tensor, and to the indexes in this view.
func (rw *Rows) AddRows(n int) { //types:add
	cur := rw.Tensor.DimSize(0) // size before adding
	rw.addRowsIndexes(n)
	rw.Tensor.SetNumRows(cur + n)
}
// InsertRows adds n rows to end of underlying Tensor, and inserts the
// corresponding new indexes starting at given index in this view.
func (rw *Rows) InsertRows(at, n int) {
	stidx := rw.Tensor.DimSize(0) // new rows get indexes stidx..stidx+n-1
	rw.IndexesNeeded()
	rw.Tensor.SetNumRows(stidx + n)
	nw := make([]int, n)
	for i := range n {
		nw[i] = stidx + i
	}
	// slices.Insert replaces the previous nested-append splice.
	rw.Indexes = slices.Insert(rw.Indexes, at, nw...)
}
// DeleteRows deletes n rows of indexes starting at given index in the list
// of indexes. Only the view's indexes are removed; the underlying tensor
// rows are not deleted.
func (rw *Rows) DeleteRows(at, n int) {
	rw.IndexesNeeded()
	// slices.Delete replaces the previous append-based splice.
	rw.Indexes = slices.Delete(rw.Indexes, at, at+n)
}
// Swap switches the indexes for i and j; no-op when Indexes is nil
// (sequential order cannot be reordered).
func (rw *Rows) Swap(i, j int) {
	if rw.Indexes == nil {
		return
	}
	rw.Indexes[j], rw.Indexes[i] = rw.Indexes[i], rw.Indexes[j]
}
/////// Floats

// Float returns the value of given index as a float64.
// The first index value is indirected through the indexes.
func (rw *Rows) Float(i ...int) float64 {
	if rw.Indexes == nil {
		return rw.Tensor.Float(i...)
	}
	ix := slices.Clone(i) // clone so the caller's index slice is not modified
	ix[0] = rw.Indexes[ix[0]]
	return rw.Tensor.Float(ix...)
}

// SetFloat sets the value of given index as a float64.
// The first index value is indirected through the [Rows.Indexes].
func (rw *Rows) SetFloat(val float64, i ...int) {
	if rw.Indexes == nil {
		rw.Tensor.SetFloat(val, i...)
		return
	}
	ix := slices.Clone(i) // clone so the caller's index slice is not modified
	ix[0] = rw.Indexes[ix[0]]
	rw.Tensor.SetFloat(val, ix...)
}
// FloatRow returns the value at given row and cell,
// where row is outermost dim, and cell is 1D index into remaining inner dims.
// Row is indirected through the [Rows.Indexes].
// This is the preferred interface for all Rows operations.
func (rw *Rows) FloatRow(row, cell int) float64 {
return rw.Tensor.FloatRow(rw.RowIndex(row), cell)
}

// SetFloatRow sets the value at given row and cell,
// where row is outermost dim, and cell is 1D index into remaining inner dims.
// Row is indirected through the [Rows.Indexes].
// This is the preferred interface for all Rows operations.
func (rw *Rows) SetFloatRow(val float64, row, cell int) {
rw.Tensor.SetFloatRow(val, rw.RowIndex(row), cell)
}
// Float1D returns the value at the given flat 1D index.
// It is somewhat expensive if indexes are set, because it needs to convert
// the flat index back into a full n-dimensional index and then use that api.
func (rw *Rows) Float1D(i int) float64 {
if rw.Indexes == nil {
return rw.Tensor.Float1D(i)
}
return rw.Float(rw.Tensor.Shape().IndexFrom1D(i)...)
}

// SetFloat1D sets the value at the given flat 1D index.
// It is somewhat expensive if indexes are set, because it needs to convert
// the flat index back into a full n-dimensional index and then use that api.
func (rw *Rows) SetFloat1D(val float64, i int) {
if rw.Indexes == nil {
rw.Tensor.SetFloat1D(val, i)
return
}
rw.SetFloat(val, rw.Tensor.Shape().IndexFrom1D(i)...)
}
/////// Strings

// StringValue returns the value of given index as a string.
// The first index value is indirected through the indexes.
func (rw *Rows) StringValue(i ...int) string {
if rw.Indexes == nil {
return rw.Tensor.StringValue(i...)
}
// clone so the caller's index slice is not modified by the row indirection
ic := slices.Clone(i)
ic[0] = rw.Indexes[ic[0]]
return rw.Tensor.StringValue(ic...)
}

// SetString sets the value of given index as a string.
// The first index value is indirected through the [Rows.Indexes].
func (rw *Rows) SetString(val string, i ...int) {
if rw.Indexes == nil {
rw.Tensor.SetString(val, i...)
return
}
// clone so the caller's index slice is not modified by the row indirection
ic := slices.Clone(i)
ic[0] = rw.Indexes[ic[0]]
rw.Tensor.SetString(val, ic...)
}
// StringRow returns the value at given row and cell,
// where row is outermost dim, and cell is 1D index into remaining inner dims.
// Row is indirected through the [Rows.Indexes].
// This is the preferred interface for all Rows operations.
func (rw *Rows) StringRow(row, cell int) string {
return rw.Tensor.StringRow(rw.RowIndex(row), cell)
}

// SetStringRow sets the value at given row and cell,
// where row is outermost dim, and cell is 1D index into remaining inner dims.
// Row is indirected through the [Rows.Indexes].
// This is the preferred interface for all Rows operations.
func (rw *Rows) SetStringRow(val string, row, cell int) {
rw.Tensor.SetStringRow(val, rw.RowIndex(row), cell)
}

// AppendRowFloat adds a row (updating Indexes if set) and appends
// given float value(s) to the source tensor, up to number of cells.
func (rw *Rows) AppendRowFloat(val ...float64) {
rw.addRowsIndexes(1)
rw.Tensor.AppendRowFloat(val...)
}
// String1D returns the value at the given flat 1D index.
// It is somewhat expensive if indexes are set, because it needs to convert
// the flat index back into a full n-dimensional index and then use that api.
func (rw *Rows) String1D(i int) string {
if rw.Indexes == nil {
return rw.Tensor.String1D(i)
}
return rw.StringValue(rw.Tensor.Shape().IndexFrom1D(i)...)
}

// SetString1D sets the value at the given flat 1D index.
// It is somewhat expensive if indexes are set, because it needs to convert
// the flat index back into a full n-dimensional index and then use that api.
func (rw *Rows) SetString1D(val string, i int) {
if rw.Indexes == nil {
rw.Tensor.SetString1D(val, i)
return
}
rw.SetString(val, rw.Tensor.Shape().IndexFrom1D(i)...)
}

// AppendRowString adds a row (updating Indexes if set) and appends
// given string value(s) to the source tensor, up to number of cells.
func (rw *Rows) AppendRowString(val ...string) {
rw.addRowsIndexes(1)
rw.Tensor.AppendRowString(val...)
}
/////// Ints

// Int returns the value of given index as an int.
// The first index value is indirected through the indexes.
func (rw *Rows) Int(i ...int) int {
if rw.Indexes == nil {
return rw.Tensor.Int(i...)
}
// clone so the caller's index slice is not modified by the row indirection
ic := slices.Clone(i)
ic[0] = rw.Indexes[ic[0]]
return rw.Tensor.Int(ic...)
}

// SetInt sets the value of given index as an int.
// The first index value is indirected through the [Rows.Indexes].
func (rw *Rows) SetInt(val int, i ...int) {
if rw.Indexes == nil {
rw.Tensor.SetInt(val, i...)
return
}
// clone so the caller's index slice is not modified by the row indirection
ic := slices.Clone(i)
ic[0] = rw.Indexes[ic[0]]
rw.Tensor.SetInt(val, ic...)
}
// IntRow returns the value at given row and cell,
// where row is outermost dim, and cell is 1D index into remaining inner dims.
// Row is indirected through the [Rows.Indexes].
// This is the preferred interface for all Rows operations.
func (rw *Rows) IntRow(row, cell int) int {
return rw.Tensor.IntRow(rw.RowIndex(row), cell)
}

// SetIntRow sets the value at given row and cell,
// where row is outermost dim, and cell is 1D index into remaining inner dims.
// Row is indirected through the [Rows.Indexes].
// This is the preferred interface for all Rows operations.
func (rw *Rows) SetIntRow(val int, row, cell int) {
rw.Tensor.SetIntRow(val, rw.RowIndex(row), cell)
}

// AppendRowInt adds a row (updating Indexes if set) and appends
// given int value(s) to the source tensor, up to number of cells.
func (rw *Rows) AppendRowInt(val ...int) {
rw.addRowsIndexes(1)
rw.Tensor.AppendRowInt(val...)
}
// Int1D returns the value at the given flat 1D index.
// It is somewhat expensive if indexes are set, because it needs to convert
// the flat index back into a full n-dimensional index and then use that api.
func (rw *Rows) Int1D(i int) int {
if rw.Indexes == nil {
return rw.Tensor.Int1D(i)
}
return rw.Int(rw.Tensor.Shape().IndexFrom1D(i)...)
}

// SetInt1D sets the value at the given flat 1D index.
// It is somewhat expensive if indexes are set, because it needs to convert
// the flat index back into a full n-dimensional index and then use that api.
func (rw *Rows) SetInt1D(val int, i int) {
if rw.Indexes == nil {
rw.Tensor.SetInt1D(val, i)
return
}
rw.SetInt(val, rw.Tensor.Shape().IndexFrom1D(i)...)
}
/////// SubSpaces

// SubSpace returns a new tensor with innermost subspace at given
// offset(s) in outermost dimension(s) (len(offs) < NumDims).
// The new tensor points to the values of this tensor (i.e., modifications
// will affect both), as its Values slice is a view onto the original (which
// is why only inner-most contiguous subspaces are supported).
// Use Clone() method to separate the two.
// Rows version does indexed indirection of the outermost row dimension
// of the offsets.
func (rw *Rows) SubSpace(offs ...int) Values {
	if len(offs) == 0 {
		return nil
	}
	// clone before indirecting so a caller-owned slice passed via offs...
	// is not modified, consistent with Float, StringValue, Int above.
	ic := slices.Clone(offs)
	ic[0] = rw.RowIndex(ic[0])
	return rw.Tensor.SubSpace(ic...)
}
// RowTensor is a convenience version of [Rows.SubSpace] to return the
// SubSpace for the outermost row dimension, indirected through the indexes.
func (rw *Rows) RowTensor(row int) Values {
return rw.Tensor.RowTensor(rw.RowIndex(row))
}

// SetRowTensor sets the values of the SubSpace at given row to given values,
// with row indirected through the indexes.
func (rw *Rows) SetRowTensor(val Values, row int) {
rw.Tensor.SetRowTensor(val, rw.RowIndex(row))
}
// AppendRow adds a row to the underlying tensor (and indexes)
// and sets its values to the given values.
func (rw *Rows) AppendRow(val Values) {
	row := rw.Tensor.DimSize(0) // index of the row about to be added
	rw.AddRows(1)
	rw.Tensor.SetRowTensor(val, row)
}

// compile-time check that Rows implements the RowMajor interface
var _ RowMajor = (*Rows)(nil)
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
import (
"fmt"
"slices"
)
// Shape manages a tensor's shape information, including sizes and strides,
// and can compute the flat index into an underlying 1D data storage array based on an
// n-dimensional index (and vice-versa).
// Per Go / C / Python conventions, indexes are Row-Major, ordered from
// outer to inner left-to-right, so the inner-most is right-most.
type Shape struct {

// Sizes is the size of each dimension, from outermost to innermost.
Sizes []int

// Strides is the flat-index offset per dimension, computed from Sizes
// (see [RowMajorStrides]); hidden from GUI display.
Strides []int `display:"-"`
}
// NewShape returns a new shape with given sizes.
// RowMajor ordering is used by default.
func NewShape(sizes ...int) *Shape {
	sh := new(Shape)
	sh.SetShapeSizes(sizes...)
	return sh
}
// SetShapeSizes sets the shape sizes from list of ints,
// recomputing the strides. RowMajor ordering is used by default.
func (sh *Shape) SetShapeSizes(sizes ...int) {
sh.Sizes = slices.Clone(sizes)
sh.Strides = RowMajorStrides(sizes...)
}

// SetShapeSizesFromTensor sets the shape sizes from given tensor,
// converting its values to ints. RowMajor ordering is used by default.
func (sh *Shape) SetShapeSizesFromTensor(sizes Tensor) {
sh.SetShapeSizes(AsIntSlice(sizes)...)
}

// SizesAsTensor returns shape sizes as an Int Tensor.
func (sh *Shape) SizesAsTensor() *Int {
return NewIntFromValues(sh.Sizes...)
}

// CopyFrom copies the shape parameters from another Shape struct.
// It copies the data so it is not accidentally subject to updates.
func (sh *Shape) CopyFrom(cp *Shape) {
sh.Sizes = slices.Clone(cp.Sizes)
sh.Strides = slices.Clone(cp.Strides)
}
// Len returns the total length of elements in the tensor
// (i.e., the product of the shape sizes); 0 for a dimensionless shape.
func (sh *Shape) Len() int {
	if len(sh.Sizes) == 0 {
		return 0
	}
	n := 1
	for _, sz := range sh.Sizes {
		n *= sz
	}
	return n
}
// NumDims returns the total number of dimensions.
func (sh *Shape) NumDims() int { return len(sh.Sizes) }

// DimSize returns the size of given dimension; panics if out of range.
func (sh *Shape) DimSize(i int) int {
return sh.Sizes[i]
}
// IndexIsValid returns true if the given index is valid: it must have
// exactly NumDims values, each within range for its dimension.
func (sh *Shape) IndexIsValid(idx ...int) bool {
	if len(idx) != sh.NumDims() {
		return false
	}
	for d, sz := range sh.Sizes {
		if idx[d] < 0 || idx[d] >= sz {
			return false
		}
	}
	return true
}
// IsEqual returns true if this shape is same as other (does not compare names):
// both Sizes and Strides must be element-wise equal.
func (sh *Shape) IsEqual(oth *Shape) bool {
	// slices.Equal is the direct idiom for the previous Compare != 0 checks,
	// and short-circuits on length mismatch first.
	return slices.Equal(sh.Sizes, oth.Sizes) && slices.Equal(sh.Strides, oth.Strides)
}
// RowCellSize returns the size of the outermost Row shape dimension,
// and the size of all the remaining inner dimensions (the "cell" size).
// Used for Tensors that are columns in a data table.
// A dimensionless shape returns (0, 1); when rows is 0, the cell size is
// computed as the product of the remaining dimension sizes.
func (sh *Shape) RowCellSize() (rows, cells int) {
	if len(sh.Sizes) == 0 {
		return 0, 1
	}
	rows = sh.Sizes[0]
	switch {
	case len(sh.Sizes) == 1:
		cells = 1
	case rows > 0:
		cells = sh.Len() / rows
	default:
		cells = 1
		for _, sz := range sh.Sizes[1:] {
			cells *= sz
		}
	}
	return rows, cells
}
// IndexTo1D returns the flat 1D index from given n-dimensional indicies.
// No checking is done on the length or size of the index values relative
// to the shape of the tensor.
func (sh *Shape) IndexTo1D(index ...int) int {
	flat := 0
	for d, v := range index {
		flat += v * sh.Strides[d]
	}
	return flat
}
// IndexFrom1D returns the n-dimensional index from a "flat" 1D array index.
// A zero-size dimension terminates the decoding, leaving remaining
// outer indexes at 0.
func (sh *Shape) IndexFrom1D(oned int) []int {
	nd := len(sh.Sizes)
	index := make([]int, nd)
	rem := oned
	for i := nd - 1; i >= 0; i-- {
		sz := sh.Sizes[i]
		if sz == 0 {
			return index
		}
		index[i] = rem % sz
		rem /= sz
	}
	return index
}
// String satisfies the fmt.Stringer interface, printing just the Sizes.
func (sh *Shape) String() string {
return fmt.Sprintf("%v", sh.Sizes)
}
// RowMajorStrides returns strides for sizes where the first dimension is outermost
// and subsequent dimensions are progressively inner.
func RowMajorStrides(sizes ...int) []int {
if len(sizes) == 0 {
return nil
}
sizes[0] = max(1, sizes[0]) // critical for strides to not be nil due to rows = 0
rem := int(1)
for _, v := range sizes {
rem *= v
}
if rem == 0 {
strides := make([]int, len(sizes))
for i := range strides {
strides[i] = rem
}
return strides
}
strides := make([]int, len(sizes))
for i, v := range sizes {
rem /= v
strides[i] = rem
}
return strides
}
// ColumnMajorStrides returns strides for sizes where the first dimension is inner-most
// and subsequent dimensions are progressively outer
func ColumnMajorStrides(sizes ...int) []int {
total := int(1)
for _, v := range sizes {
if v == 0 {
strides := make([]int, len(sizes))
for i := range strides {
strides[i] = total
}
return strides
}
}
strides := make([]int, len(sizes))
for i, v := range sizes {
strides[i] = total
total *= v
}
return strides
}
// AddShapes returns a new shape by adding two shapes one after the other.
func AddShapes(shape1, shape2 *Shape) *Shape {
	joined := make([]int, 0, len(shape1.Sizes)+len(shape2.Sizes))
	joined = append(joined, shape1.Sizes...)
	joined = append(joined, shape2.Sizes...)
	return NewShape(joined...)
}
// CellsSizes returns the sizes of inner cells dimensions given
// overall tensor sizes. It returns []int{1} for the 1D case.
// Used for ensuring cell-wise outputs are the right size.
func CellsSize(sizes []int) []int {
csz := slices.Clone(sizes)
if len(csz) == 1 {
csz[0] = 1
} else {
csz = csz[1:]
}
return csz
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
import (
"math/rand"
"reflect"
"slices"
"sort"
"cogentcore.org/core/base/metadata"
"cogentcore.org/core/base/reflectx"
"cogentcore.org/core/base/slicesx"
)
// Sliced provides a re-sliced view onto another "source" [Tensor],
// defined by a set of [Sliced.Indexes] for each dimension (must have
// at least 1 index per dimension to avoid a null view).
// Thus, each dimension can be transformed in arbitrary ways relative
// to the original tensor (filtered subsets, reversals, sorting, etc).
// This view is not memory-contiguous and does not support the [RowMajor]
// interface or efficient access to inner-dimensional subspaces.
// A new Sliced view defaults to a full transparent view of the source tensor.
// There is additional cost for every access operation associated with the
// indexed indirection, and access is always via the full n-dimensional indexes.
// See also [Rows] for a version that only indexes the outermost row dimension,
// which is much more efficient for this common use-case, and does support [RowMajor].
// To produce a new concrete [Values] that has raw data actually organized according
// to the indexed order (i.e., the copy function of numpy), call [Sliced.AsValues].
type Sliced struct { //types:add

// Tensor source that we are an indexed view onto.
Tensor Tensor

// Indexes are the indexes for each dimension, with dimensions as the outer
// slice (enforced to be the same length as the NumDims of the source Tensor),
// and a list of dimension index values (within range of DimSize(d)).
// A nil list of indexes for a dimension automatically provides a full,
// sequential view of that dimension.
Indexes [][]int
}
// NewSliced returns a new [Sliced] view of given tensor,
// with optional list of indexes for each dimension (none / nil = sequential).
// Any dimensions without indexes default to nil = full sequential view.
func NewSliced(tsr Tensor, idxs ...[]int) *Sliced {
	view := &Sliced{Tensor: tsr, Indexes: idxs}
	view.ValidIndexes()
	return view
}
// Reslice returns a new [Sliced] (and potentially [Reshaped]) view of given tensor,
// with given slice expressions for each dimension, which can be:
// - an integer, indicating a specific index value along that dimension.
// Can use negative numbers to index from the end.
// This axis will also be removed using a [Reshaped].
// - a [Slice] object expressing a range of indexes.
// - [FullAxis] includes the full original axis (equivalent to `Slice{}`).
// - [Ellipsis] creates a flexibly-sized stretch of FullAxis dimensions,
// which automatically aligns the remaining slice elements based on the source
// dimensionality.
// - [NewAxis] creates a new singleton (length=1) axis, used to reshape
// without changing the size. This triggers a [Reshaped].
// - any remaining dimensions without indexes default to nil = full sequential view.
func Reslice(tsr Tensor, sls ...any) Tensor {
ns := len(sls)
if ns == 0 {
return NewSliced(tsr) // full transparent view
}
nd := tsr.NumDims()
ed := nd - ns // extra dimensions not covered by the given expressions
ixs := make([][]int, nd)
doReshape := false // indicates if we need a Reshaped
reshape := make([]int, 0, nd+2) // if we need one, this is the target shape
// ci tracks the current source-tensor dimension as expressions are consumed
ci := 0
for d := range ns {
s := sls[d]
switch x := s.(type) {
case int:
doReshape = true // doesn't add to new shape.
if x < 0 {
ixs[ci] = []int{tsr.DimSize(ci) + x} // negative: index from the end
} else {
ixs[ci] = []int{x}
}
case Slice:
ixs[ci] = x.IntSlice(tsr.DimSize(ci))
reshape = append(reshape, len(ixs[ci]))
case SlicesMagic:
switch x {
case FullAxis:
ixs[ci] = Slice{}.IntSlice(tsr.DimSize(ci))
reshape = append(reshape, len(ixs[ci]))
case NewAxis:
ed++ // we are not real
doReshape = true
reshape = append(reshape, 1)
continue // skip the increment in ci
case Ellipsis:
ed++ // extra for us
for range ed {
ixs[ci] = Slice{}.IntSlice(tsr.DimSize(ci))
reshape = append(reshape, len(ixs[ci]))
ci++
}
if ed > 0 {
ci--
}
ed = 0 // ate them up
}
}
ci++
}
for range ed { // fill any extra dimensions
ixs[ci] = Slice{}.IntSlice(tsr.DimSize(ci))
reshape = append(reshape, len(ixs[ci]))
ci++
}
sl := NewSliced(tsr, ixs...)
if doReshape {
if len(reshape) == 0 { // all indexes
reshape = []int{1}
}
return NewReshaped(sl, reshape...)
}
return sl
}
// AsSliced returns the tensor as a [Sliced] view: returned directly if it
// already is one, otherwise wrapped in a new default full sequential
// ("transparent") Sliced view.
func AsSliced(tsr Tensor) *Sliced {
	sl, ok := tsr.(*Sliced)
	if !ok {
		sl = NewSliced(tsr)
	}
	return sl
}
// SetTensor sets tensor as source for this view, and initializes a full
// transparent view onto source (calls [Sliced.Sequential]).
func (sl *Sliced) SetTensor(tsr Tensor) {
	sl.Tensor = tsr
	sl.Sequential() // resets Indexes to match the new source's dimensionality
}
// SourceIndex returns the actual index into the source tensor dimension
// for the given view index: the identity when no indexes are set for
// that dimension, otherwise the indirected value.
func (sl *Sliced) SourceIndex(dim, idx int) int {
	if ix := sl.Indexes[dim]; ix != nil {
		return ix[idx]
	}
	return idx
}
// SourceIndexes returns the actual n-dimensional indexes into the source
// tensor for the given indexes expressed in the Sliced view shape,
// indirecting each dimension through [Sliced.SourceIndex].
func (sl *Sliced) SourceIndexes(i ...int) []int {
	src := make([]int, len(i))
	for d := range i {
		src[d] = sl.SourceIndex(d, i[d])
	}
	return src
}
// SourceIndexesFrom1D returns the n-dimensional indexes into the source tensor
// based on the given 1D (flat) index expressed in the Sliced view shape.
func (sl *Sliced) SourceIndexesFrom1D(oned int) []int {
	sh := sl.Shape()
	oix := sh.IndexFrom1D(oned) // full indexes in our coords
	return sl.SourceIndexes(oix...)
}
// ValidIndexes ensures that [Sliced.Indexes] are valid,
// removing any out-of-range values (negative or >= DimSize), and setting
// the view to nil (full sequential) for any dimension with no indexes
// (which is an invalid condition), including dimensions whose indexes
// were all removed as out of range.
// Call this when any structural changes are made to underlying Tensor.
func (sl *Sliced) ValidIndexes() {
	nd := sl.Tensor.NumDims()
	sl.Indexes = slicesx.SetLength(sl.Indexes, nd)
	for d := range nd {
		ni := len(sl.Indexes[d])
		if ni == 0 { // invalid
			sl.Indexes[d] = nil // full
			continue
		}
		ds := sl.Tensor.DimSize(d)
		ix := sl.Indexes[d]
		for i := ni - 1; i >= 0; i-- { // reverse so deletion doesn't shift pending items
			// negative values are just as invalid as >= ds: previously only
			// the upper bound was checked, leaving negatives to panic later.
			if ix[i] < 0 || ix[i] >= ds {
				ix = append(ix[:i], ix[i+1:]...)
			}
		}
		if len(ix) == 0 { // everything was out of range: invalid, restore full view
			ix = nil
		}
		sl.Indexes[d] = ix
	}
}
// Sequential resets all Indexes to nil, restoring a full sequential view
// of the source tensor.
func (sl *Sliced) Sequential() { //types:add
	nd := sl.Tensor.NumDims()
	sl.Indexes = slicesx.SetLength(sl.Indexes, nd)
	for d := 0; d < nd; d++ {
		sl.Indexes[d] = nil
	}
}
// IndexesNeeded is called prior to an operation that needs actual indexes
// on the given dimension. If Indexes are nil for that dimension, they are
// materialized as the full sequential list; existing indexes are left as is.
// Use Sequential, then IndexesNeeded to ensure all dimension indexes are
// represented.
func (sl *Sliced) IndexesNeeded(d int) {
	if sl.Indexes[d] != nil {
		return
	}
	n := sl.Tensor.DimSize(d)
	ix := make([]int, n)
	for i := 0; i < n; i++ {
		ix[i] = i
	}
	sl.Indexes[d] = ix
}
// Label satisfies the core.Labeler interface for a summary description,
// using the metadata Name (if set) and the view shape.
func (sl *Sliced) Label() string { return label(metadata.Name(sl), sl.Shape()) }
// String satisfies the fmt.Stringer interface for a string of tensor data.
func (sl *Sliced) String() string { return Sprintf("", sl, 0) }
// Metadata returns the metadata of the source tensor.
func (sl *Sliced) Metadata() *metadata.Data { return sl.Tensor.Metadata() }
// IsString returns whether the source tensor holds string values.
func (sl *Sliced) IsString() bool { return sl.Tensor.IsString() }
// DataType returns the data type of the source tensor elements.
func (sl *Sliced) DataType() reflect.Kind { return sl.Tensor.DataType() }
// Shape returns a [Shape] with the effective (indexed) sizes,
// constructed on each call from [Sliced.ShapeSizes].
func (sl *Sliced) Shape() *Shape { return NewShape(sl.ShapeSizes()...) }
// Len returns the total number of elements in the view.
func (sl *Sliced) Len() int { return sl.Shape().Len() }
// NumDims returns the number of dimensions, always that of the source tensor.
func (sl *Sliced) NumDims() int { return sl.Tensor.NumDims() }
// ShapeSizes returns, for each dimension, the effective shape sizes using
// the current number of indexes per dimension (source size where nil).
func (sl *Sliced) ShapeSizes() []int {
	nd := sl.Tensor.NumDims()
	if nd == 0 { // scalar source: defer entirely to it
		return sl.Tensor.ShapeSizes()
	}
	sh := slices.Clone(sl.Tensor.ShapeSizes())
	for d := range nd {
		if sl.Indexes[d] != nil { // indexed dimension: size is number of indexes
			sh[d] = len(sl.Indexes[d])
		}
	}
	return sh
}
// DimSize returns the effective view size of given dimension.
func (sl *Sliced) DimSize(dim int) int {
	if sl.Indexes[dim] != nil {
		return len(sl.Indexes[dim])
	}
	return sl.Tensor.DimSize(dim)
}
// AsValues returns a copy of this tensor as raw [Values].
// This "renders" the Sliced view into a fully contiguous
// and optimized memory representation of that view, which will be faster
// to access for further processing, and enables all the additional
// functionality provided by the [Values] interface.
func (sl *Sliced) AsValues() Values {
	dt := sl.Tensor.DataType()
	vt := NewOfType(dt, sl.ShapeSizes()...)
	n := sl.Len()
	if sl.Tensor.IsString() {
		for i := range n {
			vt.SetString1D(sl.String1D(i), i)
		}
	} else if reflectx.KindIsFloat(dt) {
		for i := range n {
			vt.SetFloat1D(sl.Float1D(i), i)
		}
	} else {
		for i := range n {
			vt.SetInt1D(sl.Int1D(i), i)
		}
	}
	return vt
}
//////// Floats
// Float returns the value of given index as a float64.
// The indexes are indirected through the [Sliced.Indexes].
func (sl *Sliced) Float(i ...int) float64 {
	return sl.Tensor.Float(sl.SourceIndexes(i...)...)
}
// SetFloat sets the value of given index as a float64.
// The indexes are indirected through the [Sliced.Indexes].
func (sl *Sliced) SetFloat(val float64, i ...int) {
	sl.Tensor.SetFloat(val, sl.SourceIndexes(i...)...)
}
// Float1D returns the value at the given 1D (flat) index as a float64.
// Somewhat expensive if indexes are set, because it needs to convert
// the flat index back into a full n-dimensional index and then use that api.
func (sl *Sliced) Float1D(i int) float64 {
	return sl.Tensor.Float(sl.SourceIndexesFrom1D(i)...)
}
// SetFloat1D sets the value at the given 1D (flat) index as a float64.
// Somewhat expensive if indexes are set, because it needs to convert
// the flat index back into a full n-dimensional index and then use that api.
func (sl *Sliced) SetFloat1D(val float64, i int) {
	sl.Tensor.SetFloat(val, sl.SourceIndexesFrom1D(i)...)
}
//////// Strings
// StringValue returns the value of given index as a string.
// The indexes are indirected through the [Sliced.Indexes].
func (sl *Sliced) StringValue(i ...int) string {
	return sl.Tensor.StringValue(sl.SourceIndexes(i...)...)
}
// SetString sets the value of given index as a string.
// The indexes are indirected through the [Sliced.Indexes].
func (sl *Sliced) SetString(val string, i ...int) {
	sl.Tensor.SetString(val, sl.SourceIndexes(i...)...)
}
// String1D returns the value at the given 1D (flat) index as a string.
// Somewhat expensive if indexes are set, because it needs to convert
// the flat index back into a full n-dimensional index and then use that api.
func (sl *Sliced) String1D(i int) string {
	return sl.Tensor.StringValue(sl.SourceIndexesFrom1D(i)...)
}
// SetString1D sets the value at the given 1D (flat) index as a string.
// Somewhat expensive if indexes are set, because it needs to convert
// the flat index back into a full n-dimensional index and then use that api.
func (sl *Sliced) SetString1D(val string, i int) {
	sl.Tensor.SetString(val, sl.SourceIndexesFrom1D(i)...)
}
//////// Ints
// Int returns the value of given index as an int.
// The indexes are indirected through the [Sliced.Indexes].
func (sl *Sliced) Int(i ...int) int {
	return sl.Tensor.Int(sl.SourceIndexes(i...)...)
}
// SetInt sets the value of given index as an int.
// The indexes are indirected through the [Sliced.Indexes].
func (sl *Sliced) SetInt(val int, i ...int) {
	sl.Tensor.SetInt(val, sl.SourceIndexes(i...)...)
}
// Int1D returns the value at the given 1D (flat) index as an int.
// Somewhat expensive if indexes are set, because it needs to convert
// the flat index back into a full n-dimensional index and then use that api.
func (sl *Sliced) Int1D(i int) int {
	return sl.Tensor.Int(sl.SourceIndexesFrom1D(i)...)
}
// SetInt1D sets the value at the given 1D (flat) index as an int.
// Somewhat expensive if indexes are set, because it needs to convert
// the flat index back into a full n-dimensional index and then use that api.
func (sl *Sliced) SetInt1D(val int, i int) {
	sl.Tensor.SetInt(val, sl.SourceIndexesFrom1D(i)...)
}
// Permuted sets the indexes in the given dimension to a permuted order.
// If indexes already exist, the existing list is shuffled in place;
// otherwise a fresh random permutation of the full dimension is generated.
func (sl *Sliced) Permuted(dim int) {
	if ix := sl.Indexes[dim]; ix != nil {
		rand.Shuffle(len(ix), func(a, b int) {
			ix[a], ix[b] = ix[b], ix[a]
		})
		return
	}
	sl.Indexes[dim] = rand.Perm(sl.Tensor.DimSize(dim))
}
// SortFunc sorts the indexes along given dimension using given compare function.
// The compare function operates directly on indexes into the Tensor
// as these row numbers have already been projected through the indexes.
// cmp(a, b) should return a negative number when a < b, a positive
// number when a > b and zero when a == b.
func (sl *Sliced) SortFunc(dim int, cmp func(tsr Tensor, dim, i, j int) int) {
	sl.IndexesNeeded(dim) // materialize indexes so there is something to sort
	ix := sl.Indexes[dim]
	slices.SortFunc(ix, func(a, b int) int {
		return cmp(sl.Tensor, dim, a, b) // key point: these are already indirected through indexes!!
	})
	sl.Indexes[dim] = ix
}
// SortIndexes sorts the indexes along the given dimension directly in
// numerical order, producing the native ordering, while preserving
// any filtering that might have occurred. A nil (full sequential)
// dimension is already ordered, so nothing is done.
func (sl *Sliced) SortIndexes(dim int) {
	if ix := sl.Indexes[dim]; ix != nil {
		sort.Ints(ix)
	}
}
// SortStableFunc stably sorts along given dimension using given compare function.
// The compare function operates directly on row numbers into the Tensor
// as these row numbers have already been projected through the indexes.
// cmp(a, b) should return a negative number when a < b, a positive
// number when a > b and zero when a == b.
// It is *essential* that it always returns 0 when the two are equal
// for the stable function to actually work.
func (sl *Sliced) SortStableFunc(dim int, cmp func(tsr Tensor, dim, i, j int) int) {
	sl.IndexesNeeded(dim) // materialize indexes so there is something to sort
	ix := sl.Indexes[dim]
	slices.SortStableFunc(ix, func(a, b int) int {
		return cmp(sl.Tensor, dim, a, b) // key point: these are already indirected through indexes!!
	})
	sl.Indexes[dim] = ix
}
// Filter filters the indexes along the given dimension using the given
// filter function, retaining only indexes for which it returns true.
// The function receives the source tensor ([Sliced.Tensor]) and an index
// into the source data, which has already been projected through the view
// indexes, consistent with [Sliced.SortFunc] and [Sliced.SortStableFunc].
func (sl *Sliced) Filter(dim int, filterer func(tsr Tensor, dim, idx int) bool) {
	sl.IndexesNeeded(dim)
	ix := sl.Indexes[dim]
	for i := len(ix) - 1; i >= 0; i-- { // always go in reverse for filtering
		// pass sl.Tensor, not sl: ix[i] is a source-level index (as documented),
		// and indexing the view with it would double-indirect once indexes
		// have been filtered or permuted. Matches the Sort* methods.
		if !filterer(sl.Tensor, dim, ix[i]) { // delete
			ix = append(ix[:i], ix[i+1:]...)
		}
	}
	sl.Indexes[dim] = ix
}
// compile-time check that Sliced implements the full Tensor interface.
var _ Tensor = (*Sliced)(nil)
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
// SlicesMagic are special elements in slice expressions, including
// NewAxis, FullAxis, and Ellipsis in [NewSliced] expressions.
type SlicesMagic int //enums:enum
const (
	// FullAxis indicates that the full existing axis length should be used.
	// This is equivalent to Slice{}, but is more semantic. In NumPy it is
	// equivalent to a single : colon.
	FullAxis SlicesMagic = iota
	// NewAxis creates a new singleton (length=1) axis, used to reshape
	// without changing the size. Can also be used in [Reshaped].
	NewAxis
	// Ellipsis (...) is used in [NewSliced] expressions to produce
	// a flexibly-sized stretch of FullAxis dimensions, which automatically
	// aligns the remaining slice elements based on the source dimensionality.
	Ellipsis
)
// Slice represents a slice of index values, for extracting slices of data,
// along a dimension of a given size, which is provided separately as an argument.
// Uses standard 'for' loop logic with a Start and _exclusive_ Stop value,
// and a Step increment: for i := Start; i < Stop; i += Step.
// The values stored in this struct are the _inputs_ for computing the actual
// slice values based on the actual size parameter for the dimension.
// Negative numbers count back from the end (i.e., size + val), and
// the zero value results in a list of all values in the dimension, with Step = 1 if 0.
// The behavior is identical to the NumPy slice.
type Slice struct {
	// Start is the starting value. If 0 and Step < 0, = size-1;
	// If negative, = size+Start.
	Start int
	// Stop value. If 0 and Step >= 0, = size;
	// If 0 and Step < 0, = -1, to include whole range.
	// If negative = size+Stop.
	Stop int
	// Step increment. If 0, = 1; if negative then Start must be > Stop
	// to produce anything.
	Step int
}

// NewSlice returns a new Slice with the given start, stop, step values.
func NewSlice(start, stop, step int) Slice {
	return Slice{start, stop, step}
}

// GetStart is the actual start value given the size of the dimension.
func (sl Slice) GetStart(size int) int {
	switch {
	case sl.Start == 0 && sl.Step < 0: // zero value with negative step: start at the end
		return size - 1
	case sl.Start < 0: // negative counts back from the end
		return size + sl.Start
	default:
		return sl.Start
	}
}

// GetStop is the actual (exclusive) end value given the size of the dimension.
func (sl Slice) GetStop(size int) int {
	switch {
	case sl.Stop == 0:
		if sl.Step < 0 {
			return -1 // include index 0 when iterating backward
		}
		return size
	case sl.Stop < 0: // negative counts back from the end
		return size + sl.Stop
	default:
		return min(sl.Stop, size)
	}
}

// GetStep is the actual increment value (1 for the zero value).
func (sl Slice) GetStep() int {
	if sl.Step != 0 {
		return sl.Step
	}
	return 1
}

// Len is the number of elements in the actual slice given
// size of the dimension.
func (sl Slice) Len(size int) int {
	start := sl.GetStart(size)
	stop := sl.GetStop(size)
	step := sl.GetStep()
	n := max((stop-start)/step, 0)
	// integer division truncates toward zero: if a partial step remains
	// before reaching stop, one more element fits.
	last := start + n*step
	if (step > 0 && last < stop) || (step < 0 && last > stop) {
		n++
	}
	return n
}

// ToIntSlice writes values to given []int slice, with given size parameter
// for the dimension being sliced. If slice is wrong size to hold values,
// not all are written: allocate ints using Len(size) to fit.
func (sl Slice) ToIntSlice(size int, ints []int) {
	n := len(ints)
	if n == 0 {
		return
	}
	stop := sl.GetStop(size)
	step := sl.GetStep()
	idx := 0
	// single loop handles both directions via the step-sign condition.
	for i := sl.GetStart(size); (step > 0 && i < stop) || (step < 0 && i > stop); i += step {
		ints[idx] = i
		idx++
		if idx >= n {
			break
		}
	}
}

// IntSlice returns a []int slice with slice index values, up to given actual size.
func (sl Slice) IntSlice(size int) []int {
	n := sl.Len(size)
	if n == 0 {
		return nil
	}
	ints := make([]int, n)
	sl.ToIntSlice(size, ints)
	return ints
}
// IntTensor returns an [Int] [Tensor] for slice, using actual size.
// Returns nil if the resulting slice is empty.
func (sl Slice) IntTensor(size int) *Int {
	n := sl.Len(size)
	if n == 0 {
		return nil
	}
	tsr := NewInt(n)
	sl.ToIntSlice(size, tsr.Values)
	return tsr
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
import (
"fmt"
"strconv"
"cogentcore.org/core/base/errors"
)
// String is a tensor of string values.
type String struct {
	Base[string]
}
// NewString returns a new n-dimensional tensor of string values
// with the given sizes per dimension (shape).
func NewString(sizes ...int) *String {
	tsr := &String{}
	tsr.SetShapeSizes(sizes...)
	tsr.Values = make([]string, tsr.Len())
	return tsr
}
// NewStringShape returns a new n-dimensional tensor of string values
// using given shape, which is copied (not retained).
func NewStringShape(shape *Shape) *String {
	tsr := &String{}
	tsr.shape.CopyFrom(shape)
	tsr.Values = make([]string, tsr.Len())
	return tsr
}
// StringToFloat64 converts the given string to a float64 using
// strconv.ParseFloat, returning 0 for any value that does not parse.
func StringToFloat64(str string) float64 {
	fv, err := strconv.ParseFloat(str, 64)
	if err != nil {
		return 0
	}
	return fv
}

// Float64ToString converts the given float64 to its string representation
// using strconv.FormatFloat with the 'g' format and minimal digits.
func Float64ToString(val float64) string {
	return strconv.FormatFloat(val, 'g', -1, 64)
}
// String satisfies the fmt.Stringer interface for string of tensor data.
func (tsr *String) String() string {
	return Sprintf("", tsr, 0)
}
// IsString returns true: this is the string-valued tensor type.
func (tsr *String) IsString() bool {
	return true
}
// AsValues returns this tensor directly: it already is a raw [Values].
func (tsr *String) AsValues() Values { return tsr }
/////// Strings
// SetString sets the value at the given n-dimensional index (matching shape).
func (tsr *String) SetString(val string, i ...int) {
	tsr.Values[tsr.shape.IndexTo1D(i...)] = val
}
// String1D returns the value at the given 1D (flat) index.
// If the index is negative, it counts back from the end (-1 = last),
// per the [Tensor] interface contract and consistent with Float1D,
// Int1D, and SetString1D, which all apply NegIndex (this one did not).
func (tsr *String) String1D(i int) string {
	return tsr.Values[NegIndex(i, len(tsr.Values))]
}
// SetString1D sets the value at the given 1D (flat) index.
// Negative indexes count back from the end.
func (tsr *String) SetString1D(val string, i int) {
	tsr.Values[NegIndex(i, len(tsr.Values))] = val
}
// StringRow returns the value at the given row and cell index.
func (tsr *String) StringRow(row, cell int) string {
	_, sz := tsr.shape.RowCellSize()
	return tsr.Values[row*sz+cell]
}
// SetStringRow sets the value at the given row and cell index.
func (tsr *String) SetStringRow(val string, row, cell int) {
	_, sz := tsr.shape.RowCellSize()
	tsr.Values[row*sz+cell] = val
}
// AppendRowString adds a row and sets string value(s), up to number of cells.
func (tsr *String) AppendRowString(val ...string) {
	if tsr.NumDims() == 0 { // start as 1D with 0 rows if not yet shaped
		tsr.SetShapeSizes(0)
	}
	nrow, sz := tsr.shape.RowCellSize()
	tsr.SetNumRows(nrow + 1)
	mx := min(sz, len(val))
	for i := range mx {
		tsr.SetStringRow(val[i], nrow, i)
	}
}
/////// Floats
// Float returns the value at the given n-dimensional index as a float64,
// converted via [StringToFloat64] (0 if unparseable).
func (tsr *String) Float(i ...int) float64 {
	return StringToFloat64(tsr.Values[tsr.shape.IndexTo1D(i...)])
}
// SetFloat sets the value at the given n-dimensional index from a float64,
// formatted via [Float64ToString].
func (tsr *String) SetFloat(val float64, i ...int) {
	tsr.Values[tsr.shape.IndexTo1D(i...)] = Float64ToString(val)
}
// Float1D returns the value at the given 1D (flat) index as a float64.
// Negative indexes count back from the end.
func (tsr *String) Float1D(i int) float64 {
	return StringToFloat64(tsr.Values[NegIndex(i, len(tsr.Values))])
}
// SetFloat1D sets the value at the given 1D (flat) index from a float64.
// Negative indexes count back from the end.
func (tsr *String) SetFloat1D(val float64, i int) {
	tsr.Values[NegIndex(i, len(tsr.Values))] = Float64ToString(val)
}
// FloatRow returns the value at the given row and cell index as a float64.
func (tsr *String) FloatRow(row, cell int) float64 {
	_, sz := tsr.shape.RowCellSize()
	return StringToFloat64(tsr.Values[row*sz+cell])
}
// SetFloatRow sets the value at the given row and cell index from a float64.
func (tsr *String) SetFloatRow(val float64, row, cell int) {
	_, sz := tsr.shape.RowCellSize()
	tsr.Values[row*sz+cell] = Float64ToString(val)
}
// AppendRowFloat adds a row and sets float value(s), up to number of cells.
func (tsr *String) AppendRowFloat(val ...float64) {
	if tsr.NumDims() == 0 { // start as 1D with 0 rows if not yet shaped
		tsr.SetShapeSizes(0)
	}
	nrow, sz := tsr.shape.RowCellSize()
	tsr.SetNumRows(nrow + 1)
	mx := min(sz, len(val))
	for i := range mx {
		tsr.SetFloatRow(val[i], nrow, i)
	}
}
/////// Ints
// Int returns the value at the given n-dimensional index as an int,
// parsed via strconv.Atoi (0 if unparseable; error ignored).
func (tsr *String) Int(i ...int) int {
	return errors.Ignore1(strconv.Atoi(tsr.Values[tsr.shape.IndexTo1D(i...)]))
}
// SetInt sets the value at the given n-dimensional index from an int.
func (tsr *String) SetInt(val int, i ...int) {
	tsr.Values[tsr.shape.IndexTo1D(i...)] = strconv.Itoa(val)
}
// Int1D returns the value at the given 1D (flat) index as an int.
// Negative indexes count back from the end.
func (tsr *String) Int1D(i int) int {
	return errors.Ignore1(strconv.Atoi(tsr.Values[NegIndex(i, len(tsr.Values))]))
}
// SetInt1D sets the value at the given 1D (flat) index from an int.
// Negative indexes count back from the end.
func (tsr *String) SetInt1D(val int, i int) {
	tsr.Values[NegIndex(i, len(tsr.Values))] = strconv.Itoa(val)
}
// IntRow returns the value at the given row and cell index as an int.
func (tsr *String) IntRow(row, cell int) int {
	_, sz := tsr.shape.RowCellSize()
	return errors.Ignore1(strconv.Atoi(tsr.Values[row*sz+cell]))
}
// SetIntRow sets the value at the given row and cell index from an int.
func (tsr *String) SetIntRow(val int, row, cell int) {
	_, sz := tsr.shape.RowCellSize()
	tsr.Values[row*sz+cell] = strconv.Itoa(val)
}
// AppendRowInt adds a row and sets int value(s), up to number of cells.
func (tsr *String) AppendRowInt(val ...int) {
	if tsr.NumDims() == 0 { // start as 1D with 0 rows if not yet shaped
		tsr.SetShapeSizes(0)
	}
	nrow, sz := tsr.shape.RowCellSize()
	tsr.SetNumRows(nrow + 1)
	mx := min(sz, len(val))
	for i := range mx {
		tsr.SetIntRow(val[i], nrow, i)
	}
}
// SetZeros is a simple convenience function to initialize all values to the
// zero value of the type (empty strings for string type).
func (tsr *String) SetZeros() {
	clear(tsr.Values) // built-in clear zeroes every element in place (Go 1.21+)
}
// Clone clones this tensor, creating a duplicate copy of itself with its
// own separate memory representation of all the values, and returns
// that as a Tensor (which can be converted into the known type as needed).
func (tsr *String) Clone() Values {
	csr := NewStringShape(&tsr.shape)
	copy(csr.Values, tsr.Values)
	return csr
}
// CopyFrom copies all avail values from other tensor into this tensor, with an
// optimized implementation if the other tensor is of the same type, and
// otherwise it goes through the float64 representation.
func (tsr *String) CopyFrom(frm Values) {
	if fsm, ok := frm.(*String); ok { // same type: direct slice copy
		copy(tsr.Values, fsm.Values)
		return
	}
	sz := min(tsr.Len(), frm.Len()) // only copy what fits in both
	for i := 0; i < sz; i++ {
		tsr.Values[i] = Float64ToString(frm.Float1D(i))
	}
}
// AppendFrom appends values from other tensor into this tensor,
// which must have the same cell size as this tensor.
// It uses an optimized implementation if the other tensor
// is of the same type, and otherwise it goes through the
// float64 representation.
func (tsr *String) AppendFrom(frm Values) Values {
	rows, cell := tsr.shape.RowCellSize()
	frows, fcell := frm.Shape().RowCellSize()
	if cell != fcell { // incompatible: log and return unchanged
		errors.Log(fmt.Errorf("tensor.AppendFrom: cell sizes do not match: %d != %d", cell, fcell))
		return tsr
	}
	tsr.SetNumRows(rows + frows)
	st := rows * cell // flat offset where the appended data begins
	fsz := frows * fcell
	if fsm, ok := frm.(*String); ok { // same type: direct slice copy
		copy(tsr.Values[st:st+fsz], fsm.Values)
		return tsr
	}
	for i := 0; i < fsz; i++ {
		tsr.Values[st+i] = Float64ToString(frm.Float1D(i))
	}
	return tsr
}
// CopyCellsFrom copies given range of values from other tensor into this tensor,
// using flat 1D indexes: to = starting index in this Tensor to start copying into,
// start = starting index on from Tensor to start copying from, and n = number of
// values to copy. Uses an optimized implementation if the other tensor is
// of the same type, and otherwise it goes through the float64 representation.
func (tsr *String) CopyCellsFrom(frm Values, to, start, n int) {
	if fsm, ok := frm.(*String); ok {
		copy(tsr.Values[to:to+n], fsm.Values[start:start+n])
		return
	}
	for i := range n {
		tsr.Values[to+i] = Float64ToString(frm.Float1D(start + i))
	}
}
// SubSpace returns a new tensor with innermost subspace at given
// offset(s) in outermost dimension(s) (len(offs) < NumDims).
// The new tensor points to the values of the this tensor (i.e., modifications
// will affect both), as its Values slice is a view onto the original (which
// is why only inner-most contiguous subspaces are supported).
// Use Clone() method to separate the two.
func (tsr *String) SubSpace(offs ...int) Values {
	b := tsr.subSpaceImpl(offs...)
	rt := &String{Base: *b}
	return rt
}
// RowTensor is a convenience version of [RowMajor.SubSpace] to return the
// SubSpace for the outermost row dimension. [Rows] defines a version
// of this that indirects through the row indexes.
func (tsr *String) RowTensor(row int) Values {
	return tsr.SubSpace(row)
}
// SetRowTensor sets the values of the SubSpace at given row to given values,
// copying at most the number of cells in a row.
func (tsr *String) SetRowTensor(val Values, row int) {
	_, cells := tsr.shape.RowCellSize()
	st := row * cells // flat offset of the start of this row
	mx := min(val.Len(), cells)
	tsr.CopyCellsFrom(val, st, 0, mx)
}
// AppendRow adds a row and sets values to given values.
func (tsr *String) AppendRow(val Values) {
	if tsr.NumDims() == 0 { // start as 1D with 0 rows if not yet shaped
		tsr.SetShapeSizes(0)
	}
	nrow := tsr.DimSize(0)
	tsr.SetNumRows(nrow + 1)
	tsr.SetRowTensor(val, nrow)
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
//go:generate core generate
import (
"fmt"
"reflect"
"cogentcore.org/core/base/metadata"
)
// DataTypes are the primary tensor data types with specific support.
// Any numerical type can also be used. bool is represented using an
// efficient bit slice.
type DataTypes interface {
	string | bool | float32 | float64 | int | int32 | uint32 | byte
}
// MaxSprintLength is the default maximum length of a String() representation
// of a tensor, as generated by the Sprint function. Defaults to 1000.
var MaxSprintLength = 1000
// todo: add a conversion function to copy data from Column-Major to a tensor:
// It is also possible to use Column-Major order, which is used in R, Julia, and MATLAB
// where the inner-most index is first and outermost last.
// Tensor is the most general interface for n-dimensional tensors.
// Per C / Go / Python conventions, indexes are Row-Major, ordered from
// outer to inner left-to-right, so the inner-most is right-most.
// It is implemented for raw [Values] with direct integer indexing
// by the [Number], [String], and [Bool] types, covering the different
// concrete types specified by [DataTypes] (see [Values] for
// additional interface methods for raw value types).
// For float32 and float64 values, use NaN to indicate missing values,
// as all of the data analysis and plot packages skip NaNs.
// View Tensor types provide different ways of viewing a source tensor,
// including [Sliced] for arbitrary slices of dimension indexes,
// [Masked] for boolean masked access of individual elements,
// and [Indexed] for arbitrary indexes of values, organized into the
// shape of the indexes, not the original source data.
// [Reshaped] provides length preserving reshaping (mostly for computational
// alignment purposes), and [Rows] provides an optimized row-indexed
// view for [table.Table] data.
type Tensor interface {
	fmt.Stringer
	// Label satisfies the core.Labeler interface for a summary
	// description of the tensor, including metadata Name if set.
	Label() string
	// Metadata returns the metadata for this tensor, which can be used
	// to encode name, docs, shape dimension names, plotting options, etc.
	Metadata() *metadata.Data
	// Shape returns a [Shape] representation of the tensor shape
	// (dimension sizes). For tensors that present a view onto another
	// tensor, this typically must be constructed.
	// In general, it is better to use the specific [Tensor.ShapeSizes],
	// [Tensor.DimSize] etc as needed.
	Shape() *Shape
	// ShapeSizes returns the sizes of each dimension as a slice of ints.
	// The returned slice is a copy and can be modified without side effects.
	ShapeSizes() []int
	// Len returns the total number of elements in the tensor,
	// i.e., the product of all shape dimensions.
	// Len must always be such that the 1D() accessors return
	// values using indexes from 0..Len()-1.
	Len() int
	// NumDims returns the total number of dimensions.
	NumDims() int
	// DimSize returns size of given dimension.
	DimSize(dim int) int
	// DataType returns the type of the data elements in the tensor.
	// Bool is returned for the Bool tensor type.
	DataType() reflect.Kind
	// IsString returns true if the data type is a String; otherwise it is numeric.
	IsString() bool
	// AsValues returns this tensor as raw [Values]. If it already is,
	// it is returned directly. If it is a View tensor, the view is
	// "rendered" into a fully contiguous and optimized [Values] representation
	// of that view, which will be faster to access for further processing,
	// and enables all the additional functionality provided by the [Values] interface.
	AsValues() Values
	//////// Floats
	// Float returns the value of given n-dimensional index (matching Shape) as a float64.
	Float(i ...int) float64
	// SetFloat sets the value of given n-dimensional index (matching Shape) as a float64.
	SetFloat(val float64, i ...int)
	// Float1D returns the value of given 1-dimensional index (0-Len()-1) as a float64.
	// If index is negative, it indexes from the end of the list (-1 = last).
	// This can be somewhat expensive in wrapper views ([Rows], [Sliced]), which
	// convert the flat index back into a full n-dimensional index and use that api.
	// [Tensor.FloatRow] is preferred.
	Float1D(i int) float64
	// SetFloat1D sets the value of given 1-dimensional index (0-Len()-1) as a float64.
	// If index is negative, it indexes from the end of the list (-1 = last).
	// This can be somewhat expensive in the commonly-used [Rows] view;
	// [Tensor.SetFloatRow] is preferred.
	SetFloat1D(val float64, i int)
	//////// Strings
	// StringValue returns the value of given n-dimensional index (matching Shape) as a string.
	// 'String' conflicts with [fmt.Stringer], so we have to use StringValue here.
	StringValue(i ...int) string
	// SetString sets the value of given n-dimensional index (matching Shape) as a string.
	SetString(val string, i ...int)
	// String1D returns the value of given 1-dimensional index (0-Len()-1) as a string.
	// If index is negative, it indexes from the end of the list (-1 = last).
	String1D(i int) string
	// SetString1D sets the value of given 1-dimensional index (0-Len()-1) as a string.
	// If index is negative, it indexes from the end of the list (-1 = last).
	SetString1D(val string, i int)
	//////// Ints
	// Int returns the value of given n-dimensional index (matching Shape) as a int.
	Int(i ...int) int
	// SetInt sets the value of given n-dimensional index (matching Shape) as a int.
	SetInt(val int, i ...int)
	// Int1D returns the value of given 1-dimensional index (0-Len()-1) as a int.
	// If index is negative, it indexes from the end of the list (-1 = last).
	Int1D(i int) int
	// SetInt1D sets the value of given 1-dimensional index (0-Len()-1) as a int.
	// If index is negative, it indexes from the end of the list (-1 = last).
	SetInt1D(val int, i int)
}
// NegIndex converts a negative index into an offset counting backward
// from n, leaving non-negative indexes unchanged.
func NegIndex(i, n int) int {
	if i >= 0 {
		return i
	}
	return n + i
}
// Copyright (c) 2020, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensormpi
import (
"fmt"
"log"
"cogentcore.org/lab/base/mpi"
)
// AllocN allocates n items to current mpi proc based on WorldSize and WorldRank.
// Returns start and end (exclusive) range for current proc, and a non-nil
// (already logged) error if n is not an even multiple of the number of procs,
// in which case the remainder items are not allocated to anyone.
func AllocN(n int) (st, end int, err error) {
	nproc := mpi.WorldSize()
	if n%nproc != 0 {
		err = fmt.Errorf("tensormpi.AllocN: number: %d is not an even multiple of number of MPI procs: %d -- must be!", n, nproc)
		log.Println(err)
	}
	pt := n / nproc // per-processor count (integer division drops any remainder)
	st = pt * mpi.WorldRank()
	end = st + pt
	return
}
// Copyright (c) 2021, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensormpi
import (
"errors"
"fmt"
"math/rand"
"cogentcore.org/lab/base/mpi"
)
// RandCheck checks that the current random numbers generated across each
// MPI processor are identical, returning an error (also printed) listing
// the processors whose value differs from this one.
func RandCheck(comm *mpi.Comm) error {
	ws := comm.Size()
	rnd := rand.Int()
	src := []int{rnd}
	agg := make([]int, ws) // one gathered value per processor
	err := comm.AllGatherInt(agg, src)
	if err != nil {
		return err
	}
	errs := ""
	for i := range agg {
		if agg[i] != rnd { // different value means diverged RNG state on proc i
			errs += fmt.Sprintf("%d ", i)
		}
	}
	if errs != "" {
		err = errors.New("tensormpi.RandCheck: random numbers differ in procs: " + errs)
		mpi.Printf("%s\n", err)
	}
	return err
}
// Copyright (c) 2020, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensormpi
import (
"cogentcore.org/lab/base/mpi"
"cogentcore.org/lab/table"
)
// GatherTableRows does an MPI AllGather on given src table data, gathering into dest.
// dest will have np * src.Rows Rows, filled with each processor's data, in order.
// dest must be a clone of src: if not same number of cols, will be configured from src.
func GatherTableRows(dest, src *table.Table, comm *mpi.Comm) {
	sr := src.NumRows()
	np := mpi.WorldSize()
	dr := np * sr // total rows across all processors
	if dest.NumColumns() != src.NumColumns() {
		*dest = *src.Clone() // reconfigure dest to match src structure
	}
	dest.SetNumRows(dr)
	for ci, st := range src.Columns.Values {
		dt := dest.Columns.Values[ci]
		GatherTensorRows(dt, st, comm)
	}
}
// ReduceTable does an MPI AllReduce on given src table data using given operation,
// gathering into dest.
// Each processor must have the same table organization -- the tensor values are
// just aggregated directly across processors.
// dest will be a clone of src if not the same (cols & rows);
// does nothing for strings.
func ReduceTable(dest, src *table.Table, comm *mpi.Comm, op mpi.Op) {
	sr := src.NumRows()
	if dest.NumColumns() != src.NumColumns() {
		*dest = *src.Clone() // reconfigure dest to match src structure
	}
	dest.SetNumRows(sr)
	for ci, st := range src.Columns.Values {
		dt := dest.Columns.Values[ci]
		ReduceTensor(dt, st, comm, op)
	}
}
// Copyright (c) 2020, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensormpi
import (
"reflect"
"cogentcore.org/lab/base/mpi"
"cogentcore.org/lab/tensor"
)
// GatherTensorRows does an MPI AllGather on given src tensor data, gathering into dest,
// using a row-based tensor organization (as in a table.Table).
// dest will have np * src.Rows Rows, filled with each processor's data, in order.
// dest must have same overall shape as src at start, but rows will be enforced.
// Note: Bool tensors are not yet supported and are silently skipped (todo).
func GatherTensorRows(dest, src tensor.Values, comm *mpi.Comm) error {
	kind := src.DataType()
	if kind == reflect.String {
		// strings need per-element length exchange; handled separately
		return GatherTensorRowsString(dest.(*tensor.String), src.(*tensor.String), comm)
	}
	sr, _ := src.Shape().RowCellSize()
	dr, _ := dest.Shape().RowCellSize()
	np := mpi.WorldSize()
	dl := np * sr
	if dr != dl {
		// enforce the required number of destination rows
		// (removed dead store `dr = dl`: dr is not read after this point)
		dest.SetNumRows(dl)
	}
	var err error
	switch kind {
	case reflect.Bool:
		// todo: Bool gather not yet implemented; silently does nothing
	case reflect.Uint8:
		dv := dest.(*tensor.Byte)
		sv := src.(*tensor.Byte)
		err = comm.AllGatherU8(dv.Values, sv.Values)
	case reflect.Int32:
		dv := dest.(*tensor.Int32)
		sv := src.(*tensor.Int32)
		err = comm.AllGatherI32(dv.Values, sv.Values)
	case reflect.Int:
		dv := dest.(*tensor.Int)
		sv := src.(*tensor.Int)
		err = comm.AllGatherInt(dv.Values, sv.Values)
	case reflect.Float32:
		dv := dest.(*tensor.Float32)
		sv := src.(*tensor.Float32)
		err = comm.AllGatherF32(dv.Values, sv.Values)
	case reflect.Float64:
		dv := dest.(*tensor.Float64)
		sv := src.(*tensor.Float64)
		err = comm.AllGatherF64(dv.Values, sv.Values)
	}
	return err
}
// GatherTensorRowsString does an MPI AllGather on given String src tensor data,
// gathering into dest, using a row-based tensor organization (as in a table.Table).
// dest will have np * src.Rows Rows, filled with each processor's data, in order.
// dest must have same overall shape as src at start, but rows will be enforced.
// Strings are transferred by first gathering all lengths, then packing each
// string into a fixed-size slot of the maximum length for a byte gather.
func GatherTensorRowsString(dest, src *tensor.String, comm *mpi.Comm) error {
	sr, _ := src.Shape().RowCellSize()
	dr, _ := dest.Shape().RowCellSize()
	np := mpi.WorldSize()
	dl := np * sr
	if dr != dl {
		// enforce the required number of destination rows
		// (removed dead store `dr = dl`: dr is not read after this point)
		dest.SetNumRows(dl)
	}
	ssz := len(src.Values)
	dsz := len(dest.Values)
	// exchange per-string lengths first, so receivers know how to unpack
	sln := make([]int, ssz)
	dln := make([]int, dsz)
	for i, s := range src.Values {
		sln[i] = len(s)
	}
	err := comm.AllGatherInt(dln, sln)
	if err != nil {
		return err
	}
	mxlen := 0
	for _, l := range dln {
		mxlen = max(mxlen, l)
	}
	if mxlen == 0 {
		return nil // nothing to transfer
	}
	// pack each string into a fixed mxlen-sized slot
	sdt := make([]byte, ssz*mxlen)
	ddt := make([]byte, dsz*mxlen)
	idx := 0
	for _, s := range src.Values {
		copy(sdt[idx:idx+len(s)], []byte(s))
		idx += mxlen
	}
	if err := comm.AllGatherU8(ddt, sdt); err != nil {
		// bug fix: return immediately instead of unpacking ddt,
		// whose contents are undefined after a failed gather
		return err
	}
	idx = 0
	for i := range dest.Values {
		l := dln[i]
		dest.Values[i] = string(ddt[idx : idx+l])
		idx += mxlen
	}
	return nil
}
// ReduceTensor does an MPI AllReduce on given src tensor data, using given operation,
// gathering into dest. dest must have same overall shape as src -- will be enforced.
// IMPORTANT: src and dest must be different slices!
// Each processor must have the same shape and organization for this to make sense.
// Does nothing for strings.
func ReduceTensor(dest, src tensor.Values, comm *mpi.Comm, op mpi.Op) error {
	kind := src.DataType()
	if kind == reflect.String {
		return nil
	}
	if dest.Len() != src.Len() {
		// enforce matching shapes before the reduce
		tensor.SetShapeFrom(dest, src)
	}
	var err error
	switch kind {
	case reflect.Bool:
		err = comm.AllReduceU8(op, dest.(*tensor.Bool).Values, src.(*tensor.Bool).Values)
	case reflect.Uint8:
		err = comm.AllReduceU8(op, dest.(*tensor.Byte).Values, src.(*tensor.Byte).Values)
	case reflect.Int32:
		err = comm.AllReduceI32(op, dest.(*tensor.Int32).Values, src.(*tensor.Int32).Values)
	case reflect.Int:
		err = comm.AllReduceInt(op, dest.(*tensor.Int).Values, src.(*tensor.Int).Values)
	case reflect.Float32:
		err = comm.AllReduceF32(op, dest.(*tensor.Float32).Values, src.(*tensor.Float32).Values)
	case reflect.Float64:
		err = comm.AllReduceF64(op, dest.(*tensor.Float64).Values, src.(*tensor.Float64).Values)
	}
	return err
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tmath
import (
"cogentcore.org/core/base/errors"
"cogentcore.org/lab/tensor"
)
// Equal returns a new Bool tensor with the element-wise result of a == b.
func Equal(a, b tensor.Tensor) *tensor.Bool {
	return tensor.CallOut2Bool(EqualOut, a, b)
}

// EqualOut stores in the output the bool value a == b.
// String tensors are compared as strings; all other types as float64.
func EqualOut(a, b tensor.Tensor, out *tensor.Bool) error {
	if a.IsString() {
		return tensor.BoolStringsFuncOut(func(a, b string) bool { return a == b }, a, b, out)
	}
	return tensor.BoolFloatsFuncOut(func(a, b float64) bool { return a == b }, a, b, out)
}

// Less returns a new Bool tensor with the element-wise result of a < b.
func Less(a, b tensor.Tensor) *tensor.Bool {
	return tensor.CallOut2Bool(LessOut, a, b)
}

// LessOut stores in the output the bool value a < b.
func LessOut(a, b tensor.Tensor, out *tensor.Bool) error {
	if a.IsString() {
		return tensor.BoolStringsFuncOut(func(a, b string) bool { return a < b }, a, b, out)
	}
	return tensor.BoolFloatsFuncOut(func(a, b float64) bool { return a < b }, a, b, out)
}

// Greater returns a new Bool tensor with the element-wise result of a > b.
func Greater(a, b tensor.Tensor) *tensor.Bool {
	return tensor.CallOut2Bool(GreaterOut, a, b)
}

// GreaterOut stores in the output the bool value a > b.
func GreaterOut(a, b tensor.Tensor, out *tensor.Bool) error {
	if a.IsString() {
		return tensor.BoolStringsFuncOut(func(a, b string) bool { return a > b }, a, b, out)
	}
	return tensor.BoolFloatsFuncOut(func(a, b float64) bool { return a > b }, a, b, out)
}

// NotEqual returns a new Bool tensor with the element-wise result of a != b.
func NotEqual(a, b tensor.Tensor) *tensor.Bool {
	return tensor.CallOut2Bool(NotEqualOut, a, b)
}

// NotEqualOut stores in the output the bool value a != b.
func NotEqualOut(a, b tensor.Tensor, out *tensor.Bool) error {
	if a.IsString() {
		return tensor.BoolStringsFuncOut(func(a, b string) bool { return a != b }, a, b, out)
	}
	return tensor.BoolFloatsFuncOut(func(a, b float64) bool { return a != b }, a, b, out)
}

// LessEqual returns a new Bool tensor with the element-wise result of a <= b.
func LessEqual(a, b tensor.Tensor) *tensor.Bool {
	return tensor.CallOut2Bool(LessEqualOut, a, b)
}

// LessEqualOut stores in the output the bool value a <= b.
func LessEqualOut(a, b tensor.Tensor, out *tensor.Bool) error {
	if a.IsString() {
		return tensor.BoolStringsFuncOut(func(a, b string) bool { return a <= b }, a, b, out)
	}
	return tensor.BoolFloatsFuncOut(func(a, b float64) bool { return a <= b }, a, b, out)
}

// GreaterEqual returns a new Bool tensor with the element-wise result of a >= b.
func GreaterEqual(a, b tensor.Tensor) *tensor.Bool {
	return tensor.CallOut2Bool(GreaterEqualOut, a, b)
}

// GreaterEqualOut stores in the output the bool value a >= b.
func GreaterEqualOut(a, b tensor.Tensor, out *tensor.Bool) error {
	if a.IsString() {
		return tensor.BoolStringsFuncOut(func(a, b string) bool { return a >= b }, a, b, out)
	}
	return tensor.BoolFloatsFuncOut(func(a, b float64) bool { return a >= b }, a, b, out)
}

// Or returns a new Bool tensor with the element-wise result of a || b
// (either value > 0).
func Or(a, b tensor.Tensor) *tensor.Bool {
	return tensor.CallOut2Bool(OrOut, a, b)
}

// OrOut stores in the output the bool value a || b (either value > 0).
func OrOut(a, b tensor.Tensor, out *tensor.Bool) error {
	return tensor.BoolIntsFuncOut(func(a, b int) bool { return a > 0 || b > 0 }, a, b, out)
}

// And returns a new Bool tensor with the element-wise result of a && b
// (both values > 0). (Previous comment incorrectly said a || b.)
func And(a, b tensor.Tensor) *tensor.Bool {
	return tensor.CallOut2Bool(AndOut, a, b)
}

// AndOut stores in the output the bool value a && b (both values > 0).
func AndOut(a, b tensor.Tensor, out *tensor.Bool) error {
	return tensor.BoolIntsFuncOut(func(a, b int) bool { return a > 0 && b > 0 }, a, b, out)
}

// Not returns a new Bool tensor with the element-wise result of !a
// (true where a == 0). Any error from NotOut is logged.
func Not(a tensor.Tensor) *tensor.Bool {
	out := tensor.NewBool()
	errors.Log(NotOut(a, out))
	return out
}

// NotOut stores in the output the bool value !a (true where a == 0).
// out is reshaped to match a.
func NotOut(a tensor.Tensor, out *tensor.Bool) error {
	out.SetShapeSizes(a.Shape().Sizes...)
	alen := a.Len()
	tensor.VectorizeThreaded(1, func(tsr ...tensor.Tensor) int { return alen },
		func(idx int, tsr ...tensor.Tensor) {
			out.SetBool1D(tsr[0].Int1D(idx) == 0, idx)
		}, a, out)
	return nil
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tmath
import (
"math"
"cogentcore.org/lab/tensor"
)
// Abs returns a new tensor with math.Abs applied to each element of in.
func Abs(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1Float64(AbsOut, in)
}

// AbsOut applies math.Abs to each element of in, into out.
func AbsOut(in tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, math.Abs, in, out)
}

// Acos returns a new tensor with math.Acos applied to each element of in.
func Acos(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1Float64(AcosOut, in)
}

// AcosOut applies math.Acos to each element of in, into out.
func AcosOut(in tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, math.Acos, in, out)
}

// Acosh returns a new tensor with math.Acosh applied to each element of in.
func Acosh(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1Float64(AcoshOut, in)
}

// AcoshOut applies math.Acosh to each element of in, into out.
func AcoshOut(in tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, math.Acosh, in, out)
}

// Asin returns a new tensor with math.Asin applied to each element of in.
func Asin(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1Float64(AsinOut, in)
}

// AsinOut applies math.Asin to each element of in, into out.
func AsinOut(in tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, math.Asin, in, out)
}

// Asinh returns a new tensor with math.Asinh applied to each element of in.
func Asinh(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1Float64(AsinhOut, in)
}

// AsinhOut applies math.Asinh to each element of in, into out.
func AsinhOut(in tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, math.Asinh, in, out)
}

// Atan returns a new tensor with math.Atan applied to each element of in.
func Atan(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1Float64(AtanOut, in)
}

// AtanOut applies math.Atan to each element of in, into out.
func AtanOut(in tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, math.Atan, in, out)
}

// Atanh returns a new tensor with math.Atanh applied to each element of in.
func Atanh(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1Float64(AtanhOut, in)
}

// AtanhOut applies math.Atanh to each element of in, into out.
func AtanhOut(in tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, math.Atanh, in, out)
}

// Cbrt returns a new tensor with math.Cbrt applied to each element of in.
func Cbrt(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1Float64(CbrtOut, in)
}

// CbrtOut applies math.Cbrt to each element of in, into out.
func CbrtOut(in tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, math.Cbrt, in, out)
}

// Ceil returns a new tensor with math.Ceil applied to each element of in.
func Ceil(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1Float64(CeilOut, in)
}

// CeilOut applies math.Ceil to each element of in, into out.
func CeilOut(in tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, math.Ceil, in, out)
}

// Cos returns a new tensor with math.Cos applied to each element of in.
func Cos(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1Float64(CosOut, in)
}

// CosOut applies math.Cos to each element of in, into out.
func CosOut(in tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, math.Cos, in, out)
}

// Cosh returns a new tensor with math.Cosh applied to each element of in.
func Cosh(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1Float64(CoshOut, in)
}

// CoshOut applies math.Cosh to each element of in, into out.
func CoshOut(in tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, math.Cosh, in, out)
}

// Erf returns a new tensor with math.Erf applied to each element of in.
func Erf(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1Float64(ErfOut, in)
}

// ErfOut applies math.Erf to each element of in, into out.
func ErfOut(in tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, math.Erf, in, out)
}

// Erfc returns a new tensor with math.Erfc applied to each element of in.
func Erfc(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1Float64(ErfcOut, in)
}

// ErfcOut applies math.Erfc to each element of in, into out.
func ErfcOut(in tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, math.Erfc, in, out)
}

// Erfcinv returns a new tensor with math.Erfcinv applied to each element of in.
func Erfcinv(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1Float64(ErfcinvOut, in)
}

// ErfcinvOut applies math.Erfcinv to each element of in, into out.
func ErfcinvOut(in tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, math.Erfcinv, in, out)
}

// Erfinv returns a new tensor with math.Erfinv applied to each element of in.
func Erfinv(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1Float64(ErfinvOut, in)
}

// ErfinvOut applies math.Erfinv to each element of in, into out.
func ErfinvOut(in tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, math.Erfinv, in, out)
}

// Exp returns a new tensor with math.Exp applied to each element of in.
func Exp(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1Float64(ExpOut, in)
}

// ExpOut applies math.Exp to each element of in, into out.
func ExpOut(in tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, math.Exp, in, out)
}

// Exp2 returns a new tensor with math.Exp2 applied to each element of in.
func Exp2(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1Float64(Exp2Out, in)
}

// Exp2Out applies math.Exp2 to each element of in, into out.
func Exp2Out(in tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, math.Exp2, in, out)
}

// Expm1 returns a new tensor with math.Expm1 applied to each element of in.
func Expm1(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1Float64(Expm1Out, in)
}

// Expm1Out applies math.Expm1 to each element of in, into out.
func Expm1Out(in tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, math.Expm1, in, out)
}

// Floor returns a new tensor with math.Floor applied to each element of in.
func Floor(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1Float64(FloorOut, in)
}

// FloorOut applies math.Floor to each element of in, into out.
func FloorOut(in tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, math.Floor, in, out)
}

// Gamma returns a new tensor with math.Gamma applied to each element of in.
func Gamma(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1Float64(GammaOut, in)
}

// GammaOut applies math.Gamma to each element of in, into out.
func GammaOut(in tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, math.Gamma, in, out)
}

// J0 returns a new tensor with math.J0 applied to each element of in.
func J0(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1Float64(J0Out, in)
}

// J0Out applies math.J0 to each element of in, into out.
func J0Out(in tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, math.J0, in, out)
}

// J1 returns a new tensor with math.J1 applied to each element of in.
func J1(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1Float64(J1Out, in)
}

// J1Out applies math.J1 to each element of in, into out.
func J1Out(in tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, math.J1, in, out)
}

// Log returns a new tensor with math.Log applied to each element of in.
func Log(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1Float64(LogOut, in)
}

// LogOut applies math.Log to each element of in, into out.
func LogOut(in tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, math.Log, in, out)
}

// Log10 returns a new tensor with math.Log10 applied to each element of in.
func Log10(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1Float64(Log10Out, in)
}

// Log10Out applies math.Log10 to each element of in, into out.
func Log10Out(in tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, math.Log10, in, out)
}

// Log1p returns a new tensor with math.Log1p applied to each element of in.
func Log1p(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1Float64(Log1pOut, in)
}

// Log1pOut applies math.Log1p to each element of in, into out.
func Log1pOut(in tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, math.Log1p, in, out)
}

// Log2 returns a new tensor with math.Log2 applied to each element of in.
func Log2(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1Float64(Log2Out, in)
}

// Log2Out applies math.Log2 to each element of in, into out.
func Log2Out(in tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, math.Log2, in, out)
}

// Logb returns a new tensor with math.Logb applied to each element of in.
func Logb(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1Float64(LogbOut, in)
}

// LogbOut applies math.Logb to each element of in, into out.
func LogbOut(in tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, math.Logb, in, out)
}

// Round returns a new tensor with math.Round applied to each element of in.
func Round(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1Float64(RoundOut, in)
}

// RoundOut applies math.Round to each element of in, into out.
func RoundOut(in tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, math.Round, in, out)
}

// RoundToEven returns a new tensor with math.RoundToEven applied to each element of in.
func RoundToEven(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1Float64(RoundToEvenOut, in)
}

// RoundToEvenOut applies math.RoundToEven to each element of in, into out.
func RoundToEvenOut(in tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, math.RoundToEven, in, out)
}

// Sin returns a new tensor with math.Sin applied to each element of in.
func Sin(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1Float64(SinOut, in)
}

// SinOut applies math.Sin to each element of in, into out.
func SinOut(in tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, math.Sin, in, out)
}

// Sinh returns a new tensor with math.Sinh applied to each element of in.
func Sinh(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1Float64(SinhOut, in)
}

// SinhOut applies math.Sinh to each element of in, into out.
func SinhOut(in tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, math.Sinh, in, out)
}

// Sqrt returns a new tensor with math.Sqrt applied to each element of in.
func Sqrt(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1Float64(SqrtOut, in)
}

// SqrtOut applies math.Sqrt to each element of in, into out.
func SqrtOut(in tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, math.Sqrt, in, out)
}

// Tan returns a new tensor with math.Tan applied to each element of in.
func Tan(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1Float64(TanOut, in)
}

// TanOut applies math.Tan to each element of in, into out.
func TanOut(in tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, math.Tan, in, out)
}

// Tanh returns a new tensor with math.Tanh applied to each element of in.
func Tanh(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1Float64(TanhOut, in)
}

// TanhOut applies math.Tanh to each element of in, into out.
func TanhOut(in tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, math.Tanh, in, out)
}

// Trunc returns a new tensor with math.Trunc applied to each element of in.
func Trunc(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1Float64(TruncOut, in)
}

// TruncOut applies math.Trunc to each element of in, into out.
func TruncOut(in tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, math.Trunc, in, out)
}

// Y0 returns a new tensor with math.Y0 applied to each element of in.
func Y0(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1Float64(Y0Out, in)
}

// Y0Out applies math.Y0 to each element of in, into out.
func Y0Out(in tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, math.Y0, in, out)
}

// Y1 returns a new tensor with math.Y1 applied to each element of in.
func Y1(in tensor.Tensor) tensor.Values {
	return tensor.CallOut1Float64(Y1Out, in)
}

// Y1Out applies math.Y1 to each element of in, into out.
func Y1Out(in tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, math.Y1, in, out)
}
//////// Binary

// Atan2 returns a new tensor with math.Atan2 applied element-wise to y, x.
func Atan2(y, x tensor.Tensor) tensor.Values {
	return tensor.CallOut2(Atan2Out, y, x)
}

// Atan2Out applies math.Atan2 element-wise to y, x, into out.
func Atan2Out(y, x tensor.Tensor, out tensor.Values) error {
	return tensor.FloatBinaryFuncOut(1, math.Atan2, y, x, out)
}

// Copysign returns a new tensor with math.Copysign applied element-wise to x, y.
func Copysign(x, y tensor.Tensor) tensor.Values {
	return tensor.CallOut2(CopysignOut, x, y)
}

// CopysignOut applies math.Copysign element-wise to x, y, into out.
func CopysignOut(x, y tensor.Tensor, out tensor.Values) error {
	return tensor.FloatBinaryFuncOut(1, math.Copysign, x, y, out)
}

// Dim returns a new tensor with math.Dim applied element-wise to x, y.
func Dim(x, y tensor.Tensor) tensor.Values {
	return tensor.CallOut2(DimOut, x, y)
}

// DimOut applies math.Dim element-wise to x, y, into out.
func DimOut(x, y tensor.Tensor, out tensor.Values) error {
	return tensor.FloatBinaryFuncOut(1, math.Dim, x, y, out)
}

// Hypot returns a new tensor with math.Hypot applied element-wise to x, y.
func Hypot(x, y tensor.Tensor) tensor.Values {
	return tensor.CallOut2(HypotOut, x, y)
}

// HypotOut applies math.Hypot element-wise to x, y, into out.
func HypotOut(x, y tensor.Tensor, out tensor.Values) error {
	return tensor.FloatBinaryFuncOut(1, math.Hypot, x, y, out)
}

// Max returns a new tensor with math.Max applied element-wise to x, y.
func Max(x, y tensor.Tensor) tensor.Values {
	return tensor.CallOut2(MaxOut, x, y)
}

// MaxOut applies math.Max element-wise to x, y, into out.
func MaxOut(x, y tensor.Tensor, out tensor.Values) error {
	return tensor.FloatBinaryFuncOut(1, math.Max, x, y, out)
}

// Min returns a new tensor with math.Min applied element-wise to x, y.
func Min(x, y tensor.Tensor) tensor.Values {
	return tensor.CallOut2(MinOut, x, y)
}

// MinOut applies math.Min element-wise to x, y, into out.
func MinOut(x, y tensor.Tensor, out tensor.Values) error {
	return tensor.FloatBinaryFuncOut(1, math.Min, x, y, out)
}

// Nextafter returns a new tensor with math.Nextafter applied element-wise to x, y.
func Nextafter(x, y tensor.Tensor) tensor.Values {
	return tensor.CallOut2(NextafterOut, x, y)
}

// NextafterOut applies math.Nextafter element-wise to x, y, into out.
func NextafterOut(x, y tensor.Tensor, out tensor.Values) error {
	return tensor.FloatBinaryFuncOut(1, math.Nextafter, x, y, out)
}

// Pow returns a new tensor with math.Pow applied element-wise to x, y.
func Pow(x, y tensor.Tensor) tensor.Values {
	return tensor.CallOut2(PowOut, x, y)
}

// PowOut applies math.Pow element-wise to x, y, into out.
func PowOut(x, y tensor.Tensor, out tensor.Values) error {
	return tensor.FloatBinaryFuncOut(1, math.Pow, x, y, out)
}

// Remainder returns a new tensor with math.Remainder applied element-wise to x, y.
func Remainder(x, y tensor.Tensor) tensor.Values {
	return tensor.CallOut2(RemainderOut, x, y)
}

// RemainderOut applies math.Remainder element-wise to x, y, into out.
func RemainderOut(x, y tensor.Tensor, out tensor.Values) error {
	return tensor.FloatBinaryFuncOut(1, math.Remainder, x, y, out)
}
/*
func Nextafter32(x, y float32) (r float32)
func Inf(sign int) float64
func IsInf(f float64, sign int) bool
func IsNaN(f float64) (is bool)
func NaN() float64
func Signbit(x float64) bool
func Float32bits(f float32) uint32
func Float32frombits(b uint32) float32
func Float64bits(f float64) uint64
func Float64frombits(b uint64) float64
func FMA(x, y, z float64) float64
func Jn(n int, in tensor.Tensor, out tensor.Values)
func Yn(n int, in tensor.Tensor, out tensor.Values)
func Ldexp(frac float64, exp int) float64
func Ilogb(x float64) int
func Pow10(n int) float64
func Frexp(f float64) (frac float64, exp int)
func Modf(f float64) (int float64, frac float64)
func Lgamma(x float64) (lgamma float64, sign int)
func Sincos(x float64) (sin, cos float64)
*/
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tmath
import (
"math"
"cogentcore.org/lab/tensor"
)
// Assign assigns values from b into a.
func Assign(a, b tensor.Tensor) error {
	return tensor.FloatAssignFunc(func(_, y float64) float64 { return y }, a, b)
}

// AddAssign does += add assign values from b into a
// (string concatenation for string tensors).
func AddAssign(a, b tensor.Tensor) error {
	if a.IsString() {
		return tensor.StringAssignFunc(func(x, y string) string { return x + y }, a, b)
	}
	return tensor.FloatAssignFunc(func(x, y float64) float64 { return x + y }, a, b)
}

// SubAssign does -= sub assign values from b into a.
func SubAssign(a, b tensor.Tensor) error {
	return tensor.FloatAssignFunc(func(x, y float64) float64 { return x - y }, a, b)
}

// MulAssign does *= mul assign values from b into a.
func MulAssign(a, b tensor.Tensor) error {
	return tensor.FloatAssignFunc(func(x, y float64) float64 { return x * y }, a, b)
}

// DivAssign does /= divide assign values from b into a.
func DivAssign(a, b tensor.Tensor) error {
	return tensor.FloatAssignFunc(func(x, y float64) float64 { return x / y }, a, b)
}

// ModAssign does %= modulus assign values from b into a.
func ModAssign(a, b tensor.Tensor) error {
	return tensor.FloatAssignFunc(math.Mod, a, b)
}
// Inc increments values in given tensor by 1.
func Inc(a tensor.Tensor) error {
	n := a.Len()
	tensor.VectorizeThreaded(1, func(tsr ...tensor.Tensor) int { return n },
		func(i int, tsr ...tensor.Tensor) {
			t := tsr[0]
			t.SetFloat1D(t.Float1D(i)+1, i)
		}, a)
	return nil
}

// Dec decrements values in given tensor by 1.
func Dec(a tensor.Tensor) error {
	n := a.Len()
	tensor.VectorizeThreaded(1, func(tsr ...tensor.Tensor) int { return n },
		func(i int, tsr ...tensor.Tensor) {
			t := tsr[0]
			t.SetFloat1D(t.Float1D(i)-1, i)
		}, a)
	return nil
}
// Add returns the element-wise sum a + b
// (string concatenation for string tensors).
func Add(a, b tensor.Tensor) tensor.Values {
	return tensor.CallOut2(AddOut, a, b)
}

// AddOut adds two tensors into output.
func AddOut(a, b tensor.Tensor, out tensor.Values) error {
	if a.IsString() {
		return tensor.StringBinaryFuncOut(func(x, y string) string { return x + y }, a, b, out)
	}
	return tensor.FloatBinaryFuncOut(1, func(x, y float64) float64 { return x + y }, a, b, out)
}

// Sub returns the element-wise difference a - b.
func Sub(a, b tensor.Tensor) tensor.Values {
	return tensor.CallOut2(SubOut, a, b)
}

// SubOut subtracts two tensors into output.
func SubOut(a, b tensor.Tensor, out tensor.Values) error {
	return tensor.FloatBinaryFuncOut(1, func(x, y float64) float64 { return x - y }, a, b, out)
}

// Mul returns the element-wise product a * b.
func Mul(a, b tensor.Tensor) tensor.Values {
	return tensor.CallOut2(MulOut, a, b)
}

// MulOut multiplies two tensors into output.
func MulOut(a, b tensor.Tensor, out tensor.Values) error {
	return tensor.FloatBinaryFuncOut(1, func(x, y float64) float64 { return x * y }, a, b, out)
}

// Div returns the element-wise quotient a / b. Always does floating point
// division, even with integer operands.
func Div(a, b tensor.Tensor) tensor.Values {
	return tensor.CallOut2Float64(DivOut, a, b)
}

// DivOut divides two tensors into output.
func DivOut(a, b tensor.Tensor, out tensor.Values) error {
	return tensor.FloatBinaryFuncOut(1, func(x, y float64) float64 { return x / y }, a, b, out)
}

// Mod returns the element-wise floating-point modulus a % b (math.Mod).
func Mod(a, b tensor.Tensor) tensor.Values {
	return tensor.CallOut2(ModOut, a, b)
}

// ModOut performs modulus a%b on tensors into output.
func ModOut(a, b tensor.Tensor, out tensor.Values) error {
	return tensor.FloatBinaryFuncOut(1, math.Mod, a, b, out)
}

// Negate returns the element-wise negation -a.
// (Previous comment incorrectly described this as a bool value.)
func Negate(a tensor.Tensor) tensor.Values {
	return tensor.CallOut1(NegateOut, a)
}

// NegateOut computes the element-wise negation -a into output.
func NegateOut(a tensor.Tensor, out tensor.Values) error {
	return tensor.FloatFuncOut(1, func(v float64) float64 { return -v }, a, out)
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
import (
"fmt"
"reflect"
"cogentcore.org/core/base/metadata"
)
// Values is an extended [Tensor] interface for raw value tensors.
// This supports direct setting of the shape of the underlying values,
// sub-space access to inner-dimensional subspaces of values, etc.
type Values interface {
	RowMajor

	// SetShapeSizes sets the dimension sizes of the tensor, and resizes
	// backing storage appropriately, retaining all existing data that fits.
	SetShapeSizes(sizes ...int)

	// SetNumRows sets the number of rows (outermost dimension).
	// It is safe to set this to 0. For incrementally growing tensors (e.g., a log)
	// it is best to first set the anticipated full size, which allocates the
	// full amount of memory, and then set to 0 and grow incrementally.
	SetNumRows(rows int)

	// Sizeof returns the number of bytes contained in the Values of this tensor.
	// For String types, this is just the string pointers, not the string content.
	Sizeof() int64

	// Bytes returns the underlying byte representation of the tensor values.
	// This is the actual underlying data, so make a copy if it can be
	// unintentionally modified or retained more than for immediate use.
	Bytes() []byte

	// SetZeros is a convenience function to initialize all values to the
	// zero value of the type (empty strings for string type).
	// New tensors always start out with zeros.
	SetZeros()

	// Clone clones this tensor, creating a duplicate copy of itself with its
	// own separate memory representation of all the values.
	Clone() Values

	// CopyFrom copies all values from other tensor into this tensor, with an
	// optimized implementation if the other tensor is of the same type, and
	// otherwise it goes through the appropriate standard type (Float, Int, String).
	CopyFrom(from Values)

	// CopyCellsFrom copies given range of values from other tensor into this tensor,
	// using flat 1D indexes: to = starting index in this Tensor to start copying into,
	// start = starting index on from Tensor to start copying from, and n = number of
	// values to copy. Uses an optimized implementation if the other tensor is
	// of the same type, and otherwise it goes through appropriate standard type.
	CopyCellsFrom(from Values, to, start, n int)

	// AppendFrom appends all values from other tensor into this tensor, with an
	// optimized implementation if the other tensor is of the same type, and
	// otherwise it goes through the appropriate standard type (Float, Int, String).
	AppendFrom(from Values) Values
}
// New returns a new n-dimensional tensor of given value type
// with the given sizes per dimension (shape).
func New[T DataTypes](sizes ...int) Values {
	var v T
	switch any(v).(type) {
	case string:
		return NewString(sizes...)
	case bool:
		return NewBool(sizes...)
	case float64:
		return NewNumber[float64](sizes...)
	case float32:
		return NewNumber[float32](sizes...)
	case int:
		return NewNumber[int](sizes...)
	case int32:
		return NewNumber[int32](sizes...)
	case uint32:
		return NewNumber[uint32](sizes...)
	case byte:
		return NewNumber[byte](sizes...)
	default:
		// unreachable as long as the cases above cover all of [DataTypes]
		panic("tensor.New: unexpected error: type not supported")
	}
}
// NewOfType returns a new n-dimensional tensor of given reflect.Kind type
// with the given sizes per dimension (shape).
// Supported types are in [DataTypes].
// Panics if typ is not a supported type.
func NewOfType(typ reflect.Kind, sizes ...int) Values {
	switch typ {
	case reflect.String:
		return NewString(sizes...)
	case reflect.Bool:
		return NewBool(sizes...)
	case reflect.Float64:
		return NewNumber[float64](sizes...)
	case reflect.Float32:
		return NewNumber[float32](sizes...)
	case reflect.Int:
		return NewNumber[int](sizes...)
	case reflect.Int32:
		return NewNumber[int32](sizes...)
	case reflect.Uint32:
		// added for consistency with New[uint32], which supports this type
		return NewNumber[uint32](sizes...)
	case reflect.Uint8:
		return NewNumber[byte](sizes...)
	default:
		panic(fmt.Sprintf("tensor.NewOfType: type not supported: %v", typ))
	}
}
// metadata helpers

// SetShapeNames sets the tensor shape dimension names into given metadata,
// under the "ShapeNames" key.
func SetShapeNames(md *metadata.Data, names ...string) {
	md.Set("ShapeNames", names)
}

// ShapeNames gets the tensor shape dimension names from given metadata.
func ShapeNames(md *metadata.Data) []string {
	// error deliberately ignored: presumably an absent key yields the
	// zero value (nil) — TODO confirm metadata.Get semantics
	names, _ := metadata.Get[[]string](*md, "ShapeNames")
	return names
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
import (
"math"
"runtime"
"sync"
)
var (
	// ThreadingThreshold is the threshold in number of flops (floating point ops),
	// computed as tensor N * flops per element, to engage actual parallel processing.
	// Heuristically, numbers below this threshold do not result in
	// an overall speedup, due to overhead costs. See tmath/ops_test.go for benchmark.
	ThreadingThreshold = 300

	// NumThreads is the number of threads to use for parallel threading.
	// The default of 0 causes the [runtime.GOMAXPROCS] to be used.
	NumThreads = 0
)
// Vectorize applies given function 'fun' to tensor elements indexed
// by given index, with the 'nfun' providing the number of indexes
// to vectorize over, and initializing any output vectors.
// Thus the nfun is often specific to a particular class of functions.
// Both functions are called with the same set
// of Tensors passed as the final argument(s).
// The role of each tensor is function-dependent: there could be multiple
// inputs and outputs, and the output could be effectively scalar,
// as in a sum operation. The interpretation of the index is
// function dependent as well, but often is used to iterate over
// the outermost row dimension of the tensor.
// This version runs purely sequentially on this goroutine.
// See VectorizeThreaded and VectorizeGPU for other versions.
func Vectorize(nfun func(tsr ...Tensor) int, fun func(idx int, tsr ...Tensor), tsr ...Tensor) {
	n := nfun(tsr...)
	// a non-positive n naturally produces zero iterations
	for idx := 0; idx < n; idx++ {
		fun(idx, tsr...)
	}
}
// VectorizeThreaded is a version of [Vectorize] that will automatically
// distribute the computation in parallel across multiple "threads" (goroutines)
// if the number of elements to be computed times the given flops
// (floating point operations) for the function exceeds the [ThreadingThreshold].
// Heuristically, numbers below this threshold do not result
// in an overall speedup, due to overhead costs.
// Each elemental math operation in the function adds a flop.
// See estimates in [tmath] for basic math functions.
func VectorizeThreaded(flops int, nfun func(tsr ...Tensor) int, fun func(idx int, tsr ...Tensor), tsr ...Tensor) {
	n := nfun(tsr...)
	if n <= 0 {
		return
	}
	if flops < 0 {
		flops = 1
	}
	if n*flops >= ThreadingThreshold {
		VectorizeOnThreads(0, nfun, fun, tsr...)
		return
	}
	Vectorize(nfun, fun, tsr...)
}
// DefaultNumThreads returns the default number of threads to use:
// NumThreads if non-zero, otherwise [runtime.GOMAXPROCS].
func DefaultNumThreads() int {
	if NumThreads <= 0 {
		return runtime.GOMAXPROCS(0)
	}
	return NumThreads
}
// VectorizeOnThreads runs given [Vectorize] function on given number
// of threads. Use [VectorizeThreaded] to only use parallel threads when
// it is likely to be beneficial, in terms of the ThreadingThreshold.
// If threads is 0, then the [DefaultNumThreads] will be used:
// GOMAXPROCS subject to NumThreads constraint if non-zero.
func VectorizeOnThreads(threads int, nfun func(tsr ...Tensor) int, fun func(idx int, tsr ...Tensor), tsr ...Tensor) {
	if threads == 0 {
		threads = DefaultNumThreads()
	}
	n := nfun(tsr...)
	if n <= 0 {
		return
	}
	nper := int(math.Ceil(float64(n) / float64(threads)))
	// resolves the prior "todo: move out of loop": register all chunks
	// with the WaitGroup once, before launching any goroutines
	nchunks := (n + nper - 1) / nper
	wait := sync.WaitGroup{}
	wait.Add(nchunks)
	for start := 0; start < n; start += nper {
		end := min(start+nper, n)
		// pass bounds explicitly so the closure does not depend on
		// per-iteration loop-variable semantics
		go func(s, e int) {
			defer wait.Done()
			for idx := s; idx < e; idx++ {
				fun(idx, tsr...)
			}
		}(start, end)
	}
	wait.Wait()
}
// NFirstRows is an N function for Vectorize that returns the number of
// outer-dimension rows (or Indexes) of the first tensor.
func NFirstRows(tsr ...Tensor) int {
	if len(tsr) > 0 {
		return tsr[0].DimSize(0)
	}
	return 0
}
// NFirstLen is an N function for Vectorize that returns the number of
// elements in the tensor, taking into account the Indexes view.
func NFirstLen(tsr ...Tensor) int {
	if len(tsr) > 0 {
		return tsr[0].Len()
	}
	return 0
}
// NMinLen is an N function for Vectorize that returns the min number of
// elements across given number of tensors in the list. Use a closure
// to call this with the nt.
func NMinLen(nt int, tsr ...Tensor) int {
	nt = min(len(tsr), nt)
	if nt == 0 {
		return 0
	}
	n := tsr[0].Len()
	for i := 1; i < nt; i++ {
		n = min(n, tsr[i].Len()) // fixed: was tsr[0].Len(), which ignored all other tensors
	}
	return n
}
// Copyright (c) 2019, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensorcore
import (
"cogentcore.org/core/base/metadata"
"cogentcore.org/core/core"
"cogentcore.org/core/math32/minmax"
)
// Layout are layout options for displaying tensors.
type Layout struct { //types:add --setters

	// OddRow means that even-numbered dimensions are displayed as Y*X rectangles.
	// This determines along which dimension to display any remaining
	// odd dimension: OddRow = true = organize vertically along row
	// dimension, false = organize horizontally across column dimension.
	OddRow bool

	// TopZero means that the Y=0 coordinate is displayed from the top-down;
	// otherwise the Y=0 coordinate is displayed from the bottom up,
	// which is typical for emergent network patterns.
	TopZero bool

	// Image will display the data as a bitmap image. If a 2D tensor, then it will
	// be a greyscale image. If a 3D tensor with size of either the first
	// or last dim = either 3 or 4, then it is a RGB(A) color image.
	Image bool
}
// GridStyle are options for displaying tensors.
type GridStyle struct { //types:add --setters
	Layout // embedded layout options (OddRow, TopZero, Image)

	// Range is the fixed range of values to plot.
	Range minmax.Range64 `display:"inline"`

	// MinMax has the actual range of data, if not using fixed Range.
	MinMax minmax.F64 `display:"inline"`

	// ColorMap is the name of the color map to use in translating values to colors.
	ColorMap core.ColorMapName

	// GridFill sets proportion of grid square filled by the color block:
	// 1 = all, .5 = half, etc.
	GridFill float32 `min:"0.1" max:"1" step:"0.1" default:"0.9,1"`

	// DimExtra is the amount of extra space to add at dimension boundaries,
	// as a proportion of total grid size.
	DimExtra float32 `min:"0" max:"1" step:"0.02" default:"0.1,0.3"`

	// Size sets the minimum and maximum size for grid squares.
	Size minmax.F32 `display:"inline"`

	// TotalSize sets the total preferred display size along largest dimension.
	// Grid squares will be sized to fit within this size,
	// subject to the Size.Min / Max constraints, which have precedence.
	TotalSize float32

	// FontSize is the font size in standard point units for labels.
	FontSize float32
}
// Defaults sets reasonable default values for all style parameters
// that would otherwise start at nonsensical zero values.
func (gs *GridStyle) Defaults() {
	gs.ColorMap = "ColdHot"
	gs.Range.SetMin(-1).SetMax(1)
	gs.GridFill = 0.9
	gs.DimExtra = 0.3
	gs.Size.Set(2, 32)
	gs.TotalSize = 100
	gs.FontSize = 24
}
// NewGridStyle returns a new GridStyle with defaults applied.
func NewGridStyle() *GridStyle {
	var gs GridStyle
	gs.Defaults()
	return &gs
}
// ApplyStylersFrom applies any [GridStylers] stored in the given object's
// metadata to this style; it is a no-op when none are present.
func (gs *GridStyle) ApplyStylersFrom(obj any) {
	if stylers := GetGridStylersFrom(obj); stylers != nil {
		stylers.Run(gs)
	}
}
// GridStylers is a list of styling functions that set GridStyle properties.
// These are called in the order added.
type GridStylers []func(s *GridStyle)

// Add adds a styling function to the list.
func (st *GridStylers) Add(f func(s *GridStyle)) {
	*st = append(*st, f)
}
// Run applies each styling function, in the order added, to the
// given [GridStyle] object.
func (st *GridStylers) Run(s *GridStyle) {
	for i := range *st {
		(*st)[i](s)
	}
}
// SetGridStylersTo sets the [GridStylers] into given object's [metadata],
// under the "GridStylers" key.
func SetGridStylersTo(obj any, st GridStylers) {
	metadata.SetTo(obj, "GridStylers", st)
}

// GetGridStylersFrom returns [GridStylers] from given object's [metadata].
// Returns nil if none or no metadata.
func GetGridStylersFrom(obj any) GridStylers {
	st, _ := metadata.GetFrom[GridStylers](obj, "GridStylers")
	return st
}

// AddGridStylerTo adds the given [GridStyler] function into given object's
// [metadata], appending to any existing stylers.
func AddGridStylerTo(obj any, f func(s *GridStyle)) {
	st := GetGridStylersFrom(obj)
	st.Add(f)
	SetGridStylersTo(obj, st)
}
// Copyright (c) 2023, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package tensorcore provides GUI Cogent Core widgets for tensor types.
package tensorcore
//go:generate core generate
import (
"bytes"
"encoding/csv"
"fmt"
"image"
"log"
"strconv"
"strings"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/fileinfo"
"cogentcore.org/core/base/fileinfo/mimedata"
"cogentcore.org/core/core"
"cogentcore.org/core/events"
"cogentcore.org/core/icons"
"cogentcore.org/core/styles"
"cogentcore.org/core/styles/states"
"cogentcore.org/core/styles/units"
"cogentcore.org/core/tree"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
)
// Table provides a GUI widget for representing [table.Table] values.
type Table struct {
	core.ListBase

	// Table is the table that we're a view of.
	Table *table.Table `set:"-"`

	// GridStyle has global grid display styles. GridStylers on the Table
	// are applied to this on top of defaults.
	GridStyle GridStyle `set:"-"`

	// ColumnGridStyle has per column grid display styles.
	ColumnGridStyle map[int]*GridStyle `set:"-"`

	// SortIndex is the current sort column index; -1 = unsorted.
	SortIndex int

	// SortDescending is whether the current sort order is descending.
	SortDescending bool

	// nCols is the number of columns in the table (as of last update).
	nCols int `edit:"-"`

	// headerWidths has number of characters in each header, per visfields.
	headerWidths []int `copier:"-" display:"-" json:"-" xml:"-"`

	// colMaxWidths records maximum width in chars of string type fields.
	colMaxWidths []int `set:"-" copier:"-" json:"-" xml:"-"`

	// blankString and blankFloat are blank values for out-of-range rows.
	blankString string
	blankFloat  float64

	// blankCells has per column blank tensor cells, for out-of-range rows.
	blankCells map[int]*tensor.Float64 `set:"-"`
}

// compile-time check that Table implements the core.Lister interface
var _ core.Lister = (*Table)(nil)
// Init initializes the Table widget: sets style defaults and installs
// the tree plan maker that builds the header and grid rows.
func (tb *Table) Init() {
	tb.ListBase.Init()
	tb.SortIndex = -1 // -1 = not currently sorted on any column
	tb.GridStyle.Defaults()
	tb.ColumnGridStyle = map[int]*GridStyle{}
	tb.blankCells = map[int]*tensor.Float64{}

	tb.Makers.Normal[0] = func(p *tree.Plan) { // TODO: reduce redundancy with ListBase Maker
		svi := tb.This.(core.Lister)
		svi.UpdateSliceSize()

		// honor one-shot initial selection, then scroll it into view
		scrollTo := -1
		if tb.InitSelectedIndex >= 0 {
			tb.SelectedIndex = tb.InitSelectedIndex
			tb.InitSelectedIndex = -1
			scrollTo = tb.SelectedIndex
		}
		if scrollTo >= 0 {
			tb.ScrollToIndex(scrollTo)
		}

		tb.UpdateStartIndex()
		tb.UpdateMaxWidths()

		tb.Updater(func() {
			tb.UpdateStartIndex()
		})

		tb.MakeHeader(p)
		tb.MakeGrid(p, func(p *tree.Plan) {
			for i := 0; i < tb.VisibleRows; i++ {
				svi.MakeRow(p, i)
			}
		})
	}
}
// SliceIndex returns the slice index (si) and underlying table row
// index (vi) for the given visible-row offset; invis is true when the
// index falls outside the valid rows.
func (tb *Table) SliceIndex(i int) (si, vi int, invis bool) {
	si = tb.StartIndex + i
	vi = -1
	if nr := tb.Table.NumRows(); si < nr {
		vi = tb.Table.RowIndex(si)
	}
	return si, vi, vi < 0
}
// StyleValue performs additional value widget styling: sets the minimum
// width from the header / content widths and disables text wrapping.
func (tb *Table) StyleValue(w core.Widget, s *styles.Style, row, col int) {
	width := float32(tb.headerWidths[col])
	if col == tb.SortIndex {
		width += 6 // extra room for the sort indicator
	}
	if col < len(tb.colMaxWidths) {
		width = max(width, float32(tb.colMaxWidths[col]))
	}
	chWidth := units.Ch(width)
	s.Min.X.Value = max(s.Min.X.Value, chWidth.Convert(s.Min.X.Unit, &s.UnitContext).Value)
	s.SetTextWrap(false)
}
// SetTable sets the source table that we are viewing, using a sequential view,
// and then configures the display.
func (tb *Table) SetTable(dt *table.Table) *Table {
	tb.Table = nil
	if dt != nil {
		tb.Table = table.NewView(dt)
		tb.GridStyle.ApplyStylersFrom(tb.Table)
	}
	tb.This.(core.Lister).UpdateSliceSize()
	tb.SetSliceBase()
	tb.Update()
	return tb
}
// SetSlice sets the source table to a [table.NewSliceTable]
// constructed from the given slice; conversion errors are logged.
func (tb *Table) SetSlice(sl any) *Table {
	return tb.SetTable(errors.Log1(table.NewSliceTable(sl)))
}
// AsyncUpdateTable updates the display for asynchronous updating from
// other goroutines. Also updates indexview (calling Sequential).
// The AsyncLock / AsyncUnlock pair brackets all GUI access.
func (tb *Table) AsyncUpdateTable() {
	tb.AsyncLock()
	tb.Table.Sequential()
	// scroll to the last row without triggering a separate update
	tb.ScrollToIndexNoUpdate(tb.SliceSize - 1)
	tb.Update()
	tb.AsyncUnlock()
}
// UpdateSliceSize refreshes the cached row and column counts from the
// underlying table, resetting to a sequential view if no rows remain,
// and returns the number of rows.
func (tb *Table) UpdateSliceSize() int {
	dt := tb.Table
	dt.ValidIndexes() // table could have changed
	if dt.NumRows() == 0 {
		dt.Sequential()
	}
	tb.SliceSize = dt.NumRows()
	tb.nCols = dt.NumColumns()
	return tb.SliceSize
}
// UpdateMaxWidths recomputes the per-column maximum content widths,
// in characters, for string-valued columns.
func (tb *Table) UpdateMaxWidths() {
	if len(tb.headerWidths) != tb.nCols {
		tb.headerWidths = make([]int, tb.nCols)
		tb.colMaxWidths = make([]int, tb.nCols)
	}
	if tb.SliceSize == 0 {
		return
	}
	for ci := range tb.nCols {
		tb.colMaxWidths[ci] = 0
		stsr, ok := tb.Table.Columns.Values[ci].(*tensor.String)
		if !ok {
			continue // only string columns have content-based widths
		}
		widest := 0
		for _, ix := range tb.Table.Indexes {
			if ix < 0 {
				continue
			}
			widest = max(widest, len(stsr.Values[ix]))
		}
		tb.colMaxWidths[ci] = widest
	}
}
// MakeHeader adds the header frame to the plan, with one sort button
// per column (plus an optional index label), each updating its sort
// indicator icon from the current sort state.
func (tb *Table) MakeHeader(p *tree.Plan) {
	tree.AddAt(p, "header", func(w *core.Frame) {
		core.ToolbarStyles(w)
		w.FinalStyler(func(s *styles.Style) {
			s.Padding.Zero()
			s.Grow.Set(0, 0)
			s.Gap.Set(units.Em(0.5)) // matches grid default
		})
		w.Maker(func(p *tree.Plan) {
			if tb.ShowIndexes {
				tree.AddAt(p, "_head-index", func(w *core.Text) { // TODO: is not working
					w.SetType(core.TextBodyMedium)
					w.Styler(func(s *styles.Style) {
						s.Align.Self = styles.Center
					})
					w.SetText("Index")
				})
			}
			for fli := 0; fli < tb.nCols; fli++ {
				field := tb.Table.Columns.Keys[fli]
				tree.AddAt(p, "head-"+field, func(w *core.Button) {
					w.SetType(core.ButtonAction)
					w.Styler(func(s *styles.Style) {
						s.Justify.Content = styles.Start
					})
					w.OnClick(func(e events.Event) {
						tb.SortColumn(fli)
					})
					// multi-dimensional (tensor cell) columns get a grid-style editor
					if tb.Table.Columns.Values[fli].NumDims() > 1 {
						w.AddContextMenu(func(m *core.Scene) {
							core.NewButton(m).SetText("Edit grid style").SetIcon(icons.Edit).
								OnClick(func(e events.Event) {
									tb.EditGridStyle(fli)
								})
						})
					}
					w.Updater(func() {
						// re-read the field name: columns may have changed
						field := tb.Table.Columns.Keys[fli]
						w.SetText(field).SetTooltip(field + " (tap to sort by)")
						tb.headerWidths[fli] = len(field)
						if fli == tb.SortIndex {
							if tb.SortDescending {
								w.SetIndicator(icons.KeyboardArrowDown)
							} else {
								w.SetIndicator(icons.KeyboardArrowUp)
							}
						} else {
							w.SetIndicator(icons.Blank)
						}
					})
				})
			}
		})
	})
}
// SliceHeader returns the Frame header for the slice grid;
// it is always the first child (added in MakeHeader).
func (tb *Table) SliceHeader() *core.Frame {
	return tb.Child(0).(*core.Frame)
}
// RowWidgetNs returns the number of widgets per row and the offset
// for the index label (0 when indexes are hidden).
func (tb *Table) RowWidgetNs() (nWidgPerRow, idxOff int) {
	if !tb.ShowIndexes {
		return tb.nCols, 0
	}
	return tb.nCols + 1, 1
}
// MakeRow adds the widgets for one visible row i to the plan:
// an optional index label, then one value widget per column.
// 1D columns get scalar string/float editors; higher-dimensional
// columns get a [TensorGrid] showing the row's cell tensor.
func (tb *Table) MakeRow(p *tree.Plan, i int) {
	svi := tb.This.(core.Lister)
	si, _, invis := svi.SliceIndex(i)
	itxt := strconv.Itoa(i)
	if tb.ShowIndexes {
		tb.MakeGridIndex(p, i, si, itxt, invis)
	}
	for fli := 0; fli < tb.nCols; fli++ {
		col := tb.Table.Columns.Values[fli]
		valnm := fmt.Sprintf("value-%v.%v", fli, itxt)
		_, isstr := col.(*tensor.String)
		if col.NumDims() == 1 {
			// scalar column: bind a string or float editor to the cell value.
			// str / fval are shared between the constructor and the closures below.
			str := ""
			fval := float64(0)
			tree.AddNew(p, valnm, func() core.Value {
				if isstr {
					return core.NewValue(&str, "")
				} else {
					return core.NewValue(&fval, "")
				}
			}, func(w core.Value) {
				wb := w.AsWidget()
				tb.MakeValue(w, i)
				w.AsTree().SetProperty(core.ListColProperty, fli)
				if !tb.IsReadOnly() {
					wb.OnChange(func(e events.Event) {
						// write the edited value back to the table cell
						_, vi, invis := svi.SliceIndex(i)
						if !invis {
							if isstr {
								col.SetString1D(str, vi)
							} else {
								col.SetFloat1D(fval, vi)
							}
						}
						tb.This.(core.Lister).UpdateMaxWidths()
						tb.SendChange()
					})
				}
				wb.Updater(func() {
					// refresh the bound value from the table, or bind a blank
					// placeholder for out-of-range (invisible) rows
					_, vi, invis := svi.SliceIndex(i)
					if !invis {
						if isstr {
							str = col.String1D(vi)
							core.Bind(&str, w)
						} else {
							fval = col.Float1D(vi)
							core.Bind(&fval, w)
						}
					} else {
						if isstr {
							core.Bind(tb.blankString, w)
						} else {
							core.Bind(tb.blankFloat, w)
						}
					}
					wb.SetReadOnly(tb.IsReadOnly())
					wb.SetState(invis, states.Invisible)
					if svi.HasStyler() {
						w.Style()
					}
					if invis {
						wb.SetSelected(false)
					}
				})
			})
		} else {
			// tensor cell column: display via TensorGrid
			tree.AddAt(p, valnm, func(w *TensorGrid) {
				w.SetReadOnly(tb.IsReadOnly())
				wb := w.AsWidget()
				w.SetProperty(core.ListRowProperty, i)
				w.SetProperty(core.ListColProperty, fli)
				w.Styler(func(s *styles.Style) {
					s.Grow.Set(0, 0)
				})
				wb.Updater(func() {
					si, vi, invis := svi.SliceIndex(i)
					var cell tensor.Tensor
					if invis {
						cell = tb.blankCell(fli, col) // cached all-zero placeholder
					} else {
						cell = col.RowTensor(vi)
					}
					wb.ValueTitle = tb.ValueTitle + "[" + strconv.Itoa(si) + "]"
					w.SetState(invis, states.Invisible)
					w.SetTensor(cell)
					w.GridStyle = *tb.GetColumnGridStyle(fli)
				})
			})
		}
	}
}
// blankCell returns a cached all-zero tensor matching the cell shape of
// the given column, creating and caching it on first use.
func (tb *Table) blankCell(cidx int, col tensor.Tensor) *tensor.Float64 {
	if cached, ok := tb.blankCells[cidx]; ok {
		return cached
	}
	blank := tensor.New[float64](col.ShapeSizes()...).(*tensor.Float64)
	tb.blankCells[cidx] = blank
	return blank
}
// GetColumnGridStyle gets the grid style for the given column: a stored
// per-column style if present, otherwise a copy of the global style with
// any column-metadata stylers applied on top.
func (tb *Table) GetColumnGridStyle(col int) *GridStyle {
	if gs, ok := tb.ColumnGridStyle[col]; ok {
		return gs
	}
	gs := &GridStyle{}
	*gs = tb.GridStyle // start from a copy of the global style
	if tb.Table != nil {
		gs.ApplyStylersFrom(tb.Table.Columns.Values[col])
	}
	return gs
}
// NewAt inserts a new blank element at given index in the slice -- -1
// means the end. The new row is selected and given focus.
func (tb *Table) NewAt(idx int) {
	tb.NewAtSelect(idx) // adjust existing selection for the insertion
	tb.Table.InsertRows(idx, 1)
	tb.SelectIndexEvent(idx, events.SelectOne)
	tb.Update()
	tb.IndexGrabFocus(idx)
}
// DeleteAt deletes the element at the given index from the slice;
// out-of-range indexes are ignored.
func (tb *Table) DeleteAt(idx int) {
	if 0 <= idx && idx < tb.SliceSize {
		tb.DeleteAtSelect(idx)
		tb.Table.DeleteRows(idx, 1)
		tb.Update()
	}
}
// SortColumn sorts the slice for given column index.
// Toggles ascending vs. descending if already sorting on this dimension.
func (tb *Table) SortColumn(fldIndex int) {
	header := tb.SliceHeader()
	_, idxOff := tb.RowWidgetNs()
	// reset every header button to the plain action type
	for ci := range tb.nCols {
		btn := header.Child(idxOff + ci).(*core.Button)
		btn.SetType(core.ButtonAction)
	}
	// toggle / reset direction only for a valid column index
	if fldIndex >= 0 && fldIndex < tb.nCols {
		if tb.SortIndex == fldIndex {
			tb.SortDescending = !tb.SortDescending
		} else {
			tb.SortDescending = false
		}
	}
	tb.SortIndex = fldIndex
	if fldIndex == -1 {
		tb.Table.SortIndexes()
	} else {
		tb.Table.IndexesNeeded()
		col := tb.Table.ColumnByIndex(tb.SortIndex)
		col.Sort(!tb.SortDescending)
		tb.Table.IndexesFromTensor(col)
	}
	tb.Update() // requires full update due to sort button icon
}
// EditGridStyle shows an editor dialog for grid style for given column index.
// Edits are stored as a per-column style override; a secondary button opens
// the global style editor.
func (tb *Table) EditGridStyle(col int) {
	ctd := tb.GetColumnGridStyle(col)
	d := core.NewBody("Tensor grid style")
	core.NewForm(d).SetStruct(ctd).
		OnChange(func(e events.Event) {
			// persist the edited style as this column's override
			tb.ColumnGridStyle[col] = ctd
			tb.Update()
		})
	core.NewButton(d).SetText("Edit global style").SetIcon(icons.Edit).
		OnClick(func(e events.Event) {
			tb.EditGlobalGridStyle()
		})
	d.RunWindowDialog(tb)
}
// EditGlobalGridStyle shows an editor dialog for the global grid style,
// updating the display on every change.
func (tb *Table) EditGlobalGridStyle() {
	d := core.NewBody("Tensor grid style")
	core.NewForm(d).SetStruct(&tb.GridStyle).
		OnChange(func(e events.Event) {
			tb.Update()
		})
	d.RunWindowDialog(tb)
}
// HasStyler returns false: Table does not use a per-row styling function.
func (tb *Table) HasStyler() bool { return false }

// StyleRow is a no-op; see [Table.StyleValue] for per-value styling.
func (tb *Table) StyleRow(w core.Widget, idx, fidx int) {}
// SortFieldName returns the name of the field being sorted, along with :up or
// :down depending on descending; empty string when not sorting.
func (tb *Table) SortFieldName() string {
	if tb.SortIndex < 0 || tb.SortIndex >= tb.nCols {
		return ""
	}
	name := tb.Table.Columns.Keys[tb.SortIndex]
	if tb.SortDescending {
		return name + ":down"
	}
	return name + ":up"
}
// SetSortFieldName sets sorting to happen on given field and direction,
// in the "name:up" / "name:down" format returned by [Table.SortFieldName].
// Note: this only records the sort state; it does not re-sort the table.
func (tb *Table) SetSortFieldName(nm string) {
	if nm == "" {
		return
	}
	spnm := strings.Split(nm, ":")
	for fli := 0; fli < tb.nCols; fli++ {
		if tb.Table.Columns.Keys[fli] == spnm[0] {
			tb.SortIndex = fli
		}
	}
	if len(spnm) == 2 {
		tb.SortDescending = spnm[1] == "down"
	}
}
// RowFirstVisWidget returns the first visible widget for given row (could be
// index or not) -- false if out of range.
func (tb *Table) RowFirstVisWidget(row int) (*core.WidgetBase, bool) {
	if !tb.IsRowInBounds(row) {
		return nil, false
	}
	nPer, idxOff := tb.RowWidgetNs()
	lg := tb.ListGrid
	ridx := nPer * row
	// check the row's first widget (index label when shown)
	first := lg.Children[ridx].(core.Widget).AsWidget()
	if first.Geom.TotalBBox != (image.Rectangle{}) {
		return first, true
	}
	// otherwise scan the value widgets for a visible one
	for ci := range tb.nCols {
		wb := lg.Child(ridx + idxOff + ci).(core.Widget).AsWidget()
		if wb.Geom.TotalBBox != (image.Rectangle{}) {
			return wb, true
		}
	}
	return nil, false
}
// RowGrabFocus grabs the focus for the first focusable widget in given row --
// returns that element or nil if not successful -- note: grid must have
// already rendered for focus to be grabbed!
func (tb *Table) RowGrabFocus(row int) *core.WidgetBase {
	if !tb.IsRowInBounds(row) || tb.InFocusGrab { // range check
		return nil
	}
	nPer, idxOff := tb.RowWidgetNs()
	ridx := nPer * row
	lg := tb.ListGrid
	// if a widget in this row already has (or contains) focus, keep it
	for ci := range tb.nCols {
		wb := lg.Child(ridx + idxOff + ci).(core.Widget).AsWidget()
		if wb.StateIs(states.Focused) || wb.ContainsFocus() {
			return wb
		}
	}
	tb.InFocusGrab = true
	defer func() { tb.InFocusGrab = false }()
	// otherwise focus the first widget that can accept it
	for ci := range tb.nCols {
		wb := lg.Child(ridx + idxOff + ci).(core.Widget).AsWidget()
		if wb.CanFocus() {
			wb.SetFocus()
			return wb
		}
	}
	return nil
}
//////// Header layout
// SizeFinal finalizes sizing, then copies the grid column widths onto the
// corresponding header widgets so the header stays aligned with the grid.
func (tb *Table) SizeFinal() {
	tb.ListBase.SizeFinal()
	grid := tb.ListGrid
	header := tb.SliceHeader()
	header.ForWidgetChildren(func(i int, cw core.Widget, cwb *core.WidgetBase) bool {
		gridWB := core.AsWidget(grid.Child(i))
		src := &gridWB.Geom.Size
		if src.Actual.Total.X == 0 {
			return tree.Continue // nothing laid out for this column yet
		}
		dst := &cwb.Geom.Size
		dst.Actual.Total.X = src.Actual.Total.X
		dst.Actual.Content.X = src.Actual.Content.X
		dst.Alloc.Total.X = src.Alloc.Total.X
		dst.Alloc.Content.X = src.Alloc.Content.X
		return tree.Continue
	})
	// match the overall header width to the grid width
	src := &grid.Geom.Size
	dst := &header.Geom.Size
	if src.Actual.Total.X > 0 {
		dst.Actual.Total.X = src.Actual.Total.X
		dst.Actual.Content.X = src.Actual.Content.X
		dst.Alloc.Total.X = src.Alloc.Total.X
		dst.Alloc.Content.X = src.Alloc.Content.X
	}
}
// SelectedColumnStrings returns the string values of the given column name
// for the currently selected rows, in ascending row order; nil when there
// is no selection or no table.
func (tb *Table) SelectedColumnStrings(colName string) []string {
	dt := tb.Table
	rows := tb.SelectedIndexesList(false)
	if len(rows) == 0 || dt == nil {
		return nil
	}
	col := dt.Column(colName)
	s := make([]string, 0, len(rows)) // preallocate: one value per selected row
	for _, i := range rows {
		s = append(s, col.StringRow(i, 0))
	}
	return s
}
//////// Copy / Cut / Paste
// MakeToolbar adds the standard table action buttons (add rows, sort,
// filter, unfilter, open / save CSV) to the toolbar plan; each action
// triggers a full display update when done.
func (tb *Table) MakeToolbar(p *tree.Plan) {
	if tb.Table == nil {
		return
	}
	tree.Add(p, func(w *core.FuncButton) {
		w.SetFunc(tb.Table.AddRows).SetIcon(icons.Add)
		w.SetAfterFunc(func() { tb.Update() })
	})
	tree.Add(p, func(w *core.FuncButton) {
		w.SetFunc(tb.Table.SortColumns).SetText("Sort").SetIcon(icons.Sort)
		w.SetAfterFunc(func() { tb.Update() })
	})
	tree.Add(p, func(w *core.FuncButton) {
		w.SetFunc(tb.Table.FilterString).SetText("Filter").SetIcon(icons.FilterAlt)
		w.SetAfterFunc(func() { tb.Update() })
	})
	tree.Add(p, func(w *core.FuncButton) {
		w.SetFunc(tb.Table.Sequential).SetText("Unfilter").SetIcon(icons.FilterAltOff)
		w.SetAfterFunc(func() { tb.Update() })
	})
	tree.Add(p, func(w *core.FuncButton) {
		w.SetFunc(tb.Table.OpenCSV).SetIcon(icons.Open)
		w.SetAfterFunc(func() { tb.Update() })
	})
	tree.Add(p, func(w *core.FuncButton) {
		w.SetFunc(tb.Table.SaveCSV).SetIcon(icons.Save)
		w.SetAfterFunc(func() { tb.Update() })
	})
}
// MimeDataType returns the data type used for copy / paste: CSV.
func (tb *Table) MimeDataType() string {
	return fileinfo.DataCsv
}
// CopySelectToMime copies the selected rows to CSV mime data,
// or nil if nothing is selected.
func (tb *Table) CopySelectToMime() mimedata.Mimes {
	if len(tb.SelectedIndexes) == 0 {
		return nil
	}
	view := table.NewView(tb.Table)
	sel := tb.SelectedIndexesList(false) // ascending
	rows := make([]int, len(sel))
	for i, di := range sel {
		rows[i] = tb.Table.RowIndex(di) // map view indexes to source rows
	}
	view.Indexes = rows
	var buf bytes.Buffer
	view.WriteCSV(&buf, tensor.Tab, table.Headers)
	md := mimedata.NewTextBytes(buf.Bytes())
	md[0].Type = fileinfo.DataCsv
	return md
}
// FromMimeData returns CSV records from the given mime data,
// accumulating across all CSV-typed items. Returns nil on a read
// error or when an item decodes to zero records.
func (tb *Table) FromMimeData(md mimedata.Mimes) [][]string {
	var recs [][]string
	for _, d := range md {
		if d.Type != fileinfo.DataCsv {
			continue
		}
		cr := csv.NewReader(bytes.NewBuffer(d.Data))
		cr.Comma = tensor.Tab.Rune()
		rec, err := cr.ReadAll()
		if err != nil {
			log.Printf("Error reading CSV from clipboard: %s\n", err)
			return nil
		}
		if len(rec) == 0 {
			// fixed: previously this path logged the message with a nil error
			return nil
		}
		recs = append(recs, rec...)
	}
	return recs
}
// PasteAssign assigns mime data (only the first data record!) to this idx.
// recs[0] is the CSV header row, so at least two records are required.
func (tb *Table) PasteAssign(md mimedata.Mimes, idx int) {
	recs := tb.FromMimeData(md)
	if len(recs) < 2 { // fixed: was == 0, which panicked on a header-only paste
		return
	}
	tb.Table.ReadCSVRow(recs[1], tb.Table.RowIndex(idx))
	tb.UpdateChange()
}
// PasteAtIndex inserts object(s) from mime data at (before) given slice index;
// the first record is the CSV header and is skipped.
func (tb *Table) PasteAtIndex(md mimedata.Mimes, idx int) {
	recs := tb.FromMimeData(md)
	nr := len(recs) - 1 // data records, excluding the header row
	if nr <= 0 {
		return
	}
	tb.Table.InsertRows(idx, nr)
	for ri, rec := range recs[1:] {
		tb.Table.ReadCSVRow(rec, tb.Table.RowIndex(idx+ri))
	}
	tb.SendChange()
	tb.SelectIndexEvent(idx, events.SelectOne)
	tb.Update()
}
// Copyright (c) 2023, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package tensorcore provides GUI Cogent Core widgets for tensor types.
package tensorcore
import (
"fmt"
"image"
"strconv"
"cogentcore.org/core/base/fileinfo"
"cogentcore.org/core/base/fileinfo/mimedata"
"cogentcore.org/core/base/fsx"
"cogentcore.org/core/core"
"cogentcore.org/core/events"
"cogentcore.org/core/icons"
"cogentcore.org/core/styles"
"cogentcore.org/core/styles/states"
"cogentcore.org/core/styles/units"
"cogentcore.org/core/tree"
"cogentcore.org/lab/tensor"
)
// TensorEditor provides a GUI widget for representing [tensor.Tensor] values.
type TensorEditor struct {
	core.ListBase

	// Tensor is the tensor that we're a view of.
	Tensor tensor.Tensor `set:"-"`

	// Layout has the overall layout options for tensor display.
	Layout Layout `set:"-"`

	// NCols is the number of columns in the display (as of last update).
	NCols int `edit:"-"`

	// headerWidths has number of characters in each header, per visfields.
	headerWidths []int `copier:"-" display:"-" json:"-" xml:"-"`

	// colMaxWidths records maximum width in chars of string type fields.
	colMaxWidths []int `set:"-" copier:"-" json:"-" xml:"-"`

	// BlankString and BlankFloat are blank values for out-of-range rows.
	BlankString string
	BlankFloat  float64
}

// compile-time check that TensorEditor implements the core.Lister interface
var _ core.Lister = (*TensorEditor)(nil)
// Init initializes the TensorEditor widget, installing the tree plan
// maker that builds the header and grid rows from the tensor.
func (tb *TensorEditor) Init() {
	tb.ListBase.Init()
	tb.Layout.OddRow = true
	tb.Makers.Normal[0] = func(p *tree.Plan) { // TODO: reduce redundancy with ListBase Maker
		svi := tb.This.(core.Lister)
		svi.UpdateSliceSize()

		// honor one-shot initial selection, then scroll it into view
		scrollTo := -1
		if tb.InitSelectedIndex >= 0 {
			tb.SelectedIndex = tb.InitSelectedIndex
			tb.InitSelectedIndex = -1
			scrollTo = tb.SelectedIndex
		}
		if scrollTo >= 0 {
			tb.ScrollToIndex(scrollTo)
		}

		tb.UpdateStartIndex()
		tb.UpdateMaxWidths()

		tb.Updater(func() {
			// 1D tensors are displayed top-down
			// NOTE(review): assumes tb.Tensor is non-nil here; nil would panic — confirm SetTensor is always called first
			if tb.Tensor.NumDims() == 1 {
				tb.Layout.TopZero = true
			}
			tb.UpdateStartIndex()
		})

		tb.MakeHeader(p)
		tb.MakeGrid(p, func(p *tree.Plan) {
			for i := 0; i < tb.VisibleRows; i++ {
				svi.MakeRow(p, i)
			}
		})
	}
}
// SliceIndex returns the slice index (si) and the value index (vi) for
// the given visible-row offset; vi is flipped when Y=0 is displayed at
// the bottom (TopZero false). invis is true when si is out of range.
func (tb *TensorEditor) SliceIndex(i int) (si, vi int, invis bool) {
	si = tb.StartIndex + i
	vi = si
	if !tb.Layout.TopZero {
		vi = (tb.SliceSize - 1) - si
	}
	return si, vi, si >= tb.SliceSize
}
// StyleValue performs additional value widget styling: sets the minimum
// width from the header / content widths and disables text wrapping.
func (tb *TensorEditor) StyleValue(w core.Widget, s *styles.Style, row, col int) {
	width := float32(tb.headerWidths[col])
	if col < len(tb.colMaxWidths) {
		width = max(width, float32(tb.colMaxWidths[col]))
	}
	chWidth := units.Ch(width)
	s.Min.X.Value = max(s.Min.X.Value, chWidth.Convert(s.Min.X.Unit, &s.UnitContext).Value)
	s.SetTextWrap(false)
}
// SetTensor sets the source tensor that we are viewing,
// and then configures the display.
// Returns nil (not tb) when et is nil.
func (tb *TensorEditor) SetTensor(et tensor.Tensor) *TensorEditor {
	if et == nil {
		return nil
	}
	tb.Tensor = et
	tb.This.(core.Lister).UpdateSliceSize()
	tb.SetSliceBase()
	tb.Update()
	return tb
}
// UpdateSliceSize updates the row and column counts from the tensor's
// 2D projection shape, and returns the number of rows.
func (tb *TensorEditor) UpdateSliceSize() int {
	tb.SliceSize, tb.NCols, _, _ = tensor.Projection2DShape(tb.Tensor.Shape(), tb.Layout.OddRow)
	return tb.SliceSize
}
// UpdateMaxWidths (re)allocates the per-column width slices.
// NOTE(review): content-based width computation for string tensors is
// currently disabled (commented out below), so colMaxWidths stays all-zero.
func (tb *TensorEditor) UpdateMaxWidths() {
	if len(tb.headerWidths) != tb.NCols {
		tb.headerWidths = make([]int, tb.NCols)
		tb.colMaxWidths = make([]int, tb.NCols)
	}
	if tb.SliceSize == 0 {
		return
	}
	_, isstr := tb.Tensor.(*tensor.String)
	for fli := 0; fli < tb.NCols; fli++ {
		tb.colMaxWidths[fli] = 0
		if !isstr {
			continue
		}
		mxw := 0
		// for _, ixi := range tb.Tensor.Indexes {
		// 	if ixi >= 0 {
		// 		sval := stsr.Values[ixi]
		// 		mxw = max(mxw, len(sval))
		// 	}
		// }
		tb.colMaxWidths[fli] = mxw
	}
}
// MakeHeader adds the header frame to the plan, with one label button
// per projected column (plus an optional index label); headers show the
// projected coordinates via [TensorEditor.ColumnHeader].
func (tb *TensorEditor) MakeHeader(p *tree.Plan) {
	tree.AddAt(p, "header", func(w *core.Frame) {
		core.ToolbarStyles(w)
		w.FinalStyler(func(s *styles.Style) {
			s.Padding.Zero()
			s.Grow.Set(0, 0)
			s.Gap.Set(units.Em(0.5)) // matches grid default
		})
		w.Maker(func(p *tree.Plan) {
			if tb.ShowIndexes {
				tree.AddAt(p, "_head-index", func(w *core.Text) { // TODO: is not working
					w.SetType(core.TextBodyMedium)
					w.Styler(func(s *styles.Style) {
						s.Align.Self = styles.Center
					})
					w.SetText("Index")
				})
			}
			for fli := 0; fli < tb.NCols; fli++ {
				hdr := tb.ColumnHeader(fli)
				tree.AddAt(p, "head-"+hdr, func(w *core.Button) {
					w.SetType(core.ButtonAction)
					w.Styler(func(s *styles.Style) {
						s.Justify.Content = styles.Start
					})
					w.Updater(func() {
						hdr := tb.ColumnHeader(fli)
						w.SetText(hdr).SetTooltip(hdr)
						tb.headerWidths[fli] = len(hdr)
					})
				})
			}
		})
	})
}
// ColumnHeader returns the comma-separated, zero-padded projected
// coordinates label for the given display column.
func (tb *TensorEditor) ColumnHeader(col int) string {
	_, cc := tensor.Projection2DCoords(tb.Tensor.Shape(), tb.Layout.OddRow, 0, col)
	label := ""
	for i, c := range cc {
		if i > 0 {
			label += ","
		}
		label += fmt.Sprintf("%03d", c)
	}
	return label
}
// SliceHeader returns the Frame header for the slice grid;
// it is always the first child (added in MakeHeader).
func (tb *TensorEditor) SliceHeader() *core.Frame {
	return tb.Child(0).(*core.Frame)
}
// RowWidgetNs returns the number of widgets per row and the offset
// for the index label (0 when indexes are hidden).
func (tb *TensorEditor) RowWidgetNs() (nWidgPerRow, idxOff int) {
	if !tb.ShowIndexes {
		return tb.NCols, 0
	}
	return tb.NCols + 1, 1
}
// MakeRow adds the widgets for one visible row i to the plan:
// an optional index label, then one scalar string/float editor per
// projected column, bound via the tensor's 2D projection accessors.
func (tb *TensorEditor) MakeRow(p *tree.Plan, i int) {
	svi := tb.This.(core.Lister)
	si, _, invis := svi.SliceIndex(i)
	itxt := strconv.Itoa(i)
	if tb.ShowIndexes {
		tb.MakeGridIndex(p, i, si, itxt, invis)
	}
	_, isstr := tb.Tensor.(*tensor.String)
	for fli := 0; fli < tb.NCols; fli++ {
		valnm := fmt.Sprintf("value-%v.%v", fli, itxt)
		// fval / str are shared between the constructor and the closures below
		fval := float64(0)
		str := ""
		tree.AddNew(p, valnm, func() core.Value {
			if isstr {
				return core.NewValue(&str, "")
			} else {
				return core.NewValue(&fval, "")
			}
		}, func(w core.Value) {
			wb := w.AsWidget()
			tb.MakeValue(w, i)
			w.AsTree().SetProperty(core.ListColProperty, fli)
			if !tb.IsReadOnly() {
				wb.OnChange(func(e events.Event) {
					// write the edited value back through the 2D projection
					_, vi, invis := svi.SliceIndex(i)
					if !invis {
						if isstr {
							tensor.Projection2DSetString(tb.Tensor, tb.Layout.OddRow, vi, fli, str)
						} else {
							tensor.Projection2DSet(tb.Tensor, tb.Layout.OddRow, vi, fli, fval)
						}
					}
					tb.This.(core.Lister).UpdateMaxWidths()
					tb.SendChange()
				})
			}
			wb.Updater(func() {
				// refresh the bound value from the tensor, or bind a blank
				// placeholder for out-of-range (invisible) rows
				_, vi, invis := svi.SliceIndex(i)
				if !invis {
					if isstr {
						str = tensor.Projection2DString(tb.Tensor, tb.Layout.OddRow, vi, fli)
						core.Bind(&str, w)
					} else {
						fval = tensor.Projection2DValue(tb.Tensor, tb.Layout.OddRow, vi, fli)
						core.Bind(&fval, w)
					}
				} else {
					if isstr {
						core.Bind(tb.BlankString, w)
					} else {
						core.Bind(tb.BlankFloat, w)
					}
				}
				wb.SetReadOnly(tb.IsReadOnly())
				wb.SetState(invis, states.Invisible)
				if svi.HasStyler() {
					w.Style()
				}
				if invis {
					wb.SetSelected(false)
				}
			})
		})
	}
}
// HasStyler returns false: TensorEditor does not use a per-row styling function.
func (tb *TensorEditor) HasStyler() bool { return false }

// StyleRow is a no-op; see [TensorEditor.StyleValue] for per-value styling.
func (tb *TensorEditor) StyleRow(w core.Widget, idx, fidx int) {}
// RowFirstVisWidget returns the first visible widget for given row (could be
// index or not) -- false if out of range.
func (tb *TensorEditor) RowFirstVisWidget(row int) (*core.WidgetBase, bool) {
	if !tb.IsRowInBounds(row) {
		return nil, false
	}
	nPer, idxOff := tb.RowWidgetNs()
	lg := tb.ListGrid
	ridx := nPer * row
	// check the row's first widget (index label when shown)
	first := lg.Children[ridx].(core.Widget).AsWidget()
	if first.Geom.TotalBBox != (image.Rectangle{}) {
		return first, true
	}
	// otherwise scan the value widgets for a visible one
	for ci := range tb.NCols {
		wb := lg.Child(ridx + idxOff + ci).(core.Widget).AsWidget()
		if wb.Geom.TotalBBox != (image.Rectangle{}) {
			return wb, true
		}
	}
	return nil, false
}
// RowGrabFocus grabs the focus for the first focusable widget in given row --
// returns that element or nil if not successful -- note: grid must have
// already rendered for focus to be grabbed!
func (tb *TensorEditor) RowGrabFocus(row int) *core.WidgetBase {
	if !tb.IsRowInBounds(row) || tb.InFocusGrab { // range check
		return nil
	}
	nPer, idxOff := tb.RowWidgetNs()
	ridx := nPer * row
	lg := tb.ListGrid
	// if a widget in this row already has (or contains) focus, keep it
	for ci := range tb.NCols {
		wb := lg.Child(ridx + idxOff + ci).(core.Widget).AsWidget()
		if wb.StateIs(states.Focused) || wb.ContainsFocus() {
			return wb
		}
	}
	tb.InFocusGrab = true
	defer func() { tb.InFocusGrab = false }()
	// otherwise focus the first widget that can accept it
	for ci := range tb.NCols {
		wb := lg.Child(ridx + idxOff + ci).(core.Widget).AsWidget()
		if wb.CanFocus() {
			wb.SetFocus()
			return wb
		}
	}
	return nil
}
/////// Header layout
// SizeFinal finalizes sizing, then copies the grid column widths onto the
// corresponding header widgets so the header stays aligned with the grid.
func (tb *TensorEditor) SizeFinal() {
	tb.ListBase.SizeFinal()
	grid := tb.ListGrid
	header := tb.SliceHeader()
	header.ForWidgetChildren(func(i int, cw core.Widget, cwb *core.WidgetBase) bool {
		gridWB := core.AsWidget(grid.Child(i))
		src := &gridWB.Geom.Size
		if src.Actual.Total.X == 0 {
			return tree.Continue // nothing laid out for this column yet
		}
		dst := &cwb.Geom.Size
		dst.Actual.Total.X = src.Actual.Total.X
		dst.Actual.Content.X = src.Actual.Content.X
		dst.Alloc.Total.X = src.Alloc.Total.X
		dst.Alloc.Content.X = src.Alloc.Content.X
		return tree.Continue
	})
	// match the overall header width to the grid width
	src := &grid.Geom.Size
	dst := &header.Geom.Size
	if src.Actual.Total.X > 0 {
		dst.Actual.Total.X = src.Actual.Total.X
		dst.Actual.Content.X = src.Actual.Content.X
		dst.Alloc.Total.X = src.Alloc.Total.X
		dst.Alloc.Content.X = src.Alloc.Content.X
	}
}
//////// Copy / Cut / Paste
// SaveCSV writes the tensor to a tab-separated-values (TSV) file.
// Outer-most dims are rows in the file, and inner-most is column --
// Reading just grabs all values and doesn't care about shape.
func (tb *TensorEditor) SaveCSV(filename core.Filename) error { //types:add
	return tensor.SaveCSV(tb.Tensor, fsx.Filename(filename), tensor.Tab)
}
// OpenCSV reads the tensor from a tab-separated-values (TSV) file,
// using the Go standard encoding/csv reader conforming
// to the official CSV standard.
// Reads all values and assigns as many as fit.
func (tb *TensorEditor) OpenCSV(filename core.Filename) error { //types:add
	return tensor.OpenCSV(tb.Tensor, fsx.Filename(filename), tensor.Tab)
}
// MakeToolbar adds the open / save CSV action buttons to the toolbar plan;
// each action triggers a full display update when done.
func (tb *TensorEditor) MakeToolbar(p *tree.Plan) {
	if tb.Tensor == nil {
		return
	}
	tree.Add(p, func(w *core.FuncButton) {
		w.SetFunc(tb.OpenCSV).SetIcon(icons.Open)
		w.SetAfterFunc(func() { tb.Update() })
	})
	tree.Add(p, func(w *core.FuncButton) {
		w.SetFunc(tb.SaveCSV).SetIcon(icons.Save)
		w.SetAfterFunc(func() { tb.Update() })
	})
}
// MimeDataType returns the MIME data type used for copy/paste: CSV.
func (tb *TensorEditor) MimeDataType() string {
	return fileinfo.DataCsv
}
// CopySelectToMime copies selected rows to mime data.
// NOTE(review): the CSV encoding is currently commented out, so this
// always returns nil even when rows are selected -- placeholder pending
// tensor CSV-writing support.
func (tb *TensorEditor) CopySelectToMime() mimedata.Mimes {
	nitms := len(tb.SelectedIndexes)
	if nitms == 0 {
		return nil
	}
	// idx := tb.SelectedIndexesList(false) // ascending
	// var b bytes.Buffer
	// ix.WriteCSV(&b, tensor.Tab, table.Headers)
	// md := mimedata.NewTextBytes(b.Bytes())
	// md[0].Type = fileinfo.DataCsv
	// return md
	return nil
}
// FromMimeData returns records from csv of mime data.
// NOTE(review): the CSV decoding is currently commented out, so this
// always returns an empty record list -- placeholder implementation.
func (tb *TensorEditor) FromMimeData(md mimedata.Mimes) [][]string {
	var recs [][]string
	for _, d := range md {
		if d.Type == fileinfo.DataCsv {
			// b := bytes.NewBuffer(d.Data)
			// cr := csv.NewReader(b)
			// cr.Comma = tensor.Tab.Rune()
			// rec, err := cr.ReadAll()
			// if err != nil || len(rec) == 0 {
			// 	log.Printf("Error reading CSV from clipboard: %s\n", err)
			// 	return nil
			// }
			// recs = append(recs, rec...)
		}
	}
	return recs
}
// PasteAssign assigns mime data (only the first one!) to this idx.
// NOTE(review): body is entirely commented out -- currently a no-op stub.
func (tb *TensorEditor) PasteAssign(md mimedata.Mimes, idx int) {
	// recs := tb.FromMimeData(md)
	// if len(recs) == 0 {
	// 	return
	// }
	// tb.Tensor.ReadCSVRow(recs[1], tb.Tensor.Indexes[idx])
	// tb.UpdateChange()
}
// PasteAtIndex inserts object(s) from mime data at (before) given slice index;
// adds to end of table.
// NOTE(review): body is entirely commented out -- currently a no-op stub.
func (tb *TensorEditor) PasteAtIndex(md mimedata.Mimes, idx int) {
	// recs := tb.FromMimeData(md)
	// nr := len(recs) - 1
	// if nr <= 0 {
	// 	return
	// }
	// tb.Tensor.InsertRows(idx, nr)
	// for ri := 0; ri < nr; ri++ {
	// 	rec := recs[1+ri]
	// 	rw := tb.Tensor.Indexes[idx+ri]
	// 	tb.Tensor.ReadCSVRow(rec, rw)
	// }
	// tb.SendChange()
	// tb.SelectIndexEvent(idx, events.SelectOne)
	// tb.Update()
}
// Copyright (c) 2019, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensorcore
import (
	"fmt"
	"image/color"
	"log"

	"cogentcore.org/core/colors"
	"cogentcore.org/core/colors/colormap"
	"cogentcore.org/core/core"
	"cogentcore.org/core/events"
	"cogentcore.org/core/icons"
	"cogentcore.org/core/math32"
	"cogentcore.org/core/math32/minmax"
	"cogentcore.org/core/styles"
	"cogentcore.org/core/styles/abilities"
	"cogentcore.org/core/styles/units"
	"cogentcore.org/lab/tensor"
)
// TensorGrid is a widget that displays tensor values as a grid of colored squares.
type TensorGrid struct {
	core.WidgetBase
	// Tensor is the tensor that we view.
	Tensor tensor.Tensor `set:"-"`
	// GridStyle has grid display style properties.
	GridStyle GridStyle
	// ColorMap is the colormap displayed, looked up by the
	// GridStyle.ColorMap name (see EnsureColorMap).
	ColorMap *colormap.Map
}
// WidgetValue returns a pointer to the tensor, for core value binding.
func (tg *TensorGrid) WidgetValue() any { return &tg.Tensor }
// SetWidgetValue sets the displayed tensor from an abstract value,
// for core value binding. It returns an error (instead of panicking,
// as the previous unchecked type assertion did) if the value is not
// a [tensor.Tensor].
func (tg *TensorGrid) SetWidgetValue(value any) error {
	tsr, ok := value.(tensor.Tensor)
	if !ok {
		return fmt.Errorf("TensorGrid.SetWidgetValue: expected tensor.Tensor, got %T", value)
	}
	tg.SetTensor(tsr)
	return nil
}
// Init sets up default styles and event handlers: the minimum size is
// computed from the tensor shape, double-click opens a TensorEditor,
// and the context menu offers editor and style dialogs.
func (tg *TensorGrid) Init() {
	tg.WidgetBase.Init()
	tg.GridStyle.Defaults()
	tg.Styler(func(s *styles.Style) {
		s.SetAbilities(true, abilities.DoubleClickable)
		ms := tg.MinSize()
		s.Min.Set(units.Dot(ms.X), units.Dot(ms.Y))
		s.Grow.Set(1, 1)
	})
	tg.OnDoubleClick(func(e events.Event) {
		tg.TensorEditor()
	})
	tg.AddContextMenu(func(m *core.Scene) {
		core.NewFuncButton(m).SetFunc(tg.TensorEditor).SetIcon(icons.Edit)
		core.NewFuncButton(m).SetFunc(tg.EditStyle).SetIcon(icons.Edit)
	})
}
// SetTensor sets the tensor to display. Must call Update after this.
// String tensors are not supported: a message is logged and the
// tensor is left unchanged.
func (tg *TensorGrid) SetTensor(tsr tensor.Tensor) *TensorGrid {
	if _, isString := tsr.(*tensor.String); isString {
		log.Printf("TensorGrid: String tensors cannot be displayed using TensorGrid\n")
		return tg
	}
	tg.Tensor = tsr
	if tsr != nil {
		tg.GridStyle.ApplyStylersFrom(tsr)
	}
	return tg
}
// TensorEditor pulls up a TensorEditor of our tensor in a dialog window
// with a toolbar; edits in the editor trigger a re-render of this grid.
func (tg *TensorGrid) TensorEditor() { //types:add
	d := core.NewBody("Tensor editor")
	tb := core.NewToolbar(d)
	te := NewTensorEditor(d).SetTensor(tg.Tensor)
	te.OnChange(func(e events.Event) {
		tg.NeedsRender()
	})
	tb.Maker(te.MakeToolbar)
	d.RunWindowDialog(tg)
}
// EditStyle opens a dialog with a form for editing the GridStyle;
// changes trigger a re-render of the grid.
func (tg *TensorGrid) EditStyle() { //types:add
	d := core.NewBody("Tensor grid style")
	core.NewForm(d).SetStruct(&tg.GridStyle).
		OnChange(func(e events.Event) {
			tg.NeedsRender()
		})
	d.RunWindowDialog(tg)
}
// MinSize returns minimum size based on tensor and display settings.
// In Image mode it is one dot per element (X = dim 1, Y = dim 0);
// otherwise it is based on the projected 2D grid shape, with extra
// spacing at extra-dimension boundaries, clamped to Size min/max.
// (Removed redundant float32 conversions of values that are already
// float32: frw, fcl, and the result of max.)
func (tg *TensorGrid) MinSize() math32.Vector2 {
	if tg.Tensor == nil || tg.Tensor.Len() == 0 {
		return math32.Vector2{}
	}
	if tg.GridStyle.Image {
		return math32.Vec2(float32(tg.Tensor.DimSize(1)), float32(tg.Tensor.DimSize(0)))
	}
	rows, cols, rowEx, colEx := tensor.Projection2DShape(tg.Tensor.Shape(), tg.GridStyle.OddRow)
	frw := float32(rows) + float32(rowEx)*tg.GridStyle.DimExtra // extra spacing
	fcl := float32(cols) + float32(colEx)*tg.GridStyle.DimExtra // extra spacing
	gsz := tg.GridStyle.TotalSize / max(frw, fcl)
	gsz = tg.GridStyle.Size.ClampValue(gsz)
	gsz = max(gsz, 2) // never less than 2 dots per grid square
	return math32.Vec2(gsz*fcl, gsz*frw)
}
// EnsureColorMap makes sure there is a valid color map that matches the
// name specified in GridStyle.ColorMap, resetting the style to defaults
// if the named map does not exist.
// (Fixed: the previous version did a second, unconditional map lookup
// after the first one -- now each path does exactly one lookup.)
func (tg *TensorGrid) EnsureColorMap() {
	if tg.ColorMap != nil && tg.ColorMap.Name != string(tg.GridStyle.ColorMap) {
		tg.ColorMap = nil // name changed: force re-lookup
	}
	if tg.ColorMap == nil {
		cm, ok := colormap.AvailableMaps[string(tg.GridStyle.ColorMap)]
		if !ok {
			// unknown map name: reset style to defaults and use that map
			tg.GridStyle.ColorMap = ""
			tg.GridStyle.Defaults()
			cm = colormap.AvailableMaps[string(tg.GridStyle.ColorMap)]
		}
		tg.ColorMap = cm
	}
}
// Color translates a tensor value into a normalized value and a color,
// using index lookup for indexed color maps and range-normalized
// mapping otherwise (norm is 0 in the indexed case).
func (tg *TensorGrid) Color(val float64) (norm float64, clr color.Color) {
	if tg.ColorMap.Indexed {
		return 0, tg.ColorMap.MapIndex(int(val))
	}
	norm = tg.GridStyle.Range.ClipNormValue(val)
	return norm, tg.ColorMap.Map(float32(norm))
}
// UpdateRange updates the non-fixed ends of the display range from the
// actual data range of the tensor, rounded to nice numbers.
func (tg *TensorGrid) UpdateRange() {
	rng := &tg.GridStyle.Range
	if rng.FixMin && rng.FixMax {
		return // both ends fixed: nothing to compute
	}
	mn, mx, _, _ := tensor.Range(tg.Tensor.AsValues())
	if !rng.FixMin {
		rng.Min = minmax.NiceRoundNumber(mn, true) // true = below #
	}
	if !rng.FixMax {
		rng.Max = minmax.NiceRoundNumber(mx, false) // false = above #
	}
}
// Render draws the tensor as a grid of colored squares, or as a bitmap
// image when GridStyle.Image is set, using the current ColorMap and
// value Range to translate values into colors.
func (tg *TensorGrid) Render() {
	if tg.Tensor == nil || tg.Tensor.Len() == 0 {
		return
	}
	tg.EnsureColorMap()
	tg.UpdateRange()
	pc := &tg.Scene.PaintContext
	pos := tg.Geom.Pos.Content
	sz := tg.Geom.Size.Actual.Content
	// sz.SetSubScalar(tg.Disp.BotRtSpace.Dots)
	pc.FillBox(pos, sz, tg.Styles.Background) // clear background first
	tsr := tg.Tensor
	if tg.GridStyle.Image {
		// Image mode: each tensor element becomes one pixel block.
		ysz := tsr.DimSize(0)
		xsz := tsr.DimSize(1)
		nclr := 1
		outclr := false // outer dimension is color
		if tsr.NumDims() == 3 {
			if tsr.DimSize(0) == 3 || tsr.DimSize(0) == 4 {
				// RGB(A) channels in the outer-most dimension: (chan, y, x)
				outclr = true
				ysz = tsr.DimSize(1)
				xsz = tsr.DimSize(2)
				nclr = tsr.DimSize(0)
			} else {
				// RGB(A) channels in the inner-most dimension: (y, x, chan)
				nclr = tsr.DimSize(2)
			}
		}
		tsz := math32.Vec2(float32(xsz), float32(ysz))
		gsz := sz.Div(tsz) // size of one element block
		for y := 0; y < ysz; y++ {
			for x := 0; x < xsz; x++ {
				// ey flips the Y axis unless TopZero is set
				ey := y
				if !tg.GridStyle.TopZero {
					ey = (ysz - 1) - y
				}
				switch {
				case outclr:
					// channels in outer dim: index as (chan, y, x)
					var r, g, b, a float64
					a = 1
					r = tg.GridStyle.Range.ClipNormValue(tsr.Float(0, y, x))
					g = tg.GridStyle.Range.ClipNormValue(tsr.Float(1, y, x))
					b = tg.GridStyle.Range.ClipNormValue(tsr.Float(2, y, x))
					if nclr > 3 {
						a = tg.GridStyle.Range.ClipNormValue(tsr.Float(3, y, x))
					}
					cr := math32.Vec2(float32(x), float32(ey))
					pr := pos.Add(cr.Mul(gsz))
					pc.StrokeStyle.Color = colors.Uniform(colors.FromFloat64(r, g, b, a))
					pc.FillBox(pr, gsz, pc.StrokeStyle.Color)
				case nclr > 1:
					// channels in inner dim: index as (y, x, chan)
					var r, g, b, a float64
					a = 1
					r = tg.GridStyle.Range.ClipNormValue(tsr.Float(y, x, 0))
					g = tg.GridStyle.Range.ClipNormValue(tsr.Float(y, x, 1))
					b = tg.GridStyle.Range.ClipNormValue(tsr.Float(y, x, 2))
					if nclr > 3 {
						a = tg.GridStyle.Range.ClipNormValue(tsr.Float(y, x, 3))
					}
					cr := math32.Vec2(float32(x), float32(ey))
					pr := pos.Add(cr.Mul(gsz))
					pc.StrokeStyle.Color = colors.Uniform(colors.FromFloat64(r, g, b, a))
					pc.FillBox(pr, gsz, pc.StrokeStyle.Color)
				default:
					// single channel: greyscale from normalized value
					val := tg.GridStyle.Range.ClipNormValue(tsr.Float(y, x))
					cr := math32.Vec2(float32(x), float32(ey))
					pr := pos.Add(cr.Mul(gsz))
					pc.StrokeStyle.Color = colors.Uniform(colors.FromFloat64(val, val, val, 1))
					pc.FillBox(pr, gsz, pc.StrokeStyle.Color)
				}
			}
		}
		return
	}
	// Grid mode: project tensor into a 2D rows x cols layout, with
	// extra spacing (DimExtra) at the boundaries of extra dimensions.
	rows, cols, rowEx, colEx := tensor.Projection2DShape(tsr.Shape(), tg.GridStyle.OddRow)
	frw := float32(rows) + float32(rowEx)*tg.GridStyle.DimExtra // extra spacing
	fcl := float32(cols) + float32(colEx)*tg.GridStyle.DimExtra // extra spacing
	rowsInner := rows
	colsInner := cols
	if rowEx > 0 {
		rowsInner = rows / rowEx
	}
	if colEx > 0 {
		colsInner = cols / colEx
	}
	tsz := math32.Vec2(fcl, frw)
	gsz := sz.Div(tsz)
	ssz := gsz.MulScalar(tg.GridStyle.GridFill) // smaller size with margin
	for y := 0; y < rows; y++ {
		yex := float32(int(y/rowsInner)) * tg.GridStyle.DimExtra
		for x := 0; x < cols; x++ {
			xex := float32(int(x/colsInner)) * tg.GridStyle.DimExtra
			ey := y
			if !tg.GridStyle.TopZero {
				ey = (rows - 1) - y
			}
			val := tensor.Projection2DValue(tsr, tg.GridStyle.OddRow, ey, x)
			cr := math32.Vec2(float32(x)+xex, float32(y)+yex)
			pr := pos.Add(cr.Mul(gsz))
			_, clr := tg.Color(val)
			pc.FillBox(pr, ssz, colors.Uniform(clr))
		}
	}
}
// Code generated by "core generate"; DO NOT EDIT.
package tensorcore
import (
"cogentcore.org/core/colors/colormap"
"cogentcore.org/core/core"
"cogentcore.org/core/math32/minmax"
"cogentcore.org/core/tree"
"cogentcore.org/core/types"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
)
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/tensorcore.Layout", IDName: "layout", Doc: "Layout are layout options for displaying tensors.", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"--setters"}}}, Fields: []types.Field{{Name: "OddRow", Doc: "OddRow means that even-numbered dimensions are displayed as Y*X rectangles.\nThis determines along which dimension to display any remaining\nodd dimension: OddRow = true = organize vertically along row\ndimension, false = organize horizontally across column dimension."}, {Name: "TopZero", Doc: "TopZero means that the Y=0 coordinate is displayed from the top-down;\notherwise the Y=0 coordinate is displayed from the bottom up,\nwhich is typical for emergent network patterns."}, {Name: "Image", Doc: "Image will display the data as a bitmap image. If a 2D tensor, then it will\nbe a greyscale image. If a 3D tensor with size of either the first\nor last dim = either 3 or 4, then it is a RGB(A) color image."}}})
// SetOddRow sets the [Layout.OddRow]:
// OddRow means that even-numbered dimensions are displayed as Y*X rectangles.
// This determines along which dimension to display any remaining
// odd dimension: OddRow = true = organize vertically along row
// dimension, false = organize horizontally across column dimension.
func (t *Layout) SetOddRow(v bool) *Layout { t.OddRow = v; return t }
// SetTopZero sets the [Layout.TopZero]:
// TopZero means that the Y=0 coordinate is displayed from the top-down;
// otherwise the Y=0 coordinate is displayed from the bottom up,
// which is typical for emergent network patterns.
func (t *Layout) SetTopZero(v bool) *Layout { t.TopZero = v; return t }
// SetImage sets the [Layout.Image]:
// Image will display the data as a bitmap image. If a 2D tensor, then it will
// be a greyscale image. If a 3D tensor with size of either the first
// or last dim = either 3 or 4, then it is a RGB(A) color image.
func (t *Layout) SetImage(v bool) *Layout { t.Image = v; return t }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/tensorcore.GridStyle", IDName: "grid-style", Doc: "GridStyle are options for displaying tensors", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"--setters"}}}, Embeds: []types.Field{{Name: "Layout"}}, Fields: []types.Field{{Name: "Range", Doc: "Range to plot"}, {Name: "MinMax", Doc: "MinMax has the actual range of data, if not using fixed Range."}, {Name: "ColorMap", Doc: "ColorMap is the name of the color map to use in translating values to colors."}, {Name: "GridFill", Doc: "GridFill sets proportion of grid square filled by the color block:\n1 = all, .5 = half, etc."}, {Name: "DimExtra", Doc: "DimExtra is the amount of extra space to add at dimension boundaries,\nas a proportion of total grid size."}, {Name: "Size", Doc: "Size sets the minimum and maximum size for grid squares."}, {Name: "TotalSize", Doc: "TotalSize sets the total preferred display size along largest dimension.\nGrid squares will be sized to fit within this size,\nsubject to the Size.Min / Max constraints, which have precedence."}, {Name: "FontSize", Doc: "FontSize is the font size in standard point units for labels."}}})
// SetRange sets the [GridStyle.Range]:
// Range to plot
func (t *GridStyle) SetRange(v minmax.Range64) *GridStyle { t.Range = v; return t }
// SetMinMax sets the [GridStyle.MinMax]:
// MinMax has the actual range of data, if not using fixed Range.
func (t *GridStyle) SetMinMax(v minmax.F64) *GridStyle { t.MinMax = v; return t }
// SetColorMap sets the [GridStyle.ColorMap]:
// ColorMap is the name of the color map to use in translating values to colors.
func (t *GridStyle) SetColorMap(v core.ColorMapName) *GridStyle { t.ColorMap = v; return t }
// SetGridFill sets the [GridStyle.GridFill]:
// GridFill sets proportion of grid square filled by the color block:
// 1 = all, .5 = half, etc.
func (t *GridStyle) SetGridFill(v float32) *GridStyle { t.GridFill = v; return t }
// SetDimExtra sets the [GridStyle.DimExtra]:
// DimExtra is the amount of extra space to add at dimension boundaries,
// as a proportion of total grid size.
func (t *GridStyle) SetDimExtra(v float32) *GridStyle { t.DimExtra = v; return t }
// SetSize sets the [GridStyle.Size]:
// Size sets the minimum and maximum size for grid squares.
func (t *GridStyle) SetSize(v minmax.F32) *GridStyle { t.Size = v; return t }
// SetTotalSize sets the [GridStyle.TotalSize]:
// TotalSize sets the total preferred display size along largest dimension.
// Grid squares will be sized to fit within this size,
// subject to the Size.Min / Max constraints, which have precedence.
func (t *GridStyle) SetTotalSize(v float32) *GridStyle { t.TotalSize = v; return t }
// SetFontSize sets the [GridStyle.FontSize]:
// FontSize is the font size in standard point units for labels.
func (t *GridStyle) SetFontSize(v float32) *GridStyle { t.FontSize = v; return t }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/tensorcore.Table", IDName: "table", Doc: "Table provides a GUI widget for representing [table.Table] values.", Embeds: []types.Field{{Name: "ListBase"}}, Fields: []types.Field{{Name: "Table", Doc: "Table is the table that we're a view of."}, {Name: "GridStyle", Doc: "GridStyle has global grid display styles. GridStylers on the Table\nare applied to this on top of defaults."}, {Name: "ColumnGridStyle", Doc: "ColumnGridStyle has per column grid display styles."}, {Name: "SortIndex", Doc: "current sort index."}, {Name: "SortDescending", Doc: "whether current sort order is descending."}, {Name: "nCols", Doc: "number of columns in table (as of last update)."}, {Name: "headerWidths", Doc: "headerWidths has number of characters in each header, per visfields."}, {Name: "colMaxWidths", Doc: "colMaxWidths records maximum width in chars of string type fields."}, {Name: "blankString", Doc: "\tblank values for out-of-range rows."}, {Name: "blankFloat"}, {Name: "blankCells", Doc: "blankCells has per column blank tensor cells."}}})
// NewTable returns a new [Table] with the given optional parent:
// Table provides a GUI widget for representing [table.Table] values.
func NewTable(parent ...tree.Node) *Table { return tree.New[Table](parent...) }
// SetSortIndex sets the [Table.SortIndex]:
// current sort index.
func (t *Table) SetSortIndex(v int) *Table { t.SortIndex = v; return t }
// SetSortDescending sets the [Table.SortDescending]:
// whether current sort order is descending.
func (t *Table) SetSortDescending(v bool) *Table { t.SortDescending = v; return t }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/tensorcore.TensorEditor", IDName: "tensor-editor", Doc: "TensorEditor provides a GUI widget for representing [tensor.Tensor] values.", Methods: []types.Method{{Name: "SaveCSV", Doc: "SaveTSV writes a tensor to a tab-separated-values (TSV) file.\nOuter-most dims are rows in the file, and inner-most is column --\nReading just grabs all values and doesn't care about shape.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"filename"}, Returns: []string{"error"}}, {Name: "OpenCSV", Doc: "OpenTSV reads a tensor from a tab-separated-values (TSV) file.\nusing the Go standard encoding/csv reader conforming\nto the official CSV standard.\nReads all values and assigns as many as fit.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"filename"}, Returns: []string{"error"}}}, Embeds: []types.Field{{Name: "ListBase"}}, Fields: []types.Field{{Name: "Tensor", Doc: "the tensor that we're a view of"}, {Name: "Layout", Doc: "overall layout options for tensor display"}, {Name: "NCols", Doc: "number of columns in table (as of last update)"}, {Name: "headerWidths", Doc: "headerWidths has number of characters in each header, per visfields"}, {Name: "colMaxWidths", Doc: "colMaxWidths records maximum width in chars of string type fields"}, {Name: "BlankString", Doc: "\tblank values for out-of-range rows"}, {Name: "BlankFloat"}}})
// NewTensorEditor returns a new [TensorEditor] with the given optional parent:
// TensorEditor provides a GUI widget for representing [tensor.Tensor] values.
func NewTensorEditor(parent ...tree.Node) *TensorEditor { return tree.New[TensorEditor](parent...) }
// SetNCols sets the [TensorEditor.NCols]:
// number of columns in table (as of last update)
func (t *TensorEditor) SetNCols(v int) *TensorEditor { t.NCols = v; return t }
// SetBlankString sets the [TensorEditor.BlankString]:
//
// blank values for out-of-range rows
func (t *TensorEditor) SetBlankString(v string) *TensorEditor { t.BlankString = v; return t }
// SetBlankFloat sets the [TensorEditor.BlankFloat]
func (t *TensorEditor) SetBlankFloat(v float64) *TensorEditor { t.BlankFloat = v; return t }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/tensorcore.TensorGrid", IDName: "tensor-grid", Doc: "TensorGrid is a widget that displays tensor values as a grid of colored squares.", Methods: []types.Method{{Name: "TensorEditor", Doc: "TensorEditor pulls up a TensorEditor of our tensor", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "EditStyle", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}}, Embeds: []types.Field{{Name: "WidgetBase"}}, Fields: []types.Field{{Name: "Tensor", Doc: "Tensor is the tensor that we view."}, {Name: "GridStyle", Doc: "GridStyle has grid display style properties."}, {Name: "ColorMap", Doc: "ColorMap is the colormap displayed (based on)"}}})
// NewTensorGrid returns a new [TensorGrid] with the given optional parent:
// TensorGrid is a widget that displays tensor values as a grid of colored squares.
func NewTensorGrid(parent ...tree.Node) *TensorGrid { return tree.New[TensorGrid](parent...) }
// SetGridStyle sets the [TensorGrid.GridStyle]:
// GridStyle has grid display style properties.
func (t *TensorGrid) SetGridStyle(v GridStyle) *TensorGrid { t.GridStyle = v; return t }
// SetColorMap sets the [TensorGrid.ColorMap]:
// ColorMap is the colormap displayed (based on)
func (t *TensorGrid) SetColorMap(v *colormap.Map) *TensorGrid { t.ColorMap = v; return t }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/tensorcore.TensorButton", IDName: "tensor-button", Doc: "TensorButton represents a Tensor with a button for making a [TensorGrid]\nviewer for an [tensor.Tensor].", Embeds: []types.Field{{Name: "Button"}}, Fields: []types.Field{{Name: "Tensor"}}})
// NewTensorButton returns a new [TensorButton] with the given optional parent:
// TensorButton represents a Tensor with a button for making a [TensorGrid]
// viewer for an [tensor.Tensor].
func NewTensorButton(parent ...tree.Node) *TensorButton { return tree.New[TensorButton](parent...) }
// SetTensor sets the [TensorButton.Tensor]
func (t *TensorButton) SetTensor(v tensor.Tensor) *TensorButton { t.Tensor = v; return t }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/tensorcore.TableButton", IDName: "table-button", Doc: "TableButton presents a button that pulls up the [Table] viewer for a [table.Table].", Embeds: []types.Field{{Name: "Button"}}, Fields: []types.Field{{Name: "Table"}}})
// NewTableButton returns a new [TableButton] with the given optional parent:
// TableButton presents a button that pulls up the [Table] viewer for a [table.Table].
func NewTableButton(parent ...tree.Node) *TableButton { return tree.New[TableButton](parent...) }
// SetTable sets the [TableButton.Table]
func (t *TableButton) SetTable(v *table.Table) *TableButton { t.Table = v; return t }
// Copyright (c) 2019, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensorcore
import (
"cogentcore.org/core/base/metadata"
"cogentcore.org/core/core"
"cogentcore.org/core/icons"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
)
// init registers value-widget types with the core value system:
// tables get a [TableButton], and the concrete tensor types get
// a [TensorButton].
func init() {
	core.AddValueType[table.Table, TableButton]()
	core.AddValueType[tensor.Float32, TensorButton]()
	core.AddValueType[tensor.Float64, TensorButton]()
	core.AddValueType[tensor.Int, TensorButton]()
	core.AddValueType[tensor.Int32, TensorButton]()
	core.AddValueType[tensor.Byte, TensorButton]()
	core.AddValueType[tensor.String, TensorButton]()
	core.AddValueType[tensor.Bool, TensorButton]()
	// core.AddValueType[simat.SimMat, SimMatButton]()
}
// TensorButton represents a Tensor with a button for making a [TensorGrid]
// viewer for an [tensor.Tensor].
type TensorButton struct {
	core.Button
	// Tensor is the tensor shown in the [TensorGrid] dialog.
	Tensor tensor.Tensor
}
// WidgetValue returns a pointer to the tensor, for core value binding.
func (tb *TensorButton) WidgetValue() any { return &tb.Tensor }
// Init configures the button appearance, a label updater that shows
// whether a tensor is set, and the value dialog showing a [TensorGrid].
func (tb *TensorButton) Init() {
	tb.Button.Init()
	tb.SetType(core.ButtonTonal).SetIcon(icons.Edit)
	tb.Updater(func() {
		if tb.Tensor == nil {
			tb.SetText("None")
		} else {
			tb.SetText("Tensor")
		}
	})
	core.InitValueButton(tb, true, func(d *core.Body) {
		NewTensorGrid(d).SetTensor(tb.Tensor)
	})
}
// TableButton presents a button that pulls up the [Table] viewer for a [table.Table].
type TableButton struct {
	core.Button
	// Table is the table shown in the [Table] viewer dialog.
	Table *table.Table
}
// WidgetValue returns a pointer to the table, for core value binding.
func (tb *TableButton) WidgetValue() any { return &tb.Table }
// Init configures the button appearance, a label updater that shows the
// table's "name" metadata when available, and the value dialog showing
// a [Table] view.
func (tb *TableButton) Init() {
	tb.Button.Init()
	tb.SetType(core.ButtonTonal).SetIcon(icons.Edit)
	tb.Updater(func() {
		text := "None"
		if tb.Table != nil {
			text = "Table"
			if nm, err := metadata.Get[string](tb.Table.Meta, "name"); err == nil {
				text = nm
			}
		}
		tb.SetText(text)
	})
	core.InitValueButton(tb, true, func(d *core.Body) {
		NewTable(d).SetTable(tb.Table)
	})
}
/*
// SimMatValue presents a button that pulls up the [SimMatGrid] viewer for a [table.Table].
type SimMatButton struct {
core.Button
SimMat *simat.SimMat
}
func (tb *SimMatButton) WidgetValue() any { return &tb.SimMat }
func (tb *SimMatButton) Init() {
tb.Button.Init()
tb.SetType(core.ButtonTonal).SetIcon(icons.Edit)
tb.Updater(func() {
text := "None"
if tb.SimMat != nil && tb.SimMat.Mat != nil {
if nm, has := tb.SimMat.Mat.MetaData("name"); has {
text = nm
} else {
text = "SimMat"
}
}
tb.SetText(text)
})
core.InitValueButton(tb, true, func(d *core.Body) {
NewSimMatGrid(d).SetSimMat(tb.SimMat)
})
}
*/
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensorfs
import (
"fmt"
"io/fs"
"path"
"strings"
"cogentcore.org/core/base/errors"
"cogentcore.org/lab/tensor"
)
// Global current-directory state used by the command-style API
// (Record, Chdir, Mkdir, List, Get, Set) in this file.
var (
	// CurDir is the current working directory.
	CurDir *Node
	// CurRoot is the current root tensorfs system.
	// A default root tensorfs is created at startup.
	CurRoot *Node
)
// init creates the default root tensorfs directory ("data")
// and makes it the current working directory.
func init() {
	CurRoot, _ = NewDir("data")
	CurDir = CurRoot
}
// Record saves given tensor to current directory with given name,
// falling back to the root directory if no current directory is set.
func Record(tsr tensor.Tensor, name string) {
	if CurDir == nil {
		CurDir = CurRoot
	}
	SetTensor(CurDir, tsr, name)
}
// Chdir changes the current working tensorfs directory to the named directory.
// An empty dir name (or a nil current directory) resets to the root.
// Returns an error if the named directory is not found.
func Chdir(dir string) error {
	if CurDir == nil {
		CurDir = CurRoot
	}
	if dir == "" {
		CurDir = CurRoot
		return nil
	}
	ndir, err := CurDir.DirAtPath(dir)
	if err != nil {
		return err
	}
	CurDir = ndir
	return nil
}
// Mkdir creates a new directory with the specified name in the current directory.
// It returns an existing directory of the same name without error.
// An empty name is an error, which is logged, returning nil.
func Mkdir(dir string) *Node {
	if CurDir == nil {
		CurDir = CurRoot
	}
	if dir == "" {
		err := &fs.PathError{Op: "Mkdir", Path: dir, Err: errors.New("path must not be empty")}
		errors.Log(err)
		return nil
	}
	return CurDir.Dir(dir)
}
// List lists files using arguments (options and path) from the current
// directory. A leading "-" argument is parsed for flags: "l" = long
// listing, "r" = recursive. A following path argument, when valid,
// selects the directory to list; otherwise the current one is used.
func List(opts ...string) error {
	if CurDir == nil {
		CurDir = CurRoot
	}
	long, recursive := false, false
	if len(opts) > 0 && strings.HasPrefix(opts[0], "-") {
		flags := opts[0]
		long = strings.Contains(flags, "l")
		recursive = strings.Contains(flags, "r")
		opts = opts[1:]
	}
	dir := CurDir
	if len(opts) > 0 {
		// best-effort: an invalid path falls back to the current directory
		if nd, err := CurDir.DirAtPath(opts[0]); err == nil {
			dir = nd
		}
	}
	fmt.Println(dir.List(long, recursive))
	return nil
}
// Get returns the tensor value at given path relative to the
// current working directory.
// This is the direct pointer to the node, so changes
// to it will change the node. Clone the tensor to make
// a new copy disconnected from the original.
// Errors (empty name, path not found, path is a directory)
// are logged and nil is returned.
func Get(name string) tensor.Tensor {
	if CurDir == nil {
		CurDir = CurRoot
	}
	if name == "" {
		err := &fs.PathError{Op: "Get", Path: name, Err: errors.New("name must not be empty")}
		errors.Log(err)
		return nil
	}
	nd, err := CurDir.NodeAtPath(name)
	if errors.Log(err) != nil {
		return nil
	}
	if nd.IsDir() {
		err := &fs.PathError{Op: "Get", Path: name, Err: errors.New("node is a directory, not a data node")}
		errors.Log(err)
		return nil
	}
	return nd.Tensor
}
// Set sets tensor to given name or path relative to the
// current working directory.
// If the node already exists, its previous tensor is updated to the
// given one; if it doesn't, then a new node is created.
func Set(name string, tsr tensor.Tensor) error {
	if CurDir == nil {
		CurDir = CurRoot
	}
	if name == "" {
		err := &fs.PathError{Op: "Set", Path: name, Err: errors.New("name must not be empty")}
		return errors.Log(err)
	}
	// existing node: just swap in the new tensor.
	if nd, err := CurDir.NodeAtPath(name); err == nil {
		if nd.IsDir() {
			err := &fs.PathError{Op: "Set", Path: name, Err: errors.New("existing node is a directory, not a data node")}
			return errors.Log(err)
		}
		nd.Tensor = tsr
		return nil
	}
	// new node: resolve any directory prefix in the path first.
	dir, base := path.Split(name)
	cd := CurDir
	if dir != "" {
		d, err := CurDir.DirAtPath(dir)
		if err != nil {
			return errors.Log(err)
		}
		cd = d
	}
	SetTensor(cd, tsr, base)
	return nil
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensorfs
import (
"errors"
"io/fs"
"time"
"cogentcore.org/lab/tensor"
)
// Values for the Overwrite flag in [Node.Copy].
const (
	// Preserve is used for Overwrite flag, indicating to not overwrite and preserve existing.
	Preserve = false
	// Overwrite is used for Overwrite flag, indicating to overwrite existing.
	Overwrite = true
)
// CopyFromValue copies the tensor value from the given source node,
// cloning the tensor and updating the modification time.
func (nd *Node) CopyFromValue(frd *Node) {
	nd.modTime = time.Now()
	nd.Tensor = tensor.Clone(frd.Tensor)
}
// Clone returns a deep copy of this node: directories are cloned
// recursively, and value nodes get a clone of their tensor.
func (nd *Node) Clone() *Node {
	if nd.IsDir() {
		cp, _ := NewDir(nd.name)
		nodes, _ := nd.Nodes()
		for _, sub := range nodes {
			cp.Add(sub.Clone())
		}
		return cp
	}
	cp, _ := newNode(nil, nd.name)
	cp.Tensor = tensor.Clone(nd.Tensor)
	return cp
}
// Copy copies node(s) from given paths to given path or directory.
// if there are multiple from nodes, then to must be a directory.
// must be called on a directory node.
// Per-source errors are collected and joined into the returned error.
func (dir *Node) Copy(overwrite bool, to string, from ...string) error {
	if err := dir.mustDir("Copy", to); err != nil {
		return err
	}
	switch {
	case to == "":
		return &fs.PathError{Op: "Copy", Path: to, Err: errors.New("to location is empty")}
	case len(from) == 0:
		return &fs.PathError{Op: "Copy", Path: to, Err: errors.New("no from sources specified")}
	}
	// todo: check for to conflict first here..
	tod, _ := dir.NodeAtPath(to)
	var errs []error
	if len(from) > 1 && tod != nil && !tod.IsDir() {
		return &fs.PathError{Op: "Copy", Path: to, Err: errors.New("multiple source nodes requires destination to be a directory")}
	}
	// targd is the destination directory; targf, when non-empty, is the
	// new name for the single copied node (to was not an existing dir).
	targd := dir
	targf := to
	if tod != nil && tod.IsDir() {
		targd = tod
		targf = ""
	}
	for _, fr := range from {
		opstr := fr + " -> " + to
		frd, err := dir.NodeAtPath(fr)
		if err != nil {
			errs = append(errs, err)
			continue
		}
		if targf == "" {
			// copying into a directory: resolve name conflicts with any
			// existing node of the same name in the destination.
			if trg, ok := targd.nodes.AtTry(frd.name); ok { // target exists
				switch {
				case trg.IsDir() && frd.IsDir():
					// todo: copy all nodes from frd into trg
				case trg.IsDir(): // frd is not
					errs = append(errs, &fs.PathError{Op: "Copy", Path: opstr, Err: errors.New("cannot copy from Value onto directory of same name")})
				case frd.IsDir(): // trg is not
					errs = append(errs, &fs.PathError{Op: "Copy", Path: opstr, Err: errors.New("cannot copy from Directory onto Value of same name")})
				default: // both nodes
					if overwrite { // todo: interactive!?
						trg.CopyFromValue(frd)
					}
				}
				continue
			}
		}
		// no conflict: clone the source and add it, renaming if targf is set.
		nw := frd.Clone()
		if targf != "" {
			nw.name = targf
		}
		targd.Add(nw)
	}
	return errors.Join(errs...)
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensorfs
import (
"fmt"
"io/fs"
"path"
"slices"
"sort"
"strings"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/keylist"
"cogentcore.org/lab/tensor"
)
// Nodes is a map of directory entry names to Nodes,
// implemented as an ordered key list ([keylist.List]).
// It retains the order that nodes were added in, which is
// the natural order nodes are processed in.
type Nodes = keylist.List[string, *Node]
// NewDir returns a new tensorfs directory with the given name.
// If parent != nil and a directory, this dir is added to it.
// If the parent already has an node of that name, it is returned,
// with an [fs.ErrExist] error.
// If the name is empty, then it is set to "root", the root directory.
// Note that "/" is not allowed for the root directory in Go [fs].
// If no parent (i.e., a new root) and CurRoot is nil, then it is set
// to this.
func NewDir(name string, parent ...*Node) (*Node, error) {
	if name == "" {
		name = "root"
	}
	var par *Node
	if len(parent) == 1 {
		par = parent[0]
	}
	dir, err := newNode(par, name)
	if dir != nil && dir.nodes == nil {
		dir.nodes = &Nodes{} // directories get a (possibly empty) nodes list
	}
	return dir, err
}
// Dir creates a new directory under given dir with the specified name
// if it doesn't already exist, otherwise returns the existing one.
// Path / slash separators can be used to make a path of multiple directories.
// It logs an error and returns nil if this dir node is not a directory.
// (Uses strings.Cut for the single split instead of the previous
// Split + Join round-trip, avoiding the slice and string re-allocation.)
func (dir *Node) Dir(name string) *Node {
	if err := dir.mustDir("Dir", name); errors.Log(err) != nil {
		return nil
	}
	if len(name) == 0 {
		return dir
	}
	first, rest, _ := strings.Cut(name, "/")
	cd := dir.nodes.At(first)
	if cd == nil {
		// not found: create it (NewDir returns existing on race, ignored here)
		cd, _ = NewDir(first, dir)
	}
	if rest == "" {
		return cd
	}
	return cd.Dir(rest)
}
// Node returns a Node in given directory by name.
// This is for fast access and direct usage of known
// nodes, and it will panic if this node is not a directory.
// Returns nil if no node of given name exists.
func (dir *Node) Node(name string) *Node {
	return dir.nodes.At(name) // nodes is only non-nil for directories
}
// Value returns the [tensor.Tensor] value for given node
// within this directory. This will panic if node is not
// found, and will return nil if it is not a Value
// (i.e., it is a directory).
func (dir *Node) Value(name string) tensor.Tensor {
	return dir.nodes.At(name).Tensor
}
// Nodes returns a slice of Nodes in given directory by names variadic list.
// If list is empty, then all nodes in the directory are returned.
// returned error reports any nodes not found, or if not a directory.
func (dir *Node) Nodes(names ...string) ([]*Node, error) {
	if err := dir.mustDir("Nodes", ""); err != nil {
		return nil, err
	}
	var nds []*Node
	if len(names) == 0 {
		nds = append(nds, dir.nodes.Values...)
		return nds, nil
	}
	var errs []error
	for _, nm := range names {
		nd := dir.nodes.At(nm)
		if nd == nil {
			errs = append(errs, fmt.Errorf("tensorfs Dir %q node not found: %q", dir.Path(), nm))
			continue
		}
		nds = append(nds, nd)
	}
	return nds, errors.Join(errs...)
}
// Values returns a slice of tensor values in this directory,
// by the names variadic list. If the list is empty, then all value nodes
// in the directory are returned.
// The returned error reports any nodes not found, or if not a directory.
func (dir *Node) Values(names ...string) ([]tensor.Tensor, error) {
	if err := dir.mustDir("Values", ""); err != nil {
		return nil, err
	}
	var out []tensor.Tensor
	if len(names) == 0 {
		// All value (non-directory) nodes, in directory order.
		for _, nd := range dir.nodes.Values {
			if nd.Tensor != nil {
				out = append(out, nd.Tensor)
			}
		}
		return out, nil
	}
	var errs []error
	for _, nm := range names {
		nd := dir.nodes.At(nm)
		if nd == nil || nd.Tensor == nil {
			errs = append(errs, fmt.Errorf("tensorfs Dir %q node not found: %q", dir.Path(), nm))
			continue
		}
		out = append(out, nd.Tensor)
	}
	return out, errors.Join(errs...)
}
// ValuesFunc returns all tensor Values under this directory,
// filtered by the given function, in directory order (e.g., order added),
// recursively descending into directories to return a flat list of
// the entire subtree. The function can filter out directories to prune
// the tree, e.g., using the `IsDir` method.
// If fun is nil, all Value nodes are returned.
func (dir *Node) ValuesFunc(fun func(nd *Node) bool) []tensor.Tensor {
	if err := dir.mustDir("ValuesFunc", ""); err != nil {
		return nil
	}
	var out []tensor.Tensor
	for _, nd := range dir.nodes.Values {
		switch {
		case fun != nil && !fun(nd):
			// filtered out (pruning a directory skips its whole subtree)
		case nd.IsDir():
			out = append(out, nd.ValuesFunc(fun)...)
		default:
			out = append(out, nd.Tensor)
		}
	}
	return out
}
// NodesFunc returns leaf Nodes under this directory, filtered by
// the given function, recursively descending into directories
// to return a flat list of the entire subtree, in directory order
// (e.g., order added).
// The function can filter out directories to prune the tree.
// If fun is nil, all leaf Nodes are returned.
func (dir *Node) NodesFunc(fun func(nd *Node) bool) []*Node {
	if err := dir.mustDir("NodesFunc", ""); err != nil {
		return nil
	}
	var out []*Node
	for _, nd := range dir.nodes.Values {
		switch {
		case fun != nil && !fun(nd):
			// filtered out (pruning a directory skips its whole subtree)
		case nd.IsDir():
			out = append(out, nd.NodesFunc(fun)...)
		default:
			out = append(out, nd)
		}
	}
	return out
}
// ValuesAlphaFunc returns all Value nodes (tensors) in this directory,
// recursively descending into directories to return a flat list of
// the entire subtree, filtered by the given function, with nodes at each
// directory level traversed in alphabetical order.
// The function can filter out directories to prune the tree.
// If fun is nil, all Values are returned.
func (dir *Node) ValuesAlphaFunc(fun func(nd *Node) bool) []tensor.Tensor {
	if err := dir.mustDir("ValuesAlphaFunc", ""); err != nil {
		return nil
	}
	var out []tensor.Tensor
	for _, nm := range dir.dirNamesAlpha() {
		nd := dir.nodes.At(nm)
		switch {
		case fun != nil && !fun(nd):
			// filtered out (pruning a directory skips its whole subtree)
		case nd.IsDir():
			out = append(out, nd.ValuesAlphaFunc(fun)...)
		default:
			out = append(out, nd.Tensor)
		}
	}
	return out
}
// NodesAlphaFunc returns leaf nodes under this directory, filtered
// by the given function, with nodes at each directory level
// traversed in alphabetical order, recursively descending into directories
// to return a flat list of the entire subtree.
// The function can filter out directories to prune the tree.
// If fun is nil, all leaf Nodes are returned.
func (dir *Node) NodesAlphaFunc(fun func(nd *Node) bool) []*Node {
	if err := dir.mustDir("NodesAlphaFunc", ""); err != nil {
		return nil
	}
	var out []*Node
	for _, nm := range dir.dirNamesAlpha() {
		nd := dir.nodes.At(nm)
		switch {
		case fun != nil && !fun(nd):
			// filtered out (pruning a directory skips its whole subtree)
		case nd.IsDir():
			out = append(out, nd.NodesAlphaFunc(fun)...)
		default:
			out = append(out, nd)
		}
	}
	return out
}
// todo: these must handle going up the tree using ..

// DirAtPath returns the directory at the given relative path
// from this starting dir. The path is cleaned first, and
// [Node.Sub] verifies that every element is an existing directory.
func (dir *Node) DirAtPath(dirPath string) (*Node, error) {
	// Sub ensures the result is a directory; no need to pre-declare err.
	sdf, err := dir.Sub(path.Clean(dirPath))
	if err != nil {
		return nil, err
	}
	return sdf.(*Node), nil
}
// NodeAtPath returns the node at the given relative path from this starting dir.
func (dir *Node) NodeAtPath(name string) (*Node, error) {
	if err := dir.mustDir("NodeAtPath", name); err != nil {
		return nil, err
	}
	if !fs.ValidPath(name) {
		return nil, &fs.PathError{Op: "NodeAtPath", Path: name, Err: errors.New("invalid path")}
	}
	// Resolve the containing directory, then look up the final element in it.
	dirPath, file := path.Split(name)
	sd, err := dir.DirAtPath(dirPath)
	if err != nil {
		return nil, err
	}
	if nd, ok := sd.nodes.AtTry(file); ok {
		return nd, nil
	}
	// Special case: referring to this directory itself by its own name or ".".
	if dirPath == "" && (file == dir.name || file == ".") {
		return dir, nil
	}
	return nil, &fs.PathError{Op: "NodeAtPath", Path: name, Err: errors.New("file not found")}
}
// Path returns the full path to this data node,
// joining the names of all parents up to the root.
// A visited-set guards against parent cycles, which would
// otherwise loop forever.
func (dir *Node) Path() string {
	pt := dir.name
	seen := make(map[*Node]struct{})
	for cur := dir.Parent; cur != nil; cur = cur.Parent {
		if _, dup := seen[cur]; dup {
			return pt // cycle detected: stop here
		}
		seen[cur] = struct{}{}
		pt = path.Join(cur.name, pt)
	}
	return pt
}
// dirNamesAlpha returns the names of nodes in the directory
// sorted alphabetically. Node must be a dir by this point.
func (dir *Node) dirNamesAlpha() []string {
	nms := slices.Clone(dir.nodes.Keys)
	slices.Sort(nms)
	return nms
}
// dirNamesByTime returns the names of nodes in the directory
// sorted by modTime (oldest first). Node must be a dir by this point.
func (dir *Node) dirNamesByTime() []string {
	nms := slices.Clone(dir.nodes.Keys)
	slices.SortFunc(nms, func(x, y string) int {
		xt := dir.nodes.At(x).ModTime()
		yt := dir.nodes.At(y).ModTime()
		return xt.Compare(yt)
	})
	return nms
}
// mustDir returns an error for the given operation and path
// if this data node is not a directory; nil otherwise.
func (dir *Node) mustDir(op, path string) error {
	if dir.IsDir() {
		return nil
	}
	return &fs.PathError{Op: op, Path: path, Err: errors.New("tensorfs node is not a directory")}
}
// Add adds a node to this directory data node.
// The only errors are if this node is not a directory,
// or the name already exists, in which case [fs.ErrExist] is returned.
// Names must be unique within a directory.
func (dir *Node) Add(it *Node) error {
	if err := dir.mustDir("Add", it.name); err != nil {
		return err
	}
	// The ordered map only errors on a duplicate key.
	if dir.nodes.Add(it.name, it) != nil {
		return fs.ErrExist
	}
	return nil
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensorfs
import (
"bytes"
"io"
"io/fs"
)
// File represents a data item for reading, as an [fs.File].
// All io functionality is handled by [bytes.Reader].
type File struct {
	bytes.Reader

	// Node is the data node being read.
	Node *Node
}
// Stat returns the [fs.FileInfo] for this file. The [Node] itself
// implements fs.FileInfo, so it is returned directly; never errors.
func (f *File) Stat() (fs.FileInfo, error) {
	return f.Node, nil
}
// Close implements [fs.File]. There are no underlying resources to
// release; it resets the reader onto the node's current bytes so the
// File can be re-read from the start after closing.
func (f *File) Close() error {
	f.Reader.Reset(f.Node.Bytes())
	return nil
}
// DirFile represents a directory data item for reading, as fs.ReadDirFile.
type DirFile struct {
	File

	// dirEntries caches the directory listing, populated on first ReadDir.
	dirEntries []fs.DirEntry

	// dirsRead is how many entries have already been returned,
	// supporting incremental ReadDir(n) reads.
	dirsRead int
}
// ReadDir implements [fs.ReadDirFile], returning up to n entries
// per call (or all remaining entries when n <= 0), caching the
// full listing on the first call.
func (f *DirFile) ReadDir(n int) ([]fs.DirEntry, error) {
	if err := f.Node.mustDir("DirFile:ReadDir", ""); err != nil {
		return nil, err
	}
	if f.dirEntries == nil {
		// First call: snapshot the directory contents.
		f.dirEntries, _ = f.Node.ReadDir(".")
		f.dirsRead = 0
	}
	total := len(f.dirEntries)
	if n <= 0 {
		// Return everything not yet read; nil (no error) when exhausted.
		if f.dirsRead >= total {
			return nil, nil
		}
		out := f.dirEntries[f.dirsRead:]
		f.dirsRead = total
		return out, nil
	}
	// Batched mode: io.EOF once all entries have been returned.
	if f.dirsRead >= total {
		return nil, io.EOF
	}
	end := min(f.dirsRead+n, total)
	out := f.dirEntries[f.dirsRead:end]
	f.dirsRead = end
	return out, nil
}
// Close implements [fs.File], resetting the directory read position
// so a subsequent ReadDir starts from the beginning again.
// Note: it does not reset the embedded File reader.
func (f *DirFile) Close() error {
	f.dirsRead = 0
	return nil
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensorfs
import (
"bytes"
"errors"
"io/fs"
"slices"
"time"
"cogentcore.org/core/base/fileinfo"
"cogentcore.org/core/base/fsx"
)
// fs.go contains all the io/fs interface implementations, and other fs functionality.
// Open opens the node at the given path within this tensorfs filesystem,
// returning a [DirFile] for directories and a plain [File] otherwise.
func (nd *Node) Open(name string) (fs.File, error) {
	itm, err := nd.NodeAtPath(name)
	if err != nil {
		return nil, err
	}
	fl := File{Reader: *bytes.NewReader(itm.Bytes()), Node: itm}
	if itm.IsDir() {
		return &DirFile{File: fl}, nil
	}
	return &fl, nil
}
// Stat returns a FileInfo describing the file.
// If there is an error, it should be of type *PathError.
// The [Node] itself implements [fs.FileInfo], so it is returned directly.
func (nd *Node) Stat(name string) (fs.FileInfo, error) {
	return nd.NodeAtPath(name)
}
// Sub returns a data FS corresponding to the subtree rooted at dir.
func (nd *Node) Sub(dir string) (fs.FS, error) {
	if err := nd.mustDir("Sub", dir); err != nil {
		return nil, err
	}
	if !fs.ValidPath(dir) {
		return nil, &fs.PathError{Op: "Sub", Path: dir, Err: errors.New("invalid name")}
	}
	// Fast path: referring to this directory itself.
	if dir == "." || dir == "" || dir == nd.name {
		return nd, nil
	}
	cd := dir
	cur := nd
	// Strip a leading element that names this node (or "."), so the
	// walk below starts from our immediate children.
	root, rest := fsx.SplitRootPathFS(dir)
	if root == "." || root == nd.name {
		cd = rest
	}
	// Walk down one path element at a time, verifying each is an
	// existing subdirectory of the current node.
	for {
		if cd == "." || cd == "" {
			return cur, nil
		}
		root, rest := fsx.SplitRootPathFS(cd)
		if root == "." && rest == "" {
			return cur, nil
		}
		cd = rest
		sd, ok := cur.nodes.AtTry(root)
		if !ok {
			return nil, &fs.PathError{Op: "Sub", Path: dir, Err: errors.New("directory not found")}
		}
		if !sd.IsDir() {
			return nil, &fs.PathError{Op: "Sub", Path: dir, Err: errors.New("is not a directory")}
		}
		cur = sd
	}
}
// ReadDir returns the contents of the given directory within this filesystem,
// in alphabetical order. Use "." (or "") to refer to the current directory.
func (nd *Node) ReadDir(dir string) ([]fs.DirEntry, error) {
	sd, err := nd.DirAtPath(dir)
	if err != nil {
		return nil, err
	}
	names := sd.dirNamesAlpha()
	ents := make([]fs.DirEntry, 0, len(names))
	for _, nm := range names {
		ents = append(ents, sd.nodes.At(nm)) // each Node is itself a DirEntry
	}
	return ents, nil
}
// ReadFile reads the named file and returns its contents.
// A successful call returns a nil error, not io.EOF.
// (Because ReadFile reads the whole file, the expected EOF
// from the final Read is not treated as an error to be reported.)
//
// The caller is permitted to modify the returned byte slice;
// a copy of the underlying data is returned.
func (nd *Node) ReadFile(name string) ([]byte, error) {
	itm, err := nd.NodeAtPath(name)
	if err != nil {
		return nil, err
	}
	if !itm.IsDir() {
		return slices.Clone(itm.Bytes()), nil
	}
	return nil, &fs.PathError{Op: "ReadFile", Path: name, Err: errors.New("Node is a directory")}
}
//////// FileInfo interface:

// Name returns the name of this node (not a path).
func (nd *Node) Name() string { return nd.name }

// Size returns the size of the tensor value data
// (via AsValues().Sizeof()), or 0 if this node has no tensor
// (e.g., it is a directory).
func (nd *Node) Size() int64 {
	if nd.Tensor == nil {
		return 0
	}
	return nd.Tensor.AsValues().Sizeof()
}

// IsDir reports whether this node is a directory: having a
// non-nil nodes map is what distinguishes a directory node.
func (nd *Node) IsDir() bool {
	return nd.nodes != nil
}

// ModTime returns the time this node was added to its directory.
func (nd *Node) ModTime() time.Time {
	return nd.modTime
}

// Mode returns 0755|ModeDir for directories and read-only 0444 for values.
func (nd *Node) Mode() fs.FileMode {
	if nd.IsDir() {
		return 0755 | fs.ModeDir
	}
	return 0444
}

// Sys returns the underlying data: the Tensor for a value node,
// otherwise the directory's nodes map.
func (nd *Node) Sys() any {
	if nd.Tensor != nil {
		return nd.Tensor
	}
	return nd.nodes
}

//////// DirEntry interface

// Type returns the type bits of the node's file mode.
func (nd *Node) Type() fs.FileMode {
	return nd.Mode().Type()
}

// Info returns the node itself, which implements [fs.FileInfo]; never errors.
func (nd *Node) Info() (fs.FileInfo, error) {
	return nd, nil
}
//////// Misc
// KnownFileInfo returns the [fileinfo.Known] category for this node:
// Unknown for directories (no tensor), Tensor for multi-element values,
// and String or Number for scalars based on the element type.
func (nd *Node) KnownFileInfo() fileinfo.Known {
	tsr := nd.Tensor
	switch {
	case tsr == nil:
		return fileinfo.Unknown
	case tsr.Len() > 1:
		return fileinfo.Tensor
	case tsr.IsString():
		// scalar string
		return fileinfo.String
	default:
		// scalar numeric
		return fileinfo.Number
	}
}
// Bytes returns the byte-wise representation of the data Value.
// This is the actual underlying data, so make a copy if it can be
// unintentionally modified or retained more than for immediate use.
// Returns nil for directories and empty tensors.
func (nd *Node) Bytes() []byte {
	tsr := nd.Tensor
	if tsr == nil || tsr.NumDims() == 0 || tsr.Len() == 0 {
		return nil
	}
	return tsr.AsValues().Bytes()
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensorfs
import (
"strings"
"cogentcore.org/core/base/indent"
)
const (
	// Short is the value of the long arg to [Node.List] for a names-only listing.
	Short = false

	// Long is the value of the long arg to [Node.List] for a detailed listing.
	Long = true

	// DirOnly is the value of the recursive arg to [Node.List] for this dir only.
	DirOnly = false

	// Recursive is the value of the recursive arg to [Node.List] to descend into subdirs.
	Recursive = true
)
// todo: list options string

// String renders this node: the tensor label for a value node,
// or a short, non-recursive listing for a directory.
func (nd *Node) String() string {
	if nd.IsDir() {
		return nd.List(Short, DirOnly)
	}
	return nd.Tensor.Label()
}
// List returns a listing of nodes in the given directory.
//   - long = include detailed information about each node, vs just the name.
//   - recursive = descend into subdirectories.
func (dir *Node) List(long, recursive bool) string {
	if !long {
		return dir.ListShort(recursive, 0)
	}
	return dir.ListLong(recursive, 0)
}
// ListShort returns a name-only listing of the given directory.
// When recursive, each subdirectory's name is written followed by
// its own indented listing (matching [Node.ListLong]); previously the
// subdirectory name itself was omitted, making recursive output ambiguous.
func (dir *Node) ListShort(recursive bool, ident int) string {
	var b strings.Builder
	nodes, _ := dir.Nodes()
	for _, it := range nodes {
		b.WriteString(indent.Tabs(ident))
		if it.IsDir() {
			if recursive {
				// write the dir name, then its contents indented one level
				b.WriteString(it.name + "/\n" + it.ListShort(recursive, ident+1))
			} else {
				b.WriteString(it.name + "/ ")
			}
		} else {
			b.WriteString(it.name + " ")
		}
	}
	return b.String()
}
// ListLong returns a detailed listing of the given directory,
// one node per line, recursing with increased indentation if requested.
func (dir *Node) ListLong(recursive bool, ident int) string {
	var sb strings.Builder
	nodes, _ := dir.Nodes()
	for _, nd := range nodes {
		sb.WriteString(indent.Tabs(ident))
		if !nd.IsDir() {
			sb.WriteString(nd.String() + "\n")
			continue
		}
		sb.WriteString(nd.name + "/\n")
		if recursive {
			sb.WriteString(nd.ListLong(recursive, ident+1))
		}
	}
	return sb.String()
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensorfs
import (
"cogentcore.org/core/base/errors"
"cogentcore.org/lab/tensor"
)
// This file provides standardized metadata options for frequent
// use cases, using codified key names to eliminate typos.
// SetMetaItems sets the given metadata key / value for the Value items
// in this directory with the given names. Metadata is still set on the
// items that were found; the returned error reports any that were not.
func (d *Node) SetMetaItems(key string, value any, names ...string) error {
	tsrs, err := d.Values(names...)
	for _, t := range tsrs {
		t.Metadata().Set(key, value)
	}
	return err
}
// CalcAll calls the function set by [Node.SetCalcFunc] for all items
// in this directory and all of its subdirectories
// (i.e., tensor.Calc on every item from ValuesFunc(nil)),
// returning all errors joined together.
func (d *Node) CalcAll() error {
	var errs []error
	for _, tsr := range d.ValuesFunc(nil) {
		if err := tensor.Calc(tsr); err != nil {
			errs = append(errs, err)
		}
	}
	return errors.Join(errs...)
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensorfs
import (
"io/fs"
"reflect"
"time"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/fsx"
"cogentcore.org/core/base/metadata"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
)
// Node is the element type for the filesystem, which can represent either
// a [tensor] Value as a "file" equivalent, or a "directory" containing other Nodes.
// The [tensor.Tensor] can represent everything from a single scalar value up to
// n-dimensional collections of patterns, in a range of data types.
// Directories have an ordered map of nodes.
type Node struct {
	// Parent is the parent data directory; nil for a root node.
	Parent *Node

	// name is the name of this node. it is not a path.
	name string

	// modTime tracks time added to directory, used for ordering.
	modTime time.Time

	// Tensor is the tensor value for a file or leaf Node in the FS,
	// represented using the universal [tensor] data type of
	// [tensor.Tensor], which can represent anything from a scalar
	// to n-dimensional data, in a range of data types.
	// It is nil for directories.
	Tensor tensor.Tensor

	// nodes is for directory nodes, with all the nodes in the directory.
	// A non-nil nodes map is what marks a Node as a directory (see IsDir).
	nodes *Nodes

	// DirTable is a summary [table.Table] with columns comprised of Value
	// nodes in the directory, which can be used for plotting or other operations.
	// It is a cache managed by the [DirTable] function.
	DirTable *table.Table
}
// newNode returns a new Node in the given directory Node, which can be nil.
// If dir is not a directory, returns nil and an error.
// If a node already exists in dir with that name, that node is returned
// with an [fs.ErrExist] error, and the caller can decide how to proceed.
// The modTime is set to now. The name must be unique within the parent.
func newNode(dir *Node, name string) (*Node, error) {
	if dir == nil {
		// parentless (root) node
		return &Node{name: name, modTime: time.Now()}, nil
	}
	if err := dir.mustDir("newNode", name); err != nil {
		return nil, err
	}
	if ex, ok := dir.nodes.AtTry(name); ok {
		// recycle the existing node, flagging via ErrExist
		return ex, fs.ErrExist
	}
	nd := &Node{Parent: dir, name: name, modTime: time.Now()}
	dir.nodes.Add(name, nd) // cannot fail: name verified absent above
	return nd, nil
}
// Value creates / returns a Node with the given name as a [tensor.Tensor]
// of given data type and shape sizes, in the given directory Node.
// If it already exists, it is returned as-is (no checking against the
// type or sizes provided, for efficiency -- if there is doubt, check!),
// otherwise a new tensor is created. It is fine to not pass any sizes and
// use `SetShapeSizes` method later to set the size.
func Value[T tensor.DataTypes](dir *Node, name string, sizes ...int) tensor.Values {
	if ex := dir.Node(name); ex != nil {
		// recycle existing node, without validating type / shape
		return ex.Tensor.(tensor.Values)
	}
	tsr := tensor.New[T](sizes...)
	metadata.SetName(tsr, name)
	nd, err := newNode(dir, name)
	if errors.Log(err) != nil {
		return nil
	}
	nd.Tensor = tsr
	return tsr
}
// NewValues makes new tensor Node value(s) (as a [tensor.Tensor])
// of the given data type and shape sizes, in the given directory.
// Any existing nodes with the same names are recycled without checking
// or updating the data type or sizes.
// See the [Value] documentation for more info.
func NewValues[T tensor.DataTypes](dir *Node, shape []int, names ...string) {
	for _, name := range names {
		Value[T](dir, name, shape...)
	}
}
// Scalar returns a scalar Node value (as a [tensor.Tensor])
// of given data type, in given directory and name.
// If it already exists, it is returned without checking against args,
// else a new one is made (with a single element).
// See the [Value] documentation for more info.
func Scalar[T tensor.DataTypes](dir *Node, name string) tensor.Values {
	return Value[T](dir, name, 1)
}
// ValueType creates / returns a Node with the given name as a [tensor.Tensor]
// of given data type specified as a reflect.Kind, with shape sizes,
// in the given directory Node.
// Supported types are string, bool (for [Bool]), float32, float64, int, int32, and byte.
// If it already exists, it is returned as-is (no checking against the
// type or sizes provided, for efficiency -- if there is doubt, check!),
// otherwise a new tensor is created. It is fine to not pass any sizes and
// use `SetShapeSizes` method later to set the size.
func ValueType(dir *Node, name string, typ reflect.Kind, sizes ...int) tensor.Values {
	if ex := dir.Node(name); ex != nil {
		// recycle existing node, without validating type / shape
		return ex.Tensor.(tensor.Values)
	}
	tsr := tensor.NewOfType(typ, sizes...)
	metadata.SetName(tsr, name)
	nd, err := newNode(dir, name)
	if errors.Log(err) != nil {
		return nil
	}
	nd.Tensor = tsr
	return tsr
}
// SetTensor creates / recycles a node with the given name and sets its
// value to the given existing tensor, returning the node.
func SetTensor(dir *Node, tsr tensor.Tensor, name string) *Node {
	nd := dir.Node(name)
	if nd == nil {
		// name not present: cannot be ErrExist, so error is ignored here
		nd, _ = newNode(dir, name)
	}
	// single assignment (previously duplicated in the recycle branch)
	nd.Tensor = tsr
	return nd
}
// Set creates / returns a Node with the given name, setting its value to
// the given Tensor, in this directory [Node]. Calls [SetTensor].
func (dir *Node) Set(name string, tsr tensor.Tensor) *Node {
	return SetTensor(dir, tsr, name)
}
// DirTable returns a [table.Table] with all of the tensor values under
// the given directory, with columns as the Tensor values elements in the directory
// and any subdirectories, using given filter function.
// This is a convenient mechanism for creating a plot of all the data
// in a given directory.
// If such was previously constructed, it is returned from "DirTable"
// where it is stored for later use.
// Row count is updated to current max row.
// Set DirTable = nil to regenerate.
func DirTable(dir *Node, fun func(node *Node) bool) *table.Table {
	nds := dir.NodesFunc(fun)
	if dir.DirTable != nil {
		// NOTE(review): cache is reused when column count matches len(nds),
		// but nds can include non-tensor nodes that are skipped below,
		// so the counts could legitimately differ — verify this check.
		if dir.DirTable.NumColumns() == len(nds) {
			dir.DirTable.SetNumRowsToMax()
			return dir.DirTable
		}
	}
	dt := table.New(fsx.DirAndFile(string(dir.Path())))
	for _, it := range nds {
		// only tensor nodes with at least one dimension become columns
		if it.Tensor == nil || it.Tensor.NumDims() == 0 {
			continue
		}
		tsr := it.Tensor
		rows := tsr.DimSize(0)
		// grow the table to the largest column's row count
		if dt.Columns.Rows < rows {
			dt.Columns.Rows = rows
			dt.SetNumRows(dt.Columns.Rows)
		}
		nm := it.name
		// nodes from subdirectories get a dir/file qualified column name
		if it.Parent != dir {
			nm = fsx.DirAndFile(string(it.Path()))
		}
		dt.AddColumn(nm, tsr.AsValues())
	}
	dir.DirTable = dt
	return dt
}
// DirFromTable sets tensor values under the given directory node to the
// columns of the given [table.Table]. Also sets the DirTable to this table.
func DirFromTable(dir *Node, dt *table.Table) {
	for i, nm := range dt.Columns.Keys {
		col := dt.Columns.Values[i]
		// reuse an existing node of the same name (ErrExist), else create one
		nd, err := newNode(dir, nm)
		if err == nil || err == fs.ErrExist {
			nd.Tensor = col
		}
	}
	dir.DirTable = dt
}
// Float64 creates / returns a Node with given name as a [tensor.Float64]
// for given shape sizes, in given directory [Node].
// See the [Value] function for more info.
func (dir *Node) Float64(name string, sizes ...int) *tensor.Float64 {
	return Value[float64](dir, name, sizes...).(*tensor.Float64)
}

// Float32 creates / returns a Node with given name as a [tensor.Float32]
// for given shape sizes, in given directory [Node].
// See the [Value] function for more info.
func (dir *Node) Float32(name string, sizes ...int) *tensor.Float32 {
	return Value[float32](dir, name, sizes...).(*tensor.Float32)
}

// Int creates / returns a Node with given name as a [tensor.Int]
// for given shape sizes, in given directory [Node].
// See the [Value] function for more info.
func (dir *Node) Int(name string, sizes ...int) *tensor.Int {
	return Value[int](dir, name, sizes...).(*tensor.Int)
}

// StringValue creates / returns a Node with given name as a [tensor.String]
// for given shape sizes, in given directory [Node].
// See the [Value] function for more info.
func (dir *Node) StringValue(name string, sizes ...int) *tensor.String {
	return Value[string](dir, name, sizes...).(*tensor.String)
}
// Copyright (c) 2024, Cogent Lab. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package vector provides standard vector math functions that
// always operate on 1D views of tensor inputs regardless of the original
// vector shape.
package vector
import (
"math"
"cogentcore.org/lab/tensor"
"cogentcore.org/lab/tensor/tmath"
)
// Mul multiplies two vectors element-wise, using a 1D vector
// view of the two vectors, returning the output 1D vector
// (allocated as float64 via [tensor.CallOut2Float64]).
func Mul(a, b tensor.Tensor) tensor.Values {
	return tensor.CallOut2Float64(MulOut, a, b)
}
// MulOut multiplies two vectors element-wise, using a 1D vector
// view of the two vectors, filling in values in the output 1D vector.
func MulOut(a, b tensor.Tensor, out tensor.Values) error {
	return tmath.MulOut(tensor.As1D(a), tensor.As1D(b), out)
}
// Sum returns the sum of all values in the tensor, as a scalar.
func Sum(a tensor.Tensor) tensor.Values {
	total := 0.0
	n := a.Len()
	tensor.Vectorize(
		func(tsr ...tensor.Tensor) int { return n },
		func(i int, tsr ...tensor.Tensor) {
			total += tsr[0].Float1D(i)
		}, a)
	return tensor.NewFloat64Scalar(total)
}
// Dot performs the vector dot product: the [Sum] of the [Mul] product
// of the two tensors, returning a scalar value. Also known as the inner product.
func Dot(a, b tensor.Tensor) tensor.Values {
	return Sum(Mul(a, b))
}
// L2Norm returns the length of the vector as the L2 Norm:
// square root of the sum of squared values of the vector, as a scalar.
// This is the Sqrt of the [Dot] product of the vector with itself.
func L2Norm(a tensor.Tensor) tensor.Values {
	ss := Dot(a, a).Float1D(0) // sum of squares
	return tensor.NewFloat64Scalar(math.Sqrt(ss))
}
// L1Norm returns the length of the vector as the L1 Norm:
// sum of the absolute values of the tensor, as a scalar.
func L1Norm(a tensor.Tensor) tensor.Values {
	total := 0.0
	n := a.Len()
	tensor.Vectorize(
		func(tsr ...tensor.Tensor) int { return n },
		func(i int, tsr ...tensor.Tensor) {
			total += math.Abs(tsr[0].Float1D(i))
		}, a)
	return tensor.NewFloat64Scalar(total)
}
// Code generated by 'yaegi extract cogentcore.org/lab/lab'. DO NOT EDIT.
package gui
import (
"cogentcore.org/core/core"
"cogentcore.org/lab/lab"
"reflect"
)
func init() {
Symbols["cogentcore.org/lab/lab/lab"] = map[string]reflect.Value{
// function, constant and variable definitions
"AsDataTree": reflect.ValueOf(lab.AsDataTree),
"CurTabber": reflect.ValueOf(&lab.CurTabber).Elem(),
"FirstComment": reflect.ValueOf(lab.FirstComment),
"IsTableFile": reflect.ValueOf(lab.IsTableFile),
"NewBasic": reflect.ValueOf(lab.NewBasic),
"NewBasicWindow": reflect.ValueOf(lab.NewBasicWindow),
"NewDataTree": reflect.ValueOf(lab.NewDataTree),
"NewDiffBrowserDirs": reflect.ValueOf(lab.NewDiffBrowserDirs),
"NewFileNode": reflect.ValueOf(lab.NewFileNode),
"NewTabs": reflect.ValueOf(lab.NewTabs),
"PromptOKCancel": reflect.ValueOf(lab.PromptOKCancel),
"PromptString": reflect.ValueOf(lab.PromptString),
"PromptStruct": reflect.ValueOf(lab.PromptStruct),
"TensorFS": reflect.ValueOf(lab.TensorFS),
"TheBrowser": reflect.ValueOf(&lab.TheBrowser).Elem(),
"TrimOrderPrefix": reflect.ValueOf(lab.TrimOrderPrefix),
// type definitions
"Basic": reflect.ValueOf((*lab.Basic)(nil)),
"Browser": reflect.ValueOf((*lab.Browser)(nil)),
"DataTree": reflect.ValueOf((*lab.DataTree)(nil)),
"FileNode": reflect.ValueOf((*lab.FileNode)(nil)),
"Tabber": reflect.ValueOf((*lab.Tabber)(nil)),
"Tabs": reflect.ValueOf((*lab.Tabs)(nil)),
"Treer": reflect.ValueOf((*lab.Treer)(nil)),
// interface wrapper definitions
"_Tabber": reflect.ValueOf((*_cogentcore_org_lab_lab_Tabber)(nil)),
"_Treer": reflect.ValueOf((*_cogentcore_org_lab_lab_Treer)(nil)),
}
}
// _cogentcore_org_lab_lab_Tabber is an interface wrapper for Tabber type
type _cogentcore_org_lab_lab_Tabber struct {
IValue interface{}
WAsCoreTabs func() *core.Tabs
WAsLab func() *lab.Tabs
}
func (W _cogentcore_org_lab_lab_Tabber) AsCoreTabs() *core.Tabs { return W.WAsCoreTabs() }
func (W _cogentcore_org_lab_lab_Tabber) AsLab() *lab.Tabs { return W.WAsLab() }
// _cogentcore_org_lab_lab_Treer is an interface wrapper for Treer type
type _cogentcore_org_lab_lab_Treer struct {
IValue interface{}
WAsDataTree func() *lab.DataTree
}
func (W _cogentcore_org_lab_lab_Treer) AsDataTree() *lab.DataTree { return W.WAsDataTree() }
// Code generated by 'yaegi extract cogentcore.org/lab/plot/plots'. DO NOT EDIT.
package gui
import (
"cogentcore.org/lab/plot/plots"
"go/constant"
"go/token"
"reflect"
)
func init() {
Symbols["cogentcore.org/lab/plot/plots/plots"] = map[string]reflect.Value{
// function, constant and variable definitions
"BarType": reflect.ValueOf(constant.MakeFromLiteral("\"Bar\"", token.STRING, 0)),
"LabelsType": reflect.ValueOf(constant.MakeFromLiteral("\"Labels\"", token.STRING, 0)),
"NewBar": reflect.ValueOf(plots.NewBar),
"NewLabels": reflect.ValueOf(plots.NewLabels),
"NewLine": reflect.ValueOf(plots.NewLine),
"NewScatter": reflect.ValueOf(plots.NewScatter),
"NewXErrorBars": reflect.ValueOf(plots.NewXErrorBars),
"NewXY": reflect.ValueOf(plots.NewXY),
"NewYErrorBars": reflect.ValueOf(plots.NewYErrorBars),
"XErrorBarsType": reflect.ValueOf(constant.MakeFromLiteral("\"XErrorBars\"", token.STRING, 0)),
"XYType": reflect.ValueOf(constant.MakeFromLiteral("\"XY\"", token.STRING, 0)),
"YErrorBarsType": reflect.ValueOf(constant.MakeFromLiteral("\"YErrorBars\"", token.STRING, 0)),
// type definitions
"Bar": reflect.ValueOf((*plots.Bar)(nil)),
"Labels": reflect.ValueOf((*plots.Labels)(nil)),
"XErrorBars": reflect.ValueOf((*plots.XErrorBars)(nil)),
"XY": reflect.ValueOf((*plots.XY)(nil)),
"YErrorBars": reflect.ValueOf((*plots.YErrorBars)(nil)),
}
}
// Code generated by 'yaegi extract cogentcore.org/lab/plot'. DO NOT EDIT.
package gui
import (
"cogentcore.org/core/math32/minmax"
"cogentcore.org/lab/plot"
"reflect"
)
func init() {
Symbols["cogentcore.org/lab/plot/plot"] = map[string]reflect.Value{
// function, constant and variable definitions
"AddStylerTo": reflect.ValueOf(plot.AddStylerTo),
"AxisScalesN": reflect.ValueOf(plot.AxisScalesN),
"AxisScalesValues": reflect.ValueOf(plot.AxisScalesValues),
"Box": reflect.ValueOf(plot.Box),
"CheckFloats": reflect.ValueOf(plot.CheckFloats),
"CheckNaNs": reflect.ValueOf(plot.CheckNaNs),
"Circle": reflect.ValueOf(plot.Circle),
"Color": reflect.ValueOf(plot.Color),
"CopyRole": reflect.ValueOf(plot.CopyRole),
"CopyValues": reflect.ValueOf(plot.CopyValues),
"Cross": reflect.ValueOf(plot.Cross),
"Default": reflect.ValueOf(plot.Default),
"DefaultFontFamily": reflect.ValueOf(&plot.DefaultFontFamily).Elem(),
"DefaultOffOnN": reflect.ValueOf(plot.DefaultOffOnN),
"DefaultOffOnValues": reflect.ValueOf(plot.DefaultOffOnValues),
"DrawBox": reflect.ValueOf(plot.DrawBox),
"DrawCircle": reflect.ValueOf(plot.DrawCircle),
"DrawCross": reflect.ValueOf(plot.DrawCross),
"DrawPlus": reflect.ValueOf(plot.DrawPlus),
"DrawPyramid": reflect.ValueOf(plot.DrawPyramid),
"DrawRing": reflect.ValueOf(plot.DrawRing),
"DrawSquare": reflect.ValueOf(plot.DrawSquare),
"DrawTriangle": reflect.ValueOf(plot.DrawTriangle),
"ErrInfinity": reflect.ValueOf(&plot.ErrInfinity).Elem(),
"ErrNoData": reflect.ValueOf(&plot.ErrNoData).Elem(),
"GetStylersFrom": reflect.ValueOf(plot.GetStylersFrom),
"GetStylersFromData": reflect.ValueOf(plot.GetStylersFromData),
"High": reflect.ValueOf(plot.High),
"InverseLinear": reflect.ValueOf(plot.InverseLinear),
"InverseLog": reflect.ValueOf(plot.InverseLog),
"Label": reflect.ValueOf(plot.Label),
"Linear": reflect.ValueOf(plot.Linear),
"Log": reflect.ValueOf(plot.Log),
"Low": reflect.ValueOf(plot.Low),
"MidStep": reflect.ValueOf(plot.MidStep),
"MustCopyRole": reflect.ValueOf(plot.MustCopyRole),
"New": reflect.ValueOf(plot.New),
"NewPlotter": reflect.ValueOf(plot.NewPlotter),
"NewStyle": reflect.ValueOf(plot.NewStyle),
"NewTablePlot": reflect.ValueOf(plot.NewTablePlot),
"NoRole": reflect.ValueOf(plot.NoRole),
"NoStep": reflect.ValueOf(plot.NoStep),
"Off": reflect.ValueOf(plot.Off),
"On": reflect.ValueOf(plot.On),
"PlotX": reflect.ValueOf(plot.PlotX),
"PlotY": reflect.ValueOf(plot.PlotY),
"PlotterByType": reflect.ValueOf(plot.PlotterByType),
"Plotters": reflect.ValueOf(&plot.Plotters).Elem(),
"Plus": reflect.ValueOf(plot.Plus),
"PostStep": reflect.ValueOf(plot.PostStep),
"PreStep": reflect.ValueOf(plot.PreStep),
"Pyramid": reflect.ValueOf(plot.Pyramid),
"Range": reflect.ValueOf(plot.Range),
"RangeClamp": reflect.ValueOf(plot.RangeClamp),
"RegisterPlotter": reflect.ValueOf(plot.RegisterPlotter),
"Ring": reflect.ValueOf(plot.Ring),
"RolesN": reflect.ValueOf(plot.RolesN),
"RolesValues": reflect.ValueOf(plot.RolesValues),
"SetFirstStylerTo": reflect.ValueOf(plot.SetFirstStylerTo),
"SetStylerTo": reflect.ValueOf(plot.SetStylerTo),
"SetStylersTo": reflect.ValueOf(plot.SetStylersTo),
"ShapesN": reflect.ValueOf(plot.ShapesN),
"ShapesValues": reflect.ValueOf(plot.ShapesValues),
"Size": reflect.ValueOf(plot.Size),
"Square": reflect.ValueOf(plot.Square),
"StepKindN": reflect.ValueOf(plot.StepKindN),
"StepKindValues": reflect.ValueOf(plot.StepKindValues),
"Triangle": reflect.ValueOf(plot.Triangle),
"U": reflect.ValueOf(plot.U),
"UTCUnixTime": reflect.ValueOf(&plot.UTCUnixTime).Elem(),
"UnixTimeIn": reflect.ValueOf(plot.UnixTimeIn),
"V": reflect.ValueOf(plot.V),
"W": reflect.ValueOf(plot.W),
"X": reflect.ValueOf(plot.X),
"Y": reflect.ValueOf(plot.Y),
"Z": reflect.ValueOf(plot.Z),
// type definitions
"Axis": reflect.ValueOf((*plot.Axis)(nil)),
"AxisScales": reflect.ValueOf((*plot.AxisScales)(nil)),
"AxisStyle": reflect.ValueOf((*plot.AxisStyle)(nil)),
"ConstantTicks": reflect.ValueOf((*plot.ConstantTicks)(nil)),
"Data": reflect.ValueOf((*plot.Data)(nil)),
"DefaultOffOn": reflect.ValueOf((*plot.DefaultOffOn)(nil)),
"DefaultTicks": reflect.ValueOf((*plot.DefaultTicks)(nil)),
"InvertedScale": reflect.ValueOf((*plot.InvertedScale)(nil)),
"Labels": reflect.ValueOf((*plot.Labels)(nil)),
"Legend": reflect.ValueOf((*plot.Legend)(nil)),
"LegendEntry": reflect.ValueOf((*plot.LegendEntry)(nil)),
"LegendPosition": reflect.ValueOf((*plot.LegendPosition)(nil)),
"LegendStyle": reflect.ValueOf((*plot.LegendStyle)(nil)),
"LineStyle": reflect.ValueOf((*plot.LineStyle)(nil)),
"LinearScale": reflect.ValueOf((*plot.LinearScale)(nil)),
"LogScale": reflect.ValueOf((*plot.LogScale)(nil)),
"LogTicks": reflect.ValueOf((*plot.LogTicks)(nil)),
"Normalizer": reflect.ValueOf((*plot.Normalizer)(nil)),
"PanZoom": reflect.ValueOf((*plot.PanZoom)(nil)),
"Plot": reflect.ValueOf((*plot.Plot)(nil)),
"PlotStyle": reflect.ValueOf((*plot.PlotStyle)(nil)),
"Plotter": reflect.ValueOf((*plot.Plotter)(nil)),
"PlotterName": reflect.ValueOf((*plot.PlotterName)(nil)),
"PlotterType": reflect.ValueOf((*plot.PlotterType)(nil)),
"PointStyle": reflect.ValueOf((*plot.PointStyle)(nil)),
"Roles": reflect.ValueOf((*plot.Roles)(nil)),
"Shapes": reflect.ValueOf((*plot.Shapes)(nil)),
"StepKind": reflect.ValueOf((*plot.StepKind)(nil)),
"Style": reflect.ValueOf((*plot.Style)(nil)),
"Stylers": reflect.ValueOf((*plot.Stylers)(nil)),
"Text": reflect.ValueOf((*plot.Text)(nil)),
"TextStyle": reflect.ValueOf((*plot.TextStyle)(nil)),
"Thumbnailer": reflect.ValueOf((*plot.Thumbnailer)(nil)),
"Tick": reflect.ValueOf((*plot.Tick)(nil)),
"Ticker": reflect.ValueOf((*plot.Ticker)(nil)),
"TimeTicks": reflect.ValueOf((*plot.TimeTicks)(nil)),
"Valuer": reflect.ValueOf((*plot.Valuer)(nil)),
"Values": reflect.ValueOf((*plot.Values)(nil)),
"WidthStyle": reflect.ValueOf((*plot.WidthStyle)(nil)),
"XAxisStyle": reflect.ValueOf((*plot.XAxisStyle)(nil)),
// interface wrapper definitions
"_Normalizer": reflect.ValueOf((*_cogentcore_org_lab_plot_Normalizer)(nil)),
"_Plotter": reflect.ValueOf((*_cogentcore_org_lab_plot_Plotter)(nil)),
"_Thumbnailer": reflect.ValueOf((*_cogentcore_org_lab_plot_Thumbnailer)(nil)),
"_Ticker": reflect.ValueOf((*_cogentcore_org_lab_plot_Ticker)(nil)),
"_Valuer": reflect.ValueOf((*_cogentcore_org_lab_plot_Valuer)(nil)),
}
}
// _cogentcore_org_lab_plot_Normalizer is an interface wrapper for Normalizer type
// (generated by yaegi extract). It lets an interpreter-defined value satisfy
// plot.Normalizer by delegating each method to a stored func field.
type _cogentcore_org_lab_plot_Normalizer struct {
	IValue     interface{}                                       // the original interpreter value being wrapped
	WNormalize func(min float64, max float64, x float64) float64 // delegate implementing Normalize
}

// Normalize forwards to the WNormalize delegate.
func (W _cogentcore_org_lab_plot_Normalizer) Normalize(min float64, max float64, x float64) float64 {
	return W.WNormalize(min, max, x)
}
// _cogentcore_org_lab_plot_Plotter is an interface wrapper for Plotter type
// (generated by yaegi extract). It lets an interpreter-defined value satisfy
// plot.Plotter by delegating each method to a stored func field.
type _cogentcore_org_lab_plot_Plotter struct {
	IValue       interface{}                                                           // the original interpreter value being wrapped
	WApplyStyle  func(plotStyle *plot.PlotStyle)                                       // delegate implementing ApplyStyle
	WData        func() (data plot.Data, pixX []float32, pixY []float32)               // delegate implementing Data
	WPlot        func(pt *plot.Plot)                                                   // delegate implementing Plot
	WStylers     func() *plot.Stylers                                                  // delegate implementing Stylers
	WUpdateRange func(plt *plot.Plot, xr *minmax.F64, yr *minmax.F64, zr *minmax.F64)  // delegate implementing UpdateRange
}

// ApplyStyle forwards to the WApplyStyle delegate.
func (W _cogentcore_org_lab_plot_Plotter) ApplyStyle(plotStyle *plot.PlotStyle) {
	W.WApplyStyle(plotStyle)
}

// Data forwards to the WData delegate.
func (W _cogentcore_org_lab_plot_Plotter) Data() (data plot.Data, pixX []float32, pixY []float32) {
	return W.WData()
}

// Plot forwards to the WPlot delegate.
func (W _cogentcore_org_lab_plot_Plotter) Plot(pt *plot.Plot) { W.WPlot(pt) }

// Stylers forwards to the WStylers delegate.
func (W _cogentcore_org_lab_plot_Plotter) Stylers() *plot.Stylers { return W.WStylers() }

// UpdateRange forwards to the WUpdateRange delegate.
func (W _cogentcore_org_lab_plot_Plotter) UpdateRange(plt *plot.Plot, xr *minmax.F64, yr *minmax.F64, zr *minmax.F64) {
	W.WUpdateRange(plt, xr, yr, zr)
}
// _cogentcore_org_lab_plot_Thumbnailer is an interface wrapper for Thumbnailer type
// (generated by yaegi extract). It lets an interpreter-defined value satisfy
// plot.Thumbnailer by delegating to a stored func field.
type _cogentcore_org_lab_plot_Thumbnailer struct {
	IValue     interface{}         // the original interpreter value being wrapped
	WThumbnail func(pt *plot.Plot) // delegate implementing Thumbnail
}

// Thumbnail forwards to the WThumbnail delegate.
func (W _cogentcore_org_lab_plot_Thumbnailer) Thumbnail(pt *plot.Plot) { W.WThumbnail(pt) }
// _cogentcore_org_lab_plot_Ticker is an interface wrapper for Ticker type
// (generated by yaegi extract). It lets an interpreter-defined value satisfy
// plot.Ticker by delegating to a stored func field.
type _cogentcore_org_lab_plot_Ticker struct {
	IValue interface{}                                        // the original interpreter value being wrapped
	WTicks func(min float64, max float64, nticks int) []plot.Tick // delegate implementing Ticks
}

// Ticks forwards to the WTicks delegate.
func (W _cogentcore_org_lab_plot_Ticker) Ticks(min float64, max float64, nticks int) []plot.Tick {
	return W.WTicks(min, max, nticks)
}
// _cogentcore_org_lab_plot_Valuer is an interface wrapper for Valuer type
// (generated by yaegi extract). It lets an interpreter-defined value satisfy
// plot.Valuer by delegating each method to a stored func field.
type _cogentcore_org_lab_plot_Valuer struct {
	IValue    interface{}            // the original interpreter value being wrapped
	WFloat1D  func(i int) float64    // delegate implementing Float1D
	WLen      func() int             // delegate implementing Len
	WString1D func(i int) string     // delegate implementing String1D
}

// Float1D forwards to the WFloat1D delegate.
func (W _cogentcore_org_lab_plot_Valuer) Float1D(i int) float64 { return W.WFloat1D(i) }

// Len forwards to the WLen delegate.
func (W _cogentcore_org_lab_plot_Valuer) Len() int { return W.WLen() }

// String1D forwards to the WString1D delegate.
func (W _cogentcore_org_lab_plot_Valuer) String1D(i int) string { return W.WString1D(i) }
// Code generated by 'yaegi extract cogentcore.org/lab/plotcore'. DO NOT EDIT.
package gui
import (
"cogentcore.org/lab/plotcore"
"reflect"
)
// init registers the exported API of cogentcore.org/lab/plotcore in the
// package-level Symbols table, making it available to the yaegi interpreter.
// Generated by yaegi extract; entries mirror the package's exported names.
func init() {
	Symbols["cogentcore.org/lab/plotcore/plotcore"] = map[string]reflect.Value{
		// function, constant and variable definitions
		"NewPlot":           reflect.ValueOf(plotcore.NewPlot),
		"NewPlotEditor":     reflect.ValueOf(plotcore.NewPlotEditor),
		"NewPlotterChooser": reflect.ValueOf(plotcore.NewPlotterChooser),
		"NewSubPlot":        reflect.ValueOf(plotcore.NewSubPlot),
		// type definitions (typed nil pointers carry the reflect.Type)
		"Plot":           reflect.ValueOf((*plotcore.Plot)(nil)),
		"PlotEditor":     reflect.ValueOf((*plotcore.PlotEditor)(nil)),
		"PlotterChooser": reflect.ValueOf((*plotcore.PlotterChooser)(nil)),
	}
}
// Code generated by 'yaegi extract cogentcore.org/lab/tensorcore'. DO NOT EDIT.
package gui
import (
"cogentcore.org/lab/tensorcore"
"reflect"
)
// init registers the exported API of cogentcore.org/lab/tensorcore in the
// package-level Symbols table, making it available to the yaegi interpreter.
// Generated by yaegi extract; entries mirror the package's exported names.
func init() {
	Symbols["cogentcore.org/lab/tensorcore/tensorcore"] = map[string]reflect.Value{
		// function, constant and variable definitions
		"AddGridStylerTo":    reflect.ValueOf(tensorcore.AddGridStylerTo),
		"GetGridStylersFrom": reflect.ValueOf(tensorcore.GetGridStylersFrom),
		"NewGridStyle":       reflect.ValueOf(tensorcore.NewGridStyle),
		"NewTable":           reflect.ValueOf(tensorcore.NewTable),
		"NewTableButton":     reflect.ValueOf(tensorcore.NewTableButton),
		"NewTensorButton":    reflect.ValueOf(tensorcore.NewTensorButton),
		"NewTensorEditor":    reflect.ValueOf(tensorcore.NewTensorEditor),
		"NewTensorGrid":      reflect.ValueOf(tensorcore.NewTensorGrid),
		"SetGridStylersTo":   reflect.ValueOf(tensorcore.SetGridStylersTo),
		// type definitions (typed nil pointers carry the reflect.Type)
		"GridStyle":    reflect.ValueOf((*tensorcore.GridStyle)(nil)),
		"GridStylers":  reflect.ValueOf((*tensorcore.GridStylers)(nil)),
		"Layout":       reflect.ValueOf((*tensorcore.Layout)(nil)),
		"Table":        reflect.ValueOf((*tensorcore.Table)(nil)),
		"TableButton":  reflect.ValueOf((*tensorcore.TableButton)(nil)),
		"TensorButton": reflect.ValueOf((*tensorcore.TensorButton)(nil)),
		"TensorEditor": reflect.ValueOf((*tensorcore.TensorEditor)(nil)),
		"TensorGrid":   reflect.ValueOf((*tensorcore.TensorGrid)(nil)),
	}
}
// Code generated by 'yaegi extract cogentcore.org/lab/goal/goalib'. DO NOT EDIT.
package nogui
import (
"cogentcore.org/lab/goal/goalib"
"reflect"
)
// init registers the exported API of cogentcore.org/lab/goal/goalib in the
// package-level Symbols table, making it available to the yaegi interpreter.
// Generated by yaegi extract; entries mirror the package's exported names.
func init() {
	Symbols["cogentcore.org/lab/goal/goalib/goalib"] = map[string]reflect.Value{
		// function, constant and variable definitions
		"AllFiles":      reflect.ValueOf(goalib.AllFiles),
		"FileExists":    reflect.ValueOf(goalib.FileExists),
		"ReadFile":      reflect.ValueOf(goalib.ReadFile),
		"ReplaceInFile": reflect.ValueOf(goalib.ReplaceInFile),
		"SplitLines":    reflect.ValueOf(goalib.SplitLines),
		"StringsToAnys": reflect.ValueOf(goalib.StringsToAnys),
		"WriteFile":     reflect.ValueOf(goalib.WriteFile),
	}
}
// Code generated by 'yaegi extract cogentcore.org/lab/matrix'. DO NOT EDIT.
package nogui
import (
"cogentcore.org/lab/matrix"
"reflect"
)
// init registers the exported API of cogentcore.org/lab/matrix in the
// package-level Symbols table, making it available to the yaegi interpreter.
// Generated by yaegi extract; entries mirror the package's exported names.
// NOTE(review): "TriLIndicies"/"TriUIndicies" spellings come from the
// upstream Go identifiers and must not be corrected here.
func init() {
	Symbols["cogentcore.org/lab/matrix/matrix"] = map[string]reflect.Value{
		// function, constant and variable definitions
		"CallOut1":                 reflect.ValueOf(matrix.CallOut1),
		"CallOut2":                 reflect.ValueOf(matrix.CallOut2),
		"CopyFromDense":            reflect.ValueOf(matrix.CopyFromDense),
		"Det":                      reflect.ValueOf(matrix.Det),
		"Diagonal":                 reflect.ValueOf(matrix.Diagonal),
		"DiagonalIndices":          reflect.ValueOf(matrix.DiagonalIndices),
		"DiagonalN":                reflect.ValueOf(matrix.DiagonalN),
		"Eig":                      reflect.ValueOf(matrix.Eig),
		"EigOut":                   reflect.ValueOf(matrix.EigOut),
		"EigSym":                   reflect.ValueOf(matrix.EigSym),
		"EigSymOut":                reflect.ValueOf(matrix.EigSymOut),
		"Identity":                 reflect.ValueOf(matrix.Identity),
		"Inverse":                  reflect.ValueOf(matrix.Inverse),
		"InverseOut":               reflect.ValueOf(matrix.InverseOut),
		"LogDet":                   reflect.ValueOf(matrix.LogDet),
		"Mul":                      reflect.ValueOf(matrix.Mul),
		"MulOut":                   reflect.ValueOf(matrix.MulOut),
		"NewDense":                 reflect.ValueOf(matrix.NewDense),
		"NewMatrix":                reflect.ValueOf(matrix.NewMatrix),
		"NewSymmetric":             reflect.ValueOf(matrix.NewSymmetric),
		"ProjectOnMatrixColumn":    reflect.ValueOf(matrix.ProjectOnMatrixColumn),
		"ProjectOnMatrixColumnOut": reflect.ValueOf(matrix.ProjectOnMatrixColumnOut),
		"SVD":                      reflect.ValueOf(matrix.SVD),
		"SVDOut":                   reflect.ValueOf(matrix.SVDOut),
		"SVDValues":                reflect.ValueOf(matrix.SVDValues),
		"SVDValuesOut":             reflect.ValueOf(matrix.SVDValuesOut),
		"StringCheck":              reflect.ValueOf(matrix.StringCheck),
		"Trace":                    reflect.ValueOf(matrix.Trace),
		"Tri":                      reflect.ValueOf(matrix.Tri),
		"TriL":                     reflect.ValueOf(matrix.TriL),
		"TriLIndicies":             reflect.ValueOf(matrix.TriLIndicies),
		"TriLNum":                  reflect.ValueOf(matrix.TriLNum),
		"TriLView":                 reflect.ValueOf(matrix.TriLView),
		"TriU":                     reflect.ValueOf(matrix.TriU),
		"TriUIndicies":             reflect.ValueOf(matrix.TriUIndicies),
		"TriUNum":                  reflect.ValueOf(matrix.TriUNum),
		"TriUView":                 reflect.ValueOf(matrix.TriUView),
		"TriUpper":                 reflect.ValueOf(matrix.TriUpper),
		// type definitions (typed nil pointers carry the reflect.Type)
		"Matrix":    reflect.ValueOf((*matrix.Matrix)(nil)),
		"Symmetric": reflect.ValueOf((*matrix.Symmetric)(nil)),
	}
}
// Code generated by 'yaegi extract cogentcore.org/lab/stats/cluster'. DO NOT EDIT.
package nogui
import (
"cogentcore.org/lab/stats/cluster"
"reflect"
)
// init registers the exported API of cogentcore.org/lab/stats/cluster in the
// package-level Symbols table, making it available to the yaegi interpreter.
// Generated by yaegi extract; entries mirror the package's exported names.
// Package variables (e.g. Funcs) are registered addressably via &v ... .Elem().
func init() {
	Symbols["cogentcore.org/lab/stats/cluster/cluster"] = map[string]reflect.Value{
		// function, constant and variable definitions
		"Avg":           reflect.ValueOf(cluster.Avg),
		"AvgFunc":       reflect.ValueOf(cluster.AvgFunc),
		"Call":          reflect.ValueOf(cluster.Call),
		"Cluster":       reflect.ValueOf(cluster.Cluster),
		"Contrast":      reflect.ValueOf(cluster.Contrast),
		"ContrastFunc":  reflect.ValueOf(cluster.ContrastFunc),
		"Funcs":         reflect.ValueOf(&cluster.Funcs).Elem(),
		"Glom":          reflect.ValueOf(cluster.Glom),
		"InitAllLeaves": reflect.ValueOf(cluster.InitAllLeaves),
		"Max":           reflect.ValueOf(cluster.Max),
		"MaxFunc":       reflect.ValueOf(cluster.MaxFunc),
		"MetricsN":      reflect.ValueOf(cluster.MetricsN),
		"MetricsValues": reflect.ValueOf(cluster.MetricsValues),
		"Min":           reflect.ValueOf(cluster.Min),
		"MinFunc":       reflect.ValueOf(cluster.MinFunc),
		"NewNode":       reflect.ValueOf(cluster.NewNode),
		"Plot":          reflect.ValueOf(cluster.Plot),
		// type definitions (typed nil pointers carry the reflect.Type)
		"MetricFunc": reflect.ValueOf((*cluster.MetricFunc)(nil)),
		"Metrics":    reflect.ValueOf((*cluster.Metrics)(nil)),
		"Node":       reflect.ValueOf((*cluster.Node)(nil)),
	}
}
// Code generated by 'yaegi extract cogentcore.org/lab/stats/convolve'. DO NOT EDIT.
package nogui
import (
"cogentcore.org/lab/stats/convolve"
"reflect"
)
// init registers the exported API of cogentcore.org/lab/stats/convolve in the
// package-level Symbols table, making it available to the yaegi interpreter.
// Generated by yaegi extract; entries mirror the package's exported names.
func init() {
	Symbols["cogentcore.org/lab/stats/convolve/convolve"] = map[string]reflect.Value{
		// function, constant and variable definitions
		"GaussianKernel32": reflect.ValueOf(convolve.GaussianKernel32),
		"GaussianKernel64": reflect.ValueOf(convolve.GaussianKernel64),
		"Slice32":          reflect.ValueOf(convolve.Slice32),
		"Slice64":          reflect.ValueOf(convolve.Slice64),
	}
}
// Code generated by 'yaegi extract cogentcore.org/lab/stats/glm'. DO NOT EDIT.
package nogui
import (
"cogentcore.org/lab/stats/glm"
"reflect"
)
// init registers the exported API of cogentcore.org/lab/stats/glm in the
// package-level Symbols table, making it available to the yaegi interpreter.
// Generated by yaegi extract; entries mirror the package's exported names.
func init() {
	Symbols["cogentcore.org/lab/stats/glm/glm"] = map[string]reflect.Value{
		// function, constant and variable definitions
		"NewGLM": reflect.ValueOf(glm.NewGLM),
		// type definitions (typed nil pointers carry the reflect.Type)
		"GLM": reflect.ValueOf((*glm.GLM)(nil)),
	}
}
// Code generated by 'yaegi extract cogentcore.org/lab/stats/histogram'. DO NOT EDIT.
package nogui
import (
"cogentcore.org/lab/stats/histogram"
"reflect"
)
// init registers the exported API of cogentcore.org/lab/stats/histogram in the
// package-level Symbols table, making it available to the yaegi interpreter.
// Generated by yaegi extract; entries mirror the package's exported names.
func init() {
	Symbols["cogentcore.org/lab/stats/histogram/histogram"] = map[string]reflect.Value{
		// function, constant and variable definitions
		"F64":      reflect.ValueOf(histogram.F64),
		"F64Table": reflect.ValueOf(histogram.F64Table),
	}
}
// Code generated by 'yaegi extract cogentcore.org/lab/stats/metric'. DO NOT EDIT.
package nogui
import (
"cogentcore.org/lab/stats/metric"
"reflect"
)
// init registers the exported API of cogentcore.org/lab/stats/metric in the
// package-level Symbols table, making it available to the yaegi interpreter.
// Generated by yaegi extract; entries mirror the package's exported names.
func init() {
	Symbols["cogentcore.org/lab/stats/metric/metric"] = map[string]reflect.Value{
		// function, constant and variable definitions
		"AsMetricFunc":               reflect.ValueOf(metric.AsMetricFunc),
		"AsMetricOutFunc":            reflect.ValueOf(metric.AsMetricOutFunc),
		"ClosestRow":                 reflect.ValueOf(metric.ClosestRow),
		"ClosestRowOut":              reflect.ValueOf(metric.ClosestRowOut),
		"Correlation":                reflect.ValueOf(metric.Correlation),
		"CorrelationOut":             reflect.ValueOf(metric.CorrelationOut),
		"CorrelationOut64":           reflect.ValueOf(metric.CorrelationOut64),
		"Cosine":                     reflect.ValueOf(metric.Cosine),
		"CosineOut":                  reflect.ValueOf(metric.CosineOut),
		"CosineOut64":                reflect.ValueOf(metric.CosineOut64),
		"Covariance":                 reflect.ValueOf(metric.Covariance),
		"CovarianceMatrix":           reflect.ValueOf(metric.CovarianceMatrix),
		"CovarianceMatrixOut":        reflect.ValueOf(metric.CovarianceMatrixOut),
		"CovarianceOut":              reflect.ValueOf(metric.CovarianceOut),
		"CrossEntropy":               reflect.ValueOf(metric.CrossEntropy),
		"CrossEntropyOut":            reflect.ValueOf(metric.CrossEntropyOut),
		"CrossMatrix":                reflect.ValueOf(metric.CrossMatrix),
		"CrossMatrixOut":             reflect.ValueOf(metric.CrossMatrixOut),
		"DotProduct":                 reflect.ValueOf(metric.DotProduct),
		"DotProductOut":              reflect.ValueOf(metric.DotProductOut),
		"Hamming":                    reflect.ValueOf(metric.Hamming),
		"HammingOut":                 reflect.ValueOf(metric.HammingOut),
		"InvCorrelation":             reflect.ValueOf(metric.InvCorrelation),
		"InvCorrelationOut":          reflect.ValueOf(metric.InvCorrelationOut),
		"InvCosine":                  reflect.ValueOf(metric.InvCosine),
		"InvCosineOut":               reflect.ValueOf(metric.InvCosineOut),
		"L1Norm":                     reflect.ValueOf(metric.L1Norm),
		"L1NormOut":                  reflect.ValueOf(metric.L1NormOut),
		"L2Norm":                     reflect.ValueOf(metric.L2Norm),
		"L2NormBinTol":               reflect.ValueOf(metric.L2NormBinTol),
		"L2NormBinTolOut":            reflect.ValueOf(metric.L2NormBinTolOut),
		"L2NormOut":                  reflect.ValueOf(metric.L2NormOut),
		"Matrix":                     reflect.ValueOf(metric.Matrix),
		"MatrixOut":                  reflect.ValueOf(metric.MatrixOut),
		"MetricCorrelation":          reflect.ValueOf(metric.MetricCorrelation),
		"MetricCosine":               reflect.ValueOf(metric.MetricCosine),
		"MetricCovariance":           reflect.ValueOf(metric.MetricCovariance),
		"MetricCrossEntropy":         reflect.ValueOf(metric.MetricCrossEntropy),
		"MetricDotProduct":           reflect.ValueOf(metric.MetricDotProduct),
		"MetricHamming":              reflect.ValueOf(metric.MetricHamming),
		"MetricInvCorrelation":       reflect.ValueOf(metric.MetricInvCorrelation),
		"MetricInvCosine":            reflect.ValueOf(metric.MetricInvCosine),
		"MetricL1Norm":               reflect.ValueOf(metric.MetricL1Norm),
		"MetricL2Norm":               reflect.ValueOf(metric.MetricL2Norm),
		"MetricL2NormBinTol":         reflect.ValueOf(metric.MetricL2NormBinTol),
		"MetricSumSquares":           reflect.ValueOf(metric.MetricSumSquares),
		"MetricSumSquaresBinTol":     reflect.ValueOf(metric.MetricSumSquaresBinTol),
		"MetricsN":                   reflect.ValueOf(metric.MetricsN),
		"MetricsValues":              reflect.ValueOf(metric.MetricsValues),
		"SumSquares":                 reflect.ValueOf(metric.SumSquares),
		"SumSquaresBinTol":           reflect.ValueOf(metric.SumSquaresBinTol),
		"SumSquaresBinTolOut":        reflect.ValueOf(metric.SumSquaresBinTolOut),
		"SumSquaresBinTolScaleOut64": reflect.ValueOf(metric.SumSquaresBinTolScaleOut64),
		"SumSquaresOut":              reflect.ValueOf(metric.SumSquaresOut),
		"SumSquaresOut64":            reflect.ValueOf(metric.SumSquaresOut64),
		"SumSquaresScaleOut64":       reflect.ValueOf(metric.SumSquaresScaleOut64),
		"Vectorize2Out64":            reflect.ValueOf(metric.Vectorize2Out64),
		"Vectorize3Out64":            reflect.ValueOf(metric.Vectorize3Out64),
		"VectorizeOut64":             reflect.ValueOf(metric.VectorizeOut64),
		"VectorizePre3Out64":         reflect.ValueOf(metric.VectorizePre3Out64),
		"VectorizePreOut64":          reflect.ValueOf(metric.VectorizePreOut64),
		// type definitions (typed nil pointers carry the reflect.Type)
		"MetricFunc":    reflect.ValueOf((*metric.MetricFunc)(nil)),
		"MetricOutFunc": reflect.ValueOf((*metric.MetricOutFunc)(nil)),
		"Metrics":       reflect.ValueOf((*metric.Metrics)(nil)),
	}
}
// Code generated by 'yaegi extract cogentcore.org/lab/stats/stats'. DO NOT EDIT.
package nogui
import (
"cogentcore.org/lab/stats/stats"
"reflect"
)
// init registers the exported API of cogentcore.org/lab/stats/stats in the
// package-level Symbols table, making it available to the yaegi interpreter.
// Generated by yaegi extract; entries mirror the package's exported names.
// Package variables (e.g. DescriptiveStats) are registered addressably via &v ... .Elem().
func init() {
	Symbols["cogentcore.org/lab/stats/stats/stats"] = map[string]reflect.Value{
		// function, constant and variable definitions
		"AsStatsFunc":                 reflect.ValueOf(stats.AsStatsFunc),
		"Binarize":                    reflect.ValueOf(stats.Binarize),
		"BinarizeOut":                 reflect.ValueOf(stats.BinarizeOut),
		"Clamp":                       reflect.ValueOf(stats.Clamp),
		"ClampOut":                    reflect.ValueOf(stats.ClampOut),
		"Count":                       reflect.ValueOf(stats.Count),
		"CountOut":                    reflect.ValueOf(stats.CountOut),
		"CountOut64":                  reflect.ValueOf(stats.CountOut64),
		"Describe":                    reflect.ValueOf(stats.Describe),
		"DescribeTable":               reflect.ValueOf(stats.DescribeTable),
		"DescribeTableAll":            reflect.ValueOf(stats.DescribeTableAll),
		"DescriptiveStats":            reflect.ValueOf(&stats.DescriptiveStats).Elem(),
		"Final":                       reflect.ValueOf(stats.Final),
		"FinalOut":                    reflect.ValueOf(stats.FinalOut),
		"First":                       reflect.ValueOf(stats.First),
		"FirstOut":                    reflect.ValueOf(stats.FirstOut),
		"GroupAll":                    reflect.ValueOf(stats.GroupAll),
		"GroupDescribe":               reflect.ValueOf(stats.GroupDescribe),
		"GroupStats":                  reflect.ValueOf(stats.GroupStats),
		"GroupStatsAsTable":           reflect.ValueOf(stats.GroupStatsAsTable),
		"GroupStatsAsTableNoStatName": reflect.ValueOf(stats.GroupStatsAsTableNoStatName),
		"Groups":                      reflect.ValueOf(stats.Groups),
		"L1Norm":                      reflect.ValueOf(stats.L1Norm),
		"L1NormOut":                   reflect.ValueOf(stats.L1NormOut),
		"L2Norm":                      reflect.ValueOf(stats.L2Norm),
		"L2NormOut":                   reflect.ValueOf(stats.L2NormOut),
		"L2NormOut64":                 reflect.ValueOf(stats.L2NormOut64),
		"Max":                         reflect.ValueOf(stats.Max),
		"MaxAbs":                      reflect.ValueOf(stats.MaxAbs),
		"MaxAbsOut":                   reflect.ValueOf(stats.MaxAbsOut),
		"MaxOut":                      reflect.ValueOf(stats.MaxOut),
		"Mean":                        reflect.ValueOf(stats.Mean),
		"MeanOut":                     reflect.ValueOf(stats.MeanOut),
		"MeanOut64":                   reflect.ValueOf(stats.MeanOut64),
		"Median":                      reflect.ValueOf(stats.Median),
		"MedianOut":                   reflect.ValueOf(stats.MedianOut),
		"Min":                         reflect.ValueOf(stats.Min),
		"MinAbs":                      reflect.ValueOf(stats.MinAbs),
		"MinAbsOut":                   reflect.ValueOf(stats.MinAbsOut),
		"MinOut":                      reflect.ValueOf(stats.MinOut),
		"Prod":                        reflect.ValueOf(stats.Prod),
		"ProdOut":                     reflect.ValueOf(stats.ProdOut),
		"Q1":                          reflect.ValueOf(stats.Q1),
		"Q1Out":                       reflect.ValueOf(stats.Q1Out),
		"Q3":                          reflect.ValueOf(stats.Q3),
		"Q3Out":                       reflect.ValueOf(stats.Q3Out),
		"Quantiles":                   reflect.ValueOf(stats.Quantiles),
		"QuantilesOut":                reflect.ValueOf(stats.QuantilesOut),
		"Sem":                         reflect.ValueOf(stats.Sem),
		"SemOut":                      reflect.ValueOf(stats.SemOut),
		"SemPop":                      reflect.ValueOf(stats.SemPop),
		"SemPopOut":                   reflect.ValueOf(stats.SemPopOut),
		"StatCount":                   reflect.ValueOf(stats.StatCount),
		"StatFinal":                   reflect.ValueOf(stats.StatFinal),
		"StatFirst":                   reflect.ValueOf(stats.StatFirst),
		"StatL1Norm":                  reflect.ValueOf(stats.StatL1Norm),
		"StatL2Norm":                  reflect.ValueOf(stats.StatL2Norm),
		"StatMax":                     reflect.ValueOf(stats.StatMax),
		"StatMaxAbs":                  reflect.ValueOf(stats.StatMaxAbs),
		"StatMean":                    reflect.ValueOf(stats.StatMean),
		"StatMedian":                  reflect.ValueOf(stats.StatMedian),
		"StatMin":                     reflect.ValueOf(stats.StatMin),
		"StatMinAbs":                  reflect.ValueOf(stats.StatMinAbs),
		"StatProd":                    reflect.ValueOf(stats.StatProd),
		"StatQ1":                      reflect.ValueOf(stats.StatQ1),
		"StatQ3":                      reflect.ValueOf(stats.StatQ3),
		"StatSem":                     reflect.ValueOf(stats.StatSem),
		"StatSemPop":                  reflect.ValueOf(stats.StatSemPop),
		"StatStd":                     reflect.ValueOf(stats.StatStd),
		"StatStdPop":                  reflect.ValueOf(stats.StatStdPop),
		"StatSum":                     reflect.ValueOf(stats.StatSum),
		"StatSumSq":                   reflect.ValueOf(stats.StatSumSq),
		"StatVar":                     reflect.ValueOf(stats.StatVar),
		"StatVarPop":                  reflect.ValueOf(stats.StatVarPop),
		"StatsN":                      reflect.ValueOf(stats.StatsN),
		"StatsValues":                 reflect.ValueOf(stats.StatsValues),
		"Std":                         reflect.ValueOf(stats.Std),
		"StdOut":                      reflect.ValueOf(stats.StdOut),
		"StdOut64":                    reflect.ValueOf(stats.StdOut64),
		"StdPop":                      reflect.ValueOf(stats.StdPop),
		"StdPopOut":                   reflect.ValueOf(stats.StdPopOut),
		"StripPackage":                reflect.ValueOf(stats.StripPackage),
		"Sum":                         reflect.ValueOf(stats.Sum),
		"SumOut":                      reflect.ValueOf(stats.SumOut),
		"SumOut64":                    reflect.ValueOf(stats.SumOut64),
		"SumSq":                       reflect.ValueOf(stats.SumSq),
		"SumSqDevOut64":               reflect.ValueOf(stats.SumSqDevOut64),
		"SumSqOut":                    reflect.ValueOf(stats.SumSqOut),
		"SumSqOut64":                  reflect.ValueOf(stats.SumSqOut64),
		"SumSqScaleOut64":             reflect.ValueOf(stats.SumSqScaleOut64),
		"TableGroupDescribe":          reflect.ValueOf(stats.TableGroupDescribe),
		"TableGroupStats":             reflect.ValueOf(stats.TableGroupStats),
		"TableGroups":                 reflect.ValueOf(stats.TableGroups),
		"UnitNorm":                    reflect.ValueOf(stats.UnitNorm),
		"UnitNormOut":                 reflect.ValueOf(stats.UnitNormOut),
		"Var":                         reflect.ValueOf(stats.Var),
		"VarOut":                      reflect.ValueOf(stats.VarOut),
		"VarOut64":                    reflect.ValueOf(stats.VarOut64),
		"VarPop":                      reflect.ValueOf(stats.VarPop),
		"VarPopOut":                   reflect.ValueOf(stats.VarPopOut),
		"VarPopOut64":                 reflect.ValueOf(stats.VarPopOut64),
		"Vectorize2Out64":             reflect.ValueOf(stats.Vectorize2Out64),
		"VectorizeOut64":              reflect.ValueOf(stats.VectorizeOut64),
		"VectorizePreOut64":           reflect.ValueOf(stats.VectorizePreOut64),
		"ZScore":                      reflect.ValueOf(stats.ZScore),
		"ZScoreOut":                   reflect.ValueOf(stats.ZScoreOut),
		// type definitions (typed nil pointers carry the reflect.Type)
		"Stats":        reflect.ValueOf((*stats.Stats)(nil)),
		"StatsFunc":    reflect.ValueOf((*stats.StatsFunc)(nil)),
		"StatsOutFunc": reflect.ValueOf((*stats.StatsOutFunc)(nil)),
	}
}
// Code generated by 'yaegi extract cogentcore.org/lab/table'. DO NOT EDIT.
package nogui
import (
"cogentcore.org/lab/table"
"reflect"
)
// init registers the exported API of cogentcore.org/lab/table in the
// package-level Symbols table, making it available to the yaegi interpreter.
// Generated by yaegi extract; entries mirror the package's exported names.
// Package variables (ErrLogNoNewRows, TableHeaderToType) are registered
// addressably via &v ... .Elem().
func init() {
	Symbols["cogentcore.org/lab/table/table"] = map[string]reflect.Value{
		// function, constant and variable definitions
		"CleanCatTSV":            reflect.ValueOf(table.CleanCatTSV),
		"ConfigFromDataValues":   reflect.ValueOf(table.ConfigFromDataValues),
		"ConfigFromHeaders":      reflect.ValueOf(table.ConfigFromHeaders),
		"ConfigFromTableHeaders": reflect.ValueOf(table.ConfigFromTableHeaders),
		"DetectTableHeaders":     reflect.ValueOf(table.DetectTableHeaders),
		"ErrLogNoNewRows":        reflect.ValueOf(&table.ErrLogNoNewRows).Elem(),
		"Headers":                reflect.ValueOf(table.Headers),
		"InferDataType":          reflect.ValueOf(table.InferDataType),
		"New":                    reflect.ValueOf(table.New),
		"NewColumns":             reflect.ValueOf(table.NewColumns),
		"NewSliceTable":          reflect.ValueOf(table.NewSliceTable),
		"NewView":                reflect.ValueOf(table.NewView),
		"NoHeaders":              reflect.ValueOf(table.NoHeaders),
		"ShapeFromString":        reflect.ValueOf(table.ShapeFromString),
		"TableColumnType":        reflect.ValueOf(table.TableColumnType),
		"TableHeaderChar":        reflect.ValueOf(table.TableHeaderChar),
		"TableHeaderToType":      reflect.ValueOf(&table.TableHeaderToType).Elem(),
		"UpdateSliceTable":       reflect.ValueOf(table.UpdateSliceTable),
		// type definitions (typed nil pointers carry the reflect.Type)
		"Columns":    reflect.ValueOf((*table.Columns)(nil)),
		"FilterFunc": reflect.ValueOf((*table.FilterFunc)(nil)),
		"Table":      reflect.ValueOf((*table.Table)(nil)),
	}
}
// Code generated by 'yaegi extract cogentcore.org/lab/tensor/tmath'. DO NOT EDIT.
package nogui
import (
"cogentcore.org/lab/tensor/tmath"
"reflect"
)
// init registers the exported API of cogentcore.org/lab/tensor/tmath in the
// package-level Symbols table, making it available to the yaegi interpreter.
// Generated by yaegi extract; entries mirror the package's exported names.
func init() {
	Symbols["cogentcore.org/lab/tensor/tmath/tmath"] = map[string]reflect.Value{
		// function, constant and variable definitions
		"Abs":             reflect.ValueOf(tmath.Abs),
		"AbsOut":          reflect.ValueOf(tmath.AbsOut),
		"Acos":            reflect.ValueOf(tmath.Acos),
		"AcosOut":         reflect.ValueOf(tmath.AcosOut),
		"Acosh":           reflect.ValueOf(tmath.Acosh),
		"AcoshOut":        reflect.ValueOf(tmath.AcoshOut),
		"Add":             reflect.ValueOf(tmath.Add),
		"AddAssign":       reflect.ValueOf(tmath.AddAssign),
		"AddOut":          reflect.ValueOf(tmath.AddOut),
		"And":             reflect.ValueOf(tmath.And),
		"AndOut":          reflect.ValueOf(tmath.AndOut),
		"Asin":            reflect.ValueOf(tmath.Asin),
		"AsinOut":         reflect.ValueOf(tmath.AsinOut),
		"Asinh":           reflect.ValueOf(tmath.Asinh),
		"AsinhOut":        reflect.ValueOf(tmath.AsinhOut),
		"Assign":          reflect.ValueOf(tmath.Assign),
		"Atan":            reflect.ValueOf(tmath.Atan),
		"Atan2":           reflect.ValueOf(tmath.Atan2),
		"Atan2Out":        reflect.ValueOf(tmath.Atan2Out),
		"AtanOut":         reflect.ValueOf(tmath.AtanOut),
		"Atanh":           reflect.ValueOf(tmath.Atanh),
		"AtanhOut":        reflect.ValueOf(tmath.AtanhOut),
		"Cbrt":            reflect.ValueOf(tmath.Cbrt),
		"CbrtOut":         reflect.ValueOf(tmath.CbrtOut),
		"Ceil":            reflect.ValueOf(tmath.Ceil),
		"CeilOut":         reflect.ValueOf(tmath.CeilOut),
		"Copysign":        reflect.ValueOf(tmath.Copysign),
		"CopysignOut":     reflect.ValueOf(tmath.CopysignOut),
		"Cos":             reflect.ValueOf(tmath.Cos),
		"CosOut":          reflect.ValueOf(tmath.CosOut),
		"Cosh":            reflect.ValueOf(tmath.Cosh),
		"CoshOut":         reflect.ValueOf(tmath.CoshOut),
		"Dec":             reflect.ValueOf(tmath.Dec),
		"Dim":             reflect.ValueOf(tmath.Dim),
		"DimOut":          reflect.ValueOf(tmath.DimOut),
		"Div":             reflect.ValueOf(tmath.Div),
		"DivAssign":       reflect.ValueOf(tmath.DivAssign),
		"DivOut":          reflect.ValueOf(tmath.DivOut),
		"Equal":           reflect.ValueOf(tmath.Equal),
		"EqualOut":        reflect.ValueOf(tmath.EqualOut),
		"Erf":             reflect.ValueOf(tmath.Erf),
		"ErfOut":          reflect.ValueOf(tmath.ErfOut),
		"Erfc":            reflect.ValueOf(tmath.Erfc),
		"ErfcOut":         reflect.ValueOf(tmath.ErfcOut),
		"Erfcinv":         reflect.ValueOf(tmath.Erfcinv),
		"ErfcinvOut":      reflect.ValueOf(tmath.ErfcinvOut),
		"Erfinv":          reflect.ValueOf(tmath.Erfinv),
		"ErfinvOut":       reflect.ValueOf(tmath.ErfinvOut),
		"Exp":             reflect.ValueOf(tmath.Exp),
		"Exp2":            reflect.ValueOf(tmath.Exp2),
		"Exp2Out":         reflect.ValueOf(tmath.Exp2Out),
		"ExpOut":          reflect.ValueOf(tmath.ExpOut),
		"Expm1":           reflect.ValueOf(tmath.Expm1),
		"Expm1Out":        reflect.ValueOf(tmath.Expm1Out),
		"Floor":           reflect.ValueOf(tmath.Floor),
		"FloorOut":        reflect.ValueOf(tmath.FloorOut),
		"Gamma":           reflect.ValueOf(tmath.Gamma),
		"GammaOut":        reflect.ValueOf(tmath.GammaOut),
		"Greater":         reflect.ValueOf(tmath.Greater),
		"GreaterEqual":    reflect.ValueOf(tmath.GreaterEqual),
		"GreaterEqualOut": reflect.ValueOf(tmath.GreaterEqualOut),
		"GreaterOut":      reflect.ValueOf(tmath.GreaterOut),
		"Hypot":           reflect.ValueOf(tmath.Hypot),
		"HypotOut":        reflect.ValueOf(tmath.HypotOut),
		"Inc":             reflect.ValueOf(tmath.Inc),
		"J0":              reflect.ValueOf(tmath.J0),
		"J0Out":           reflect.ValueOf(tmath.J0Out),
		"J1":              reflect.ValueOf(tmath.J1),
		"J1Out":           reflect.ValueOf(tmath.J1Out),
		"Less":            reflect.ValueOf(tmath.Less),
		"LessEqual":       reflect.ValueOf(tmath.LessEqual),
		"LessEqualOut":    reflect.ValueOf(tmath.LessEqualOut),
		"LessOut":         reflect.ValueOf(tmath.LessOut),
		"Log":             reflect.ValueOf(tmath.Log),
		"Log10":           reflect.ValueOf(tmath.Log10),
		"Log10Out":        reflect.ValueOf(tmath.Log10Out),
		"Log1p":           reflect.ValueOf(tmath.Log1p),
		"Log1pOut":        reflect.ValueOf(tmath.Log1pOut),
		"Log2":            reflect.ValueOf(tmath.Log2),
		"Log2Out":         reflect.ValueOf(tmath.Log2Out),
		"LogOut":          reflect.ValueOf(tmath.LogOut),
		"Logb":            reflect.ValueOf(tmath.Logb),
		"LogbOut":         reflect.ValueOf(tmath.LogbOut),
		"Max":             reflect.ValueOf(tmath.Max),
		"MaxOut":          reflect.ValueOf(tmath.MaxOut),
		"Min":             reflect.ValueOf(tmath.Min),
		"MinOut":          reflect.ValueOf(tmath.MinOut),
		"Mod":             reflect.ValueOf(tmath.Mod),
		"ModAssign":       reflect.ValueOf(tmath.ModAssign),
		"ModOut":          reflect.ValueOf(tmath.ModOut),
		"Mul":             reflect.ValueOf(tmath.Mul),
		"MulAssign":       reflect.ValueOf(tmath.MulAssign),
		"MulOut":          reflect.ValueOf(tmath.MulOut),
		"Negate":          reflect.ValueOf(tmath.Negate),
		"NegateOut":       reflect.ValueOf(tmath.NegateOut),
		"Nextafter":       reflect.ValueOf(tmath.Nextafter),
		"NextafterOut":    reflect.ValueOf(tmath.NextafterOut),
		"Not":             reflect.ValueOf(tmath.Not),
		"NotEqual":        reflect.ValueOf(tmath.NotEqual),
		"NotEqualOut":     reflect.ValueOf(tmath.NotEqualOut),
		"NotOut":          reflect.ValueOf(tmath.NotOut),
		"Or":              reflect.ValueOf(tmath.Or),
		"OrOut":           reflect.ValueOf(tmath.OrOut),
		"Pow":             reflect.ValueOf(tmath.Pow),
		"PowOut":          reflect.ValueOf(tmath.PowOut),
		"Remainder":       reflect.ValueOf(tmath.Remainder),
		"RemainderOut":    reflect.ValueOf(tmath.RemainderOut),
		"Round":           reflect.ValueOf(tmath.Round),
		"RoundOut":        reflect.ValueOf(tmath.RoundOut),
		"RoundToEven":     reflect.ValueOf(tmath.RoundToEven),
		"RoundToEvenOut":  reflect.ValueOf(tmath.RoundToEvenOut),
		"Sin":             reflect.ValueOf(tmath.Sin),
		"SinOut":          reflect.ValueOf(tmath.SinOut),
		"Sinh":            reflect.ValueOf(tmath.Sinh),
		"SinhOut":         reflect.ValueOf(tmath.SinhOut),
		"Sqrt":            reflect.ValueOf(tmath.Sqrt),
		"SqrtOut":         reflect.ValueOf(tmath.SqrtOut),
		"Sub":             reflect.ValueOf(tmath.Sub),
		"SubAssign":       reflect.ValueOf(tmath.SubAssign),
		"SubOut":          reflect.ValueOf(tmath.SubOut),
		"Tan":             reflect.ValueOf(tmath.Tan),
		"TanOut":          reflect.ValueOf(tmath.TanOut),
		"Tanh":            reflect.ValueOf(tmath.Tanh),
		"TanhOut":         reflect.ValueOf(tmath.TanhOut),
		"Trunc":           reflect.ValueOf(tmath.Trunc),
		"TruncOut":        reflect.ValueOf(tmath.TruncOut),
		"Y0":              reflect.ValueOf(tmath.Y0),
		"Y0Out":           reflect.ValueOf(tmath.Y0Out),
		"Y1":              reflect.ValueOf(tmath.Y1),
		"Y1Out":           reflect.ValueOf(tmath.Y1Out),
	}
}
// Code generated by 'yaegi extract cogentcore.org/lab/tensor'. DO NOT EDIT.
package nogui
import (
"cogentcore.org/core/base/metadata"
"cogentcore.org/lab/tensor"
"reflect"
)
// init registers the exported API of cogentcore.org/lab/tensor in the
// Symbols map (declared elsewhere in this package), keyed by import path,
// so yaegi-interpreted code can resolve these names via reflection.
// Package-level variables are registered as addressable values via
// reflect.ValueOf(&v).Elem(); types are registered as typed nil pointers.
func init() {
Symbols["cogentcore.org/lab/tensor/tensor"] = map[string]reflect.Value{
// function, constant and variable definitions
"AddFunc": reflect.ValueOf(tensor.AddFunc),
"AddShapes": reflect.ValueOf(tensor.AddShapes),
"AlignForAssign": reflect.ValueOf(tensor.AlignForAssign),
"AlignShapes": reflect.ValueOf(tensor.AlignShapes),
"As1D": reflect.ValueOf(tensor.As1D),
"AsFloat32": reflect.ValueOf(tensor.AsFloat32),
"AsFloat64": reflect.ValueOf(tensor.AsFloat64),
"AsFloat64Scalar": reflect.ValueOf(tensor.AsFloat64Scalar),
"AsFloat64Slice": reflect.ValueOf(tensor.AsFloat64Slice),
"AsIndexed": reflect.ValueOf(tensor.AsIndexed),
"AsInt": reflect.ValueOf(tensor.AsInt),
"AsIntScalar": reflect.ValueOf(tensor.AsIntScalar),
"AsIntSlice": reflect.ValueOf(tensor.AsIntSlice),
"AsMasked": reflect.ValueOf(tensor.AsMasked),
"AsReshaped": reflect.ValueOf(tensor.AsReshaped),
"AsRows": reflect.ValueOf(tensor.AsRows),
"AsSliced": reflect.ValueOf(tensor.AsSliced),
"AsString": reflect.ValueOf(tensor.AsString),
"AsStringScalar": reflect.ValueOf(tensor.AsStringScalar),
"AsStringSlice": reflect.ValueOf(tensor.AsStringSlice),
"Ascending": reflect.ValueOf(tensor.Ascending),
"BoolFloatsFunc": reflect.ValueOf(tensor.BoolFloatsFunc),
"BoolFloatsFuncOut": reflect.ValueOf(tensor.BoolFloatsFuncOut),
"BoolIntsFunc": reflect.ValueOf(tensor.BoolIntsFunc),
"BoolIntsFuncOut": reflect.ValueOf(tensor.BoolIntsFuncOut),
"BoolStringsFunc": reflect.ValueOf(tensor.BoolStringsFunc),
"BoolStringsFuncOut": reflect.ValueOf(tensor.BoolStringsFuncOut),
"BoolToFloat64": reflect.ValueOf(tensor.BoolToFloat64),
"BoolToInt": reflect.ValueOf(tensor.BoolToInt),
"Calc": reflect.ValueOf(tensor.Calc),
"CallOut1": reflect.ValueOf(tensor.CallOut1),
"CallOut1Float64": reflect.ValueOf(tensor.CallOut1Float64),
"CallOut2": reflect.ValueOf(tensor.CallOut2),
"CallOut2Bool": reflect.ValueOf(tensor.CallOut2Bool),
"CallOut2Float64": reflect.ValueOf(tensor.CallOut2Float64),
"CallOut3": reflect.ValueOf(tensor.CallOut3),
"Cells1D": reflect.ValueOf(tensor.Cells1D),
"CellsSize": reflect.ValueOf(tensor.CellsSize),
"Clone": reflect.ValueOf(tensor.Clone),
"ColumnMajorStrides": reflect.ValueOf(tensor.ColumnMajorStrides),
"Comma": reflect.ValueOf(tensor.Comma),
"DefaultNumThreads": reflect.ValueOf(tensor.DefaultNumThreads),
"DelimsN": reflect.ValueOf(tensor.DelimsN),
"DelimsValues": reflect.ValueOf(tensor.DelimsValues),
"Descending": reflect.ValueOf(tensor.Descending),
"Detect": reflect.ValueOf(tensor.Detect),
"Ellipsis": reflect.ValueOf(tensor.Ellipsis),
"Flatten": reflect.ValueOf(tensor.Flatten),
"Float64ToBool": reflect.ValueOf(tensor.Float64ToBool),
"Float64ToString": reflect.ValueOf(tensor.Float64ToString),
"FloatAssignFunc": reflect.ValueOf(tensor.FloatAssignFunc),
"FloatBinaryFunc": reflect.ValueOf(tensor.FloatBinaryFunc),
"FloatBinaryFuncOut": reflect.ValueOf(tensor.FloatBinaryFuncOut),
"FloatFunc": reflect.ValueOf(tensor.FloatFunc),
"FloatFuncOut": reflect.ValueOf(tensor.FloatFuncOut),
"FloatPromoteType": reflect.ValueOf(tensor.FloatPromoteType),
"FloatSetFunc": reflect.ValueOf(tensor.FloatSetFunc),
"FullAxis": reflect.ValueOf(tensor.FullAxis),
"FuncByName": reflect.ValueOf(tensor.FuncByName),
// &v…Elem() yields an addressable reflect.Value so interpreted code can assign to the variable.
"Funcs": reflect.ValueOf(&tensor.Funcs).Elem(),
"IntToBool": reflect.ValueOf(tensor.IntToBool),
"Mask": reflect.ValueOf(tensor.Mask),
"MaxPrintLineWidth": reflect.ValueOf(&tensor.MaxPrintLineWidth).Elem(),
"MaxSprintLength": reflect.ValueOf(&tensor.MaxSprintLength).Elem(),
"MustBeSameShape": reflect.ValueOf(tensor.MustBeSameShape),
"MustBeValues": reflect.ValueOf(tensor.MustBeValues),
"NFirstLen": reflect.ValueOf(tensor.NFirstLen),
"NFirstRows": reflect.ValueOf(tensor.NFirstRows),
"NMinLen": reflect.ValueOf(tensor.NMinLen),
"NegIndex": reflect.ValueOf(tensor.NegIndex),
"NewAxis": reflect.ValueOf(tensor.NewAxis),
"NewBool": reflect.ValueOf(tensor.NewBool),
"NewBoolShape": reflect.ValueOf(tensor.NewBoolShape),
"NewByte": reflect.ValueOf(tensor.NewByte),
"NewFloat32": reflect.ValueOf(tensor.NewFloat32),
"NewFloat32FromValues": reflect.ValueOf(tensor.NewFloat32FromValues),
"NewFloat32Scalar": reflect.ValueOf(tensor.NewFloat32Scalar),
"NewFloat64": reflect.ValueOf(tensor.NewFloat64),
"NewFloat64FromValues": reflect.ValueOf(tensor.NewFloat64FromValues),
"NewFloat64Full": reflect.ValueOf(tensor.NewFloat64Full),
"NewFloat64Ones": reflect.ValueOf(tensor.NewFloat64Ones),
"NewFloat64Rand": reflect.ValueOf(tensor.NewFloat64Rand),
"NewFloat64Scalar": reflect.ValueOf(tensor.NewFloat64Scalar),
"NewFloat64SpacedLinear": reflect.ValueOf(tensor.NewFloat64SpacedLinear),
"NewFunc": reflect.ValueOf(tensor.NewFunc),
"NewIndexed": reflect.ValueOf(tensor.NewIndexed),
"NewInt": reflect.ValueOf(tensor.NewInt),
"NewInt32": reflect.ValueOf(tensor.NewInt32),
"NewIntFromValues": reflect.ValueOf(tensor.NewIntFromValues),
"NewIntFull": reflect.ValueOf(tensor.NewIntFull),
"NewIntRange": reflect.ValueOf(tensor.NewIntRange),
"NewIntScalar": reflect.ValueOf(tensor.NewIntScalar),
"NewMasked": reflect.ValueOf(tensor.NewMasked),
"NewOfType": reflect.ValueOf(tensor.NewOfType),
"NewReshaped": reflect.ValueOf(tensor.NewReshaped),
"NewRowCellsView": reflect.ValueOf(tensor.NewRowCellsView),
"NewRows": reflect.ValueOf(tensor.NewRows),
"NewShape": reflect.ValueOf(tensor.NewShape),
"NewSlice": reflect.ValueOf(tensor.NewSlice),
"NewSliced": reflect.ValueOf(tensor.NewSliced),
"NewString": reflect.ValueOf(tensor.NewString),
"NewStringFromValues": reflect.ValueOf(tensor.NewStringFromValues),
"NewStringFull": reflect.ValueOf(tensor.NewStringFull),
"NewStringScalar": reflect.ValueOf(tensor.NewStringScalar),
"NewStringShape": reflect.ValueOf(tensor.NewStringShape),
"NewUint32": reflect.ValueOf(tensor.NewUint32),
"NumThreads": reflect.ValueOf(&tensor.NumThreads).Elem(),
"OnedColumn": reflect.ValueOf(tensor.OnedColumn),
"OnedRow": reflect.ValueOf(tensor.OnedRow),
"OpenCSV": reflect.ValueOf(tensor.OpenCSV),
"Precision": reflect.ValueOf(tensor.Precision),
"Projection2DCoords": reflect.ValueOf(tensor.Projection2DCoords),
"Projection2DDimShapes": reflect.ValueOf(tensor.Projection2DDimShapes),
"Projection2DIndex": reflect.ValueOf(tensor.Projection2DIndex),
"Projection2DSet": reflect.ValueOf(tensor.Projection2DSet),
"Projection2DSetString": reflect.ValueOf(tensor.Projection2DSetString),
"Projection2DShape": reflect.ValueOf(tensor.Projection2DShape),
"Projection2DString": reflect.ValueOf(tensor.Projection2DString),
"Projection2DValue": reflect.ValueOf(tensor.Projection2DValue),
"Range": reflect.ValueOf(tensor.Range),
"ReadCSV": reflect.ValueOf(tensor.ReadCSV),
"Reshape": reflect.ValueOf(tensor.Reshape),
"Reslice": reflect.ValueOf(tensor.Reslice),
"RowMajorStrides": reflect.ValueOf(tensor.RowMajorStrides),
"SaveCSV": reflect.ValueOf(tensor.SaveCSV),
"SetAllFloat64": reflect.ValueOf(tensor.SetAllFloat64),
"SetAllInt": reflect.ValueOf(tensor.SetAllInt),
"SetAllString": reflect.ValueOf(tensor.SetAllString),
"SetCalcFunc": reflect.ValueOf(tensor.SetCalcFunc),
"SetPrecision": reflect.ValueOf(tensor.SetPrecision),
"SetShape": reflect.ValueOf(tensor.SetShape),
"SetShapeFrom": reflect.ValueOf(tensor.SetShapeFrom),
"SetShapeNames": reflect.ValueOf(tensor.SetShapeNames),
"SetShapeSizesFromTensor": reflect.ValueOf(tensor.SetShapeSizesFromTensor),
"ShapeNames": reflect.ValueOf(tensor.ShapeNames),
"SlicesMagicN": reflect.ValueOf(tensor.SlicesMagicN),
"SlicesMagicValues": reflect.ValueOf(tensor.SlicesMagicValues),
"Space": reflect.ValueOf(tensor.Space),
"SplitAtInnerDims": reflect.ValueOf(tensor.SplitAtInnerDims),
"Sprintf": reflect.ValueOf(tensor.Sprintf),
"Squeeze": reflect.ValueOf(tensor.Squeeze),
"StableSort": reflect.ValueOf(tensor.StableSort),
"StringAssignFunc": reflect.ValueOf(tensor.StringAssignFunc),
"StringBinaryFunc": reflect.ValueOf(tensor.StringBinaryFunc),
"StringBinaryFuncOut": reflect.ValueOf(tensor.StringBinaryFuncOut),
"StringToFloat64": reflect.ValueOf(tensor.StringToFloat64),
"Tab": reflect.ValueOf(tensor.Tab),
"ThreadingThreshold": reflect.ValueOf(&tensor.ThreadingThreshold).Elem(),
"Transpose": reflect.ValueOf(tensor.Transpose),
"UnstableSort": reflect.ValueOf(tensor.UnstableSort),
"Vectorize": reflect.ValueOf(tensor.Vectorize),
"VectorizeOnThreads": reflect.ValueOf(tensor.VectorizeOnThreads),
"VectorizeThreaded": reflect.ValueOf(tensor.VectorizeThreaded),
"WrapIndex1D": reflect.ValueOf(tensor.WrapIndex1D),
"WriteCSV": reflect.ValueOf(tensor.WriteCSV),
// type definitions: a (*T)(nil) carries type identity without allocating a value.
"Arg": reflect.ValueOf((*tensor.Arg)(nil)),
"Bool": reflect.ValueOf((*tensor.Bool)(nil)),
"Delims": reflect.ValueOf((*tensor.Delims)(nil)),
"FilterFunc": reflect.ValueOf((*tensor.FilterFunc)(nil)),
"FilterOptions": reflect.ValueOf((*tensor.FilterOptions)(nil)),
"Func": reflect.ValueOf((*tensor.Func)(nil)),
"Indexed": reflect.ValueOf((*tensor.Indexed)(nil)),
"Masked": reflect.ValueOf((*tensor.Masked)(nil)),
"Reshaped": reflect.ValueOf((*tensor.Reshaped)(nil)),
"RowMajor": reflect.ValueOf((*tensor.RowMajor)(nil)),
"Rows": reflect.ValueOf((*tensor.Rows)(nil)),
"Shape": reflect.ValueOf((*tensor.Shape)(nil)),
"Slice": reflect.ValueOf((*tensor.Slice)(nil)),
"Sliced": reflect.ValueOf((*tensor.Sliced)(nil)),
"SlicesMagic": reflect.ValueOf((*tensor.SlicesMagic)(nil)),
"String": reflect.ValueOf((*tensor.String)(nil)),
"Tensor": reflect.ValueOf((*tensor.Tensor)(nil)),
"Values": reflect.ValueOf((*tensor.Values)(nil)),
// interface wrapper definitions: generated structs (below) that let
// interpreter-defined values satisfy these Go interfaces.
"_RowMajor": reflect.ValueOf((*_cogentcore_org_lab_tensor_RowMajor)(nil)),
"_Tensor": reflect.ValueOf((*_cogentcore_org_lab_tensor_Tensor)(nil)),
"_Values": reflect.ValueOf((*_cogentcore_org_lab_tensor_Values)(nil)),
}
}
// _cogentcore_org_lab_tensor_RowMajor is an interface wrapper for RowMajor type
//
// Generated by yaegi extract: each W-prefixed func field holds the
// implementation of the correspondingly-named interface method, and the
// methods below (on the value receiver) delegate to these fields.
type _cogentcore_org_lab_tensor_RowMajor struct {
// IValue holds the underlying wrapped value (presumably the interpreter
// value this wrapper was built from — confirm against yaegi internals).
IValue interface{}
WAppendRow func(val tensor.Values)
WAppendRowFloat func(val ...float64)
WAppendRowInt func(val ...int)
WAppendRowString func(val ...string)
WAsValues func() tensor.Values
WDataType func() reflect.Kind
WDimSize func(dim int) int
WFloat func(i ...int) float64
WFloat1D func(i int) float64
WFloatRow func(row int, cell int) float64
WInt func(i ...int) int
WInt1D func(i int) int
WIntRow func(row int, cell int) int
WIsString func() bool
WLabel func() string
WLen func() int
WMetadata func() *metadata.Data
WNumDims func() int
WRowTensor func(row int) tensor.Values
WSetFloat func(val float64, i ...int)
WSetFloat1D func(val float64, i int)
WSetFloatRow func(val float64, row int, cell int)
WSetInt func(val int, i ...int)
WSetInt1D func(val int, i int)
WSetIntRow func(val int, row int, cell int)
WSetRowTensor func(val tensor.Values, row int)
WSetString func(val string, i ...int)
WSetString1D func(val string, i int)
WSetStringRow func(val string, row int, cell int)
WShape func() *tensor.Shape
WShapeSizes func() []int
WString func() string
WString1D func(i int) string
WStringRow func(row int, cell int) string
WStringValue func(i ...int) string
WSubSpace func(offs ...int) tensor.Values
}
// Each method below forwards directly to the matching W* func field,
// so that _cogentcore_org_lab_tensor_RowMajor satisfies tensor.RowMajor.
// Only String has a nil guard (see note there); all other fields are
// assumed non-nil by the generator.
func (W _cogentcore_org_lab_tensor_RowMajor) AppendRow(val tensor.Values) { W.WAppendRow(val) }
func (W _cogentcore_org_lab_tensor_RowMajor) AppendRowFloat(val ...float64) {
W.WAppendRowFloat(val...)
}
func (W _cogentcore_org_lab_tensor_RowMajor) AppendRowInt(val ...int) { W.WAppendRowInt(val...) }
func (W _cogentcore_org_lab_tensor_RowMajor) AppendRowString(val ...string) {
W.WAppendRowString(val...)
}
func (W _cogentcore_org_lab_tensor_RowMajor) AsValues() tensor.Values { return W.WAsValues() }
func (W _cogentcore_org_lab_tensor_RowMajor) DataType() reflect.Kind { return W.WDataType() }
func (W _cogentcore_org_lab_tensor_RowMajor) DimSize(dim int) int { return W.WDimSize(dim) }
func (W _cogentcore_org_lab_tensor_RowMajor) Float(i ...int) float64 { return W.WFloat(i...) }
func (W _cogentcore_org_lab_tensor_RowMajor) Float1D(i int) float64 { return W.WFloat1D(i) }
func (W _cogentcore_org_lab_tensor_RowMajor) FloatRow(row int, cell int) float64 {
return W.WFloatRow(row, cell)
}
func (W _cogentcore_org_lab_tensor_RowMajor) Int(i ...int) int { return W.WInt(i...) }
func (W _cogentcore_org_lab_tensor_RowMajor) Int1D(i int) int { return W.WInt1D(i) }
func (W _cogentcore_org_lab_tensor_RowMajor) IntRow(row int, cell int) int {
return W.WIntRow(row, cell)
}
func (W _cogentcore_org_lab_tensor_RowMajor) IsString() bool { return W.WIsString() }
func (W _cogentcore_org_lab_tensor_RowMajor) Label() string { return W.WLabel() }
func (W _cogentcore_org_lab_tensor_RowMajor) Len() int { return W.WLen() }
func (W _cogentcore_org_lab_tensor_RowMajor) Metadata() *metadata.Data { return W.WMetadata() }
func (W _cogentcore_org_lab_tensor_RowMajor) NumDims() int { return W.WNumDims() }
func (W _cogentcore_org_lab_tensor_RowMajor) RowTensor(row int) tensor.Values {
return W.WRowTensor(row)
}
func (W _cogentcore_org_lab_tensor_RowMajor) SetFloat(val float64, i ...int) { W.WSetFloat(val, i...) }
func (W _cogentcore_org_lab_tensor_RowMajor) SetFloat1D(val float64, i int) { W.WSetFloat1D(val, i) }
func (W _cogentcore_org_lab_tensor_RowMajor) SetFloatRow(val float64, row int, cell int) {
W.WSetFloatRow(val, row, cell)
}
func (W _cogentcore_org_lab_tensor_RowMajor) SetInt(val int, i ...int) { W.WSetInt(val, i...) }
func (W _cogentcore_org_lab_tensor_RowMajor) SetInt1D(val int, i int) { W.WSetInt1D(val, i) }
func (W _cogentcore_org_lab_tensor_RowMajor) SetIntRow(val int, row int, cell int) {
W.WSetIntRow(val, row, cell)
}
func (W _cogentcore_org_lab_tensor_RowMajor) SetRowTensor(val tensor.Values, row int) {
W.WSetRowTensor(val, row)
}
func (W _cogentcore_org_lab_tensor_RowMajor) SetString(val string, i ...int) { W.WSetString(val, i...) }
func (W _cogentcore_org_lab_tensor_RowMajor) SetString1D(val string, i int) { W.WSetString1D(val, i) }
func (W _cogentcore_org_lab_tensor_RowMajor) SetStringRow(val string, row int, cell int) {
W.WSetStringRow(val, row, cell)
}
func (W _cogentcore_org_lab_tensor_RowMajor) Shape() *tensor.Shape { return W.WShape() }
func (W _cogentcore_org_lab_tensor_RowMajor) ShapeSizes() []int { return W.WShapeSizes() }
// String returns "" when WString is unset rather than panicking on a nil
// func call — presumably because fmt can invoke String implicitly on any
// value; confirm against the yaegi generator's rationale.
func (W _cogentcore_org_lab_tensor_RowMajor) String() string {
if W.WString == nil {
return ""
}
return W.WString()
}
func (W _cogentcore_org_lab_tensor_RowMajor) String1D(i int) string { return W.WString1D(i) }
func (W _cogentcore_org_lab_tensor_RowMajor) StringRow(row int, cell int) string {
return W.WStringRow(row, cell)
}
func (W _cogentcore_org_lab_tensor_RowMajor) StringValue(i ...int) string {
return W.WStringValue(i...)
}
func (W _cogentcore_org_lab_tensor_RowMajor) SubSpace(offs ...int) tensor.Values {
return W.WSubSpace(offs...)
}
// _cogentcore_org_lab_tensor_Tensor is an interface wrapper for Tensor type
//
// Generated by yaegi extract: each W-prefixed func field holds the
// implementation of the correspondingly-named interface method.
type _cogentcore_org_lab_tensor_Tensor struct {
// IValue holds the underlying wrapped value.
IValue interface{}
WAsValues func() tensor.Values
WDataType func() reflect.Kind
WDimSize func(dim int) int
WFloat func(i ...int) float64
WFloat1D func(i int) float64
WInt func(i ...int) int
WInt1D func(i int) int
WIsString func() bool
WLabel func() string
WLen func() int
WMetadata func() *metadata.Data
WNumDims func() int
WSetFloat func(val float64, i ...int)
WSetFloat1D func(val float64, i int)
WSetInt func(val int, i ...int)
WSetInt1D func(val int, i int)
WSetString func(val string, i ...int)
WSetString1D func(val string, i int)
WShape func() *tensor.Shape
WShapeSizes func() []int
WString func() string
WString1D func(i int) string
WStringValue func(i ...int) string
}
// Delegating methods satisfying tensor.Tensor; each forwards to its W* field.
func (W _cogentcore_org_lab_tensor_Tensor) AsValues() tensor.Values { return W.WAsValues() }
func (W _cogentcore_org_lab_tensor_Tensor) DataType() reflect.Kind { return W.WDataType() }
func (W _cogentcore_org_lab_tensor_Tensor) DimSize(dim int) int { return W.WDimSize(dim) }
func (W _cogentcore_org_lab_tensor_Tensor) Float(i ...int) float64 { return W.WFloat(i...) }
func (W _cogentcore_org_lab_tensor_Tensor) Float1D(i int) float64 { return W.WFloat1D(i) }
func (W _cogentcore_org_lab_tensor_Tensor) Int(i ...int) int { return W.WInt(i...) }
func (W _cogentcore_org_lab_tensor_Tensor) Int1D(i int) int { return W.WInt1D(i) }
func (W _cogentcore_org_lab_tensor_Tensor) IsString() bool { return W.WIsString() }
func (W _cogentcore_org_lab_tensor_Tensor) Label() string { return W.WLabel() }
func (W _cogentcore_org_lab_tensor_Tensor) Len() int { return W.WLen() }
func (W _cogentcore_org_lab_tensor_Tensor) Metadata() *metadata.Data { return W.WMetadata() }
func (W _cogentcore_org_lab_tensor_Tensor) NumDims() int { return W.WNumDims() }
func (W _cogentcore_org_lab_tensor_Tensor) SetFloat(val float64, i ...int) { W.WSetFloat(val, i...) }
func (W _cogentcore_org_lab_tensor_Tensor) SetFloat1D(val float64, i int) { W.WSetFloat1D(val, i) }
func (W _cogentcore_org_lab_tensor_Tensor) SetInt(val int, i ...int) { W.WSetInt(val, i...) }
func (W _cogentcore_org_lab_tensor_Tensor) SetInt1D(val int, i int) { W.WSetInt1D(val, i) }
func (W _cogentcore_org_lab_tensor_Tensor) SetString(val string, i ...int) { W.WSetString(val, i...) }
func (W _cogentcore_org_lab_tensor_Tensor) SetString1D(val string, i int) { W.WSetString1D(val, i) }
func (W _cogentcore_org_lab_tensor_Tensor) Shape() *tensor.Shape { return W.WShape() }
func (W _cogentcore_org_lab_tensor_Tensor) ShapeSizes() []int { return W.WShapeSizes() }
// String guards against an unset WString, returning "" instead of panicking.
func (W _cogentcore_org_lab_tensor_Tensor) String() string {
if W.WString == nil {
return ""
}
return W.WString()
}
func (W _cogentcore_org_lab_tensor_Tensor) String1D(i int) string { return W.WString1D(i) }
func (W _cogentcore_org_lab_tensor_Tensor) StringValue(i ...int) string { return W.WStringValue(i...) }
// _cogentcore_org_lab_tensor_Values is an interface wrapper for Values type
//
// Generated by yaegi extract: each W-prefixed func field holds the
// implementation of the correspondingly-named interface method.
type _cogentcore_org_lab_tensor_Values struct {
// IValue holds the underlying wrapped value.
IValue interface{}
WAppendFrom func(from tensor.Values) tensor.Values
WAppendRow func(val tensor.Values)
WAppendRowFloat func(val ...float64)
WAppendRowInt func(val ...int)
WAppendRowString func(val ...string)
WAsValues func() tensor.Values
WBytes func() []byte
WClone func() tensor.Values
WCopyCellsFrom func(from tensor.Values, to int, start int, n int)
WCopyFrom func(from tensor.Values)
WDataType func() reflect.Kind
WDimSize func(dim int) int
WFloat func(i ...int) float64
WFloat1D func(i int) float64
WFloatRow func(row int, cell int) float64
WInt func(i ...int) int
WInt1D func(i int) int
WIntRow func(row int, cell int) int
WIsString func() bool
WLabel func() string
WLen func() int
WMetadata func() *metadata.Data
WNumDims func() int
WRowTensor func(row int) tensor.Values
WSetFloat func(val float64, i ...int)
WSetFloat1D func(val float64, i int)
WSetFloatRow func(val float64, row int, cell int)
WSetInt func(val int, i ...int)
WSetInt1D func(val int, i int)
WSetIntRow func(val int, row int, cell int)
WSetNumRows func(rows int)
WSetRowTensor func(val tensor.Values, row int)
WSetShapeSizes func(sizes ...int)
WSetString func(val string, i ...int)
WSetString1D func(val string, i int)
WSetStringRow func(val string, row int, cell int)
WSetZeros func()
WShape func() *tensor.Shape
WShapeSizes func() []int
WSizeof func() int64
WString func() string
WString1D func(i int) string
WStringRow func(row int, cell int) string
WStringValue func(i ...int) string
WSubSpace func(offs ...int) tensor.Values
}
// Delegating methods satisfying tensor.Values; each forwards to its W* field.
func (W _cogentcore_org_lab_tensor_Values) AppendFrom(from tensor.Values) tensor.Values {
return W.WAppendFrom(from)
}
func (W _cogentcore_org_lab_tensor_Values) AppendRow(val tensor.Values) { W.WAppendRow(val) }
func (W _cogentcore_org_lab_tensor_Values) AppendRowFloat(val ...float64) { W.WAppendRowFloat(val...) }
func (W _cogentcore_org_lab_tensor_Values) AppendRowInt(val ...int) { W.WAppendRowInt(val...) }
func (W _cogentcore_org_lab_tensor_Values) AppendRowString(val ...string) { W.WAppendRowString(val...) }
func (W _cogentcore_org_lab_tensor_Values) AsValues() tensor.Values { return W.WAsValues() }
func (W _cogentcore_org_lab_tensor_Values) Bytes() []byte { return W.WBytes() }
func (W _cogentcore_org_lab_tensor_Values) Clone() tensor.Values { return W.WClone() }
func (W _cogentcore_org_lab_tensor_Values) CopyCellsFrom(from tensor.Values, to int, start int, n int) {
W.WCopyCellsFrom(from, to, start, n)
}
func (W _cogentcore_org_lab_tensor_Values) CopyFrom(from tensor.Values) { W.WCopyFrom(from) }
func (W _cogentcore_org_lab_tensor_Values) DataType() reflect.Kind { return W.WDataType() }
func (W _cogentcore_org_lab_tensor_Values) DimSize(dim int) int { return W.WDimSize(dim) }
func (W _cogentcore_org_lab_tensor_Values) Float(i ...int) float64 { return W.WFloat(i...) }
func (W _cogentcore_org_lab_tensor_Values) Float1D(i int) float64 { return W.WFloat1D(i) }
func (W _cogentcore_org_lab_tensor_Values) FloatRow(row int, cell int) float64 {
return W.WFloatRow(row, cell)
}
func (W _cogentcore_org_lab_tensor_Values) Int(i ...int) int { return W.WInt(i...) }
func (W _cogentcore_org_lab_tensor_Values) Int1D(i int) int { return W.WInt1D(i) }
func (W _cogentcore_org_lab_tensor_Values) IntRow(row int, cell int) int { return W.WIntRow(row, cell) }
func (W _cogentcore_org_lab_tensor_Values) IsString() bool { return W.WIsString() }
func (W _cogentcore_org_lab_tensor_Values) Label() string { return W.WLabel() }
func (W _cogentcore_org_lab_tensor_Values) Len() int { return W.WLen() }
func (W _cogentcore_org_lab_tensor_Values) Metadata() *metadata.Data { return W.WMetadata() }
func (W _cogentcore_org_lab_tensor_Values) NumDims() int { return W.WNumDims() }
func (W _cogentcore_org_lab_tensor_Values) RowTensor(row int) tensor.Values { return W.WRowTensor(row) }
func (W _cogentcore_org_lab_tensor_Values) SetFloat(val float64, i ...int) { W.WSetFloat(val, i...) }
func (W _cogentcore_org_lab_tensor_Values) SetFloat1D(val float64, i int) { W.WSetFloat1D(val, i) }
func (W _cogentcore_org_lab_tensor_Values) SetFloatRow(val float64, row int, cell int) {
W.WSetFloatRow(val, row, cell)
}
func (W _cogentcore_org_lab_tensor_Values) SetInt(val int, i ...int) { W.WSetInt(val, i...) }
func (W _cogentcore_org_lab_tensor_Values) SetInt1D(val int, i int) { W.WSetInt1D(val, i) }
func (W _cogentcore_org_lab_tensor_Values) SetIntRow(val int, row int, cell int) {
W.WSetIntRow(val, row, cell)
}
func (W _cogentcore_org_lab_tensor_Values) SetNumRows(rows int) { W.WSetNumRows(rows) }
func (W _cogentcore_org_lab_tensor_Values) SetRowTensor(val tensor.Values, row int) {
W.WSetRowTensor(val, row)
}
func (W _cogentcore_org_lab_tensor_Values) SetShapeSizes(sizes ...int) { W.WSetShapeSizes(sizes...) }
func (W _cogentcore_org_lab_tensor_Values) SetString(val string, i ...int) { W.WSetString(val, i...) }
func (W _cogentcore_org_lab_tensor_Values) SetString1D(val string, i int) { W.WSetString1D(val, i) }
func (W _cogentcore_org_lab_tensor_Values) SetStringRow(val string, row int, cell int) {
W.WSetStringRow(val, row, cell)
}
func (W _cogentcore_org_lab_tensor_Values) SetZeros() { W.WSetZeros() }
func (W _cogentcore_org_lab_tensor_Values) Shape() *tensor.Shape { return W.WShape() }
func (W _cogentcore_org_lab_tensor_Values) ShapeSizes() []int { return W.WShapeSizes() }
func (W _cogentcore_org_lab_tensor_Values) Sizeof() int64 { return W.WSizeof() }
// String guards against an unset WString, returning "" instead of panicking.
func (W _cogentcore_org_lab_tensor_Values) String() string {
if W.WString == nil {
return ""
}
return W.WString()
}
func (W _cogentcore_org_lab_tensor_Values) String1D(i int) string { return W.WString1D(i) }
func (W _cogentcore_org_lab_tensor_Values) StringRow(row int, cell int) string {
return W.WStringRow(row, cell)
}
func (W _cogentcore_org_lab_tensor_Values) StringValue(i ...int) string { return W.WStringValue(i...) }
func (W _cogentcore_org_lab_tensor_Values) SubSpace(offs ...int) tensor.Values {
return W.WSubSpace(offs...)
}
// Code generated by 'yaegi extract cogentcore.org/lab/tensorfs'. DO NOT EDIT.
package nogui
import (
"cogentcore.org/lab/tensorfs"
"reflect"
)
// init registers the exported API of cogentcore.org/lab/tensorfs in the
// Symbols map, keyed by import path, for resolution by yaegi-interpreted code.
// Package-level variables use reflect.ValueOf(&v).Elem() for addressability;
// types are registered as typed nil pointers.
func init() {
Symbols["cogentcore.org/lab/tensorfs/tensorfs"] = map[string]reflect.Value{
// function, constant and variable definitions
"Chdir": reflect.ValueOf(tensorfs.Chdir),
"CurDir": reflect.ValueOf(&tensorfs.CurDir).Elem(),
"CurRoot": reflect.ValueOf(&tensorfs.CurRoot).Elem(),
"DirFromTable": reflect.ValueOf(tensorfs.DirFromTable),
"DirOnly": reflect.ValueOf(tensorfs.DirOnly),
"DirTable": reflect.ValueOf(tensorfs.DirTable),
"Get": reflect.ValueOf(tensorfs.Get),
"List": reflect.ValueOf(tensorfs.List),
"Long": reflect.ValueOf(tensorfs.Long),
"Mkdir": reflect.ValueOf(tensorfs.Mkdir),
"NewDir": reflect.ValueOf(tensorfs.NewDir),
"Overwrite": reflect.ValueOf(tensorfs.Overwrite),
"Preserve": reflect.ValueOf(tensorfs.Preserve),
"Record": reflect.ValueOf(tensorfs.Record),
"Recursive": reflect.ValueOf(tensorfs.Recursive),
"Set": reflect.ValueOf(tensorfs.Set),
"SetTensor": reflect.ValueOf(tensorfs.SetTensor),
"Short": reflect.ValueOf(tensorfs.Short),
"ValueType": reflect.ValueOf(tensorfs.ValueType),
// type definitions
"DirFile": reflect.ValueOf((*tensorfs.DirFile)(nil)),
"File": reflect.ValueOf((*tensorfs.File)(nil)),
"Node": reflect.ValueOf((*tensorfs.Node)(nil)),
}
}
// Code generated by 'yaegi extract cogentcore.org/lab/vector'. DO NOT EDIT.
package nogui
import (
"cogentcore.org/lab/vector"
"reflect"
)
// init registers the exported API of cogentcore.org/lab/vector in the
// Symbols map, keyed by import path, for resolution by yaegi-interpreted code.
func init() {
Symbols["cogentcore.org/lab/vector/vector"] = map[string]reflect.Value{
// function, constant and variable definitions
"Dot": reflect.ValueOf(vector.Dot),
"L1Norm": reflect.ValueOf(vector.L1Norm),
"L2Norm": reflect.ValueOf(vector.L2Norm),
"Mul": reflect.ValueOf(vector.Mul),
"MulOut": reflect.ValueOf(vector.MulOut),
"Sum": reflect.ValueOf(vector.Sum),
}
}