// Copyright (c) 2018, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package atomicx implements miscellaneous atomic functions, including a basic atomic Counter.
package atomicx
import (
"sync/atomic"
)
// Counter implements a basic atomic int64 counter.
type Counter int64
// Add adds to counter.
func (a *Counter) Add(inc int64) int64 {
return atomic.AddInt64((*int64)(a), inc)
}
// Sub subtracts from counter.
func (a *Counter) Sub(dec int64) int64 {
return atomic.AddInt64((*int64)(a), -dec)
}
// Inc increments by 1.
func (a *Counter) Inc() int64 {
return atomic.AddInt64((*int64)(a), 1)
}
// Dec decrements by 1.
func (a *Counter) Dec() int64 {
return atomic.AddInt64((*int64)(a), -1)
}
// Value returns the current value.
func (a *Counter) Value() int64 {
return atomic.LoadInt64((*int64)(a))
}
// Set sets the counter to a new value.
func (a *Counter) Set(val int64) {
atomic.StoreInt64((*int64)(a), val)
}
// Swap swaps a new value in and returns the old value.
func (a *Counter) Swap(val int64) int64 {
return atomic.SwapInt64((*int64)(a), val)
}
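// Usage sketch (illustrative only, not part of this package): a Counter shared
// across goroutines. The import path is an assumption; adjust it to your module.
//
//	import (
//		"fmt"
//		"sync"
//
//		"cogentcore.org/core/base/atomicx"
//	)
//
//	func countConcurrently() {
//		var n atomicx.Counter
//		var wg sync.WaitGroup
//		for i := 0; i < 100; i++ {
//			wg.Add(1)
//			go func() {
//				defer wg.Done()
//				n.Inc() // safe from any goroutine
//			}()
//		}
//		wg.Wait()
//		fmt.Println(n.Value()) // 100
//	}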
// Copyright (c) 2018, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package atomicx
import "sync/atomic"
// MaxInt32 performs an atomic Max operation: a = max(a, b)
func MaxInt32(a *int32, b int32) {
old := atomic.LoadInt32(a)
for old < b && !atomic.CompareAndSwapInt32(a, old, b) {
old = atomic.LoadInt32(a)
}
}
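// Usage sketch (illustrative only, not part of this package): computing a
// running maximum across goroutines with the CAS loop above.
//
//	var hi int32
//	var wg sync.WaitGroup
//	for _, v := range []int32{3, 9, 4, 7} {
//		v := v
//		wg.Add(1)
//		go func() {
//			defer wg.Done()
//			atomicx.MaxInt32(&hi, v)
//		}()
//	}
//	wg.Wait()
//	// hi is now 9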
// Code generated by dummy.gen.go.tmpl. DO NOT EDIT.
// Copyright (c) 2020, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !mpi && !mpich
package mpi
// SendF64 sends values to toProc, using given unique tag identifier.
// This is Blocking. Must have a corresponding Recv call with same tag on toProc, from this proc
func (cm *Comm) SendF64(toProc int, tag int, vals []float64) error {
return nil
}
// RecvF64 receives values from proc fmProc, using given unique tag identifier
// This is Blocking. Must have a corresponding Send call with same tag on fmProc, to this proc
func (cm *Comm) RecvF64(fmProc int, tag int, vals []float64) error {
return nil
}
// BcastF64 broadcasts slice from fmProc to all other procs.
// All nodes have the same vals after this call, copied from fmProc.
func (cm *Comm) BcastF64(fmProc int, vals []float64) error {
return nil
}
// ReduceF64 reduces all values across procs to toProc in orig to dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ReduceF64(toProc int, op Op, dest, orig []float64) error {
return nil
}
// AllReduceF64 reduces all values across procs to all procs from orig into dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllReduceF64(op Op, dest, orig []float64) error {
return nil
}
// GatherF64 gathers values from all procs into toProc proc, tiled into dest of size np * len(orig).
// This is inverse of Scatter.
// IMPORTANT: orig and dest must be different slices.
func (cm *Comm) GatherF64(toProc int, dest, orig []float64) error {
return nil
}
// AllGatherF64 gathers values from all procs into all procs,
// tiled by proc into dest of size np * len(orig).
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllGatherF64(dest, orig []float64) error {
return nil
}
// ScatterF64 scatters values from fmProc to all procs, distributing len(dest) size chunks to
// each proc from orig slice, which must be of size np * len(dest). This is inverse of Gather.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ScatterF64(fmProc int, dest, orig []float64) error {
return nil
}
// SendF32 sends values to toProc, using given unique tag identifier.
// This is Blocking. Must have a corresponding Recv call with same tag on toProc, from this proc
func (cm *Comm) SendF32(toProc int, tag int, vals []float32) error {
return nil
}
// RecvF32 receives values from proc fmProc, using given unique tag identifier
// This is Blocking. Must have a corresponding Send call with same tag on fmProc, to this proc
func (cm *Comm) RecvF32(fmProc int, tag int, vals []float32) error {
return nil
}
// BcastF32 broadcasts slice from fmProc to all other procs.
// All nodes have the same vals after this call, copied from fmProc.
func (cm *Comm) BcastF32(fmProc int, vals []float32) error {
return nil
}
// ReduceF32 reduces all values across procs to toProc in orig to dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ReduceF32(toProc int, op Op, dest, orig []float32) error {
return nil
}
// AllReduceF32 reduces all values across procs to all procs from orig into dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllReduceF32(op Op, dest, orig []float32) error {
return nil
}
// GatherF32 gathers values from all procs into toProc proc, tiled into dest of size np * len(orig).
// This is inverse of Scatter.
// IMPORTANT: orig and dest must be different slices.
func (cm *Comm) GatherF32(toProc int, dest, orig []float32) error {
return nil
}
// AllGatherF32 gathers values from all procs into all procs,
// tiled by proc into dest of size np * len(orig).
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllGatherF32(dest, orig []float32) error {
return nil
}
// ScatterF32 scatters values from fmProc to all procs, distributing len(dest) size chunks to
// each proc from orig slice, which must be of size np * len(dest). This is inverse of Gather.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ScatterF32(fmProc int, dest, orig []float32) error {
return nil
}
// SendInt sends values to toProc, using given unique tag identifier.
// This is Blocking. Must have a corresponding Recv call with same tag on toProc, from this proc
func (cm *Comm) SendInt(toProc int, tag int, vals []int) error {
return nil
}
// RecvInt receives values from proc fmProc, using given unique tag identifier
// This is Blocking. Must have a corresponding Send call with same tag on fmProc, to this proc
func (cm *Comm) RecvInt(fmProc int, tag int, vals []int) error {
return nil
}
// BcastInt broadcasts slice from fmProc to all other procs.
// All nodes have the same vals after this call, copied from fmProc.
func (cm *Comm) BcastInt(fmProc int, vals []int) error {
return nil
}
// ReduceInt reduces all values across procs to toProc in orig to dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ReduceInt(toProc int, op Op, dest, orig []int) error {
return nil
}
// AllReduceInt reduces all values across procs to all procs from orig into dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllReduceInt(op Op, dest, orig []int) error {
return nil
}
// GatherInt gathers values from all procs into toProc proc, tiled into dest of size np * len(orig).
// This is inverse of Scatter.
// IMPORTANT: orig and dest must be different slices.
func (cm *Comm) GatherInt(toProc int, dest, orig []int) error {
return nil
}
// AllGatherInt gathers values from all procs into all procs,
// tiled by proc into dest of size np * len(orig).
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllGatherInt(dest, orig []int) error {
return nil
}
// ScatterInt scatters values from fmProc to all procs, distributing len(dest) size chunks to
// each proc from orig slice, which must be of size np * len(dest). This is inverse of Gather.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ScatterInt(fmProc int, dest, orig []int) error {
return nil
}
// SendI64 sends values to toProc, using given unique tag identifier.
// This is Blocking. Must have a corresponding Recv call with same tag on toProc, from this proc
func (cm *Comm) SendI64(toProc int, tag int, vals []int64) error {
return nil
}
// RecvI64 receives values from proc fmProc, using given unique tag identifier
// This is Blocking. Must have a corresponding Send call with same tag on fmProc, to this proc
func (cm *Comm) RecvI64(fmProc int, tag int, vals []int64) error {
return nil
}
// BcastI64 broadcasts slice from fmProc to all other procs.
// All nodes have the same vals after this call, copied from fmProc.
func (cm *Comm) BcastI64(fmProc int, vals []int64) error {
return nil
}
// ReduceI64 reduces all values across procs to toProc in orig to dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ReduceI64(toProc int, op Op, dest, orig []int64) error {
return nil
}
// AllReduceI64 reduces all values across procs to all procs from orig into dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllReduceI64(op Op, dest, orig []int64) error {
return nil
}
// GatherI64 gathers values from all procs into toProc proc, tiled into dest of size np * len(orig).
// This is inverse of Scatter.
// IMPORTANT: orig and dest must be different slices.
func (cm *Comm) GatherI64(toProc int, dest, orig []int64) error {
return nil
}
// AllGatherI64 gathers values from all procs into all procs,
// tiled by proc into dest of size np * len(orig).
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllGatherI64(dest, orig []int64) error {
return nil
}
// ScatterI64 scatters values from fmProc to all procs, distributing len(dest) size chunks to
// each proc from orig slice, which must be of size np * len(dest). This is inverse of Gather.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ScatterI64(fmProc int, dest, orig []int64) error {
return nil
}
// SendU64 sends values to toProc, using given unique tag identifier.
// This is Blocking. Must have a corresponding Recv call with same tag on toProc, from this proc
func (cm *Comm) SendU64(toProc int, tag int, vals []uint64) error {
return nil
}
// RecvU64 receives values from proc fmProc, using given unique tag identifier
// This is Blocking. Must have a corresponding Send call with same tag on fmProc, to this proc
func (cm *Comm) RecvU64(fmProc int, tag int, vals []uint64) error {
return nil
}
// BcastU64 broadcasts slice from fmProc to all other procs.
// All nodes have the same vals after this call, copied from fmProc.
func (cm *Comm) BcastU64(fmProc int, vals []uint64) error {
return nil
}
// ReduceU64 reduces all values across procs to toProc in orig to dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ReduceU64(toProc int, op Op, dest, orig []uint64) error {
return nil
}
// AllReduceU64 reduces all values across procs to all procs from orig into dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllReduceU64(op Op, dest, orig []uint64) error {
return nil
}
// GatherU64 gathers values from all procs into toProc proc, tiled into dest of size np * len(orig).
// This is inverse of Scatter.
// IMPORTANT: orig and dest must be different slices.
func (cm *Comm) GatherU64(toProc int, dest, orig []uint64) error {
return nil
}
// AllGatherU64 gathers values from all procs into all procs,
// tiled by proc into dest of size np * len(orig).
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllGatherU64(dest, orig []uint64) error {
return nil
}
// ScatterU64 scatters values from fmProc to all procs, distributing len(dest) size chunks to
// each proc from orig slice, which must be of size np * len(dest). This is inverse of Gather.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ScatterU64(fmProc int, dest, orig []uint64) error {
return nil
}
// SendI32 sends values to toProc, using given unique tag identifier.
// This is Blocking. Must have a corresponding Recv call with same tag on toProc, from this proc
func (cm *Comm) SendI32(toProc int, tag int, vals []int32) error {
return nil
}
// RecvI32 receives values from proc fmProc, using given unique tag identifier
// This is Blocking. Must have a corresponding Send call with same tag on fmProc, to this proc
func (cm *Comm) RecvI32(fmProc int, tag int, vals []int32) error {
return nil
}
// BcastI32 broadcasts slice from fmProc to all other procs.
// All nodes have the same vals after this call, copied from fmProc.
func (cm *Comm) BcastI32(fmProc int, vals []int32) error {
return nil
}
// ReduceI32 reduces all values across procs to toProc in orig to dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ReduceI32(toProc int, op Op, dest, orig []int32) error {
return nil
}
// AllReduceI32 reduces all values across procs to all procs from orig into dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllReduceI32(op Op, dest, orig []int32) error {
return nil
}
// GatherI32 gathers values from all procs into toProc proc, tiled into dest of size np * len(orig).
// This is inverse of Scatter.
// IMPORTANT: orig and dest must be different slices.
func (cm *Comm) GatherI32(toProc int, dest, orig []int32) error {
return nil
}
// AllGatherI32 gathers values from all procs into all procs,
// tiled by proc into dest of size np * len(orig).
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllGatherI32(dest, orig []int32) error {
return nil
}
// ScatterI32 scatters values from fmProc to all procs, distributing len(dest) size chunks to
// each proc from orig slice, which must be of size np * len(dest). This is inverse of Gather.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ScatterI32(fmProc int, dest, orig []int32) error {
return nil
}
// SendU32 sends values to toProc, using given unique tag identifier.
// This is Blocking. Must have a corresponding Recv call with same tag on toProc, from this proc
func (cm *Comm) SendU32(toProc int, tag int, vals []uint32) error {
return nil
}
// RecvU32 receives values from proc fmProc, using given unique tag identifier
// This is Blocking. Must have a corresponding Send call with same tag on fmProc, to this proc
func (cm *Comm) RecvU32(fmProc int, tag int, vals []uint32) error {
return nil
}
// BcastU32 broadcasts slice from fmProc to all other procs.
// All nodes have the same vals after this call, copied from fmProc.
func (cm *Comm) BcastU32(fmProc int, vals []uint32) error {
return nil
}
// ReduceU32 reduces all values across procs to toProc in orig to dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ReduceU32(toProc int, op Op, dest, orig []uint32) error {
return nil
}
// AllReduceU32 reduces all values across procs to all procs from orig into dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllReduceU32(op Op, dest, orig []uint32) error {
return nil
}
// GatherU32 gathers values from all procs into toProc proc, tiled into dest of size np * len(orig).
// This is inverse of Scatter.
// IMPORTANT: orig and dest must be different slices.
func (cm *Comm) GatherU32(toProc int, dest, orig []uint32) error {
return nil
}
// AllGatherU32 gathers values from all procs into all procs,
// tiled by proc into dest of size np * len(orig).
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllGatherU32(dest, orig []uint32) error {
return nil
}
// ScatterU32 scatters values from fmProc to all procs, distributing len(dest) size chunks to
// each proc from orig slice, which must be of size np * len(dest). This is inverse of Gather.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ScatterU32(fmProc int, dest, orig []uint32) error {
return nil
}
// SendI16 sends values to toProc, using given unique tag identifier.
// This is Blocking. Must have a corresponding Recv call with same tag on toProc, from this proc
func (cm *Comm) SendI16(toProc int, tag int, vals []int16) error {
return nil
}
// RecvI16 receives values from proc fmProc, using given unique tag identifier
// This is Blocking. Must have a corresponding Send call with same tag on fmProc, to this proc
func (cm *Comm) RecvI16(fmProc int, tag int, vals []int16) error {
return nil
}
// BcastI16 broadcasts slice from fmProc to all other procs.
// All nodes have the same vals after this call, copied from fmProc.
func (cm *Comm) BcastI16(fmProc int, vals []int16) error {
return nil
}
// ReduceI16 reduces all values across procs to toProc in orig to dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ReduceI16(toProc int, op Op, dest, orig []int16) error {
return nil
}
// AllReduceI16 reduces all values across procs to all procs from orig into dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllReduceI16(op Op, dest, orig []int16) error {
return nil
}
// GatherI16 gathers values from all procs into toProc proc, tiled into dest of size np * len(orig).
// This is inverse of Scatter.
// IMPORTANT: orig and dest must be different slices.
func (cm *Comm) GatherI16(toProc int, dest, orig []int16) error {
return nil
}
// AllGatherI16 gathers values from all procs into all procs,
// tiled by proc into dest of size np * len(orig).
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllGatherI16(dest, orig []int16) error {
return nil
}
// ScatterI16 scatters values from fmProc to all procs, distributing len(dest) size chunks to
// each proc from orig slice, which must be of size np * len(dest). This is inverse of Gather.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ScatterI16(fmProc int, dest, orig []int16) error {
return nil
}
// SendU16 sends values to toProc, using given unique tag identifier.
// This is Blocking. Must have a corresponding Recv call with same tag on toProc, from this proc
func (cm *Comm) SendU16(toProc int, tag int, vals []uint16) error {
return nil
}
// RecvU16 receives values from proc fmProc, using given unique tag identifier
// This is Blocking. Must have a corresponding Send call with same tag on fmProc, to this proc
func (cm *Comm) RecvU16(fmProc int, tag int, vals []uint16) error {
return nil
}
// BcastU16 broadcasts slice from fmProc to all other procs.
// All nodes have the same vals after this call, copied from fmProc.
func (cm *Comm) BcastU16(fmProc int, vals []uint16) error {
return nil
}
// ReduceU16 reduces all values across procs to toProc in orig to dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ReduceU16(toProc int, op Op, dest, orig []uint16) error {
return nil
}
// AllReduceU16 reduces all values across procs to all procs from orig into dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllReduceU16(op Op, dest, orig []uint16) error {
return nil
}
// GatherU16 gathers values from all procs into toProc proc, tiled into dest of size np * len(orig).
// This is inverse of Scatter.
// IMPORTANT: orig and dest must be different slices.
func (cm *Comm) GatherU16(toProc int, dest, orig []uint16) error {
return nil
}
// AllGatherU16 gathers values from all procs into all procs,
// tiled by proc into dest of size np * len(orig).
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllGatherU16(dest, orig []uint16) error {
return nil
}
// ScatterU16 scatters values from fmProc to all procs, distributing len(dest) size chunks to
// each proc from orig slice, which must be of size np * len(dest). This is inverse of Gather.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ScatterU16(fmProc int, dest, orig []uint16) error {
return nil
}
// SendI8 sends values to toProc, using given unique tag identifier.
// This is Blocking. Must have a corresponding Recv call with same tag on toProc, from this proc
func (cm *Comm) SendI8(toProc int, tag int, vals []int8) error {
return nil
}
// RecvI8 receives values from proc fmProc, using given unique tag identifier
// This is Blocking. Must have a corresponding Send call with same tag on fmProc, to this proc
func (cm *Comm) RecvI8(fmProc int, tag int, vals []int8) error {
return nil
}
// BcastI8 broadcasts slice from fmProc to all other procs.
// All nodes have the same vals after this call, copied from fmProc.
func (cm *Comm) BcastI8(fmProc int, vals []int8) error {
return nil
}
// ReduceI8 reduces all values across procs to toProc in orig to dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ReduceI8(toProc int, op Op, dest, orig []int8) error {
return nil
}
// AllReduceI8 reduces all values across procs to all procs from orig into dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllReduceI8(op Op, dest, orig []int8) error {
return nil
}
// GatherI8 gathers values from all procs into toProc proc, tiled into dest of size np * len(orig).
// This is inverse of Scatter.
// IMPORTANT: orig and dest must be different slices.
func (cm *Comm) GatherI8(toProc int, dest, orig []int8) error {
return nil
}
// AllGatherI8 gathers values from all procs into all procs,
// tiled by proc into dest of size np * len(orig).
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllGatherI8(dest, orig []int8) error {
return nil
}
// ScatterI8 scatters values from fmProc to all procs, distributing len(dest) size chunks to
// each proc from orig slice, which must be of size np * len(dest). This is inverse of Gather.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ScatterI8(fmProc int, dest, orig []int8) error {
return nil
}
// SendU8 sends values to toProc, using given unique tag identifier.
// This is Blocking. Must have a corresponding Recv call with same tag on toProc, from this proc
func (cm *Comm) SendU8(toProc int, tag int, vals []uint8) error {
return nil
}
// RecvU8 receives values from proc fmProc, using given unique tag identifier
// This is Blocking. Must have a corresponding Send call with same tag on fmProc, to this proc
func (cm *Comm) RecvU8(fmProc int, tag int, vals []uint8) error {
return nil
}
// BcastU8 broadcasts slice from fmProc to all other procs.
// All nodes have the same vals after this call, copied from fmProc.
func (cm *Comm) BcastU8(fmProc int, vals []uint8) error {
return nil
}
// ReduceU8 reduces all values across procs to toProc in orig to dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ReduceU8(toProc int, op Op, dest, orig []uint8) error {
return nil
}
// AllReduceU8 reduces all values across procs to all procs from orig into dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllReduceU8(op Op, dest, orig []uint8) error {
return nil
}
// GatherU8 gathers values from all procs into toProc proc, tiled into dest of size np * len(orig).
// This is inverse of Scatter.
// IMPORTANT: orig and dest must be different slices.
func (cm *Comm) GatherU8(toProc int, dest, orig []uint8) error {
return nil
}
// AllGatherU8 gathers values from all procs into all procs,
// tiled by proc into dest of size np * len(orig).
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllGatherU8(dest, orig []uint8) error {
return nil
}
// ScatterU8 scatters values from fmProc to all procs, distributing len(dest) size chunks to
// each proc from orig slice, which must be of size np * len(dest). This is inverse of Gather.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ScatterU8(fmProc int, dest, orig []uint8) error {
return nil
}
// SendC128 sends values to toProc, using given unique tag identifier.
// This is Blocking. Must have a corresponding Recv call with same tag on toProc, from this proc
func (cm *Comm) SendC128(toProc int, tag int, vals []complex128) error {
return nil
}
// RecvC128 receives values from proc fmProc, using given unique tag identifier
// This is Blocking. Must have a corresponding Send call with same tag on fmProc, to this proc
func (cm *Comm) RecvC128(fmProc int, tag int, vals []complex128) error {
return nil
}
// BcastC128 broadcasts slice from fmProc to all other procs.
// All nodes have the same vals after this call, copied from fmProc.
func (cm *Comm) BcastC128(fmProc int, vals []complex128) error {
return nil
}
// ReduceC128 reduces all values across procs to toProc in orig to dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ReduceC128(toProc int, op Op, dest, orig []complex128) error {
return nil
}
// AllReduceC128 reduces all values across procs to all procs from orig into dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllReduceC128(op Op, dest, orig []complex128) error {
return nil
}
// GatherC128 gathers values from all procs into toProc proc, tiled into dest of size np * len(orig).
// This is inverse of Scatter.
// IMPORTANT: orig and dest must be different slices.
func (cm *Comm) GatherC128(toProc int, dest, orig []complex128) error {
return nil
}
// AllGatherC128 gathers values from all procs into all procs,
// tiled by proc into dest of size np * len(orig).
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllGatherC128(dest, orig []complex128) error {
return nil
}
// ScatterC128 scatters values from fmProc to all procs, distributing len(dest) size chunks to
// each proc from orig slice, which must be of size np * len(dest). This is inverse of Gather.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ScatterC128(fmProc int, dest, orig []complex128) error {
return nil
}
// SendC64 sends values to toProc, using given unique tag identifier.
// This is Blocking. Must have a corresponding Recv call with same tag on toProc, from this proc
func (cm *Comm) SendC64(toProc int, tag int, vals []complex64) error {
return nil
}
// RecvC64 receives values from proc fmProc, using given unique tag identifier
// This is Blocking. Must have a corresponding Send call with same tag on fmProc, to this proc
func (cm *Comm) RecvC64(fmProc int, tag int, vals []complex64) error {
return nil
}
// BcastC64 broadcasts slice from fmProc to all other procs.
// All nodes have the same vals after this call, copied from fmProc.
func (cm *Comm) BcastC64(fmProc int, vals []complex64) error {
return nil
}
// ReduceC64 reduces all values across procs to toProc in orig to dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ReduceC64(toProc int, op Op, dest, orig []complex64) error {
return nil
}
// AllReduceC64 reduces all values across procs to all procs from orig into dest using given operation.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllReduceC64(op Op, dest, orig []complex64) error {
return nil
}
// GatherC64 gathers values from all procs into toProc proc, tiled into dest of size np * len(orig).
// This is inverse of Scatter.
// IMPORTANT: orig and dest must be different slices.
func (cm *Comm) GatherC64(toProc int, dest, orig []complex64) error {
return nil
}
// AllGatherC64 gathers values from all procs into all procs,
// tiled by proc into dest of size np * len(orig).
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) AllGatherC64(dest, orig []complex64) error {
return nil
}
// ScatterC64 scatters values from fmProc to all procs, distributing len(dest) size chunks to
// each proc from orig slice, which must be of size np * len(dest). This is inverse of Gather.
// IMPORTANT: orig and dest must be different slices
func (cm *Comm) ScatterC64(fmProc int, dest, orig []complex64) error {
return nil
}
// Copyright (c) 2020, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !mpi && !mpich
package mpi
// this file provides dummy versions, built by default, so mpi can be included
// generically without incurring additional complexity.
// LogErrors controls whether MPI errors are automatically logged or not
var LogErrors = true
// Op is an aggregation operation: Sum, Min, Max, etc
type Op int
const (
OpSum Op = iota
OpMax
OpMin
OpProd
OpLAND // logical AND
OpLOR // logical OR
OpBAND // bitwise AND
OpBOR // bitwise OR
)
const (
// Root is the rank 0 node; using this named constant is clearer than a literal 0
Root int = 0
)
// IsOn tells whether MPI is on or not
// NOTE: this returns true even after Stop
func IsOn() bool {
return false
}
// Init initialises MPI
func Init() {
}
// InitThreadSafe initialises MPI thread safe
func InitThreadSafe() error {
return nil
}
// Finalize finalises MPI (frees resources, shuts it down)
func Finalize() {
}
// WorldRank returns this proc's rank/ID within the World communicator.
// Returns 0 if not yet initialized, so it is always safe to call.
func WorldRank() (rank int) {
return 0
}
// WorldSize returns the number of procs in the World communicator.
// Returns 1 if not yet initialized, so it is always safe to call.
func WorldSize() (size int) {
return 1
}
// Comm is the MPI communicator -- all MPI communication operates as methods
// on this struct. It holds the MPI_Comm communicator and MPI_Group for
// sub-World group communication.
type Comm struct {
}
// NewComm creates a new communicator.
// If ranks is nil, the communicator is for World (all active procs).
// Otherwise, it defines a group-level communicator for the given ranks.
func NewComm(ranks []int) (*Comm, error) {
cm := &Comm{}
return cm, nil
}
// Rank returns the rank/ID for this proc
func (cm *Comm) Rank() (rank int) {
return 0
}
// Size returns the number of procs in this communicator
func (cm *Comm) Size() (size int) {
return 1
}
// Abort aborts MPI
func (cm *Comm) Abort() error {
return nil
}
// Barrier forces synchronisation
func (cm *Comm) Barrier() error {
return nil
}
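// Usage sketch (illustrative only, not part of this package): the same calling
// code compiles against both the real MPI build and this dummy no-op build,
// so programs can run single-process without MPI installed.
//
//	mpi.Init()
//	defer mpi.Finalize()
//	comm, _ := mpi.NewComm(nil) // nil ranks = World communicator
//	vals := make([]float64, 4)
//	if comm.Rank() == mpi.Root {
//		for i := range vals {
//			vals[i] = float64(i)
//		}
//	}
//	_ = comm.BcastF64(mpi.Root, vals) // no-op here; a real broadcast under the mpi build tags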
// Copyright (c) 2020, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mpi
import "fmt"
// PrintAllProcs causes mpi.Printf to print on all processors -- otherwise just 0
var PrintAllProcs = false
// Printf does fmt.Printf only on the rank 0 node (see also AllPrintf to print on all nodes).
// Set the PrintAllProcs var to override this for debugging and print on all nodes.
func Printf(fs string, pars ...any) {
if !PrintAllProcs && WorldRank() > 0 {
return
}
if WorldRank() > 0 {
AllPrintf(fs, pars...)
} else {
fmt.Printf(fs, pars...)
}
}
// AllPrintf does fmt.Printf on all nodes, with node rank printed first
// This is best for debugging MPI itself.
func AllPrintf(fs string, pars ...any) {
fs = fmt.Sprintf("P%d: ", WorldRank()) + fs
fmt.Printf(fs, pars...)
}
// Println does fmt.Println only on the rank 0 node (see also AllPrintln to print on all nodes).
// Set the PrintAllProcs var to override this for debugging and print on all nodes.
func Println(fs ...any) {
if !PrintAllProcs && WorldRank() > 0 {
return
}
if WorldRank() > 0 {
AllPrintln(fs...)
} else {
fmt.Println(fs...)
}
}
// AllPrintln does fmt.Println on all nodes, with node rank printed first
// This is best for debugging MPI itself.
func AllPrintln(fs ...any) {
fsa := make([]any, len(fs)+1)
copy(fsa[1:], fs)
fsa[0] = fmt.Sprintf("P%d: ", WorldRank())
fmt.Println(fsa...)
}
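// Usage sketch (illustrative only, not part of this package): rank-aware printing.
//
//	mpi.Printf("starting %d runs\n", 10) // rank 0 only by default
//	mpi.PrintAllProcs = true
//	mpi.Printf("debugging\n") // now printed on every rank (non-zero ranks get a "P<rank>: " prefix)
//	mpi.AllPrintln("done")    // always printed on every rank, with the rank prefix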
// Copyright (c) 2019, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package randx
// BoolP is a simple method to generate a true value with given probability
// (else false). It is just rand.Float64() < p but this is more readable
// and explicit.
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func BoolP(p float64, randOpt ...Rand) bool {
var rnd Rand
if len(randOpt) == 0 {
rnd = NewGlobalRand()
} else {
rnd = randOpt[0]
}
return rnd.Float64() < p
}
// BoolP32 is a simple method to generate a true value with given probability
// (else false). It is just rand.Float32() < p but this is more readable
// and explicit.
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func BoolP32(p float32, randOpt ...Rand) bool {
var rnd Rand
if len(randOpt) == 0 {
rnd = NewGlobalRand()
} else {
rnd = randOpt[0]
}
return rnd.Float32() < p
}
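// Usage sketch (illustrative only, not part of this package): a 30% coin flip,
// once with the global source and once with a private, seeded source
// (NewSysRand is defined elsewhere in this package).
//
//	if randx.BoolP(0.3) {
//		// taken ~30% of the time
//	}
//	rnd := randx.NewSysRand(42)
//	hit := randx.BoolP32(0.3, rnd) // reproducible for a given seed
//	_ = hit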
// Copyright (c) 2023, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package randx
import (
"math"
)
// note: this file contains random distribution functions
// from gonum.org/v1/gonum/stat/distuv
// which we modified only to use the randx.Rand interface.
// BinomialGen returns binomial with n trials (par) each of probability p (var)
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func BinomialGen(n, p float64, randOpt ...Rand) float64 {
var rnd Rand
if len(randOpt) == 0 {
rnd = NewGlobalRand()
} else {
rnd = randOpt[0]
}
// NUMERICAL RECIPES IN C: THE ART OF SCIENTIFIC COMPUTING (ISBN 0-521-43108-5)
// p. 295-6
// http://www.aip.de/groups/soe/local/numres/bookcpdf/c7-3.pdf
porg := p
if p > 0.5 {
p = 1 - p
}
am := n * p
if n < 25 {
// Use direct method.
bnl := 0.0
for i := 0; i < int(n); i++ {
if rnd.Float64() < p {
bnl++
}
}
if p != porg {
return n - bnl
}
return bnl
}
if am < 1 {
// Use rejection method with Poisson proposal.
const logM = 2.6e-2 // constant for rejection sampling (https://en.wikipedia.org/wiki/Rejection_sampling)
var bnl float64
z := -p
pclog := (1 + 0.5*z) * z / (1 + (1+1.0/6*z)*z) // Padé approximant of log(1 + x)
for {
bnl = 0.0
t := 0.0
for i := 0; i < int(n); i++ {
t += rnd.ExpFloat64()
if t >= am {
break
}
bnl++
}
bnlc := n - bnl
z = -bnl / n
log1p := (1 + 0.5*z) * z / (1 + (1+1.0/6*z)*z)
t = (bnlc+0.5)*log1p + bnl - bnlc*pclog + 1/(12*bnlc) - am + logM // Uses Stirling's expansion of log(n!)
if rnd.ExpFloat64() >= t {
break
}
}
if p != porg {
return n - bnl
}
return bnl
}
// Original algorithm samples from a Poisson distribution with the
// appropriate expected value. However, the Poisson approximation is
// asymptotic such that the absolute deviation in probability is O(1/n).
// Rejection sampling produces exact variates with at worst less than 3%
// rejection with minimal additional computation.
// Use rejection method with Cauchy proposal.
g, _ := math.Lgamma(n + 1)
plog := math.Log(p)
pclog := math.Log1p(-p)
sq := math.Sqrt(2 * am * (1 - p))
for {
var em, y float64
for {
y = math.Tan(math.Pi * rnd.Float64())
em = sq*y + am
if em >= 0 && em < n+1 {
break
}
}
em = math.Floor(em)
lg1, _ := math.Lgamma(em + 1)
lg2, _ := math.Lgamma(n - em + 1)
t := 1.2 * sq * (1 + y*y) * math.Exp(g-lg1-lg2+em*plog+(n-em)*pclog)
if rnd.Float64() <= t {
if p != porg {
return n - em
}
return em
}
}
}
// PoissonGen returns a Poisson variable, as the number of events in an interval,
// with event rate (lambda = Var) plus Mean.
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func PoissonGen(lambda float64, randOpt ...Rand) float64 {
// NUMERICAL RECIPES IN C: THE ART OF SCIENTIFIC COMPUTING (ISBN 0-521-43108-5)
// p. 294
// <http://www.aip.de/groups/soe/local/numres/bookcpdf/c7-3.pdf>
var rnd Rand
if len(randOpt) == 0 {
rnd = NewGlobalRand()
} else {
rnd = randOpt[0]
}
if lambda < 10.0 {
// Use direct method.
var em float64
t := 0.0
for {
t += rnd.ExpFloat64()
if t >= lambda {
break
}
em++
}
return em
}
// Generate using:
// W. Hörmann. "The transformed rejection method for generating Poisson
// random variables." Insurance: Mathematics and Economics
// 12.1 (1993): 39-45.
b := 0.931 + 2.53*math.Sqrt(lambda)
a := -0.059 + 0.02483*b
invalpha := 1.1239 + 1.1328/(b-3.4)
vr := 0.9277 - 3.6224/(b-2)
for {
U := rnd.Float64() - 0.5
V := rnd.Float64()
us := 0.5 - math.Abs(U)
k := math.Floor((2*a/us+b)*U + lambda + 0.43)
if us >= 0.07 && V <= vr {
return k
}
if k <= 0 || (us < 0.013 && V > us) {
continue
}
lg, _ := math.Lgamma(k + 1)
if math.Log(V*invalpha/(a/(us*us)+b)) <= k*math.Log(lambda)-lambda-lg {
return k
}
}
}
// GammaGen returns a gamma-distributed random number, from the maximum entropy distribution with two parameters:
// a shape parameter (Alpha, Par in RandParams),
// and a scaling parameter (Beta, Var in RandParams).
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func GammaGen(alpha, beta float64, randOpt ...Rand) float64 {
const (
// The 0.2 threshold is from https://www4.stat.ncsu.edu/~rmartin/Codes/rgamss.R
// described in detail in https://arxiv.org/abs/1302.1884.
smallAlphaThresh = 0.2
)
var rnd Rand
if len(randOpt) == 0 {
rnd = NewGlobalRand()
} else {
rnd = randOpt[0]
}
if beta <= 0 {
panic("GammaGen: beta <= 0")
}
a := alpha
b := beta
switch {
case a <= 0:
panic("gamma: alpha <= 0")
case a == 1:
// Generate from exponential
return rnd.ExpFloat64() / b
case a < smallAlphaThresh:
// Generate using
// Liu, Chuanhai, Martin, Ryan and Syring, Nick. "Simulating from a
// gamma distribution with small shape parameter"
// https://arxiv.org/abs/1302.1884
// use this reference: http://link.springer.com/article/10.1007/s00180-016-0692-0
// Algorithm adjusted to work in log space as much as possible.
lambda := 1/a - 1
lr := -math.Log1p(1 / lambda / math.E)
for {
e := rnd.ExpFloat64()
var z float64
if e >= -lr {
z = e + lr
} else {
z = -rnd.ExpFloat64() / lambda
}
eza := math.Exp(-z / a)
lh := -z - eza
var lEta float64
if z >= 0 {
lEta = -z
} else {
lEta = -1 + lambda*z
}
if lh-lEta > -rnd.ExpFloat64() {
return eza / b
}
}
case a >= smallAlphaThresh:
// Generate using:
// Marsaglia, George, and Wai Wan Tsang. "A simple method for generating
// gamma variables." ACM Transactions on Mathematical Software (TOMS)
// 26.3 (2000): 363-372.
d := a - 1.0/3
m := 1.0
if a < 1 {
d += 1.0
m = math.Pow(rnd.Float64(), 1/a)
}
c := 1 / (3 * math.Sqrt(d))
for {
x := rnd.NormFloat64()
v := 1 + x*c
if v <= 0.0 {
continue
}
v = v * v * v
u := rnd.Float64()
if u < 1.0-0.0331*(x*x)*(x*x) {
return m * d * v / b
}
if math.Log(u) < 0.5*x*x+d*(1-v+math.Log(v)) {
return m * d * v / b
}
}
}
panic("unreachable")
}
// GaussianGen returns gaussian (normal) random number with given
// mean and sigma standard deviation.
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func GaussianGen(mean, sigma float64, randOpt ...Rand) float64 {
var rnd Rand
if len(randOpt) == 0 {
rnd = NewGlobalRand()
} else {
rnd = randOpt[0]
}
return mean + sigma*rnd.NormFloat64()
}
// BetaGen returns beta random number with two shape parameters
// alpha > 0 and beta > 0
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func BetaGen(alpha, beta float64, randOpt ...Rand) float64 {
var rnd Rand
if len(randOpt) == 0 {
rnd = NewGlobalRand()
} else {
rnd = randOpt[0]
}
ga := GammaGen(alpha, 1, rnd)
gb := GammaGen(beta, 1, rnd)
return ga / (ga + gb)
}
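// Usage sketch (illustrative only, not part of this package): drawing samples
// from the generators above with a private, seeded source.
//
//	rnd := randx.NewSysRand(7)
//	g := randx.GaussianGen(0, 1, rnd)     // mean 0, stddev 1
//	k := randx.BinomialGen(10, 0.25, rnd) // successes in 10 trials with p = 0.25
//	w := randx.GammaGen(2, 1.5, rnd)      // shape alpha = 2, scaling beta = 1.5
//	b := randx.BetaGen(2, 5, rnd)         // in (0, 1), skewed toward 0
//	_ = g + k + w + b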
// Code generated by "core generate -add-types"; DO NOT EDIT.
package randx
import (
"cogentcore.org/core/enums"
)
var _RandDistsValues = []RandDists{0, 1, 2, 3, 4, 5, 6}
// RandDistsN is the highest valid value for type RandDists, plus one.
const RandDistsN RandDists = 7
var _RandDistsValueMap = map[string]RandDists{`Uniform`: 0, `Binomial`: 1, `Poisson`: 2, `Gamma`: 3, `Gaussian`: 4, `Beta`: 5, `Mean`: 6}
var _RandDistsDescMap = map[RandDists]string{0: `Uniform has a uniform probability distribution over Var = range on either side of the Mean`, 1: `Binomial represents number of 1's in n (Par) random (Bernoulli) trials of probability p (Var)`, 2: `Poisson represents number of events in interval, with event rate (lambda = Var) plus Mean`, 3: `Gamma represents maximum entropy distribution with two parameters: scaling parameter (Var) and shape parameter k (Par) plus Mean`, 4: `Gaussian normal with Var = stddev plus Mean`, 5: `Beta with Var = alpha and Par = beta shape parameters`, 6: `Mean is just the constant Mean, no randomness`}
var _RandDistsMap = map[RandDists]string{0: `Uniform`, 1: `Binomial`, 2: `Poisson`, 3: `Gamma`, 4: `Gaussian`, 5: `Beta`, 6: `Mean`}
// String returns the string representation of this RandDists value.
func (i RandDists) String() string { return enums.String(i, _RandDistsMap) }
// SetString sets the RandDists value from its string representation,
// and returns an error if the string is invalid.
func (i *RandDists) SetString(s string) error {
return enums.SetString(i, s, _RandDistsValueMap, "RandDists")
}
// Int64 returns the RandDists value as an int64.
func (i RandDists) Int64() int64 { return int64(i) }
// SetInt64 sets the RandDists value from an int64.
func (i *RandDists) SetInt64(in int64) { *i = RandDists(in) }
// Desc returns the description of the RandDists value.
func (i RandDists) Desc() string { return enums.Desc(i, _RandDistsDescMap) }
// RandDistsValues returns all possible values for the type RandDists.
func RandDistsValues() []RandDists { return _RandDistsValues }
// Values returns all possible values for the type RandDists.
func (i RandDists) Values() []enums.Enum { return enums.Values(_RandDistsValues) }
// MarshalText implements the [encoding.TextMarshaler] interface.
func (i RandDists) MarshalText() ([]byte, error) { return []byte(i.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *RandDists) UnmarshalText(text []byte) error {
return enums.UnmarshalText(i, text, "RandDists")
}
// Copyright (c) 2019, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package randx
// PChoose32 chooses an index in given slice of float32's at random according
// to the probabilities of each item (must be normalized to sum to 1).
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func PChoose32(ps []float32, randOpt ...Rand) int {
var rnd Rand
if len(randOpt) == 0 {
rnd = NewGlobalRand()
} else {
rnd = randOpt[0]
}
pv := rnd.Float32()
sum := float32(0)
for i, p := range ps {
sum += p
if pv < sum { // note: lower values already excluded
return i
}
}
return len(ps) - 1
}
// PChoose64 chooses an index in given slice of float64's at random according
// to the probabilities of each item (must be normalized to sum to 1).
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func PChoose64(ps []float64, randOpt ...Rand) int {
var rnd Rand
if len(randOpt) == 0 {
rnd = NewGlobalRand()
} else {
rnd = randOpt[0]
}
pv := rnd.Float64()
sum := float64(0)
for i, p := range ps {
sum += p
if pv < sum { // note: lower values already excluded
return i
}
}
return len(ps) - 1
}
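// Usage sketch (illustrative only, not part of this package): weighted choice
// with PChoose64; the probabilities must already sum to 1.
//
//	probs := []float64{0.1, 0.3, 0.6}
//	counts := make([]int, len(probs))
//	rnd := randx.NewSysRand(1)
//	for i := 0; i < 10000; i++ {
//		counts[randx.PChoose64(probs, rnd)]++
//	}
//	// counts ends up roughly proportional to probs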
// Copyright (c) 2019, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package randx
// SequentialInts initializes slice of ints to sequential start..start+N-1
// numbers -- for cases where permuting the order is optional.
func SequentialInts(ins []int, start int) {
for i := range ins {
ins[i] = start + i
}
}
// PermuteInts permutes (shuffles) the order of elements in the given int slice
// using the standard Fisher-Yates shuffle
// https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle
// So you don't have to remember how to call rand.Shuffle.
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func PermuteInts(ins []int, randOpt ...Rand) {
var rnd Rand
if len(randOpt) == 0 {
rnd = NewGlobalRand()
} else {
rnd = randOpt[0]
}
rnd.Shuffle(len(ins), func(i, j int) {
ins[i], ins[j] = ins[j], ins[i]
})
}
// PermuteStrings permutes (shuffles) the order of elements in the given string slice
// using the standard Fisher-Yates shuffle
// https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle
// So you don't have to remember how to call rand.Shuffle
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func PermuteStrings(ins []string, randOpt ...Rand) {
var rnd Rand
if len(randOpt) == 0 {
rnd = NewGlobalRand()
} else {
rnd = randOpt[0]
}
rnd.Shuffle(len(ins), func(i, j int) {
ins[i], ins[j] = ins[j], ins[i]
})
}
// PermuteFloat32s permutes (shuffles) the order of elements in the given float32 slice
// using the standard Fisher-Yates shuffle
// https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle
// So you don't have to remember how to call rand.Shuffle
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func PermuteFloat32s(ins []float32, randOpt ...Rand) {
var rnd Rand
if len(randOpt) == 0 {
rnd = NewGlobalRand()
} else {
rnd = randOpt[0]
}
rnd.Shuffle(len(ins), func(i, j int) {
ins[i], ins[j] = ins[j], ins[i]
})
}
// PermuteFloat64s permutes (shuffles) the order of elements in the given float64 slice
// using the standard Fisher-Yates shuffle
// https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle
// So you don't have to remember how to call rand.Shuffle
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func PermuteFloat64s(ins []float64, randOpt ...Rand) {
var rnd Rand
if len(randOpt) == 0 {
rnd = NewGlobalRand()
} else {
rnd = randOpt[0]
}
rnd.Shuffle(len(ins), func(i, j int) {
ins[i], ins[j] = ins[j], ins[i]
})
}
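// Usage sketch (illustrative only, not part of this package): a shuffled
// presentation order, combining SequentialInts and PermuteInts.
//
//	order := make([]int, 10)
//	randx.SequentialInts(order, 0) // 0..9
//	randx.PermuteInts(order)       // shuffled in place, using the global source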
// Copyright (c) 2023, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package randx
//go:generate core generate -add-types
import "math/rand"
// Rand provides an interface with most of the standard
// rand.Rand methods, to support the use of either the
// global rand generator or a separate Rand source.
type Rand interface {
// Seed uses the provided seed value to initialize the generator to a deterministic state.
// Seed should not be called concurrently with any other Rand method.
Seed(seed int64)
// Int63 returns a non-negative pseudo-random 63-bit integer as an int64.
Int63() int64
// Uint32 returns a pseudo-random 32-bit value as a uint32.
Uint32() uint32
// Uint64 returns a pseudo-random 64-bit value as a uint64.
Uint64() uint64
// Int31 returns a non-negative pseudo-random 31-bit integer as an int32.
Int31() int32
// Int returns a non-negative pseudo-random int.
Int() int
// Int63n returns, as an int64, a non-negative pseudo-random number in the half-open interval [0,n).
// It panics if n <= 0.
Int63n(n int64) int64
// Int31n returns, as an int32, a non-negative pseudo-random number in the half-open interval [0,n).
// It panics if n <= 0.
Int31n(n int32) int32
// Intn returns, as an int, a non-negative pseudo-random number in the half-open interval [0,n).
// It panics if n <= 0.
Intn(n int) int
// Float64 returns, as a float64, a pseudo-random number in the half-open interval [0.0,1.0).
Float64() float64
// Float32 returns, as a float32, a pseudo-random number in the half-open interval [0.0,1.0).
Float32() float32
// NormFloat64 returns a normally distributed float64 in the range
// [-math.MaxFloat64, +math.MaxFloat64] with
// standard normal distribution (mean = 0, stddev = 1)
// from the default Source.
// To produce a different normal distribution, callers can
// adjust the output using:
//
// sample = NormFloat64() * desiredStdDev + desiredMean
NormFloat64() float64
// ExpFloat64 returns an exponentially distributed float64 in the range
// (0, +math.MaxFloat64] with an exponential distribution whose rate parameter
// (lambda) is 1 and whose mean is 1/lambda (1) from the default Source.
// To produce a distribution with a different rate parameter,
// callers can adjust the output using:
//
// sample = ExpFloat64() / desiredRateParameter
ExpFloat64() float64
// Perm returns, as a slice of n ints, a pseudo-random permutation of the integers
// in the half-open interval [0,n).
Perm(n int) []int
// Shuffle pseudo-randomizes the order of elements.
// n is the number of elements. Shuffle panics if n < 0.
// swap swaps the elements with indexes i and j.
Shuffle(n int, swap func(i, j int))
}
// SysRand supports the system random number generator
// for either a separate rand.Rand source, or, if that
// is nil, the global rand stream.
type SysRand struct {
// if non-nil, use this random number source instead of the global default one
Rand *rand.Rand `display:"-"`
}
// NewGlobalRand returns a new SysRand that implements the
// randx.Rand interface, with the system global rand source.
func NewGlobalRand() *SysRand {
r := &SysRand{}
return r
}
// NewSysRand returns a new SysRand with a new
// rand.Rand random source with given initial seed.
func NewSysRand(seed int64) *SysRand {
r := &SysRand{}
r.NewRand(seed)
return r
}
// NewRand sets Rand to a new rand.Rand source using given seed.
func (r *SysRand) NewRand(seed int64) {
r.Rand = rand.New(rand.NewSource(seed))
}
// Seed uses the provided seed value to initialize the generator to a deterministic state.
// Seed should not be called concurrently with any other Rand method.
func (r *SysRand) Seed(seed int64) {
if r.Rand == nil {
rand.Seed(seed)
return
}
r.Rand.Seed(seed)
}
// Int63 returns a non-negative pseudo-random 63-bit integer as an int64.
func (r *SysRand) Int63() int64 {
if r.Rand == nil {
return rand.Int63()
}
return r.Rand.Int63()
}
// Uint32 returns a pseudo-random 32-bit value as a uint32.
func (r *SysRand) Uint32() uint32 {
if r.Rand == nil {
return rand.Uint32()
}
return r.Rand.Uint32()
}
// Uint64 returns a pseudo-random 64-bit value as a uint64.
func (r *SysRand) Uint64() uint64 {
if r.Rand == nil {
return rand.Uint64()
}
return r.Rand.Uint64()
}
// Int31 returns a non-negative pseudo-random 31-bit integer as an int32.
func (r *SysRand) Int31() int32 {
if r.Rand == nil {
return rand.Int31()
}
return r.Rand.Int31()
}
// Int returns a non-negative pseudo-random int.
func (r *SysRand) Int() int {
if r.Rand == nil {
return rand.Int()
}
return r.Rand.Int()
}
// Int63n returns, as an int64, a non-negative pseudo-random number in the half-open interval [0,n).
// It panics if n <= 0.
func (r *SysRand) Int63n(n int64) int64 {
if r.Rand == nil {
return rand.Int63n(n)
}
return r.Rand.Int63n(n)
}
// Int31n returns, as an int32, a non-negative pseudo-random number in the half-open interval [0,n).
// It panics if n <= 0.
func (r *SysRand) Int31n(n int32) int32 {
if r.Rand == nil {
return rand.Int31n(n)
}
return r.Rand.Int31n(n)
}
// Intn returns, as an int, a non-negative pseudo-random number in the half-open interval [0,n).
// It panics if n <= 0.
func (r *SysRand) Intn(n int) int {
if r.Rand == nil {
return rand.Intn(n)
}
return r.Rand.Intn(n)
}
// Float64 returns, as a float64, a pseudo-random number in the half-open interval [0.0,1.0).
func (r *SysRand) Float64() float64 {
if r.Rand == nil {
return rand.Float64()
}
return r.Rand.Float64()
}
// Float32 returns, as a float32, a pseudo-random number in the half-open interval [0.0,1.0).
func (r *SysRand) Float32() float32 {
if r.Rand == nil {
return rand.Float32()
}
return r.Rand.Float32()
}
// NormFloat64 returns a normally distributed float64 in the range
// [-math.MaxFloat64, +math.MaxFloat64] with
// standard normal distribution (mean = 0, stddev = 1)
// from the default Source.
// To produce a different normal distribution, callers can
// adjust the output using:
//
// sample = NormFloat64() * desiredStdDev + desiredMean
func (r *SysRand) NormFloat64() float64 {
if r.Rand == nil {
return rand.NormFloat64()
}
return r.Rand.NormFloat64()
}
// ExpFloat64 returns an exponentially distributed float64 in the range
// (0, +math.MaxFloat64] with an exponential distribution whose rate parameter
// (lambda) is 1 and whose mean is 1/lambda (1) from the default Source.
// To produce a distribution with a different rate parameter,
// callers can adjust the output using:
//
// sample = ExpFloat64() / desiredRateParameter
func (r *SysRand) ExpFloat64() float64 {
if r.Rand == nil {
return rand.ExpFloat64()
}
return r.Rand.ExpFloat64()
}
// Perm returns, as a slice of n ints, a pseudo-random permutation of the integers
// in the half-open interval [0,n).
func (r *SysRand) Perm(n int) []int {
if r.Rand == nil {
return rand.Perm(n)
}
return r.Rand.Perm(n)
}
// Shuffle pseudo-randomizes the order of elements.
// n is the number of elements. Shuffle panics if n < 0.
// swap swaps the elements with indexes i and j.
func (r *SysRand) Shuffle(n int, swap func(i, j int)) {
if r.Rand == nil {
rand.Shuffle(n, swap)
return
}
r.Rand.Shuffle(n, swap)
}
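// Usage sketch (illustrative only, not part of this package): SysRand as a thin
// wrapper over the global source, and as an independent deterministic stream.
//
//	global := randx.NewGlobalRand() // delegates to the math/rand package-level functions
//	_ = global.Intn(10)
//
//	private := randx.NewSysRand(42) // own rand.Rand source, reproducible
//	a := private.Float64()
//	private.Seed(42) // reset to the same state
//	b := private.Float64()
//	// a == b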
// Copyright (c) 2019, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package randx
// RandParams provides parameterized random number generation according to different distributions,
// with mean and variance parameters.
type RandParams struct { //git:add
// distribution to generate random numbers from
Dist RandDists
// mean of random distribution -- typically added to generated random variates
Mean float64
// variability parameter for the random numbers (gauss = standard deviation, not variance; uniform = half-range, others as noted in RandDists)
Var float64
// extra parameter for distribution (depends on each one)
Par float64
}
// Defaults sets default parameter values (Var = 1, Par = 1).
func (rp *RandParams) Defaults() {
rp.Var = 1
rp.Par = 1
}
// ShouldDisplay returns whether the given field should be displayed:
// Par only applies to the Gamma, Binomial, and Beta distributions.
func (rp *RandParams) ShouldDisplay(field string) bool {
switch field {
case "Par":
return rp.Dist == Gamma || rp.Dist == Binomial || rp.Dist == Beta
}
return true
}
// Gen generates a random variable according to current parameters.
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func (rp *RandParams) Gen(randOpt ...Rand) float64 {
var rnd Rand
if len(randOpt) == 0 {
rnd = NewGlobalRand()
} else {
rnd = randOpt[0]
}
switch rp.Dist {
case Uniform:
return UniformMeanRange(rp.Mean, rp.Var, rnd)
case Binomial:
return rp.Mean + BinomialGen(rp.Par, rp.Var, rnd)
case Poisson:
return rp.Mean + PoissonGen(rp.Var, rnd)
case Gamma:
return rp.Mean + GammaGen(rp.Par, rp.Var, rnd)
case Gaussian:
return GaussianGen(rp.Mean, rp.Var, rnd)
case Beta:
return rp.Mean + BetaGen(rp.Var, rp.Par, rnd)
}
return rp.Mean
}
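// Usage sketch (illustrative only, not part of this package): Gaussian noise
// with stddev 0.1 around a mean of 0.5, via RandParams.
//
//	rp := randx.RandParams{}
//	rp.Defaults()
//	rp.Dist = randx.Gaussian
//	rp.Mean = 0.5
//	rp.Var = 0.1  // stddev for Gaussian
//	v := rp.Gen() // pass a Rand interface to make this reproducible
//	_ = v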
// RandDists are different random number distributions
type RandDists int32 //enums:enum
// The random number distributions
const (
// Uniform has a uniform probability distribution over Var = range on either side of the Mean
Uniform RandDists = iota
// Binomial represents number of 1's in n (Par) random (Bernoulli) trials of probability p (Var)
Binomial
// Poisson represents number of events in interval, with event rate (lambda = Var) plus Mean
Poisson
// Gamma represents maximum entropy distribution with two parameters: scaling parameter (Var)
// and shape parameter k (Par) plus Mean
Gamma
// Gaussian normal with Var = stddev plus Mean
Gaussian
// Beta with Var = alpha and Par = beta shape parameters
Beta
// Mean is just the constant Mean, no randomness
Mean
)
// IntZeroN returns uniform random integer in the range between 0 and n, exclusive of n: [0,n).
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func IntZeroN(n int64, randOpt ...Rand) int64 {
var rnd Rand
if len(randOpt) == 0 {
rnd = NewGlobalRand()
} else {
rnd = randOpt[0]
}
return rnd.Int63n(n)
}
// IntMinMax returns uniform random integer in range between min and max, exclusive of max: [min,max).
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func IntMinMax(min, max int64, randOpt ...Rand) int64 {
var rnd Rand
if len(randOpt) == 0 {
rnd = NewGlobalRand()
} else {
rnd = randOpt[0]
}
return min + rnd.Int63n(max-min)
}
// IntMeanRange returns uniform random integer with given range on either side of the mean:
// [mean - range, mean + range]
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func IntMeanRange(mean, rnge int64, randOpt ...Rand) int64 {
var rnd Rand
if len(randOpt) == 0 {
rnd = NewGlobalRand()
} else {
rnd = randOpt[0]
}
return mean + (rnd.Int63n(2*rnge+1) - rnge)
}
// ZeroOne returns a uniform random number between zero and one (exclusive of 1)
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func ZeroOne(randOpt ...Rand) float64 {
var rnd Rand
if len(randOpt) == 0 {
rnd = NewGlobalRand()
} else {
rnd = randOpt[0]
}
return rnd.Float64()
}
// UniformMinMax returns uniform random number in the half-open interval [min,max).
// (Do not use for generating integers -- it will not include max!)
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func UniformMinMax(min, max float64, randOpt ...Rand) float64 {
var rnd Rand
if len(randOpt) == 0 {
rnd = NewGlobalRand()
} else {
rnd = randOpt[0]
}
return min + (max-min)*rnd.Float64()
}
// UniformMeanRange returns uniform random number with given range on either side of the mean:
// [mean - range, mean + range]
// Optionally can pass a single Rand interface to use --
// otherwise uses system global Rand source.
func UniformMeanRange(mean, rnge float64, randOpt ...Rand) float64 {
var rnd Rand
if len(randOpt) == 0 {
rnd = NewGlobalRand()
} else {
rnd = randOpt[0]
}
return mean + rnge*2.0*(rnd.Float64()-0.5)
}
// Copyright (c) 2019, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package randx
import (
"time"
)
// Seeds is a set of random seeds, typically used one per Run
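//
// A minimal usage sketch (run is a hypothetical run index):
//
//	var seeds Seeds
//	seeds.Init(10)  // seeds 1..10
//	seeds.Set(run)  // seed the global random source with seeds[run]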
type Seeds []int64
// Init allocates given number of seeds and initializes them to
// sequential numbers 1..n
func (rs *Seeds) Init(n int) {
*rs = make([]int64, n)
for i := range *rs {
(*rs)[i] = int64(i) + 1
}
}
// Set seeds either the single Rand interface passed, or the system
// global Rand source, using the seed at the given index.
func (rs *Seeds) Set(idx int, randOpt ...Rand) {
var rnd Rand
if len(randOpt) == 0 {
rnd = NewGlobalRand()
} else {
rnd = randOpt[0]
}
rnd.Seed((*rs)[idx])
}
// NewSeeds sets a new set of random seeds based on current time
func (rs *Seeds) NewSeeds() {
rn := time.Now().UnixNano()
for i := range *rs {
(*rs)[i] = rn + int64(i)
}
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"bufio"
"flag"
"fmt"
"os"
"sort"
"cogentcore.org/core/core"
"cogentcore.org/lab/stats/stats"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
"cogentcore.org/lab/tensorfs"
)
var (
Output string
Col string
OutFile *os.File
OutWriter *bufio.Writer
LF = []byte("\n")
Delete bool
LogPrec = 4
)
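// Example invocations (a sketch; the file and column names are hypothetical,
// and the flags are those registered in main below):
//
//	etcat -o all.tsv run1.tsv run2.tsv          # concatenate, keeping only the first header
//	etcat -avg -o avg.tsv run1.tsv run2.tsv     # average float columns across files
//	etcat -colavg -col Epoch -o epc.tsv run.tsv # average float columns grouped by the Epoch column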
func main() {
var help bool
var avg bool
var colavg bool
flag.BoolVar(&help, "help", false, "if true, report usage info")
flag.BoolVar(&avg, "avg", false, "if true, files must have same cols (ideally rows too, though not necessary), outputs average of any float-type columns across files")
flag.BoolVar(&colavg, "colavg", false, "if true, outputs average of any float-type columns aggregated by column")
flag.StringVar(&Col, "col", "", "name of column for colavg")
flag.StringVar(&Output, "output", "", "name of output file -- stdout if not specified")
flag.StringVar(&Output, "o", "", "name of output file -- stdout if not specified")
flag.BoolVar(&Delete, "delete", false, "if true, delete the source files after cat -- careful!")
flag.BoolVar(&Delete, "d", false, "if true, delete the source files after cat -- careful!")
flag.IntVar(&LogPrec, "prec", 4, "precision for number output -- defaults to 4")
flag.Parse()
files := flag.Args()
sort.StringSlice(files).Sort()
if Output != "" {
var err error
OutFile, err = os.Create(Output)
if err != nil {
fmt.Println("Error creating output file:", err)
os.Exit(1)
}
defer OutFile.Close()
OutWriter = bufio.NewWriter(OutFile)
} else {
OutWriter = bufio.NewWriter(os.Stdout)
}
switch {
case help || len(files) == 0:
fmt.Printf("\netcat is a data table concatenation utility\n\tassumes all files have header lines, and only retains the header for the first file\n\t(otherwise just use regular cat)\n")
flag.PrintDefaults()
case colavg:
AvgByColumn(files, Col)
case avg:
AvgCat(files)
default:
RawCat(files)
}
OutWriter.Flush()
}
// RawCat concatenates all data in one big file
func RawCat(files []string) {
for fi, fn := range files {
fp, err := os.Open(fn)
if err != nil {
fmt.Println("Error opening file: ", err)
continue
}
scan := bufio.NewScanner(fp)
li := 0
for {
if !scan.Scan() {
break
}
ln := scan.Bytes()
if li == 0 {
if fi == 0 {
OutWriter.Write(ln)
OutWriter.Write(LF)
}
} else {
OutWriter.Write(ln)
OutWriter.Write(LF)
}
li++
}
fp.Close()
if Delete {
os.Remove(fn)
}
}
}
// AvgCat computes the average across all files
func AvgCat(files []string) {
dts := make([]*table.Table, 0, len(files))
for _, fn := range files {
dt := table.New()
err := dt.OpenCSV(core.Filename(fn), tensor.Tab)
if err != nil {
fmt.Println("Error opening file: ", err)
continue
}
if dt.NumRows() == 0 {
fmt.Printf("File %v empty\n", fn)
continue
}
dts = append(dts, dt)
}
if len(dts) == 0 {
fmt.Println("No files or files are empty, exiting")
return
}
avgdt := stats.MeanTables(dts)
tensor.SetPrecision(avgdt, LogPrec)
avgdt.SaveCSV(core.Filename(Output), tensor.Tab, table.Headers)
}
// AvgByColumn computes average by given column for given files
// If column is empty, averages across all rows.
func AvgByColumn(files []string, column string) {
for _, fn := range files {
dt := table.New()
err := dt.OpenCSV(core.Filename(fn), tensor.Tab)
if err != nil {
fmt.Println("Error opening file: ", err)
continue
}
if dt.NumRows() == 0 {
fmt.Printf("File %v empty\n", fn)
continue
}
dir, _ := tensorfs.NewDir("Groups")
if column == "" {
stats.GroupAll(dir, dt.ColumnByIndex(0))
} else {
stats.TableGroups(dir, dt, column)
}
var cols []string
for ci, cl := range dt.Columns.Values {
if cl.IsString() || dt.Columns.Keys[ci] == column {
continue
}
cols = append(cols, dt.Columns.Keys[ci])
}
stats.TableGroupStats(dir, stats.StatMean, dt, cols...)
std := dir.Node("Stats")
avgdt := tensorfs.DirTable(std, nil) // todo: has stat name slash
tensor.SetPrecision(avgdt, LogPrec)
avgdt.SaveCSV(core.Filename(Output), tensor.Tab, table.Headers)
}
}
// Copyright (c) 2020, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"errors"
"fmt"
"cogentcore.org/core/base/fsx"
"cogentcore.org/core/cli"
"cogentcore.org/lab/stats/stats"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
"cogentcore.org/lab/tensorfs"
)
//go:generate core generate -add-types -add-funcs
type Config struct {
// Name of the column to compute stats on.
Column string `posarg:"0" required:"+"`
// Files to compute stats on.
Files []string `posarg:"leftover" required:"+"`
}
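// Run prints a pipe-delimited table of descriptive statistics for the
// configured column in each file. Example invocation (a sketch; the column
// and file names are hypothetical):
//
//	tstats Loss run1.tsv run2.tsv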
func Run(c *Config) error {
var errs []error
fmt.Printf("| %44s ", "File")
for _, st := range stats.DescriptiveStats {
fmt.Printf("| %12s ", st.String())
}
fmt.Printf("|\n")
for _, f := range c.Files {
if ok, err := fsx.FileExists(f); !ok {
errs = append(errs, err, fmt.Errorf("file %q not found", f))
continue
}
dt := table.New()
dt.OpenCSV(fsx.Filename(f), tensor.Detect)
cl, err := dt.ColumnTry(c.Column)
if err != nil {
errs = append(errs, err)
continue
}
dir, _ := tensorfs.NewDir("Desc")
stats.Describe(dir, cl)
ds := dir.Dir("Describe/" + c.Column)
fmt.Printf("| %44s ", f)
for _, st := range stats.DescriptiveStats {
v := ds.Float64(st.String())
fmt.Printf("| %12.2f ", v.Float(0))
}
fmt.Printf("|\n")
}
return errors.Join(errs...)
}
func main() {
opts := cli.DefaultOptions("tstats", "tstats computes standard descriptive statistics on a column of data in a CSV / TSV file.")
cli.Run(opts, &Config{}, Run)
}
// Copyright (c) 2020, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"bufio"
"os"
"strings"
"time"
)
// File represents one opened file -- all data is read in and maintained here
type File struct {
// file name (either in same dir or include path)
FName string `desc:"file name (either in same dir or include path)"`
// mod time of file when last read
ModTime time.Time `desc:"mod time of file when last read"`
// delim is commas, not tabs
Commas bool `desc:"delim is commas, not tabs"`
// rows of data == len(Data)
Rows int `desc:"rows of data == len(Data)"`
// width of each column: resized to fit widest element
Widths []int `desc:"width of each column: resized to fit widest element"`
// headers
Heads []string `desc:"headers"`
// data -- rows 1..end
Data [][]string `desc:"data -- rows 1..end"`
}
// Files is a slice of open files
type Files []*File
// TheFiles are the set of open files
var TheFiles Files
// Open opens file, reads it
func (fl *File) Open(fname string) error {
fl.FName = fname
return fl.Read()
}
// CheckUpdate checks if file has been modified -- returns true if so
func (fl *File) CheckUpdate() bool {
st, err := os.Stat(fl.FName)
if err != nil {
return false
}
return st.ModTime().After(fl.ModTime)
}
// Read reads data from file
func (fl *File) Read() error {
st, err := os.Stat(fl.FName)
if err != nil {
return err
}
fl.ModTime = st.ModTime()
f, err := os.Open(fl.FName)
if err != nil {
return err
}
defer f.Close()
if fl.Data != nil {
fl.Data = fl.Data[:0]
}
scan := bufio.NewScanner(f)
ln := 0
for scan.Scan() {
s := string(scan.Bytes())
var fd []string
if fl.Commas {
fd = strings.Split(s, ",")
} else {
fd = strings.Split(s, "\t")
}
if ln == 0 {
if len(fd) == 0 || strings.Count(s, ",") > strings.Count(s, "\t") {
fl.Commas = true
fd = strings.Split(s, ",")
}
fl.Heads = fd
fl.Widths = make([]int, len(fl.Heads))
fl.FitWidths(fd)
ln++
continue
}
fl.Data = append(fl.Data, fd)
fl.FitWidths(fd)
ln++
}
fl.Rows = ln - 1 // skip header
return scan.Err()
}
// FitWidths expands widths given current set of fields
func (fl *File) FitWidths(fd []string) {
nw := len(fl.Widths)
for i, f := range fd {
if i >= nw {
break
}
w := max(fl.Widths[i], len(f))
fl.Widths[i] = w
}
}
/////////////////////////////////////////////////////////////////
// Files
// Open opens all files
func (fl *Files) Open(fnms []string) {
for _, fn := range fnms {
f := &File{}
err := f.Open(fn)
if err == nil {
*fl = append(*fl, f)
}
}
}
// CheckUpdates checks for any updated files and re-reads them if so -- returns true if any were updated
func (fl *Files) CheckUpdates() bool {
got := false
for _, f := range *fl {
if f.CheckUpdate() {
f.Read()
got = true
}
}
return got
}
// Copyright (c) 2020, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import "github.com/nsf/termbox-go"
func (tm *Term) Help() {
termbox.Clear(termbox.ColorDefault, termbox.ColorDefault)
ln := 0
tm.DrawStringDef(0, ln, "Key(s) Function")
ln++
tm.DrawStringDef(0, ln, "--------------------------------------------------------------")
ln++
tm.DrawStringDef(0, ln, "spc,n page down")
ln++
tm.DrawStringDef(0, ln, "p page up")
ln++
tm.DrawStringDef(0, ln, "f scroll right-hand panel to the right")
ln++
tm.DrawStringDef(0, ln, "b scroll right-hand panel to the left")
ln++
tm.DrawStringDef(0, ln, "w widen the left-hand panel of columns")
ln++
tm.DrawStringDef(0, ln, "s shrink the left-hand panel of columns")
ln++
tm.DrawStringDef(0, ln, "t toggle tail-mode (auto updating as file grows) on/off")
ln++
tm.DrawStringDef(0, ln, "a jump to top")
ln++
tm.DrawStringDef(0, ln, "e jump to end")
ln++
tm.DrawStringDef(0, ln, "v rotate down through the list of files (if not all displayed)")
ln++
tm.DrawStringDef(0, ln, "u rotate up through the list of files (if not all displayed)")
ln++
tm.DrawStringDef(0, ln, "m more minimum lines per file -- increase amount shown of each file")
ln++
tm.DrawStringDef(0, ln, "l less minimum lines per file -- decrease amount shown of each file")
ln++
tm.DrawStringDef(0, ln, "d toggle display of file names")
ln++
tm.DrawStringDef(0, ln, "c toggle display of column numbers instead of names")
ln++
tm.DrawStringDef(0, ln, "q quit")
ln++
termbox.Flush()
}
// Copyright (c) 2020, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"fmt"
"image"
"sync"
termbox "github.com/nsf/termbox-go"
)
// Term represents the terminal display -- has all drawing routines
// and all display data. See Tail for the two different display modes.
type Term struct {
// size of terminal
Size image.Point `desc:"size of terminal"`
// number of fixed (non-scrolling) columns on left
FixCols int `desc:"number of fixed (non-scrolling) columns on left"`
// starting column index -- relative to FixCols
ColSt int `desc:"starting column index -- relative to FixCols"`
// starting row index -- for !Tail mode
RowSt int `desc:"starting row index -- for !Tail mode"`
// row from end -- for Tail mode
RowFromEnd int `desc:"row from end (relative to RowsPer) -- for Tail mode"`
// starting index into files (if too many to display)
FileSt int `desc:"starting index into files (if too many to display)"`
// number of files to display (if too many to display)
NFiles int `desc:"number of files to display (if too many to display)"`
// minimum number of lines per file
MinLines int `desc:"minimum number of lines per file"`
// maximum column width (1/4 of term width)
MaxWd int `desc:"maximum column width (1/4 of term width)"`
// max number of rows across all files
MaxRows int `desc:"max number of rows across all files"`
// number of Y rows per file total: Size.Y / len(TheFiles)
YPer int `desc:"number of Y rows per file total: Size.Y / len(TheFiles)"`
// rows of data per file (subtracting header, filename)
RowsPer int `desc:"rows of data per file (subtracting header, filename)"`
// if true, print filename
ShowFName bool `desc:"if true, print filename"`
// if true, display is synchronized by the last row for each file, and otherwise it is synchronized by the starting row. Tail also checks for file updates
Tail bool `desc:"if true, display is synchronized by the last row for each file, and otherwise it is synchronized by the starting row. Tail also checks for file updates"`
// display column numbers instead of names
ColNums bool `desc:"display column numbers instead of names"`
// draw mutex
Mu sync.Mutex `desc:"draw mutex"`
}
// TheTerm is the terminal instance
var TheTerm Term
// Draw draws the current terminal display
func (tm *Term) Draw() error {
tm.Mu.Lock()
defer tm.Mu.Unlock()
err := termbox.Clear(termbox.ColorDefault, termbox.ColorDefault)
if err != nil {
return err
}
w, h := termbox.Size()
tm.Size.X = w
tm.Size.Y = h
tm.MaxWd = tm.Size.X / 4
if tm.MinLines == 0 {
tm.MinLines = min(5, tm.Size.Y-1)
}
nf := len(TheFiles)
if nf == 0 {
return fmt.Errorf("No files")
}
ysz := tm.Size.Y - 1 // status line
tm.YPer = ysz / nf
tm.NFiles = nf
if tm.YPer < tm.MinLines {
tm.NFiles = ysz / tm.MinLines
tm.YPer = tm.MinLines
}
if tm.NFiles+tm.FileSt > nf {
tm.FileSt = max(0, nf-tm.NFiles)
}
tm.RowsPer = tm.YPer - 1
if tm.ShowFName {
tm.RowsPer--
}
sty := 0
mxrows := 0
for fi := 0; fi < tm.NFiles; fi++ {
ffi := tm.FileSt + fi
if ffi >= nf {
break
}
fl := TheFiles[ffi]
tm.DrawFile(fl, sty)
sty += tm.YPer
mxrows = max(mxrows, fl.Rows)
}
tm.MaxRows = mxrows
tm.StatusLine()
termbox.Flush()
return nil
}
// StatusLine renders the status line at bottom
func (tm *Term) StatusLine() {
pos := tm.RowSt
if tm.Tail {
pos = tm.RowFromEnd
}
stat := fmt.Sprintf("Tail: %v\tPos: %d\tMaxRows: %d\tNFile: %d\tFileSt: %d\t h = help [spc,n,p,f,b,w,s,a,e,v,u,m,l,d,t,c,h,q] ", tm.Tail, pos, tm.MaxRows, len(TheFiles), tm.FileSt)
tm.DrawString(0, tm.Size.Y-1, stat, len(stat), termbox.AttrReverse, termbox.AttrReverse)
}
// NextPage moves down a page
func (tm *Term) NextPage() error {
if tm.Tail {
mn := min(-(tm.MaxRows - tm.RowsPer), 0)
tm.RowFromEnd = min(tm.RowFromEnd+tm.RowsPer, 0)
tm.RowFromEnd = max(tm.RowFromEnd, mn)
} else {
tm.RowSt = min(tm.RowSt+tm.RowsPer, tm.MaxRows-tm.RowsPer)
tm.RowSt = max(tm.RowSt, 0)
}
return tm.Draw()
}
// PrevPage moves up a page
func (tm *Term) PrevPage() error {
if tm.Tail {
mn := min(-(tm.MaxRows - tm.RowsPer), 0)
tm.RowFromEnd = min(tm.RowFromEnd-tm.RowsPer, 0)
tm.RowFromEnd = max(tm.RowFromEnd, mn)
} else {
tm.RowSt = max(tm.RowSt-tm.RowsPer, 0)
tm.RowSt = min(tm.RowSt, tm.MaxRows-tm.RowsPer)
}
return tm.Draw()
}
// NextLine moves down one line
func (tm *Term) NextLine() error {
if tm.Tail {
mn := min(-(tm.MaxRows - tm.RowsPer), 0)
tm.RowFromEnd = min(tm.RowFromEnd+1, 0)
tm.RowFromEnd = max(tm.RowFromEnd, mn)
} else {
tm.RowSt = min(tm.RowSt+1, tm.MaxRows-tm.RowsPer)
tm.RowSt = max(tm.RowSt, 0)
}
return tm.Draw()
}
// PrevLine moves up one line
func (tm *Term) PrevLine() error {
if tm.Tail {
mn := min(-(tm.MaxRows - tm.RowsPer), 0)
tm.RowFromEnd = min(tm.RowFromEnd-1, 0)
tm.RowFromEnd = max(tm.RowFromEnd, mn)
} else {
tm.RowSt = max(tm.RowSt-1, 0)
tm.RowSt = min(tm.RowSt, tm.MaxRows-tm.RowsPer)
}
return tm.Draw()
}
// Top moves to starting row = 0
func (tm *Term) Top() error {
mn := min(-(tm.MaxRows - tm.RowsPer), 0)
tm.RowFromEnd = mn
tm.RowSt = 0
return tm.Draw()
}
// End moves row start to last position in longest file
func (tm *Term) End() error {
mx := max(tm.MaxRows-tm.RowsPer, 0)
tm.RowFromEnd = 0
tm.RowSt = mx
return tm.Draw()
}
// ScrollRight scrolls columns to right
func (tm *Term) ScrollRight() error {
tm.ColSt++ // no obvious max
return tm.Draw()
}
// ScrollLeft scrolls columns to left
func (tm *Term) ScrollLeft() error {
tm.ColSt = max(tm.ColSt-1, 0)
return tm.Draw()
}
// FixRight increases number of fixed columns
func (tm *Term) FixRight() error {
tm.FixCols++ // no obvious max
return tm.Draw()
}
// FixLeft decreases number of fixed columns
func (tm *Term) FixLeft() error {
tm.FixCols = max(tm.FixCols-1, 0)
return tm.Draw()
}
// FilesNext moves down in list of files to display
func (tm *Term) FilesNext() error {
nf := len(TheFiles)
tm.FileSt = min(tm.FileSt+1, nf-tm.NFiles)
tm.FileSt = max(tm.FileSt, 0)
return tm.Draw()
}
// FilesPrev moves up in list of files to display
func (tm *Term) FilesPrev() error {
nf := len(TheFiles)
tm.FileSt = max(tm.FileSt-1, 0)
tm.FileSt = min(tm.FileSt, nf-tm.NFiles)
return tm.Draw()
}
// MoreMinLines increases minimum number of lines per file
func (tm *Term) MoreMinLines() error {
tm.MinLines++
return tm.Draw()
}
// LessMinLines decreases minimum number of lines per file
func (tm *Term) LessMinLines() error {
tm.MinLines--
tm.MinLines = max(3, tm.MinLines)
return tm.Draw()
}
// ToggleNames toggles whether file names are shown
func (tm *Term) ToggleNames() error {
tm.ShowFName = !tm.ShowFName
return tm.Draw()
}
// ToggleTail toggles Tail mode
func (tm *Term) ToggleTail() error {
tm.Tail = !tm.Tail
return tm.Draw()
}
// ToggleColNums toggles ColNums mode
func (tm *Term) ToggleColNums() error {
tm.ColNums = !tm.ColNums
return tm.Draw()
}
// TailCheck does tail update check -- returns true if updated
func (tm *Term) TailCheck() bool {
if !tm.Tail {
return false
}
tm.Mu.Lock()
update := TheFiles.CheckUpdates()
tm.Mu.Unlock()
if !update {
return false
}
tm.Draw()
return true
}
// DrawFile draws one file, starting at given y offset
func (tm *Term) DrawFile(fl *File, sty int) {
tdo := (fl.Rows - tm.RowsPer) + tm.RowFromEnd // tail data offset for this file
tdo = max(0, tdo)
rst := min(tm.RowSt, fl.Rows-tm.RowsPer)
rst = max(0, rst)
stx := 0
for ci, hs := range fl.Heads {
if !(ci < tm.FixCols || ci >= tm.FixCols+tm.ColSt) {
continue
}
my := sty
if tm.ShowFName {
tm.DrawString(0, my, fl.FName, tm.Size.X, termbox.AttrReverse, termbox.AttrReverse)
my++
}
wmax := min(fl.Widths[ci], tm.MaxWd)
if tm.ColNums {
hs = fmt.Sprintf("%d", ci)
}
tm.DrawString(stx, my, hs, wmax, termbox.AttrReverse, termbox.AttrReverse)
if ci == tm.FixCols-1 {
tm.DrawString(stx+wmax+1, my, "|", 1, termbox.AttrReverse, termbox.AttrReverse)
}
my++
for ri := 0; ri < tm.RowsPer; ri++ {
var di int
if tm.Tail {
di = tdo + ri
} else {
di = rst + ri
}
if di >= len(fl.Data) || di < 0 {
continue
}
dr := fl.Data[di]
if ci >= len(dr) {
break
}
ds := dr[ci]
tm.DrawString(stx, my+ri, ds, wmax, termbox.ColorDefault, termbox.ColorDefault)
if ci == tm.FixCols-1 {
tm.DrawString(stx+wmax+1, my+ri, "|", 1, termbox.AttrReverse, termbox.AttrReverse)
}
}
stx += wmax + 1
if ci == tm.FixCols-1 {
stx += 2
}
if stx >= tm.Size.X {
break
}
}
}
// DrawStringDef draws string at given position, using default colors
func (tm *Term) DrawStringDef(x, y int, s string) {
tm.DrawString(x, y, s, tm.Size.X, termbox.ColorDefault, termbox.ColorDefault)
}
// DrawString draws string at given position, using given attributes
func (tm *Term) DrawString(x, y int, s string, maxlen int, fg, bg termbox.Attribute) {
if y >= tm.Size.Y || y < 0 {
return
}
for i, r := range s {
if i >= maxlen {
break
}
xp := x + i
if xp >= tm.Size.X || xp < 0 {
continue
}
termbox.SetCell(xp, y, r, fg, bg)
}
}
// Copyright (c) 2020, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"fmt"
"log"
"os"
"time"
"github.com/nsf/termbox-go"
)
func main() {
err := termbox.Init()
if err != nil {
log.Println(err)
panic(err)
}
defer termbox.Close()
TheFiles.Open(os.Args[1:])
nf := len(TheFiles)
if nf == 0 {
fmt.Printf("usage: etail <filename>... (space separated)\n")
return
}
if nf > 1 {
TheTerm.ShowFName = true
}
err = TheTerm.ToggleTail() // start in tail mode
if err != nil {
log.Println(err)
panic(err)
}
Tailer := time.NewTicker(500 * time.Millisecond)
go func() {
for {
<-Tailer.C
TheTerm.TailCheck()
}
}()
loop:
for {
switch ev := termbox.PollEvent(); ev.Type {
case termbox.EventKey:
switch {
case ev.Key == termbox.KeyEsc || ev.Ch == 'Q' || ev.Ch == 'q':
break loop
case ev.Ch == ' ' || ev.Ch == 'n' || ev.Ch == 'N' || ev.Key == termbox.KeyPgdn || ev.Key == termbox.KeySpace:
TheTerm.NextPage()
case ev.Ch == 'p' || ev.Ch == 'P' || ev.Key == termbox.KeyPgup:
TheTerm.PrevPage()
case ev.Key == termbox.KeyArrowDown:
TheTerm.NextLine()
case ev.Key == termbox.KeyArrowUp:
TheTerm.PrevLine()
case ev.Ch == 'f' || ev.Ch == 'F' || ev.Key == termbox.KeyArrowRight:
TheTerm.ScrollRight()
case ev.Ch == 'b' || ev.Ch == 'B' || ev.Key == termbox.KeyArrowLeft:
TheTerm.ScrollLeft()
case ev.Ch == 'a' || ev.Ch == 'A' || ev.Key == termbox.KeyHome:
TheTerm.Top()
case ev.Ch == 'e' || ev.Ch == 'E' || ev.Key == termbox.KeyEnd:
TheTerm.End()
case ev.Ch == 'w' || ev.Ch == 'W':
TheTerm.FixRight()
case ev.Ch == 's' || ev.Ch == 'S':
TheTerm.FixLeft()
case ev.Ch == 'v' || ev.Ch == 'V':
TheTerm.FilesNext()
case ev.Ch == 'u' || ev.Ch == 'U':
TheTerm.FilesPrev()
case ev.Ch == 'm' || ev.Ch == 'M':
TheTerm.MoreMinLines()
case ev.Ch == 'l' || ev.Ch == 'L':
TheTerm.LessMinLines()
case ev.Ch == 'd' || ev.Ch == 'D':
TheTerm.ToggleNames()
case ev.Ch == 't' || ev.Ch == 'T':
TheTerm.ToggleTail()
case ev.Ch == 'c' || ev.Ch == 'C':
TheTerm.ToggleColNums()
case ev.Ch == 'h' || ev.Ch == 'H':
TheTerm.Help()
}
case termbox.EventResize:
TheTerm.Draw()
}
}
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"embed"
"cogentcore.org/core/content"
"cogentcore.org/core/core"
"cogentcore.org/core/htmlcore"
"cogentcore.org/core/icons"
"cogentcore.org/core/text/csl"
_ "cogentcore.org/core/text/tex" // include this to get math
"cogentcore.org/core/tree"
"cogentcore.org/lab/physics/examples/balls"
"cogentcore.org/lab/physics/examples/collide"
"cogentcore.org/lab/physics/examples/virtroom"
_ "cogentcore.org/lab/yaegilab"
)
// NOTE: you must make a symbolic link to the zotero CCNLab CSL file as ccnlab.json
// in this directory, to generate references and have the generated reference links
// use the official APA style. https://www.zotero.org/groups/340666/ccnlab
// Must configure using BetterBibTeX for zotero: https://retorque.re/zotero-better-bibtex/
//go:generate mdcite -vv -refs ./ccnlab.json -d ./content
//go:embed content citedrefs.json
var econtent embed.FS
func main() {
b := core.NewBody("Cogent Lab")
ct := content.NewContent(b).SetContent(econtent)
ctx := ct.Context
content.OfflineURL = "https://cogentcore.org/lab"
refs, err := csl.OpenFS(econtent, "citedrefs.json")
if err == nil {
ct.References = csl.NewKeyList(refs)
}
ctx.AddWikilinkHandler(htmlcore.GoDocWikilink("doc", "cogentcore.org/lab"))
b.AddTopBar(func(bar *core.Frame) {
tb := core.NewToolbar(bar)
tb.Maker(ct.MakeToolbar)
tb.Maker(func(p *tree.Plan) {
tree.Add(p, func(w *core.Button) {
ctx.LinkButton(w, "https://github.com/cogentcore/lab")
w.SetText("GitHub").SetIcon(icons.GitHub)
})
tree.Add(p, func(w *core.Button) {
ctx.LinkButton(w, "https://youtube.com/@CogentCore")
w.SetText("Videos").SetIcon(icons.VideoLibrary)
})
})
})
ctx.ElementHandlers["physics-balls"] = func(ctx *htmlcore.Context) bool {
balls.Config(ctx.BlockParent)
return true
}
ctx.ElementHandlers["physics-collide"] = func(ctx *htmlcore.Context) bool {
collide.Config(ctx.BlockParent)
return true
}
ctx.ElementHandlers["physics-virtroom"] = func(ctx *htmlcore.Context) bool {
ev := &virtroom.Env{}
ev.Defaults()
ev.ConfigGUI(ctx.BlockParent)
return true
}
b.RunMainWindow()
}
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package baremetal
import (
"errors"
"time"
)
// This file has the exported API for direct usage,
// which wraps calls in locks.
// OpenLog opens a log file for recording actions.
func (bm *BareMetal) OpenLog(filename string) error {
// todo: openlog file on slog
return nil
}
// StartBGUpdates starts a ticker to update job status periodically.
func (bm *BareMetal) StartBGUpdates() {
bm.ticker = time.NewTicker(10 * time.Second)
go bm.bgLoop()
}
// Submit adds a new Active job with given parameters.
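// A minimal usage sketch (the source, path, script, results glob, and tar
// payload shown here are hypothetical):
//
//	job := bm.Submit("mysim", "jobs/run1", "job.sh", "*.tsv", tarFiles)
//	jobs := bm.JobStatus(int(job.ID))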
func (bm *BareMetal) Submit(src, path, script, results string, files []byte) *Job {
bm.Lock()
defer bm.Unlock()
return bm.submit(src, path, script, results, files)
}
// JobStatus gets current job data for given job id(s).
// An empty list returns all of the currently Active jobs.
func (bm *BareMetal) JobStatus(ids ...int) []*Job {
bm.Lock()
defer bm.Unlock()
if len(ids) == 0 {
return bm.Active.Values
}
jobs := make([]*Job, 0, len(ids))
for _, id := range ids {
job := bm.job(id)
if job == nil {
continue
}
jobs = append(jobs, job)
}
return jobs
}
// CancelJobs cancels list of job IDs. Returns error for jobs not found.
func (bm *BareMetal) CancelJobs(ids ...int) error {
bm.Lock()
defer bm.Unlock()
return bm.cancelJobs(ids...)
}
// FetchResults gets job results back from server for given job id(s).
// Results are available as job.Results as a compressed tar file.
func (bm *BareMetal) FetchResults(resultsGlob string, ids ...int) ([]*Job, error) {
bm.Lock()
defer bm.Unlock()
return bm.fetchResults(resultsGlob, ids...)
}
// UpdateJobs polls currently running jobs and runs any pending jobs if there are available GPUs.
// It returns the number of jobs started and finished, and any errors incurred.
func (bm *BareMetal) UpdateJobs() (nrun, nfinished int, err error) {
bm.Lock()
defer bm.Unlock()
var perr, rerr error
nfinished, perr = bm.pollJobs()
nrun, rerr = bm.runPendingJobs()
err = errors.Join(perr, rerr)
bm.saveState()
return
}
// RecoverJob reinstates job information so files can be recovered etc.
func (bm *BareMetal) RecoverJob(job *Job) (*Job, error) {
bm.Lock()
defer bm.Unlock()
return bm.recoverJob(job)
}
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package baremetal
//go:generate core generate
import (
"log/slog"
"os"
"os/user"
"path/filepath"
"sync"
"time"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/fsx"
"cogentcore.org/core/base/iox/jsonx"
"cogentcore.org/core/base/iox/tomlx"
"cogentcore.org/core/base/keylist"
"cogentcore.org/lab/goal"
"cogentcore.org/lab/goal/goalib"
"cogentcore.org/lab/goal/interpreter"
"github.com/cogentcore/yaegi/interp"
)
// goalrun is needed for running goal commands.
var goalrun *goal.Goal
// BareMetal is the overall bare metal job manager.
type BareMetal struct {
// Servers is the ordered list of server machines.
Servers Servers `json:"-"`
// NextID is the next job ID to assign.
NextID int `edit:"-"`
// Active has all the active (pending, running) jobs being managed,
// in the order submitted.
// The unique key is the bare metal job ID (int).
Active Jobs
// Done has all the completed jobs that have been run.
// This list can be purged by time as needed.
// The unique key is the bare metal job ID (int).
Done Jobs
// interp is the goal interpreter that we exclusively control.
interp *interpreter.Interpreter `json:"-" toml:"-"`
// ticker is the [time.Ticker] used to control the background update loop.
ticker *time.Ticker `json:"-" toml:"-"`
// Lock for responding to inputs.
// Everything below top-level input processing is assumed to be locked.
sync.Mutex `json:"-" toml:"-"`
}
// Jobs is the ordered list of jobs, in order submitted.
type Jobs = keylist.List[int, *Job]
// Servers is the ordered list of servers, in order of use preference.
type Servers = keylist.List[string, *Server]
func NewBareMetal() *BareMetal {
bm := &BareMetal{}
return bm
}
// Init does the full initialization of the baremetal server.
func (bm *BareMetal) Init() {
bm.config()
bm.openState()
bm.newInterpreter()
bm.initServers()
}
// todo: should have this logic in fsx presumably
// dataDir returns the app data dir
func (bm *BareMetal) dataDir() string {
usr, err := user.Current()
if errors.Log(err) != nil {
return "/tmp"
}
return filepath.Join(usr.HomeDir, "Library")
}
// config loads a toml format config file from
// TheApp.DataDir()/BareMetal/config.toml to load the servers.
// Use [[Servers.Values]] header for each server.
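//
// A minimal config.toml sketch (only the Name field is shown here; the
// remaining keys follow the Server struct defined elsewhere):
//
//	[[Servers.Values]]
//	Name = "myserver"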
func (bm *BareMetal) config() {
dir := filepath.Join(bm.dataDir(), "BareMetal")
os.MkdirAll(dir, 0777)
file := filepath.Join(dir, "config.toml")
if !goalib.FileExists(file) {
slog.Error("BareMetal config file not found: no servers will be configured", "File location:", file)
return
}
errors.Log(tomlx.Open(bm, file))
bm.updateServerIndexes()
}
// saveState saves the current active state to a JSON file:
// TheApp.DataDir()/BareMetal/state.json. A backup ~ file is
// made of any existing file prior to saving.
func (bm *BareMetal) saveState() {
dir := filepath.Join(bm.dataDir(), "BareMetal")
os.MkdirAll(dir, 0777)
file := filepath.Join(dir, "state.json")
bkup := filepath.Join(dir, "state.json~")
bkup2 := filepath.Join(dir, "#state.json#")
if goalib.FileExists(file) {
if goalib.FileExists(bkup) {
os.Rename(bkup, bkup2)
}
os.Rename(file, bkup)
}
err := jsonx.Save(bm, file)
if err != nil {
fsx.CopyFile(bkup, filepath.Join(dir, "state_pre_err.json"), 0666)
panic(err)
}
os.Remove(bkup2)
}
// openState opens the current active state from the file made by saveState,
// to restore to prior running state.
func (bm *BareMetal) openState() {
dir := filepath.Join(bm.dataDir(), "BareMetal")
file := filepath.Join(dir, "state.json")
if !goalib.FileExists(file) {
return
}
errors.Log(jsonx.Open(bm, file))
bm.updateServerIndexes()
bm.Active.UpdateIndexes()
bm.Done.UpdateIndexes()
bm.setServerUsedFromJobs()
}
// initServers initializes the server state, including opening SSH connections.
func (bm *BareMetal) initServers() {
for _, sv := range bm.Servers.Values {
sv.OpenSSH()
}
goalrun.Run("@0")
}
// updateServerIndexes updates the indexes in the Servers ordered map,
// which is needed after loading new Server config.
func (bm *BareMetal) updateServerIndexes() {
svs := &bm.Servers
svs.Keys = make([]string, len(svs.Values))
for i, v := range svs.Values {
svs.Keys[i] = v.Name
}
svs.UpdateIndexes()
}
// newInterpreter creates a new goal interpreter for exclusive use of bm.
func (bm *BareMetal) newInterpreter() {
in := interpreter.NewInterpreter(interp.Options{})
// has tensor, etc builtin
in.Config()
bm.interp = in
goalrun = in.Goal
}
// Interactive runs the interpreter in interactive mode.
func (bm *BareMetal) Interactive() {
bm.interp.Interactive()
}
// bgLoop is the background update loop
func (bm *BareMetal) bgLoop() {
for {
bm.Lock()
if bm.ticker == nil {
bm.Unlock()
return
}
bm.Unlock()
<-bm.ticker.C
nrun, nfin, err := bm.UpdateJobs()
if err != nil {
errors.Log(err)
} else {
if nrun > 0 || nfin > 0 {
slog.Info("Jobs Updated:", "N Run:", nrun, "N Finished:", nfin)
}
}
}
}
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.3
// protoc v5.29.3
// source: baremetal/baremetal.proto
package baremetal
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
emptypb "google.golang.org/protobuf/types/known/emptypb"
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
// Status are the job status values.
type Status int32
const (
// NoStatus is the unknown status state.
Status_NoStatus Status = 0
// Pending means the job has been submitted, but not yet run.
Status_Pending Status = 1
// Running means the job is running.
Status_Running Status = 2
// Completed means the job finished on its own, with no error status.
Status_Completed Status = 3
// Canceled means the job was canceled by the user.
Status_Canceled Status = 4
// Errored means the job quit with an error
Status_Errored Status = 5
)
// Enum value maps for Status.
var (
Status_name = map[int32]string{
0: "NoStatus",
1: "Pending",
2: "Running",
3: "Completed",
4: "Canceled",
5: "Errored",
}
Status_value = map[string]int32{
"NoStatus": 0,
"Pending": 1,
"Running": 2,
"Completed": 3,
"Canceled": 4,
"Errored": 5,
}
)
func (x Status) Enum() *Status {
p := new(Status)
*p = x
return p
}
func (x Status) String() string {
return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
}
func (Status) Descriptor() protoreflect.EnumDescriptor {
return file_baremetal_baremetal_proto_enumTypes[0].Descriptor()
}
func (Status) Type() protoreflect.EnumType {
return &file_baremetal_baremetal_proto_enumTypes[0]
}
func (x Status) Number() protoreflect.EnumNumber {
return protoreflect.EnumNumber(x)
}
// Deprecated: Use Status.Descriptor instead.
func (Status) EnumDescriptor() ([]byte, []int) {
return file_baremetal_baremetal_proto_rawDescGZIP(), []int{0}
}
// Submission is a job submission.
type Submission struct {
state protoimpl.MessageState `protogen:"open.v1"`
// Source is info about the source of the job, e.g., simrun sim project.
Source string `protobuf:"bytes,1,opt,name=source,proto3" json:"source,omitempty"`
// Path is the path from the SSH home directory to launch the job in.
// This path will be created on the server when the job is run.
Path string `protobuf:"bytes,2,opt,name=path,proto3" json:"path,omitempty"`
// Script is name of the job script to run, which must be at the top level
// within the given tar file.
Script string `protobuf:"bytes,3,opt,name=script,proto3" json:"script,omitempty"`
// ResultsGlob is a glob expression for the result files to get back
// from the server (e.g., *.tsv). job.out is automatically included as well,
// which has the job stdout and stderr output.
ResultsGlob string `protobuf:"bytes,4,opt,name=resultsGlob,proto3" json:"resultsGlob,omitempty"`
// Files is the gzipped tar file of the job files set at submission.
Files []byte `protobuf:"bytes,5,opt,name=files,proto3" json:"files,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *Submission) Reset() {
*x = Submission{}
mi := &file_baremetal_baremetal_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *Submission) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Submission) ProtoMessage() {}
func (x *Submission) ProtoReflect() protoreflect.Message {
mi := &file_baremetal_baremetal_proto_msgTypes[0]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Submission.ProtoReflect.Descriptor instead.
func (*Submission) Descriptor() ([]byte, []int) {
return file_baremetal_baremetal_proto_rawDescGZIP(), []int{0}
}
func (x *Submission) GetSource() string {
if x != nil {
return x.Source
}
return ""
}
func (x *Submission) GetPath() string {
if x != nil {
return x.Path
}
return ""
}
func (x *Submission) GetScript() string {
if x != nil {
return x.Script
}
return ""
}
func (x *Submission) GetResultsGlob() string {
if x != nil {
return x.ResultsGlob
}
return ""
}
func (x *Submission) GetFiles() []byte {
if x != nil {
return x.Files
}
return nil
}
// Job is one bare metal job.
type Job struct {
state protoimpl.MessageState `protogen:"open.v1"`
// ID is the overall baremetal unique job ID number.
ID int64 `protobuf:"varint,1,opt,name=ID,proto3" json:"ID,omitempty"`
// Status is the current status of the job.
Status Status `protobuf:"varint,2,opt,name=status,proto3,enum=baremetal.Status" json:"status,omitempty"`
// Source is info about the source of the job, e.g., simrun sim project.
Source string `protobuf:"bytes,3,opt,name=source,proto3" json:"source,omitempty"`
// Path is the path from the SSH home directory to launch the job in.
// This path will be created on the server when the job is run.
Path string `protobuf:"bytes,4,opt,name=path,proto3" json:"path,omitempty"`
// Script is name of the job script to run, which must be at the top level
// within the given tar file.
Script string `protobuf:"bytes,5,opt,name=script,proto3" json:"script,omitempty"`
// Files is the gzipped tar file of the job files set at submission.
Files []byte `protobuf:"bytes,6,opt,name=files,proto3" json:"files,omitempty"`
// ResultsGlob is a glob expression for the result files to get back
// from the server (e.g., *.tsv). job.out is automatically included as well,
// which has the job stdout and stderr output.
ResultsGlob string `protobuf:"bytes,7,opt,name=resultsGlob,proto3" json:"resultsGlob,omitempty"`
// Results is the gzipped tar file of the job result files, gathered
// at completion or when queried for results.
Results []byte `protobuf:"bytes,8,opt,name=results,proto3" json:"results,omitempty"`
// Submit is the time submitted.
Submit *timestamppb.Timestamp `protobuf:"bytes,9,opt,name=submit,proto3" json:"submit,omitempty"`
// Start is the time actually started.
Start *timestamppb.Timestamp `protobuf:"bytes,10,opt,name=start,proto3" json:"start,omitempty"`
// End is the time stopped running.
End *timestamppb.Timestamp `protobuf:"bytes,11,opt,name=end,proto3" json:"end,omitempty"`
// ServerName is the name of the server it is running / ran on. Empty for pending.
ServerName string `protobuf:"bytes,12,opt,name=serverName,proto3" json:"serverName,omitempty"`
// ServerGPU is the logical index of the GPU assigned to this job (0..N-1).
ServerGPU int32 `protobuf:"varint,13,opt,name=serverGPU,proto3" json:"serverGPU,omitempty"`
// PID is the process id of the job script.
PID int64 `protobuf:"varint,14,opt,name=PID,proto3" json:"PID,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *Job) Reset() {
*x = Job{}
mi := &file_baremetal_baremetal_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *Job) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Job) ProtoMessage() {}
func (x *Job) ProtoReflect() protoreflect.Message {
mi := &file_baremetal_baremetal_proto_msgTypes[1]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Job.ProtoReflect.Descriptor instead.
func (*Job) Descriptor() ([]byte, []int) {
return file_baremetal_baremetal_proto_rawDescGZIP(), []int{1}
}
func (x *Job) GetID() int64 {
if x != nil {
return x.ID
}
return 0
}
func (x *Job) GetStatus() Status {
if x != nil {
return x.Status
}
return Status_NoStatus
}
func (x *Job) GetSource() string {
if x != nil {
return x.Source
}
return ""
}
func (x *Job) GetPath() string {
if x != nil {
return x.Path
}
return ""
}
func (x *Job) GetScript() string {
if x != nil {
return x.Script
}
return ""
}
func (x *Job) GetFiles() []byte {
if x != nil {
return x.Files
}
return nil
}
func (x *Job) GetResultsGlob() string {
if x != nil {
return x.ResultsGlob
}
return ""
}
func (x *Job) GetResults() []byte {
if x != nil {
return x.Results
}
return nil
}
func (x *Job) GetSubmit() *timestamppb.Timestamp {
if x != nil {
return x.Submit
}
return nil
}
func (x *Job) GetStart() *timestamppb.Timestamp {
if x != nil {
return x.Start
}
return nil
}
func (x *Job) GetEnd() *timestamppb.Timestamp {
if x != nil {
return x.End
}
return nil
}
func (x *Job) GetServerName() string {
if x != nil {
return x.ServerName
}
return ""
}
func (x *Job) GetServerGPU() int32 {
if x != nil {
return x.ServerGPU
}
return 0
}
func (x *Job) GetPID() int64 {
if x != nil {
return x.PID
}
return 0
}
// JobList is a list of Jobs.
type JobList struct {
state protoimpl.MessageState `protogen:"open.v1"`
Jobs []*Job `protobuf:"bytes,1,rep,name=jobs,proto3" json:"jobs,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *JobList) Reset() {
*x = JobList{}
mi := &file_baremetal_baremetal_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *JobList) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*JobList) ProtoMessage() {}
func (x *JobList) ProtoReflect() protoreflect.Message {
mi := &file_baremetal_baremetal_proto_msgTypes[2]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use JobList.ProtoReflect.Descriptor instead.
func (*JobList) Descriptor() ([]byte, []int) {
return file_baremetal_baremetal_proto_rawDescGZIP(), []int{2}
}
func (x *JobList) GetJobs() []*Job {
if x != nil {
return x.Jobs
}
return nil
}
// JobIDList is a list of unique job ID numbers
type JobIDList struct {
state protoimpl.MessageState `protogen:"open.v1"`
JobID []int64 `protobuf:"varint,1,rep,packed,name=jobID,proto3" json:"jobID,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *JobIDList) Reset() {
*x = JobIDList{}
mi := &file_baremetal_baremetal_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *JobIDList) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*JobIDList) ProtoMessage() {}
func (x *JobIDList) ProtoReflect() protoreflect.Message {
mi := &file_baremetal_baremetal_proto_msgTypes[3]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use JobIDList.ProtoReflect.Descriptor instead.
func (*JobIDList) Descriptor() ([]byte, []int) {
return file_baremetal_baremetal_proto_rawDescGZIP(), []int{3}
}
func (x *JobIDList) GetJobID() []int64 {
if x != nil {
return x.JobID
}
return nil
}
// Error is an error message about an operation.
type Error struct {
state protoimpl.MessageState `protogen:"open.v1"`
Error string `protobuf:"bytes,1,opt,name=error,proto3" json:"error,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *Error) Reset() {
*x = Error{}
mi := &file_baremetal_baremetal_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *Error) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Error) ProtoMessage() {}
func (x *Error) ProtoReflect() protoreflect.Message {
mi := &file_baremetal_baremetal_proto_msgTypes[4]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Error.ProtoReflect.Descriptor instead.
func (*Error) Descriptor() ([]byte, []int) {
return file_baremetal_baremetal_proto_rawDescGZIP(), []int{4}
}
func (x *Error) GetError() string {
if x != nil {
return x.Error
}
return ""
}
var File_baremetal_baremetal_proto protoreflect.FileDescriptor
var file_baremetal_baremetal_proto_rawDesc = []byte{
0x0a, 0x19, 0x62, 0x61, 0x72, 0x65, 0x6d, 0x65, 0x74, 0x61, 0x6c, 0x2f, 0x62, 0x61, 0x72, 0x65,
0x6d, 0x65, 0x74, 0x61, 0x6c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x09, 0x62, 0x61, 0x72,
0x65, 0x6d, 0x65, 0x74, 0x61, 0x6c, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70,
0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d,
0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1b, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f,
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x65, 0x6d, 0x70, 0x74, 0x79, 0x2e, 0x70,
0x72, 0x6f, 0x74, 0x6f, 0x22, 0x88, 0x01, 0x0a, 0x0a, 0x53, 0x75, 0x62, 0x6d, 0x69, 0x73, 0x73,
0x69, 0x6f, 0x6e, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x18, 0x01, 0x20,
0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x70,
0x61, 0x74, 0x68, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12,
0x16, 0x0a, 0x06, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52,
0x06, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x12, 0x20, 0x0a, 0x0b, 0x72, 0x65, 0x73, 0x75, 0x6c,
0x74, 0x73, 0x47, 0x6c, 0x6f, 0x62, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x72, 0x65,
0x73, 0x75, 0x6c, 0x74, 0x73, 0x47, 0x6c, 0x6f, 0x62, 0x12, 0x14, 0x0a, 0x05, 0x66, 0x69, 0x6c,
0x65, 0x73, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x05, 0x66, 0x69, 0x6c, 0x65, 0x73, 0x22,
0xba, 0x03, 0x0a, 0x03, 0x4a, 0x6f, 0x62, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20,
0x01, 0x28, 0x03, 0x52, 0x02, 0x49, 0x44, 0x12, 0x29, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75,
0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x11, 0x2e, 0x62, 0x61, 0x72, 0x65, 0x6d, 0x65,
0x74, 0x61, 0x6c, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74,
0x75, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x18, 0x03, 0x20, 0x01,
0x28, 0x09, 0x52, 0x06, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61,
0x74, 0x68, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x16,
0x0a, 0x06, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06,
0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x66, 0x69, 0x6c, 0x65, 0x73, 0x18,
0x06, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x05, 0x66, 0x69, 0x6c, 0x65, 0x73, 0x12, 0x20, 0x0a, 0x0b,
0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x47, 0x6c, 0x6f, 0x62, 0x18, 0x07, 0x20, 0x01, 0x28,
0x09, 0x52, 0x0b, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x47, 0x6c, 0x6f, 0x62, 0x12, 0x18,
0x0a, 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0c, 0x52,
0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x12, 0x32, 0x0a, 0x06, 0x73, 0x75, 0x62, 0x6d,
0x69, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c,
0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73,
0x74, 0x61, 0x6d, 0x70, 0x52, 0x06, 0x73, 0x75, 0x62, 0x6d, 0x69, 0x74, 0x12, 0x30, 0x0a, 0x05,
0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f,
0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69,
0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x2c,
0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f,
0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69,
0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x03, 0x65, 0x6e, 0x64, 0x12, 0x1e, 0x0a, 0x0a,
0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x09,
0x52, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x1c, 0x0a, 0x09,
0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x50, 0x55, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x05, 0x52,
0x09, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x50, 0x55, 0x12, 0x10, 0x0a, 0x03, 0x50, 0x49,
0x44, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x50, 0x49, 0x44, 0x22, 0x2d, 0x0a, 0x07,
0x4a, 0x6f, 0x62, 0x4c, 0x69, 0x73, 0x74, 0x12, 0x22, 0x0a, 0x04, 0x6a, 0x6f, 0x62, 0x73, 0x18,
0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x62, 0x61, 0x72, 0x65, 0x6d, 0x65, 0x74, 0x61,
0x6c, 0x2e, 0x4a, 0x6f, 0x62, 0x52, 0x04, 0x6a, 0x6f, 0x62, 0x73, 0x22, 0x21, 0x0a, 0x09, 0x4a,
0x6f, 0x62, 0x49, 0x44, 0x4c, 0x69, 0x73, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x6a, 0x6f, 0x62, 0x49,
0x44, 0x18, 0x01, 0x20, 0x03, 0x28, 0x03, 0x52, 0x05, 0x6a, 0x6f, 0x62, 0x49, 0x44, 0x22, 0x1d,
0x0a, 0x05, 0x45, 0x72, 0x72, 0x6f, 0x72, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72,
0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x2a, 0x5a, 0x0a,
0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x0c, 0x0a, 0x08, 0x4e, 0x6f, 0x53, 0x74, 0x61,
0x74, 0x75, 0x73, 0x10, 0x00, 0x12, 0x0b, 0x0a, 0x07, 0x50, 0x65, 0x6e, 0x64, 0x69, 0x6e, 0x67,
0x10, 0x01, 0x12, 0x0b, 0x0a, 0x07, 0x52, 0x75, 0x6e, 0x6e, 0x69, 0x6e, 0x67, 0x10, 0x02, 0x12,
0x0d, 0x0a, 0x09, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x64, 0x10, 0x03, 0x12, 0x0c,
0x0a, 0x08, 0x43, 0x61, 0x6e, 0x63, 0x65, 0x6c, 0x65, 0x64, 0x10, 0x04, 0x12, 0x0b, 0x0a, 0x07,
0x45, 0x72, 0x72, 0x6f, 0x72, 0x65, 0x64, 0x10, 0x05, 0x32, 0xcf, 0x02, 0x0a, 0x09, 0x42, 0x61,
0x72, 0x65, 0x4d, 0x65, 0x74, 0x61, 0x6c, 0x12, 0x2f, 0x0a, 0x06, 0x53, 0x75, 0x62, 0x6d, 0x69,
0x74, 0x12, 0x15, 0x2e, 0x62, 0x61, 0x72, 0x65, 0x6d, 0x65, 0x74, 0x61, 0x6c, 0x2e, 0x53, 0x75,
0x62, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x1a, 0x0e, 0x2e, 0x62, 0x61, 0x72, 0x65, 0x6d,
0x65, 0x74, 0x61, 0x6c, 0x2e, 0x4a, 0x6f, 0x62, 0x12, 0x35, 0x0a, 0x09, 0x4a, 0x6f, 0x62, 0x53,
0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x14, 0x2e, 0x62, 0x61, 0x72, 0x65, 0x6d, 0x65, 0x74, 0x61,
0x6c, 0x2e, 0x4a, 0x6f, 0x62, 0x49, 0x44, 0x4c, 0x69, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x62, 0x61,
0x72, 0x65, 0x6d, 0x65, 0x74, 0x61, 0x6c, 0x2e, 0x4a, 0x6f, 0x62, 0x4c, 0x69, 0x73, 0x74, 0x12,
0x34, 0x0a, 0x0a, 0x43, 0x61, 0x6e, 0x63, 0x65, 0x6c, 0x4a, 0x6f, 0x62, 0x73, 0x12, 0x14, 0x2e,
0x62, 0x61, 0x72, 0x65, 0x6d, 0x65, 0x74, 0x61, 0x6c, 0x2e, 0x4a, 0x6f, 0x62, 0x49, 0x44, 0x4c,
0x69, 0x73, 0x74, 0x1a, 0x10, 0x2e, 0x62, 0x61, 0x72, 0x65, 0x6d, 0x65, 0x74, 0x61, 0x6c, 0x2e,
0x45, 0x72, 0x72, 0x6f, 0x72, 0x12, 0x38, 0x0a, 0x0c, 0x46, 0x65, 0x74, 0x63, 0x68, 0x52, 0x65,
0x73, 0x75, 0x6c, 0x74, 0x73, 0x12, 0x14, 0x2e, 0x62, 0x61, 0x72, 0x65, 0x6d, 0x65, 0x74, 0x61,
0x6c, 0x2e, 0x4a, 0x6f, 0x62, 0x49, 0x44, 0x4c, 0x69, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x62, 0x61,
0x72, 0x65, 0x6d, 0x65, 0x74, 0x61, 0x6c, 0x2e, 0x4a, 0x6f, 0x62, 0x4c, 0x69, 0x73, 0x74, 0x12,
0x3c, 0x0a, 0x0a, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x4a, 0x6f, 0x62, 0x73, 0x12, 0x16, 0x2e,
0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e,
0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70,
0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x2c, 0x0a,
0x0a, 0x52, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x4a, 0x6f, 0x62, 0x12, 0x0e, 0x2e, 0x62, 0x61,
0x72, 0x65, 0x6d, 0x65, 0x74, 0x61, 0x6c, 0x2e, 0x4a, 0x6f, 0x62, 0x1a, 0x0e, 0x2e, 0x62, 0x61,
0x72, 0x65, 0x6d, 0x65, 0x74, 0x61, 0x6c, 0x2e, 0x4a, 0x6f, 0x62, 0x42, 0x31, 0x5a, 0x2f, 0x63,
0x6f, 0x67, 0x65, 0x6e, 0x74, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x6f, 0x72, 0x67, 0x2f, 0x6c, 0x61,
0x62, 0x2f, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x73, 0x2f, 0x62, 0x61, 0x72, 0x65, 0x6d,
0x65, 0x74, 0x61, 0x6c, 0x2f, 0x62, 0x61, 0x72, 0x65, 0x6d, 0x65, 0x74, 0x61, 0x6c, 0x62, 0x06,
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_baremetal_baremetal_proto_rawDescOnce sync.Once
file_baremetal_baremetal_proto_rawDescData = file_baremetal_baremetal_proto_rawDesc
)
func file_baremetal_baremetal_proto_rawDescGZIP() []byte {
file_baremetal_baremetal_proto_rawDescOnce.Do(func() {
file_baremetal_baremetal_proto_rawDescData = protoimpl.X.CompressGZIP(file_baremetal_baremetal_proto_rawDescData)
})
return file_baremetal_baremetal_proto_rawDescData
}
var file_baremetal_baremetal_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
var file_baremetal_baremetal_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
var file_baremetal_baremetal_proto_goTypes = []any{
(Status)(0), // 0: baremetal.Status
(*Submission)(nil), // 1: baremetal.Submission
(*Job)(nil), // 2: baremetal.Job
(*JobList)(nil), // 3: baremetal.JobList
(*JobIDList)(nil), // 4: baremetal.JobIDList
(*Error)(nil), // 5: baremetal.Error
(*timestamppb.Timestamp)(nil), // 6: google.protobuf.Timestamp
(*emptypb.Empty)(nil), // 7: google.protobuf.Empty
}
var file_baremetal_baremetal_proto_depIdxs = []int32{
0, // 0: baremetal.Job.status:type_name -> baremetal.Status
6, // 1: baremetal.Job.submit:type_name -> google.protobuf.Timestamp
6, // 2: baremetal.Job.start:type_name -> google.protobuf.Timestamp
6, // 3: baremetal.Job.end:type_name -> google.protobuf.Timestamp
2, // 4: baremetal.JobList.jobs:type_name -> baremetal.Job
1, // 5: baremetal.BareMetal.Submit:input_type -> baremetal.Submission
4, // 6: baremetal.BareMetal.JobStatus:input_type -> baremetal.JobIDList
4, // 7: baremetal.BareMetal.CancelJobs:input_type -> baremetal.JobIDList
4, // 8: baremetal.BareMetal.FetchResults:input_type -> baremetal.JobIDList
7, // 9: baremetal.BareMetal.UpdateJobs:input_type -> google.protobuf.Empty
2, // 10: baremetal.BareMetal.RecoverJob:input_type -> baremetal.Job
2, // 11: baremetal.BareMetal.Submit:output_type -> baremetal.Job
3, // 12: baremetal.BareMetal.JobStatus:output_type -> baremetal.JobList
5, // 13: baremetal.BareMetal.CancelJobs:output_type -> baremetal.Error
3, // 14: baremetal.BareMetal.FetchResults:output_type -> baremetal.JobList
7, // 15: baremetal.BareMetal.UpdateJobs:output_type -> google.protobuf.Empty
2, // 16: baremetal.BareMetal.RecoverJob:output_type -> baremetal.Job
11, // [11:17] is the sub-list for method output_type
5, // [5:11] is the sub-list for method input_type
5, // [5:5] is the sub-list for extension type_name
5, // [5:5] is the sub-list for extension extendee
0, // [0:5] is the sub-list for field type_name
}
func init() { file_baremetal_baremetal_proto_init() }
func file_baremetal_baremetal_proto_init() {
if File_baremetal_baremetal_proto != nil {
return
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_baremetal_baremetal_proto_rawDesc,
NumEnums: 1,
NumMessages: 5,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_baremetal_baremetal_proto_goTypes,
DependencyIndexes: file_baremetal_baremetal_proto_depIdxs,
EnumInfos: file_baremetal_baremetal_proto_enumTypes,
MessageInfos: file_baremetal_baremetal_proto_msgTypes,
}.Build()
File_baremetal_baremetal_proto = out.File
file_baremetal_baremetal_proto_rawDesc = nil
file_baremetal_baremetal_proto_goTypes = nil
file_baremetal_baremetal_proto_depIdxs = nil
}
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.5.1
// - protoc v5.29.3
// source: baremetal/baremetal.proto
package baremetal
import (
context "context"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
emptypb "google.golang.org/protobuf/types/known/emptypb"
)
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.64.0 or later.
const _ = grpc.SupportPackageIsVersion9
const (
BareMetal_Submit_FullMethodName = "/baremetal.BareMetal/Submit"
BareMetal_JobStatus_FullMethodName = "/baremetal.BareMetal/JobStatus"
BareMetal_CancelJobs_FullMethodName = "/baremetal.BareMetal/CancelJobs"
BareMetal_FetchResults_FullMethodName = "/baremetal.BareMetal/FetchResults"
BareMetal_UpdateJobs_FullMethodName = "/baremetal.BareMetal/UpdateJobs"
BareMetal_RecoverJob_FullMethodName = "/baremetal.BareMetal/RecoverJob"
)
// BareMetalClient is the client API for BareMetal service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
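//
// A minimal client sketch (the address is hypothetical; insecure credentials
// come from google.golang.org/grpc/credentials/insecure):
//
//	conn, err := grpc.NewClient("localhost:8585",
//		grpc.WithTransportCredentials(insecure.NewCredentials()))
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer conn.Close()
//	client := NewBareMetalClient(conn)
//	jobs, err := client.JobStatus(context.Background(), &JobIDList{})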
type BareMetalClient interface {
Submit(ctx context.Context, in *Submission, opts ...grpc.CallOption) (*Job, error)
JobStatus(ctx context.Context, in *JobIDList, opts ...grpc.CallOption) (*JobList, error)
CancelJobs(ctx context.Context, in *JobIDList, opts ...grpc.CallOption) (*Error, error)
FetchResults(ctx context.Context, in *JobIDList, opts ...grpc.CallOption) (*JobList, error)
UpdateJobs(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*emptypb.Empty, error)
RecoverJob(ctx context.Context, in *Job, opts ...grpc.CallOption) (*Job, error)
}
type bareMetalClient struct {
cc grpc.ClientConnInterface
}
func NewBareMetalClient(cc grpc.ClientConnInterface) BareMetalClient {
return &bareMetalClient{cc}
}
func (c *bareMetalClient) Submit(ctx context.Context, in *Submission, opts ...grpc.CallOption) (*Job, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(Job)
err := c.cc.Invoke(ctx, BareMetal_Submit_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *bareMetalClient) JobStatus(ctx context.Context, in *JobIDList, opts ...grpc.CallOption) (*JobList, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(JobList)
err := c.cc.Invoke(ctx, BareMetal_JobStatus_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *bareMetalClient) CancelJobs(ctx context.Context, in *JobIDList, opts ...grpc.CallOption) (*Error, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(Error)
err := c.cc.Invoke(ctx, BareMetal_CancelJobs_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *bareMetalClient) FetchResults(ctx context.Context, in *JobIDList, opts ...grpc.CallOption) (*JobList, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(JobList)
err := c.cc.Invoke(ctx, BareMetal_FetchResults_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *bareMetalClient) UpdateJobs(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*emptypb.Empty, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(emptypb.Empty)
err := c.cc.Invoke(ctx, BareMetal_UpdateJobs_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *bareMetalClient) RecoverJob(ctx context.Context, in *Job, opts ...grpc.CallOption) (*Job, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(Job)
err := c.cc.Invoke(ctx, BareMetal_RecoverJob_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
// BareMetalServer is the server API for BareMetal service.
// All implementations must embed UnimplementedBareMetalServer
// for forward compatibility.
type BareMetalServer interface {
Submit(context.Context, *Submission) (*Job, error)
JobStatus(context.Context, *JobIDList) (*JobList, error)
CancelJobs(context.Context, *JobIDList) (*Error, error)
FetchResults(context.Context, *JobIDList) (*JobList, error)
UpdateJobs(context.Context, *emptypb.Empty) (*emptypb.Empty, error)
RecoverJob(context.Context, *Job) (*Job, error)
mustEmbedUnimplementedBareMetalServer()
}
// UnimplementedBareMetalServer must be embedded to have
// forward compatible implementations.
//
// NOTE: this should be embedded by value instead of pointer to avoid a nil
// pointer dereference when methods are called.
type UnimplementedBareMetalServer struct{}
func (UnimplementedBareMetalServer) Submit(context.Context, *Submission) (*Job, error) {
return nil, status.Errorf(codes.Unimplemented, "method Submit not implemented")
}
func (UnimplementedBareMetalServer) JobStatus(context.Context, *JobIDList) (*JobList, error) {
return nil, status.Errorf(codes.Unimplemented, "method JobStatus not implemented")
}
func (UnimplementedBareMetalServer) CancelJobs(context.Context, *JobIDList) (*Error, error) {
return nil, status.Errorf(codes.Unimplemented, "method CancelJobs not implemented")
}
func (UnimplementedBareMetalServer) FetchResults(context.Context, *JobIDList) (*JobList, error) {
return nil, status.Errorf(codes.Unimplemented, "method FetchResults not implemented")
}
func (UnimplementedBareMetalServer) UpdateJobs(context.Context, *emptypb.Empty) (*emptypb.Empty, error) {
return nil, status.Errorf(codes.Unimplemented, "method UpdateJobs not implemented")
}
func (UnimplementedBareMetalServer) RecoverJob(context.Context, *Job) (*Job, error) {
return nil, status.Errorf(codes.Unimplemented, "method RecoverJob not implemented")
}
func (UnimplementedBareMetalServer) mustEmbedUnimplementedBareMetalServer() {}
func (UnimplementedBareMetalServer) testEmbeddedByValue() {}
// UnsafeBareMetalServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to BareMetalServer will
// result in compilation errors.
type UnsafeBareMetalServer interface {
mustEmbedUnimplementedBareMetalServer()
}
func RegisterBareMetalServer(s grpc.ServiceRegistrar, srv BareMetalServer) {
// If the following call panics, it indicates UnimplementedBareMetalServer was
// embedded by pointer and is nil. This will cause panics if an
// unimplemented method is ever invoked, so we test this at initialization
// time to prevent it from happening at runtime later due to I/O.
if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
t.testEmbeddedByValue()
}
s.RegisterService(&BareMetal_ServiceDesc, srv)
}
func _BareMetal_Submit_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(Submission)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(BareMetalServer).Submit(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: BareMetal_Submit_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BareMetalServer).Submit(ctx, req.(*Submission))
}
return interceptor(ctx, in, info, handler)
}
func _BareMetal_JobStatus_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(JobIDList)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(BareMetalServer).JobStatus(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: BareMetal_JobStatus_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BareMetalServer).JobStatus(ctx, req.(*JobIDList))
}
return interceptor(ctx, in, info, handler)
}
func _BareMetal_CancelJobs_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(JobIDList)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(BareMetalServer).CancelJobs(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: BareMetal_CancelJobs_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BareMetalServer).CancelJobs(ctx, req.(*JobIDList))
}
return interceptor(ctx, in, info, handler)
}
func _BareMetal_FetchResults_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(JobIDList)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(BareMetalServer).FetchResults(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: BareMetal_FetchResults_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BareMetalServer).FetchResults(ctx, req.(*JobIDList))
}
return interceptor(ctx, in, info, handler)
}
func _BareMetal_UpdateJobs_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(emptypb.Empty)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(BareMetalServer).UpdateJobs(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: BareMetal_UpdateJobs_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BareMetalServer).UpdateJobs(ctx, req.(*emptypb.Empty))
}
return interceptor(ctx, in, info, handler)
}
func _BareMetal_RecoverJob_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(Job)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(BareMetalServer).RecoverJob(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: BareMetal_RecoverJob_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BareMetalServer).RecoverJob(ctx, req.(*Job))
}
return interceptor(ctx, in, info, handler)
}
// BareMetal_ServiceDesc is the grpc.ServiceDesc for BareMetal service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var BareMetal_ServiceDesc = grpc.ServiceDesc{
ServiceName: "baremetal.BareMetal",
HandlerType: (*BareMetalServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "Submit",
Handler: _BareMetal_Submit_Handler,
},
{
MethodName: "JobStatus",
Handler: _BareMetal_JobStatus_Handler,
},
{
MethodName: "CancelJobs",
Handler: _BareMetal_CancelJobs_Handler,
},
{
MethodName: "FetchResults",
Handler: _BareMetal_FetchResults_Handler,
},
{
MethodName: "UpdateJobs",
Handler: _BareMetal_UpdateJobs_Handler,
},
{
MethodName: "RecoverJob",
Handler: _BareMetal_RecoverJob_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "baremetal/baremetal.proto",
}
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package baremetal
import (
"context"
"fmt"
"math"
"time"
"cogentcore.org/core/base/errors"
pb "cogentcore.org/lab/examples/baremetal/baremetal"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
type Client struct {
// The server address including port number.
Host string `default:"localhost:8585"`
// Timeout is the timeout for grpc calls (default 2 minutes, set in NewClient).
Timeout time.Duration
// grpc connection
conn *grpc.ClientConn
client pb.BareMetalClient
}
func NewClient() *Client {
cl := &Client{}
cl.Host = "localhost:8585"
cl.Timeout = 120 * time.Second
return cl
}
// maxCallRecvMessageSize increases the maximum grpc receive message size, for big data.
func maxCallRecvMessageSize() grpc.CallOption {
return grpc.MaxCallRecvMsgSize(16 * math.MaxInt32)
}
// maxCallSendMessageSize increases the maximum grpc send message size, for big data.
func maxCallSendMessageSize() grpc.CallOption {
return grpc.MaxCallSendMsgSize(16 * math.MaxInt32)
}
// Connect connects to the server
func (cl *Client) Connect() error {
conn, err := grpc.NewClient(cl.Host, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return errors.Log(fmt.Errorf("did not connect: %v", err))
}
cl.conn = conn
cl.client = pb.NewBareMetalClient(conn)
return nil
}
// Submit adds a new Active job with given parameters.
func (cl *Client) Submit(source, path, script, resultsGlob string, files []byte) (*Job, error) {
ctx, cancel := context.WithTimeout(context.Background(), cl.Timeout)
defer cancel()
sub := &pb.Submission{Source: source, Path: path, Script: script, ResultsGlob: resultsGlob, Files: files}
job, err := cl.client.Submit(ctx, sub) // todo: not working: maxCallSendMessageSize
if err != nil {
return nil, errors.Log(fmt.Errorf("could not submit: %v", err))
}
return JobFromPB(job), nil
}
// JobStatus gets current job data back from server for given job id(s).
func (cl *Client) JobStatus(ids ...int) ([]*Job, error) {
ctx, cancel := context.WithTimeout(context.Background(), cl.Timeout)
defer cancel()
pids := &pb.JobIDList{JobID: JobIDsToPB(ids)}
jobs, err := cl.client.JobStatus(ctx, pids, maxCallRecvMessageSize())
if err != nil {
return nil, errors.Log(fmt.Errorf("JobStatus failed: %v", err))
}
return JobsFromPB(jobs), nil
}
// CancelJobs cancels list of job IDs. Returns error for jobs not found.
func (cl *Client) CancelJobs(ids ...int) error {
ctx, cancel := context.WithTimeout(context.Background(), cl.Timeout)
defer cancel()
pids := &pb.JobIDList{JobID: JobIDsToPB(ids)}
emsg, err := cl.client.CancelJobs(ctx, pids)
if err != nil {
return errors.Log(fmt.Errorf("CancelJobs failed: %v", err))
}
if emsg.Error != "" {
return errors.New(emsg.Error)
}
return nil
}
// FetchResults gets job results back from the server for the given job id(s).
// Results are available as job.Results as a compressed tar file.
// Note that the resultsGlob argument is not currently sent over the RPC;
// the glob recorded at submission is used by the server.
func (cl *Client) FetchResults(resultsGlob string, ids ...int) ([]*Job, error) {
ctx, cancel := context.WithTimeout(context.Background(), cl.Timeout)
defer cancel()
pids := &pb.JobIDList{JobID: JobIDsToPB(ids)}
jobs, err := cl.client.FetchResults(ctx, pids, maxCallRecvMessageSize())
if err != nil {
return nil, errors.Log(fmt.Errorf("FetchResults failed: %v", err))
}
return JobsFromPB(jobs), nil
}
// UpdateJobs pings the server to run its updates.
// This happens automatically every 10 seconds, but this is for the impatient.
func (cl *Client) UpdateJobs() {
ctx, cancel := context.WithTimeout(context.Background(), cl.Timeout)
defer cancel()
_, err := cl.client.UpdateJobs(ctx, &emptypb.Empty{})
if err != nil {
errors.Log(fmt.Errorf("UpdateJobs failed: %v", err))
}
}
// RecoverJob recovers a job which has been lost somehow.
// It just adds the given job to the job table.
func (cl *Client) RecoverJob(job *Job) (*Job, error) {
ctx, cancel := context.WithTimeout(context.Background(), cl.Timeout)
defer cancel()
pjob := JobToPB(job)
rjob, err := cl.client.RecoverJob(ctx, pjob)
if err != nil {
return nil, errors.Log(fmt.Errorf("RecoverJob failed: %v", err))
}
return JobFromPB(rjob), nil
}
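// exampleClientUsage is a minimal, hypothetical usage sketch of the Client
// API above (not called anywhere in this package): it assumes a baremetal
// server is already listening on the default localhost:8585, and that files
// holds a gzipped tar of the job directory (e.g., produced by TarFiles).
// The source, path, and script strings are placeholders.
func exampleClientUsage(files []byte) {
cl := NewClient()
if err := cl.Connect(); err != nil {
return
}
job, err := cl.Submit("simrun", "simserver/myproj/user/jobs/1", "job.sbatch", "*.tsv", files)
if err != nil {
return
}
jobs, err := cl.JobStatus(job.ID)
if err != nil || len(jobs) == 0 {
return
}
fmt.Println("job", jobs[0].ID, "status:", jobs[0].Status)
}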
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"context"
"fmt"
"log"
"log/slog"
"net"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/logx"
"cogentcore.org/core/cli"
"cogentcore.org/lab/examples/baremetal"
pb "cogentcore.org/lab/examples/baremetal/baremetal"
"google.golang.org/grpc"
"google.golang.org/protobuf/types/known/emptypb"
)
type Config struct {
// The server port number.
Port int `default:"8585"`
}
type server struct {
pb.UnimplementedBareMetalServer
bm *baremetal.BareMetal
}
// Submit adds a new Active job with given parameters.
func (s *server) Submit(_ context.Context, in *pb.Submission) (*pb.Job, error) {
slog.Info("Submitting Job", "Source:", in.Source, "Path:", in.Path)
job := s.bm.Submit(in.Source, in.Path, in.Script, in.ResultsGlob, in.Files)
return baremetal.JobToPB(job), nil
}
// JobStatus gets current job data for given job id(s).
// An empty list returns all of the active jobs.
func (s *server) JobStatus(_ context.Context, in *pb.JobIDList) (*pb.JobList, error) {
slog.Info("JobStatus")
jobs := s.bm.JobStatus(baremetal.JobIDsFromPB(in.JobID)...)
return baremetal.JobsToPB(jobs), nil
}
// CancelJobs cancels list of job IDs. Returns error for jobs not found.
func (s *server) CancelJobs(_ context.Context, in *pb.JobIDList) (*pb.Error, error) {
slog.Info("CancelJobs")
err := s.bm.CancelJobs(baremetal.JobIDsFromPB(in.JobID)...)
estr := ""
if err != nil {
estr = err.Error()
}
return &pb.Error{Error: estr}, nil
}
// FetchResults gets job results back from server for given job id(s).
// Results are available as job.Results as a compressed tar file.
func (s *server) FetchResults(_ context.Context, in *pb.JobIDList) (*pb.JobList, error) {
slog.Info("FetchResults")
jobs, err := s.bm.FetchResults("", baremetal.JobIDsFromPB(in.JobID)...)
errors.Log(err)
return baremetal.JobsToPB(jobs), nil
}
// UpdateJobs pings the server to run its updates.
// This happens automatically every 10 seconds, but this is for the impatient.
func (s *server) UpdateJobs(_ context.Context, in *emptypb.Empty) (*emptypb.Empty, error) {
slog.Info("UpdateJobs")
s.bm.UpdateJobs()
return &emptypb.Empty{}, nil
}
// RecoverJob recovers a job that has been lost somehow,
// adding it back to the job table.
func (s *server) RecoverJob(_ context.Context, in *pb.Job) (*pb.Job, error) {
slog.Info("RecoverJob")
job, err := s.bm.RecoverJob(baremetal.JobFromPB(in))
if errors.Log(err) != nil {
return nil, err
}
return baremetal.JobToPB(job), nil
}
func main() {
logx.UserLevel = slog.LevelInfo
opts := cli.DefaultOptions("baremetal", "Bare metal server for job running on bare servers over ssh")
cfg := &Config{}
cli.Run(opts, cfg, Run)
}
func Run(cfg *Config) error {
lis, err := net.Listen("tcp", fmt.Sprintf(":%d", cfg.Port))
if err != nil {
return errors.Log(fmt.Errorf("failed to listen: %v", err))
}
s := grpc.NewServer()
bms := &server{}
bms.bm = baremetal.NewBareMetal()
bms.bm.Init()
bms.bm.StartBGUpdates()
pb.RegisterBareMetalServer(s, bms)
log.Printf("server listening at %v", lis.Addr())
if err := s.Serve(lis); err != nil {
return errors.Log(fmt.Errorf("failed to serve: %v", err))
}
bms.bm.Interactive()
return nil
}
// Code generated by "core generate"; DO NOT EDIT.
package baremetal
import (
"cogentcore.org/core/enums"
)
var _StatusValues = []Status{0, 1, 2, 3, 4, 5}
// StatusN is the highest valid value for type Status, plus one.
const StatusN Status = 6
var _StatusValueMap = map[string]Status{`NoStatus`: 0, `Pending`: 1, `Running`: 2, `Completed`: 3, `Canceled`: 4, `Errored`: 5}
var _StatusDescMap = map[Status]string{0: `NoStatus is the unknown status state.`, 1: `Pending means the job has been submitted, but not yet run.`, 2: `Running means the job is running.`, 3: `Completed means the job finished on its own, with no error status.`, 4: `Canceled means the job was canceled by the user.`, 5: `Errored means the job quit with an error`}
var _StatusMap = map[Status]string{0: `NoStatus`, 1: `Pending`, 2: `Running`, 3: `Completed`, 4: `Canceled`, 5: `Errored`}
// String returns the string representation of this Status value.
func (i Status) String() string { return enums.String(i, _StatusMap) }
// SetString sets the Status value from its string representation,
// and returns an error if the string is invalid.
func (i *Status) SetString(s string) error { return enums.SetString(i, s, _StatusValueMap, "Status") }
// Int64 returns the Status value as an int64.
func (i Status) Int64() int64 { return int64(i) }
// SetInt64 sets the Status value from an int64.
func (i *Status) SetInt64(in int64) { *i = Status(in) }
// Desc returns the description of the Status value.
func (i Status) Desc() string { return enums.Desc(i, _StatusDescMap) }
// StatusValues returns all possible values for the type Status.
func StatusValues() []Status { return _StatusValues }
// Values returns all possible values for the type Status.
func (i Status) Values() []enums.Enum { return enums.Values(_StatusValues) }
// MarshalText implements the [encoding.TextMarshaler] interface.
func (i Status) MarshalText() ([]byte, error) { return []byte(i.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *Status) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "Status") }
// Code generated by "goal build"; DO NOT EDIT.
//line jobs.goal:1
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package baremetal
import (
"bytes"
"fmt"
"log/slog"
"strconv"
"strings"
"time"
"cogentcore.org/core/base/errors"
"cogentcore.org/lab/goal/goalib"
)
// Status are the job status values.
type Status int32 //enums:enum
const (
// NoStatus is the unknown status state.
NoStatus Status = iota
// Pending means the job has been submitted, but not yet run.
Pending
// Running means the job is running.
Running
// Completed means the job finished on its own, with no error status.
Completed
// Canceled means the job was canceled by the user.
Canceled
// Errored means the job quit with an error
Errored
)
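// exampleStatusRoundTrip is a small, hypothetical sketch (not called anywhere)
// of the generated Status enum API: string and int64 conversions are
// symmetric, and SetString returns an error for unknown names.
func exampleStatusRoundTrip() (Status, error) {
var st Status
if err := st.SetString("Running"); err != nil {
return NoStatus, err
}
st.SetInt64(st.Int64()) // no-op round trip
return st, nil // st.String() == "Running"
}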
// Job is one bare metal job.
type Job struct {
// ID is the overall baremetal unique ID number.
ID int
// Status is the current status of the job.
Status Status
// Source is info about the source of the job, e.g., simrun sim project.
Source string
// Path is the path from the SSH home directory to launch the job in.
// This path will be created on the server when the job is run.
Path string
// Script is name of the job script to run, which must be at the top level
// within the given tar file.
Script string
// Files is the gzipped tar file of the job files set at submission.
Files []byte `display:"-"`
// ResultsGlob is a glob expression for the result files to get back
// from the server (e.g., *.tsv). job.out is automatically included as well,
// which has the job's stdout and stderr output.
ResultsGlob string `display:"-"`
// Results is the gzipped tar file of the job result files, gathered
// at completion or when queried for results.
Results []byte `display:"-"`
// Submit is the time submitted.
Submit time.Time
// Start is the time actually started.
Start time.Time
// End is the time stopped running.
End time.Time
// ServerName is the name of the server it is running / ran on. Empty for pending.
ServerName string
// ServerGPU is the logical index of the GPU assigned to this job (0..N-1).
ServerGPU int
// PID is the process id of the job script.
PID int
}
// job returns the Job record for the given job number,
// or nil if not found in Active or Done.
func (bm *BareMetal) job(jobno int) *Job {
job, ok := bm.Active.AtTry(jobno)
if ok {
return job
}
job, ok = bm.Done.AtTry(jobno)
if ok {
return job
}
return nil
}
// submit adds a new Active job with given parameters.
func (bm *BareMetal) submit(src, path, script, results string, files []byte) *Job {
job := &Job{ID: bm.NextID, Status: Pending, Source: src, Path: path, Script: script, Files: files, ResultsGlob: results, Submit: time.Now(), ServerGPU: -1}
bm.NextID++
bm.Active.Add(job.ID, job)
bm.saveState()
return job
}
// runJob runs the given job on the given server, on the given GPU number.
func (bm *BareMetal) runJob(job *Job, sv *Server, gpu int) error {
defer func() {
goalrun.Run("cd")
goalrun.Run("@0")
}()
sv.Use()
goalrun.Run("cd")
goalrun.Run("mkdir", "-p", job.Path)
goalrun.Run("cd", job.Path)
sshcl, err := goalrun.SSHByHost(sv.Name)
if errors.Log(err) != nil {
return err
}
b := bytes.NewReader(job.Files)
sz := int64(len(job.Files))
ctx := goalrun.StartContext()
err = sshcl.CopyLocalToHost(ctx, b, sz, "job.files.tar.gz")
goalrun.EndContext()
if errors.Log(err) != nil {
return err
}
goalrun.Run("tar", "-xzf", "job.files.tar.gz")
// set BARE_GPU {gpu}
gpus := strconv.Itoa(gpu)
// $nohup {"./"+job.Script} > job.out 2>&1 & echo "$!" > job.pid $
// note: anything with an & in it just doesn't work on our ssh client, for unknown reasons.
// goalrun.Run("nohup", "./"+job.Script, ">&", "job.out", "&", "echo", "$!", ">", "job.pid")
goalrun.Run("BARE_GPU="+gpus, "nohup", "./"+job.Script)
for range 10 {
if bm.getJobPID(job) {
break
}
time.Sleep(time.Second)
}
job.ServerName = sv.Name
job.ServerGPU = gpu
job.Start = time.Now()
slog.Info("Job running on server", "Job:", job.ID, "Server:", sv.Name)
return nil
}
// getJobPID tries to get the job PID, returning true if obtained.
// Must already be in the ssh and directory for correct server.
func (bm *BareMetal) getJobPID(job *Job) bool {
pids := strings.TrimSpace(goalrun.Output("cat", "job.pid"))
if pids != "" {
pidn, err := strconv.Atoi(pids)
if err == nil {
job.PID = pidn
return true
}
}
return false
}
// runPendingJobs runs any pending jobs if there are available GPUs to run on.
// Returns the number of jobs started, and any errors incurred in starting jobs.
func (bm *BareMetal) runPendingJobs() (int, error) {
avail := bm.availableGPUs()
if len(avail) == 0 {
return 0, nil
}
nRun := 0
var errs []error
for _, job := range bm.Active.Values {
if job.Status != Pending {
continue
}
fmt.Println("job status:", job.Status, "jobno:", job.ID)
av := avail[0]
sv := bm.Servers.At(av.Name)
next := sv.NextGPU()
for next < 0 {
if len(avail) == 1 {
return nRun, errors.Join(errs...)
}
avail = avail[1:]
av = avail[0]
sv = bm.Servers.At(av.Name)
next = sv.NextGPU()
}
err := bm.runJob(job, sv, next)
if err != nil { // note: errors are server errors, not job errors, so don't affect job status
sv.FreeGPU(next)
errs = append(errs, err)
} else {
job.Status = Running
nRun++
}
}
if nRun > 0 {
bm.saveState()
}
return nRun, errors.Join(errs...)
}
// cancelJobs cancels the given list of job IDs.
// Jobs not found in the Active list are reported as errors.
func (bm *BareMetal) cancelJobs(jobs ...int) error {
var errs []error
for _, jid := range jobs {
job, ok := bm.Active.AtTry(jid)
if !ok {
err := errors.Log(fmt.Errorf("CancelJobs: job id not found in Active job list: %d", jid))
errs = append(errs, err)
} else {
err := bm.cancelJob(job)
if err != nil {
errs = append(errs, err)
}
}
}
bm.saveState()
return errors.Join(errs...)
}
// cancelJob cancels the running of the given job (killing process if Running).
func (bm *BareMetal) cancelJob(job *Job) error {
if job.Status == Pending {
job.Status = Canceled
job.End = time.Now()
bm.Done.Add(job.ID, job)
bm.Active.DeleteByKey(job.ID)
return nil
}
sv, err := bm.Server(job.ServerName)
if errors.Log(err) != nil {
return err
}
goalrun.Run("@0")
job.Status = Canceled // always mark job as canceled, even if other stuff fails
bm.jobDone(job, sv)
bm.saveState()
sv.Use()
goalrun.Run("cd")
if job.PID == 0 {
goalrun.RunErrOK("cd", job.Path)
if !bm.getJobPID(job) {
return errors.Log(fmt.Errorf("CancelJob: Job %d PID is 0 and could not get it from job.pid file: must cancel manually", job.ID))
}
}
goalrun.RunErrOK("kill", "-9", job.PID)
goalrun.Run("cd")
goalrun.Run("@0")
return nil
}
// pollJobs checks to see if any running jobs have finished.
// Returns number of jobs that finished.
func (bm *BareMetal) pollJobs() (int, error) {
nDone := 0
njobs := bm.Active.Len()
goalrun.Run("@0")
// todo: this screws up parsing:
var errs []error
for ji := njobs - 1; ji >= 0; ji-- { // reverse b/c moves jobs to Done
job := bm.Active.Values[ji]
// fmt.Println("job status:", job.Status, "jobno:", job.ID)
if job.Status != Pending && job.Status != Running { // stray non-running job in Active: move to Done
job.Status = Completed
bm.Done.Add(job.ID, job)
bm.Active.DeleteByKey(job.ID)
nDone++
continue
}
if job.Status != Running {
continue
}
sv, err := bm.Server(job.ServerName)
if errors.Log(err) != nil {
errs = append(errs, err)
continue
}
sv.Use()
if job.PID == 0 {
goalrun.Run("cd")
goalrun.RunErrOK("cd", job.Path)
if !bm.getJobPID(job) {
err := fmt.Errorf("PollJobs: Job %d PID is 0 and could not get it from job.pid file: must cancel manually", job.ID)
errs = append(errs, err)
goalrun.Run("cd")
job.Status = Completed
bm.jobDone(job, sv)
nDone++
continue // already moved to Done; skip the ps liveness check below
}
goalrun.Run("cd")
}
goalrun.Run("cd")
// psout := $ps -p {job.PID} >/dev/null; echo "$?"$ // todo: don't parse ; ourselves!
psout := strings.TrimSpace(goalrun.Output("ps", "-p", job.PID, ">", "/dev/null", ";", "echo", "$?"))
// fmt.Println("status:", psout)
if psout == "1" {
job.Status = Completed
bm.fetchResultsJob(job, sv)
bm.jobDone(job, sv)
nDone++
}
}
goalrun.Run("@0")
if nDone > 0 {
bm.saveState()
}
return nDone, errors.Join(errs...)
}
// jobDone sets job to be completed and moves to Done category.
func (bm *BareMetal) jobDone(job *Job, sv *Server) {
job.End = time.Now()
if job.ServerGPU >= 0 {
sv.FreeGPU(job.ServerGPU)
}
bm.Done.Add(job.ID, job)
bm.Active.DeleteByKey(job.ID)
}
// fetchResults gets job results back from server for given job id(s).
// Results are available as job.Results as a compressed tar file.
func (bm *BareMetal) fetchResults(resultsGlob string, ids ...int) ([]*Job, error) {
var errs []error
var jobs []*Job
for _, id := range ids {
job := bm.job(id)
if job == nil {
errs = append(errs, fmt.Errorf("FetchResults: job id not found: %d", id))
continue
}
sv, err := bm.Server(job.ServerName)
if err != nil {
errs = append(errs, err)
continue
}
if resultsGlob != "" {
job.ResultsGlob = resultsGlob
}
err = bm.fetchResultsJob(job, sv)
if err != nil {
errs = append(errs, err)
} else {
jobs = append(jobs, job)
}
}
return jobs, errors.Join(errs...)
}
// fetchResultsJob gets job results back from server.
func (bm *BareMetal) fetchResultsJob(job *Job, sv *Server) error {
defer func() {
goalrun.Run("cd")
goalrun.Run("@0")
}()
sv.Use()
goalrun.Run("cd")
goalrun.Run("cd", job.Path)
jglob := job.ResultsGlob
if !strings.Contains(jglob, "job.label") {
jglob += " job.label"
}
res := goalrun.OutputErrOK("/bin/ls", "-1", jglob)
if strings.Contains(res, "No such file") {
res = ""
}
ress := goalib.SplitLines(res)
fmt.Println("results for:", jglob, ress)
goalrun.Run("tar", "-czf", "job.results.tar.gz", "job.out", "nohup.out", ress)
var b bytes.Buffer
sshcl, err := goalrun.SSHByHost(sv.Name)
if errors.Log(err) != nil {
return err
}
ctx := goalrun.StartContext()
err = sshcl.CopyHostToLocal(ctx, "job.results.tar.gz", &b)
goalrun.EndContext()
if errors.Log(err) != nil {
return err
}
job.Results = b.Bytes()
return nil
}
// setServerUsedFromJobs is called at startup to set the server Used status
// based on the current Active jobs, loaded from State.
func (bm *BareMetal) setServerUsedFromJobs() error {
for _, sv := range bm.Servers.Values {
sv.Used = make(map[int]bool)
}
var errs []error
for _, job := range bm.Active.Values {
if job.Status != Running {
continue
}
sv, err := bm.Server(job.ServerName)
if errors.Log(err) != nil {
errs = append(errs, err)
continue
}
sv.Used[job.ServerGPU] = true
}
return errors.Join(errs...)
}
// recoverJob attempts to recover the job information
// from given job, using remote or local data.
func (bm *BareMetal) recoverJob(job *Job) (*Job, error) {
// do we need to do anything actually?
bm.Active.Add(job.ID, job)
bm.saveState()
return job, nil
}
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package baremetal
//go:generate protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative baremetal/baremetal.proto
import (
pb "cogentcore.org/lab/examples/baremetal/baremetal"
"google.golang.org/protobuf/types/known/timestamppb"
)
// JobToPB returns the protobuf version of given Job.
func JobToPB(job *Job) *pb.Job {
pj := &pb.Job{ID: int64(job.ID), Status: pb.Status(job.Status), Source: job.Source, Path: job.Path, Script: job.Script, Files: job.Files, ResultsGlob: job.ResultsGlob, Results: job.Results, ServerName: job.ServerName, ServerGPU: int32(job.ServerGPU), PID: int64(job.PID)}
pj.Submit = timestamppb.New(job.Submit)
pj.Start = timestamppb.New(job.Start)
pj.End = timestamppb.New(job.End)
return pj
}
// JobFromPB returns a Job based on the protobuf version.
func JobFromPB(job *pb.Job) *Job {
bj := &Job{ID: int(job.ID), Status: Status(job.Status), Source: job.Source, Path: job.Path, Script: job.Script, Files: job.Files, ResultsGlob: job.ResultsGlob, Results: job.Results, ServerName: job.ServerName, ServerGPU: int(job.ServerGPU), PID: int(job.PID)}
bj.Submit = job.Submit.AsTime()
bj.Start = job.Start.AsTime()
bj.End = job.End.AsTime()
return bj
}
// JobsToPB returns the protobuf version of given Jobs list.
func JobsToPB(jobs []*Job) *pb.JobList {
pjs := make([]*pb.Job, len(jobs))
for i, job := range jobs {
pjs[i] = JobToPB(job)
}
return &pb.JobList{Jobs: pjs}
}
// JobsFromPB returns Jobs from the protobuf version of given Jobs list.
func JobsFromPB(pjs *pb.JobList) []*Job {
jobs := make([]*Job, len(pjs.Jobs))
for i, pj := range pjs.Jobs {
jobs[i] = JobFromPB(pj)
}
return jobs
}
// JobIDsToPB returns job id numbers as int64 for pb.JobIDList
func JobIDsToPB(ids []int) []int64 {
i64 := make([]int64, len(ids))
for i, id := range ids {
i64[i] = int64(id)
}
return i64
}
// JobIDsFromPB returns job id numbers from int64 in pb.JobIDList
func JobIDsFromPB(ids []int64) []int {
is := make([]int, len(ids))
for i, id := range ids {
is[i] = int(id)
}
return is
}
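// exampleJobRoundTrip is a hypothetical sketch (not called anywhere) showing
// that converting a Job to its protobuf form and back preserves the fields,
// with times passing through timestamppb.
func exampleJobRoundTrip(job *Job) bool {
back := JobFromPB(JobToPB(job))
return back.ID == job.ID && back.Status == job.Status && back.Submit.Equal(job.Submit)
}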
// Code generated by "goal build"; DO NOT EDIT.
//line server.goal:1
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package baremetal
import "fmt"
// Server specifies a bare metal Server.
type Server struct {
// Name is the alias used for gossh.
Name string
// SSH is string to gossh to.
SSH string
// NGPUs is the number of GPUs on this server.
NGPUs int
// Used is a map of GPUs currently being used.
Used map[int]bool `edit:"-" toml:"-"`
}
// ServerAvail is used to report the number of available gpus per server.
type ServerAvail struct {
Name string
Avail int
}
// OpenSSH opens the SSH connection for this server.
func (sv *Server) OpenSSH() {
if sv.Used == nil {
sv.Used = make(map[int]bool)
}
goalrun.Run("gossh", sv.SSH, sv.Name)
}
// ID returns the server SSH ID string: @Name
func (sv *Server) ID() string {
return "@" + sv.Name
}
// Use makes this the active server.
func (sv *Server) Use() {
goalrun.Run(sv.ID())
}
// Avail returns the number of available (unused) GPUs on this server.
func (sv *Server) Avail() int {
return sv.NGPUs - len(sv.Used)
}
// Server provides error-wrapped access to Servers by name.
func (bm *BareMetal) Server(name string) (*Server, error) {
sv, ok := bm.Servers.AtTry(name)
if !ok {
return nil, fmt.Errorf("BareMetal:Server name not found: %q", name)
}
return sv, nil
}
// availableGPUs returns the servers that have GPUs available,
// with the number available on each (nil if none are available).
func (bm *BareMetal) availableGPUs() []ServerAvail {
var avail []ServerAvail
for _, sv := range bm.Servers.Values {
na := sv.Avail()
if na > 0 {
avail = append(avail, ServerAvail{Name: sv.Name, Avail: na})
}
}
return avail
}
// NextGPU returns the next GPU index available,
// and adds it to the Used list. Returns -1 if none available.
func (sv *Server) NextGPU() int {
for i := range sv.NGPUs {
_, used := sv.Used[i]
if !used {
sv.Used[i] = true
return i
}
}
return -1
}
// FreeGPU makes the given GPU number available.
func (sv *Server) FreeGPU(gpu int) {
delete(sv.Used, gpu)
}
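// exampleGPUAllocation is a hypothetical sketch (not called anywhere) of the
// GPU bookkeeping above, as used by the job scheduler: NextGPU reserves a
// slot, FreeGPU releases it, and Avail reports what remains. The server name
// is a placeholder.
func exampleGPUAllocation() {
sv := &Server{Name: "gpuhost", NGPUs: 2, Used: make(map[int]bool)}
g0 := sv.NextGPU() // 0
g1 := sv.NextGPU() // 1
none := sv.NextGPU() // -1: all GPUs in use
fmt.Println(g0, g1, none, sv.Avail()) // 0 1 -1 0
sv.FreeGPU(g0)
fmt.Println(sv.Avail()) // 1
}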
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package baremetal
import (
"archive/tar"
"compress/gzip"
"fmt"
"io"
"io/fs"
"os"
"path/filepath"
"cogentcore.org/core/base/errors"
)
// AllFiles returns all file names within the given directory, including subdirectories,
// excluding those matching given glob expressions. Files are relative to dir,
// and do not include the full path.
func AllFiles(dir string, exclude ...string) ([]string, error) {
var files []string
err := filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
if !d.Type().IsRegular() {
return nil
}
for _, ex := range exclude {
if errors.Log1(filepath.Match(ex, path)) {
return nil
}
}
files = append(files, path)
return nil
})
return files, err
}
// note: Tar code helped significantly by Steve Domino examples:
// https://medium.com/@skdomino/taring-untaring-files-in-go-6b07cf56bc07
// TarFiles writes a tar file to the given writer, from the given source directory.
// File names inside the tar are as given in files, so it will unpack
// directly to those relative paths. If gz is true, then the tar is gzipped.
func TarFiles(w io.Writer, dir string, gz bool, files ...string) error {
// ensure the src actually exists before trying to tar it
if _, err := os.Stat(dir); err != nil {
return fmt.Errorf("TarFiles: directory not accessible: %s", err.Error())
}
ow := w
if gz {
gzw := gzip.NewWriter(w)
defer gzw.Close()
ow = gzw
}
tw := tar.NewWriter(ow)
defer tw.Close()
var errs []error
for _, fn := range files {
fname := filepath.Join(dir, fn)
fi, err := os.Stat(fname)
if err != nil {
errs = append(errs, err)
continue
}
hdr, err := tar.FileInfoHeader(fi, fi.Name())
if err != nil {
errs = append(errs, err)
continue
}
hdr.Name = fn
if err := tw.WriteHeader(hdr); err != nil {
errs = append(errs, err)
break
}
f, err := os.Open(fname)
if err != nil {
errs = append(errs, err)
continue
}
_, err = io.Copy(tw, f)
f.Close() // close even if the copy fails
if err != nil {
errs = append(errs, err)
break
}
}
return errors.Join(errs...)
}
// Untar extracts a tar file from the given reader into the given destination directory.
// If gz is true, then the tar is gzipped.
func Untar(r io.Reader, dir string, gz bool) error {
or := r
if gz {
gzr, err := gzip.NewReader(r)
if err != nil {
return err
}
or = gzr
defer gzr.Close()
}
tr := tar.NewReader(or)
var errs []error
addErr := func(err error) error { // if != nil, return
if err == nil {
return nil
}
errs = append(errs, err)
if len(errs) > 10 {
return errors.Join(errs...)
}
return nil
}
for {
hdr, err := tr.Next()
switch {
case err == io.EOF:
return errors.Join(errs...)
case err != nil:
if allErr := addErr(err); allErr != nil {
return allErr
}
continue
case hdr == nil:
continue
}
fn := filepath.Join(dir, hdr.Name)
switch hdr.Typeflag {
case tar.TypeDir:
err := os.MkdirAll(fn, 0755)
if allErr := addErr(err); allErr != nil {
return allErr
}
case tar.TypeReg:
f, err := os.OpenFile(fn, os.O_CREATE|os.O_RDWR|os.O_TRUNC, os.FileMode(hdr.Mode))
if err != nil {
if allErr := addErr(err); allErr != nil {
return allErr
}
continue // cannot copy into a file that failed to open
}
_, err = io.Copy(f, tr)
f.Close()
if allErr := addErr(err); allErr != nil {
return allErr
}
}
}
return errors.Join(errs...)
}
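// exampleTarRoundTrip is a hypothetical sketch (not called anywhere) tying the
// helpers above together, mirroring how SubmitBare and FetchJobBare use them:
// gather the current directory's files, tar+gzip them to a file on disk, then
// unpack into "dstdir" ("dstdir" and "files.tar.gz" are placeholder paths;
// dstdir is assumed to exist and to receive top-level files).
func exampleTarRoundTrip() error {
files, err := AllFiles("./", ".*") // exclude dot files, as in SubmitBare
if err != nil {
return err
}
out, err := os.Create("files.tar.gz")
if err != nil {
return err
}
err = TarFiles(out, "./", true, files...)
out.Close()
if err != nil {
return err
}
in, err := os.Open("files.tar.gz")
if err != nil {
return err
}
defer in.Close()
return Untar(in, "dstdir", true)
}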
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
//go:generate core generate
import (
"cogentcore.org/core/cli"
"cogentcore.org/core/core"
"cogentcore.org/core/events"
"cogentcore.org/core/icons"
"cogentcore.org/core/tree"
"cogentcore.org/core/yaegicore/coresymbols"
"cogentcore.org/lab/goal/interpreter"
"cogentcore.org/lab/lab"
_ "cogentcore.org/lab/lab/labscripts"
"cogentcore.org/lab/tensorfs"
"cogentcore.org/lab/yaegilab/labsymbols"
)
// important: must be run from an interactive terminal.
// Will quit immediately if not!
func main() {
tensorfs.Mkdir("Data")
opts := cli.DefaultOptions("basic", "basic Cogent Lab browser.")
cfg := &interpreter.Config{}
cfg.InteractiveFunc = Interactive
cli.Run(opts, cfg, interpreter.Run, interpreter.Build)
}
func Interactive(c *interpreter.Config, in *interpreter.Interpreter) error {
b, br := lab.NewBasicWindow(tensorfs.CurRoot, "Data")
br.Interpreter = in
in.Interp.Use(coresymbols.Symbols)
in.Interp.Use(labsymbols.Symbols)
in.Config()
b.AddTopBar(func(bar *core.Frame) {
tb := core.NewToolbar(bar)
// tb.Maker(tbv.MakeToolbar)
tb.Maker(func(p *tree.Plan) {
tree.Add(p, func(w *core.Button) {
w.SetText("README").SetIcon(icons.FileMarkdown).
SetTooltip("open README help file").OnClick(func(e events.Event) {
core.TheApp.OpenURL("https://github.com/cogentcore/lab/blob/main/examples/basic/README.md")
})
})
})
})
b.OnShow(func(e events.Event) {
go func() {
if c.Expr != "" {
in.Eval(c.Expr)
}
in.Interactive()
}()
})
b.RunWindow()
core.Wait()
return nil
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"embed"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/metadata"
"cogentcore.org/core/core"
"cogentcore.org/lab/plotcore"
"cogentcore.org/lab/stats/cluster"
"cogentcore.org/lab/stats/metric"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
"cogentcore.org/lab/tensorcore"
)
//go:embed *.tsv
var tsv embed.FS
func main() {
pats := table.New("TrainPats")
metadata.SetDoc(pats, "Training patterns")
// todo: meta data for grid size
errors.Log(pats.OpenFS(tsv, "random_5x5_25.tsv", tensor.Tab))
b := core.NewBody("grids")
tv := core.NewTabs(b)
nt, _ := tv.NewTab("Patterns")
etv := tensorcore.NewTable(nt)
tensorcore.AddGridStylerTo(pats, func(s *tensorcore.GridStyle) {
s.TotalSize = 200
})
etv.SetTable(pats)
b.AddTopBar(func(bar *core.Frame) {
core.NewToolbar(bar).Maker(etv.MakeToolbar)
})
lt, _ := tv.NewTab("Labels")
gv := tensorcore.NewTensorGrid(lt)
tsr := pats.Column("Input").RowTensor(0).Clone()
tensorcore.AddGridStylerTo(tsr, func(s *tensorcore.GridStyle) {
s.ColumnRotation = 45
})
gv.SetTensor(tsr)
gv.RowLabels = []string{"Row 0", "Row 1,2", "", "Row 3", "Row 4"}
gv.ColumnLabels = []string{"Col 0,1", "", "Col 2", "Col 3", "Col 4"}
ct, _ := tv.NewTab("Cluster")
ctb := core.NewToolbar(ct)
plt := plotcore.NewEditor(ct)
ctb.Maker(plt.MakeToolbar)
cluster.PlotFromTable(plt, pats, metric.MetricL2Norm, cluster.Min, "Input", "Name")
b.RunMainWindow()
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"fmt"
"cogentcore.org/core/cli"
"cogentcore.org/core/core"
"cogentcore.org/core/events"
"cogentcore.org/core/icons"
"cogentcore.org/core/tree"
"cogentcore.org/core/yaegicore/coresymbols"
"cogentcore.org/lab/goal/interpreter"
"cogentcore.org/lab/lab"
_ "cogentcore.org/lab/lab/labscripts"
"cogentcore.org/lab/matrix"
"cogentcore.org/lab/stats/metric"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
"cogentcore.org/lab/tensorfs"
"cogentcore.org/lab/yaegilab/labsymbols"
)
// important: must be run from an interactive terminal.
// Will quit immediately if not!
func main() {
opts := cli.DefaultOptions("iris", "interactive data analysis.")
cfg := &interpreter.Config{}
cfg.InteractiveFunc = Interactive
cli.Run(opts, cfg, interpreter.Run, interpreter.Build)
}
func Interactive(c *interpreter.Config, in *interpreter.Interpreter) error {
dir := tensorfs.Mkdir("Iris")
b, br := lab.NewBasicWindow(tensorfs.CurRoot, "Iris")
in.Interp.Use(coresymbols.Symbols)
in.Interp.Use(labsymbols.Symbols)
in.Config()
br.Interpreter = in
b.AddTopBar(func(bar *core.Frame) {
tb := core.NewToolbar(bar)
// tb.Maker(tbv.MakeToolbar)
tb.Maker(func(p *tree.Plan) {
tree.Add(p, func(w *core.Button) {
w.SetText("README").SetIcon(icons.FileMarkdown).
SetTooltip("open README help file").OnClick(func(e events.Event) {
core.TheApp.OpenURL("https://github.com/cogentcore/lab/blob/main/examples/pca/README.md")
})
})
})
})
b.OnShow(func(e events.Event) {
go func() {
if c.Expr != "" {
in.Eval(c.Expr)
}
AnalyzeIris(dir, br)
in.Interactive()
}()
})
b.RunWindow()
core.Wait()
return nil
}
func AnalyzeIris(dir *tensorfs.Node, br *lab.Basic) {
dt := table.New("iris")
err := dt.OpenCSV("iris.data", tensor.Comma)
if err != nil {
fmt.Println(err)
return
}
ddir := dir.Dir("Data")
tensorfs.DirFromTable(ddir, dt)
ped := br.Tabs.PlotTable("Iris", dt)
_ = ped
cdt := table.New()
cdt.AddFloat64Column("data", 4)
cdt.AddStringColumn("class")
err = cdt.OpenCSV("iris_nohead.data", tensor.Comma)
if err != nil {
fmt.Println(err)
return
}
data := cdt.Column("data")
covar := tensor.NewFloat64()
err = metric.CovarianceMatrixOut(metric.Correlation, data, covar)
if err != nil {
fmt.Println(err)
return
}
cvg := br.Tabs.TensorGrid("Covar", covar)
_ = cvg
vecs, _ := matrix.EigSym(covar)
pcdir := dir.Dir("PCA")
tensorfs.SetTensor(pcdir, tensor.Reslice(vecs, 3, tensor.FullAxis), "pc0")
tensorfs.SetTensor(pcdir, tensor.Reslice(vecs, 2, tensor.FullAxis), "pc1")
colidx := tensor.NewFloat64Scalar(3) // strongest at end
prjn0 := tensor.NewFloat64()
matrix.ProjectOnMatrixColumnOut(vecs, data, colidx, prjn0)
pjdir := dir.Dir("Prjn")
tensorfs.SetTensor(pjdir, prjn0, "pc0")
colidx = tensor.NewFloat64Scalar(2)
prjn1 := tensor.NewFloat64()
matrix.ProjectOnMatrixColumnOut(vecs, data, colidx, prjn1)
tensorfs.SetTensor(pjdir, prjn1, "pc1")
tensorfs.SetTensor(pjdir, dt.Column("Name"), "name")
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"embed"
"math"
"cogentcore.org/core/cli"
"cogentcore.org/core/core"
"cogentcore.org/core/events"
"cogentcore.org/core/icons"
"cogentcore.org/core/tree"
"cogentcore.org/core/yaegicore/coresymbols"
"cogentcore.org/lab/goal/interpreter"
"cogentcore.org/lab/lab"
_ "cogentcore.org/lab/lab/labscripts"
"cogentcore.org/lab/stats/stats"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
"cogentcore.org/lab/tensorfs"
"cogentcore.org/lab/yaegilab/labsymbols"
)
//go:embed *.csv
var csv embed.FS
// AnalyzePlanets analyzes planets.csv data following some of the examples
// in pandas from:
// https://jakevdp.github.io/PythonDataScienceHandbook/03.08-aggregation-and-grouping.html
func AnalyzePlanets(dir *tensorfs.Node) {
Planets := table.New("planets")
Planets.OpenFS(csv, "planets.csv", tensor.Comma)
vals := []string{"number", "orbital_period", "mass", "distance", "year"}
stats.DescribeTable(dir, Planets, vals...)
decade := Planets.AddFloat64Column("decade")
year := Planets.Column("year")
for row := range Planets.NumRows() {
yr := year.FloatRow(row, 0)
dec := math.Floor(yr/10) * 10
decade.SetFloatRow(dec, row, 0)
}
stats.TableGroups(dir, Planets, "method", "decade")
stats.TableGroupDescribe(dir, Planets, vals...)
// byMethod := split.GroupBy(PlanetsAll, "method")
// split.AggColumn(byMethod, "orbital_period", stats.Median)
// GpMethodOrbit = byMethod.AggsToTable(table.AddAggName)
// byMethod.DeleteAggs()
// split.DescColumn(byMethod, "year") // full desc stats of year
// byMethod.Filter(func(idx int) bool {
// ag := errors.Log1(byMethod.AggByColumnName("year:Std"))
// return ag.Aggs[idx][0] > 0 // exclude results with 0 std
// })
// GpMethodYear = byMethod.AggsToTable(table.AddAggName)
// split.AggColumn(byMethodDecade, "number", stats.Sum)
// uncomment this to switch to decade first, then method
// byMethodDecade.ReorderLevels([]int{1, 0})
// byMethodDecade.SortLevels()
// decadeOnly := errors.Log1(byMethodDecade.ExtractLevels([]int{1}))
// split.AggColumn(decadeOnly, "number", stats.Sum)
// GpDecade = decadeOnly.AggsToTable(table.AddAggName)
//
// GpMethodDecade = byMethodDecade.AggsToTable(table.AddAggName) // here to ensure that decadeOnly didn't mess up..
// todo: need unstack -- should be specific to the splits data because we already have the cols and
// groups etc -- the ExtractLevels method provides key starting point.
// todo: pivot table -- needs unstack function.
// todo: could have a generic unstack-like method that takes a column for the data to turn into columns
// and another that has the data to put in the cells.
}
// important: must be run from an interactive terminal.
// Will quit immediately if not!
func main() {
dir := tensorfs.Mkdir("Planets")
AnalyzePlanets(dir)
opts := cli.DefaultOptions("planets", "interactive data analysis.")
cfg := &interpreter.Config{}
cfg.InteractiveFunc = Interactive
cli.Run(opts, cfg, interpreter.Run, interpreter.Build)
}
func Interactive(c *interpreter.Config, in *interpreter.Interpreter) error {
b, _ := lab.NewBasicWindow(tensorfs.CurRoot, "Planets")
in.Interp.Use(coresymbols.Symbols)
in.Interp.Use(labsymbols.Symbols)
in.Config()
b.AddTopBar(func(bar *core.Frame) {
tb := core.NewToolbar(bar)
// tb.Maker(tbv.MakeToolbar)
tb.Maker(func(p *tree.Plan) {
tree.Add(p, func(w *core.Button) {
w.SetText("README").SetIcon(icons.FileMarkdown).
SetTooltip("open README help file").OnClick(func(e events.Event) {
core.TheApp.OpenURL("https://github.com/cogentcore/lab/blob/main/examples/planets/README.md")
})
})
})
})
b.OnShow(func(e events.Event) {
go func() {
if c.Expr != "" {
in.Eval(c.Expr)
}
in.Interactive()
}()
})
b.RunWindow()
core.Wait()
return nil
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"embed"
"cogentcore.org/core/core"
"cogentcore.org/lab/plot"
"cogentcore.org/lab/plotcore"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
)
//go:embed *.tsv
var tsv embed.FS
func main() {
b := core.NewBody("Plot Example")
epc := table.New("epc")
epc.OpenFS(tsv, "ra25epoch.tsv", tensor.Tab)
pst := func(s *plot.Style) {
s.Plot.Title = "RA25 Epoch Train"
}
perr := epc.Column("PctErr")
plot.SetStyler(perr, pst, func(s *plot.Style) {
s.On = true
s.Role = plot.Y
})
pl := plotcore.NewEditor(b)
pl.SetTable(epc)
b.AddTopBar(func(bar *core.Frame) {
core.NewToolbar(bar).Maker(pl.MakeToolbar)
})
b.RunMainWindow()
}
// Copyright (c) 2020, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package random plots histograms of random distributions.
package main
//go:generate core generate
import (
"cogentcore.org/core/base/metadata"
"cogentcore.org/core/core"
"cogentcore.org/core/events"
"cogentcore.org/core/icons"
"cogentcore.org/core/math32/minmax"
"cogentcore.org/core/styles"
"cogentcore.org/core/tree"
"cogentcore.org/lab/base/randx"
"cogentcore.org/lab/plot"
"cogentcore.org/lab/plot/plots"
"cogentcore.org/lab/plotcore"
"cogentcore.org/lab/stats/histogram"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
)
// Random is the random distribution plotter widget.
type Random struct {
core.Frame
Data
}
// Data contains the random distribution plotter data and options.
type Data struct { //types:add
// random params
Dist randx.RandParams `display:"add-fields"`
// number of samples
NumSamples int
// number of bins in the histogram
NumBins int
// range for histogram
Range minmax.F64
// table for raw data
Table *table.Table `display:"no-inline"`
// histogram of data
Histogram *table.Table `display:"no-inline"`
// the plot
plot *plotcore.Editor `display:"-"`
}
// logPrec is the precision for saving float values in logs.
const logPrec = 4
func (rd *Random) Init() {
rd.Frame.Init()
rd.Dist.Defaults()
rd.Dist.Dist = randx.Gaussian
rd.Dist.Mean = 0.5
rd.Dist.Var = 0.15
rd.NumSamples = 1000000
rd.NumBins = 100
rd.Range.Set(0, 1)
rd.Table = table.New()
rd.Histogram = table.New()
rd.ConfigTable(rd.Table)
rd.Plot()
rd.Styler(func(s *styles.Style) {
s.Grow.Set(1, 1)
})
tree.AddChild(rd, func(w *core.Splits) {
w.SetSplits(0.3, 0.7)
tree.AddChild(w, func(w *core.Form) {
w.SetStruct(&rd.Data)
w.OnChange(func(e events.Event) {
rd.Plot()
})
})
tree.AddChild(w, func(w *plotcore.Editor) {
w.SetTable(rd.Histogram)
rd.plot = w
})
})
}
// Plot generates the data and plots a histogram of results.
func (rd *Random) Plot() { //types:add
dt := rd.Table
dt.SetNumRows(rd.NumSamples)
for vi := 0; vi < rd.NumSamples; vi++ {
vl := rd.Dist.Gen()
dt.Column("Value").SetFloat(float64(vl), vi)
}
histogram.F64Table(rd.Histogram, dt.Columns.Values[0].(*tensor.Float64).Values, rd.NumBins, rd.Range.Min, rd.Range.Max)
if rd.plot != nil {
rd.plot.UpdatePlot()
}
}
func (rd *Random) ConfigTable(dt *table.Table) {
metadata.SetName(dt, "Data")
// metadata.SetReadOnly(dt, true)
tensor.SetPrecision(dt, logPrec)
val := dt.AddFloat64Column("Value")
plot.SetStyler(val, func(s *plot.Style) {
s.Role = plot.X
s.Plotter = plots.BarType
s.Plot.XAxis.Rotation = 45
s.Plot.Title = "Random distribution histogram"
})
}
func (rd *Random) MakeToolbar(p *tree.Plan) {
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(rd.Plot).SetIcon(icons.ScatterPlot)
})
tree.Add(p, func(w *core.Separator) {})
if rd.plot != nil {
rd.plot.MakeToolbar(p)
}
}
func main() {
b := core.NewBody("Random numbers")
rd := NewRandom(b)
b.AddTopBar(func(bar *core.Frame) {
core.NewToolbar(bar).Maker(rd.MakeToolbar)
})
b.RunMainWindow()
}
// Code generated by "core generate"; DO NOT EDIT.
package main
import (
"cogentcore.org/core/tree"
"cogentcore.org/core/types"
)
var _ = types.AddType(&types.Type{Name: "main.Random", IDName: "random", Doc: "Random is the random distribution plotter widget.", Methods: []types.Method{{Name: "Plot", Doc: "Plot generates the data and plots a histogram of results.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}}, Embeds: []types.Field{{Name: "Frame"}, {Name: "Data"}}})
// NewRandom returns a new [Random] with the given optional parent:
// Random is the random distribution plotter widget.
func NewRandom(parent ...tree.Node) *Random { return tree.New[Random](parent...) }
var _ = types.AddType(&types.Type{Name: "main.Data", IDName: "data", Doc: "Data contains the random distribution plotter data and options.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Fields: []types.Field{{Name: "Dist", Doc: "random params"}, {Name: "NumSamples", Doc: "number of samples"}, {Name: "NumBins", Doc: "number of bins in the histogram"}, {Name: "Range", Doc: "range for histogram"}, {Name: "Table", Doc: "table for raw data"}, {Name: "Histogram", Doc: "histogram of data"}, {Name: "plot", Doc: "the plot"}}})
// Code generated by "goal build"; DO NOT EDIT.
//line bare.goal:1
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"bytes"
"fmt"
"io"
"os"
"strconv"
"strings"
"time"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/core"
"cogentcore.org/lab/examples/baremetal"
"cogentcore.org/lab/goal/goalib"
)
// bare supports the baremetal platform without slurm or other job management infra.
// SubmitBare submits a bare metal run job, returning the baremetal job ID of the submitted job.
func (sr *Simmer) SubmitBare(jid, args string) string {
goalrun.Run("@0")
script := "job.sbatch"
f, _ := os.Create(script)
sr.WriteBare(f, jid, args)
f.Close()
goalrun.Run("chmod", "+x", script)
files, err := baremetal.AllFiles("./", ".*") // no .DS_Store etc
if errors.Log(err) != nil {
return ""
}
// fmt.Println("files:", files)
// todo: this is triggering a type defn mode:
var b bytes.Buffer
err = baremetal.TarFiles(&b, "./", true, files...)
if errors.Log(err) != nil {
return ""
}
bm := sr.BareMetal
spath := sr.ServerJobPath(jid)
job, err := bm.Submit(sr.Config.Project, spath, script, sr.Config.FetchFiles, b.Bytes())
goalrun.Run("@0")
if err != nil {
core.ErrorSnackbar(sr, err)
return "-1"
}
bid := strconv.Itoa(job.ID)
goalib.WriteFile("job.job", bid)
bm.UpdateJobs()
sr.GetMeta(jid)
return bid
}
// WriteBare writes the bash script to run a "bare metal" run.
func (sr *Simmer) WriteBare(w io.Writer, jid, args string) {
if sr.Config.JobScript != "" {
js := sr.Config.JobScript
js = strings.ReplaceAll(js, "$JOB_ARGS", args)
fmt.Fprintln(w, js)
return
}
fmt.Fprintf(w, "#!/bin/bash -l\n") // -l = login session, sources your .bash_profile
fmt.Fprintf(w, "\n\n")
if sr.Config.SetupScript != "" {
fmt.Fprintln(w, sr.Config.SetupScript)
}
exe := sr.Config.Project
if sr.Config.Job.SubCmd {
projname := exe
exe += "/" + exe
fmt.Fprintf(w, "cd %s\n", projname)
fmt.Fprintf(w, "go build -mod=mod %s\n", sr.Config.Job.BuildArgs)
fmt.Fprintf(w, "cd ../\n")
} else {
fmt.Fprintf(w, "go build -mod=mod %s\n", sr.Config.Job.BuildArgs)
}
cmd := `date '+%Y-%m-%d %T %Z' > job.start`
fmt.Fprintln(w, cmd)
fmt.Fprintf(w, "./%s -nogui -cfg config_job.toml -gpu-device $BARE_GPU %s >& job.out & echo $! > job.pid", exe, args)
}
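// For reference, with SubCmd false and no SetupScript, WriteBare emits a
// script of roughly this shape (project name and args are placeholders):
//
// #!/bin/bash -l
//
// go build -mod=mod
// date '+%Y-%m-%d %T %Z' > job.start
// ./myproj -nogui -cfg config_job.toml -gpu-device $BARE_GPU <args> >& job.out & echo $! > job.pid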
func (sr *Simmer) QueueBare() {
sr.UpdateBare()
ts := sr.Tabs.AsLab()
goalrun.Run("@1")
goalrun.Run("cd")
smi := goalrun.Output("nvidia-smi")
goalrun.Run("@0")
ts.EditorString("Queue", smi)
}
// UpdateBare updates the BareMetal jobs
func (sr *Simmer) UpdateBare() { //types:add
// nrun, nfin := errors.Log2(sr.BareMetal.UpdateJobs())
// core.MessageSnackbar(sr, fmt.Sprintf("BareMetal jobs run: %d finished: %d", nrun, nfin))
}
// FetchJobBare downloads results files from bare metal server.
func (sr *Simmer) FetchJobBare(jid string, force bool) {
jpath := sr.JobPath(jid)
goalrun.Run("@0")
goalrun.Run("cd", jpath)
sstat := goalib.ReadFile("job.status")
if !force && sstat == "Fetched" {
return
}
sjob := sr.ValueForJob(jid, "ServerJob")
sj := errors.Log1(strconv.Atoi(sjob))
jobs, err := sr.BareMetal.FetchResults(sr.Config.FetchFiles, sj)
if err != nil {
core.ErrorSnackbar(sr, err)
return
}
if len(jobs) == 0 {
return
}
job := jobs[0]
baremetal.Untar(bytes.NewReader(job.Results), jpath, true) // gzip
// note: we don't do any post-processing here -- see slurm version for combining separate runs
if sstat == "Finalized" {
// fmt.Println("status finalized")
goalib.WriteFile("job.status", "Fetched")
goalib.ReplaceInFile("dbmeta.toml", "\"Finalized\"", "\"Fetched\"")
// } else {
// fmt.Println("status:", sstat)
}
}
func (sr *Simmer) CancelJobsBare(jobs []string) {
jnos := make([]int, 0, len(jobs))
for _, jid := range jobs {
sjob := sr.ValueForJob(jid, "ServerJob")
if sjob != "" {
jno := errors.Log1(strconv.Atoi(sjob))
jnos = append(jnos, jno)
}
}
sr.BareMetal.CancelJobs(jnos...)
}
func (sr *Simmer) RecoverJobsBare(jobs []string) {
for _, jid := range jobs {
jno := 0
sjob := sr.ValueForJob(jid, "ServerJob")
if sjob != "" {
jno = errors.Log1(strconv.Atoi(sjob))
} else {
fmt.Println("job does not have a ServerJob id")
}
job := &baremetal.Job{ID: jno}
job.Status.SetString(sr.ValueForJob(jid, "Status"))
job.Path = sr.ServerJobPath(jid)
job.Source = sr.Config.Project
job.Script = "job.sbatch"
job.ResultsGlob = sr.Config.FetchFiles
job.Submit = errors.Log1(time.Parse(sr.Config.TimeFormat, sr.ValueForJob(jid, "Submit")))
job.Start = errors.Log1(time.Parse(sr.Config.TimeFormat, sr.ValueForJob(jid, "Start")))
job.End = errors.Log1(time.Parse(sr.Config.TimeFormat, sr.ValueForJob(jid, "End")))
job.ServerName = sr.ValueForJob(jid, "Server")
// todo: PID
// jpath := sr.JobPath(jid)
// @0
// cd {jpath}
// job.Status = goalib.ReadFile("job.status")
// job.Label = goalib.ReadFile("job.label")
sr.BareMetal.RecoverJob(job)
}
}
// Code generated by "goal build"; DO NOT EDIT.
//line config.goal:1
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"path/filepath"
"strings"
"cogentcore.org/lab/lab"
"cogentcore.org/lab/table"
)
// FilterResults specifies which results files to open.
type FilterResults struct {
// File name contains this string, e.g., "_epoch" or "_run"
FileContains string `width:"60"`
// extension of files, e.g., .tsv
Ext string
// if true, fetch results before opening.
Fetch bool
}
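// Defaults sets the default filter values. A minimal usage sketch (hypothetical,
// not part of the original API docs): switch from the default epoch-level files
// to run-level files:
//
//	var flt FilterResults
//	flt.Defaults()            // FileContains = "_epc", Ext = ".tsv", Fetch = true
//	flt.FileContains = "_run" // select run-level results instead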
func (fp *FilterResults) Defaults() {
fp.FileContains = "_epc"
fp.Ext = ".tsv"
fp.Fetch = true
}
// SubmitParams specifies the parameters for submitting a job.
type SubmitParams struct {
// Message describing the simulation:
// this is key info for what is special about this job, like a github commit message
Message string `width:"80"`
// Label is brief, unique label used for plots to label this job
Label string `width:"80"`
// arguments to pass on the command line.
//
// -nogui is already passed by default
Args string `width:"80"`
}
// JobParams are parameters for running the job
type JobParams struct {
// NRuns is the number of parallel runs; it can also be set to 1,
// with multiple runs per job specified via args.
NRuns int
// Hours is the max number of hours: slurm will terminate the job if it runs
// longer, so be generous (2d = 48, 3d = 72, 4d = 96, 5d = 120, 6d = 144, 7d = 168).
Hours int
// Memory per CPU in gigabytes
Memory int
// Tasks is the number of mpi "tasks" (procs in MPI terminology).
Tasks int
// CPUsPerTask is the number of cpu cores (threads) per task.
CPUsPerTask int
// TasksPerNode is how to allocate tasks within compute nodes;
// cpus_per_task * tasks_per_node must be <= total cores per node.
TasksPerNode int
// Qos is the queue "quality of service" name.
Qos string
// If true, the executable source is in a subdirectory with the same name as
// [Config.Project], with filename main.go, which allows the primary directory
// to be imported into other apps. Job submission manages the copying and
// building of this sub-command.
SubCmd bool
// BuildArgs are extra args to pass during building, such as -tags mpi for mpi.
BuildArgs string
}
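// Defaults sets the default job resource parameters. As a rough sketch of the
// resource math implied by these defaults: 1 task x 8 cpus-per-task with
// 1 task-per-node needs a node with at least 8 free cores and about
// 8 x 1 GB = 8 GB of memory, and the job is limited to 1 hour of wall-clock time.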
func (jp *JobParams) Defaults() {
jp.NRuns = 10
jp.Hours = 1
jp.Memory = 1
jp.Tasks = 1
jp.CPUsPerTask = 8
jp.TasksPerNode = 1
}
// ServerParams are parameters for the server.
type ServerParams struct {
// Name is the name of the current server used to run jobs;
// it gets recorded with each job.
Name string
// Root is the root path from the user home dir on the server.
// It is auto-set to: filepath.Join("simserver", Project, User).
Root string
// Slurm uses the slurm job manager. Otherwise uses a bare job manager.
Slurm bool
}
func (sp *ServerParams) Defaults() {
}
// Configuration holds all of the user-settable parameters
type Configuration struct {
// DataRoot is the path to the root of the data to browse.
DataRoot string
// StartDir is the starting directory, where the app was originally started.
StartDir string
// User id as in system login name (i.e., user@system).
User string
// UserShort is the first 3 letters of User,
// for naming jobs (auto-set from User).
UserShort string
// Project is the name of simulation project, lowercase
// (should be name of source dir).
Project string
// Package is the parent package: e.g., github.com/emer/axon/v2
// This is used to update the go.mod, along with the Version.
Package string
// Version is the current git version string, from git describe --tags.
Version string
// Job has the parameters for job resources etc.
Job JobParams `display:"inline"`
// Server has server parameters.
Server ServerParams
// GroupColumns are the column(s) to use for grouping result data, for PlotMean.
// e.g., Epoch for epoch-level results.
GroupColumns []string
// FetchFiles is a glob expression for files to fetch from server,
// for Fetch command. Is *.tsv by default.
FetchFiles string
// ExcludeNodes are nodes to exclude from job, based on what is slow.
ExcludeNodes string
// ExtraFiles has extra files to upload with job submit, from same dir.
ExtraFiles []string
// ExtraDirs has subdirs with other files to upload with job submit
// (non-code -- see CodeDirs).
ExtraDirs []string
// CodeDirs has subdirs with code to upload with job submit;
// go.mod auto-updated to use.
CodeDirs []string
// ExtraGoGet is an extra package to do "go get" with, for launching the job.
ExtraGoGet string
// JobScript is a job script to use for running the simulation,
// instead of the basic default, if non-empty.
// This is written to the job.sbatch file. If it contains a $JOB_ARGS string
// then that is replaced with the args entered during submission.
// If using slurm, this switches to a simple direct sbatch submission instead
// of the default parallel job submission. All standard slurm job parameters
// are automatically inserted at the start of the file, so this script should
// just be the actual job running actions after that point.
JobScript string
// SetupScript contains optional lines of bash script code to insert at
// the start of the job submission script, which is then followed by
// the default script. For example, if a symbolic link to a large shared
// resource is needed, make that link here.
SetupScript string
// TimeFormat is the format for timestamps,
// defaults to "2006-01-02 15:04:05 MST"
TimeFormat string `default:"2006-01-02 15:04:05 MST"`
// Filter has parameters for filtering results.
Filter FilterResults
// Submit has parameters for submitting jobs; set from last job run.
Submit SubmitParams
}
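// Defaults sets the default configuration values, deriving Version, User, and
// Project from the environment. A minimal initialization sketch (hypothetical
// paths, mirroring what InitSimmer does):
//
//	cfg := &Configuration{}
//	cfg.StartDir = "/home/me/mysim"         // hypothetical
//	cfg.DataRoot = "/home/me/mysim/simdata" // hypothetical
//	cfg.Defaults() // git version, user, project, filter, etc.
//	cfg.Update()   // UserShort, Server.Root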
func (cf *Configuration) Defaults() {
goalrun.Run("@0")
goalrun.Run("cd")
goalrun.Run("cd", cf.StartDir)
cf.Version = strings.TrimSpace(goalrun.Output("git", "describe", "--tags"))
goalrun.Run("cd", cf.DataRoot)
cf.User = strings.TrimSpace(goalrun.Output("echo", "$USER"))
_, pj := filepath.Split(cf.StartDir)
cf.Project = pj
cf.Job.Defaults()
cf.Server.Defaults()
cf.FetchFiles = "*.tsv job.label"
cf.Filter.Defaults()
cf.TimeFormat = "2006-01-02 15:04:05 MST"
cf.GroupColumns = []string{"Epoch"}
}
func (cf *Configuration) Update() {
cf.UserShort = cf.User
if len(cf.User) > 3 { // UserShort is the first 3 letters of User
cf.UserShort = cf.User[:3]
}
cf.Server.Root = filepath.Join("simserver", cf.Project, cf.User)
}
// Result has info for one loaded result, as a table.Table
type Result struct {
// job id for results
JobID string
// short label used as a legend in the plot
Label string
// description of job
Message string
// args used in running job
Args string
// path to data
Path string
// result data
Table *table.Table
}
// EditConfig edits the configuration
func (sr *Simmer) EditConfig() { //types:add
lab.PromptStruct(sr, &sr.Config, "Configuration parameters", nil)
}
// Code generated by "goal build"; DO NOT EDIT.
//line jobs.goal:1
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"bytes"
"fmt"
"path/filepath"
"strconv"
"strings"
"cogentcore.org/core/base/elide"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/fsx"
"cogentcore.org/core/base/iox/tomlx"
"cogentcore.org/core/base/strcase"
"cogentcore.org/core/core"
"cogentcore.org/core/styles"
"cogentcore.org/lab/goal/goalib"
"cogentcore.org/lab/lab"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
)
// Jobs updates the Jobs tab with a Table showing all the Jobs
// with their meta data. Uses the dbmeta.toml data compiled from
// the Status function.
func (sr *Simmer) Jobs() { //types:add
ts := sr.Tabs.AsLab()
if !sr.IsSlurm() {
// todo: get data back from server
sr.BareMetalActive = errors.Log1(sr.BareMetal.JobStatus())
at := ts.SliceTable("Bare", &sr.BareMetalActive)
if sr.BareMetalActiveTable != at {
sr.BareMetalActiveTable = at
at.Styler(func(s *styles.Style) {
s.SetReadOnly(true)
})
}
at.Update()
}
tv := ts.TensorTable("Jobs", sr.JobsTable)
dt := sr.JobsTable
if sr.JobsTableView != tv {
sr.JobsTableView = tv
tv.ShowIndexes = true
tv.ReadOnlyMultiSelect = true
tv.Styler(func(s *styles.Style) {
s.SetReadOnly(true)
})
}
dpath := filepath.Join(sr.DataRoot, "jobs")
// fmt.Println("opening data at:", dpath)
if dt.NumColumns() == 0 {
dbfmt := filepath.Join(sr.DataRoot, "dbformat.csv")
fdt := table.New()
if errors.Log1(fsx.FileExists(dbfmt)) {
fdt.OpenCSV(fsx.Filename(dbfmt), tensor.Comma)
} else {
fdt.ReadCSV(bytes.NewBuffer([]byte(defaultJobFormat)), tensor.Comma)
}
dt.ConfigFromTable(fdt)
}
ds := fsx.Dirs(dpath)
dt.SetNumRows(len(ds))
for i, d := range ds {
dt.Column("JobID").SetString(d, i)
dp := filepath.Join(dpath, d)
meta := filepath.Join(dp, "dbmeta.toml")
if goalib.FileExists(meta) {
md := make(map[string]string)
tomlx.Open(&md, meta)
for k, v := range md {
dc := dt.Column(k)
if dc != nil {
dc.SetString(v, i)
}
}
}
}
tv.Table.Sequential()
nrows := dt.NumRows()
if nrows > 0 && sr.Config.Submit.Message == "" {
sr.Config.Submit.Message = dt.Column("Message").String1D(nrows - 1)
sr.Config.Submit.Args = dt.Column("Args").String1D(nrows - 1)
sr.Config.Submit.Label = dt.Column("Label").String1D(nrows - 1)
}
}
// UpdateSims updates the Jobs tab and refreshes the overall display.
func (sr *Simmer) UpdateSims() { //types:add
sr.Jobs()
sr.Update()
}
// UpdateSimsAsync updates the sim status info, for the async case
// (when called from a goroutine).
func (sr *Simmer) UpdateSimsAsync() {
sr.AsyncLock()
sr.Jobs()
sr.Update()
sr.AsyncUnlock()
}
func (sr *Simmer) JobPath(jid string) string {
return filepath.Join(sr.DataRoot, "jobs", jid)
}
func (sr *Simmer) ServerJobPath(jid string) string {
return filepath.Join(sr.Config.Server.Root, "jobs", jid)
}
func (sr *Simmer) JobRow(jid string) int {
jt := sr.JobsTable.Column("JobID")
nr := jt.DimSize(0)
for i := range nr {
if jt.String1D(i) == jid {
return i
}
}
fmt.Println("JobRow ERROR: job id:", jid, "not found")
return -1
}
// ValueForJob returns value in given column for given job id
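// For example (a sketch with a hypothetical job id):
//
//	sjob := sr.ValueForJob("abc00042", "ServerJob") // "" if the job is not found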
func (sr *Simmer) ValueForJob(jid, column string) string {
if jrow := sr.JobRow(jid); jrow >= 0 {
return sr.JobsTable.Column(column).String1D(jrow)
}
return ""
}
// Queue runs a queue query command on the server and shows the results.
func (sr *Simmer) Queue() { //types:add
if sr.IsSlurm() {
sr.QueueSlurm()
} else {
sr.QueueBare()
}
}
// JobStatus gets job status from server for given job id.
// jobs that are already Finalized are skipped, unless force is true.
func (sr *Simmer) JobStatus(jid string, force bool) {
// fmt.Println("############\nStatus of Job:", jid)
spath := sr.ServerJobPath(jid)
jpath := sr.JobPath(jid)
goalrun.Run("@0")
goalrun.Run("cd", jpath)
if !goalib.FileExists("job.status") {
goalib.WriteFile("job.status", "Unknown")
}
sstat := goalib.ReadFile("job.status")
if !force && (sstat == "Finalized" || sstat == "Fetched" || sstat == "Canceled") {
return
}
if sr.IsSlurm() {
goalrun.Run("@1")
goalrun.Run("cd")
goalrun.Run("cd", spath)
sj := goalrun.Output("@1", "cat", "job.job")
// fmt.Println("server job:", sj)
if sstat != "Done" && !force {
goalrun.RunErrOK("@1", "squeue", "-j", sj, "-o", "%T", ">&", "job.squeue")
stat := goalrun.Output("@1", "cat", "job.squeue")
// fmt.Println("server status:", stat)
switch {
case strings.Contains(stat, "Invalid job id"):
goalrun.Run("@1", "echo", "Invalid job id", ">", "job.squeue")
sstat = "Done"
case strings.Contains(stat, "RUNNING"):
nrep := strings.Count(stat, "RUNNING")
sstat = fmt.Sprintf("Running:%d", nrep)
case strings.Contains(stat, "PENDING"):
nrep := strings.Count(stat, "PENDING")
sstat = fmt.Sprintf("Pending:%d", nrep)
case strings.Contains(stat, "STATE"): // still visible in queue but done
sstat = "Done"
}
goalib.WriteFile("job.status", sstat)
}
goalrun.Run("@1", "/bin/ls", "-1", ">", "job.files")
goalrun.Run("@0")
core.MessageSnackbar(sr, "Retrieving job files for: "+jid)
jfiles := goalrun.Output("@1", "/bin/ls", "-1", "job.*")
for _, jf := range goalib.SplitLines(jfiles) {
if !sr.IsSlurm() && jf == "job.status" {
continue
}
// fmt.Println(jf)
rfn := "@1:" + jf
if !force {
goalrun.Run("scp", rfn, jf)
}
}
} else {
jstr := strings.TrimSpace(goalrun.OutputErrOK("cat", "job.job"))
if jstr == "" {
msg := fmt.Sprintf("Status for Job: %s: job.job is empty, so can't proceed with BareMetal", jid)
fmt.Println(msg)
core.MessageSnackbar(sr, msg)
return
}
sj := errors.Log1(strconv.Atoi(jstr))
// fmt.Println(jid, "jobno:", sj)
jobs, err := sr.BareMetal.JobStatus(sj)
if err != nil {
core.ErrorSnackbar(sr, err)
} else {
if len(jobs) == 1 {
job := jobs[0]
sstat = job.Status.String()
goalib.WriteFile("job.status", sstat)
goalib.WriteFile("job.squeue", sstat)
if !job.Start.IsZero() {
goalib.WriteFile("job.start", job.Start.Format(sr.Config.TimeFormat))
}
if !job.End.IsZero() {
goalib.WriteFile("job.end", job.End.Format(sr.Config.TimeFormat))
}
// fmt.Println(jid, sstat)
}
}
goalrun.Run("@0")
if sstat == "Running" {
// core.MessageSnackbar(sr, "Retrieving job files for: " + jid)
// @1
// cd
// todo: need more robust ways of testing for files and error recovery on remote ssh con
// cd {spath}
// scp @1:job.out job.out
// scp @1:nohup.out nohup.out
}
}
goalrun.Run("@0")
if sstat == "Done" || sstat == "Completed" {
sstat = "Finalized"
goalib.WriteFile("job.status", sstat)
if sr.IsSlurm() {
goalrun.RunErrOK("/bin/rm", "job.*.out")
}
sr.FetchJob(jid, false)
}
sr.GetMeta(jid)
core.MessageSnackbar(sr, "Job: "+jid+" updated with status: "+sstat)
}
// GetMeta gets the dbmeta.toml file from all job.* files in job dir.
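// Keys are derived from the job.* file suffixes (job.job -> ServerJob,
// job.squeue -> ServerStatus), so the resulting dbmeta.toml looks roughly like
// this sketch (values are hypothetical):
//
//	Server = "myserver"
//	ServerJob = "1234567"
//	ServerStatus = "RUNNING"
//	Label = "baseline"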
func (sr *Simmer) GetMeta(jid string) {
goalrun.Run("@0")
goalrun.Run("cd")
jpath := sr.JobPath(jid)
goalrun.Run("cd", jpath)
// fmt.Println("getting meta for", jid)
jfiles := goalrun.Output("/bin/ls", "-1", "job.*") // local
meta := fmt.Sprintf("%s = %q\n", "Server", sr.Config.Server.Name)
for _, jf := range goalib.SplitLines(jfiles) {
if strings.Contains(jf, "sbatch") || strings.HasSuffix(jf, ".out") || strings.HasSuffix(jf, ".gz") {
continue
}
key := strcase.ToCamel(strings.TrimPrefix(jf, "job."))
switch key {
case "Job":
key = "ServerJob"
case "Squeue":
key = "ServerStatus"
}
val := strings.TrimSpace(goalib.ReadFile(jf))
if key == "ServerStatus" {
val = strings.ReplaceAll(elide.Middle(val, 50), "…", "...")
}
ln := fmt.Sprintf("%s = %q\n", key, val)
// fmt.Println(ln)
meta += ln
}
goalib.WriteFile("dbmeta.toml", meta)
}
// Status gets updated job.* files from the server for any job that
// doesn't have a Finalized or Fetched status. It updates the
// status based on the server job status query, assigning a
// status of Finalized if job is done. Updates the dbmeta.toml
// data based on current job data.
func (sr *Simmer) Status() { //types:add
goalrun.Run("@0")
sr.UpdateFiles()
sr.Jobs()
dpath := filepath.Join(sr.DataRoot, "jobs")
ds := fsx.Dirs(dpath)
for _, jid := range ds {
sr.JobStatus(jid, false) // true = update all -- for format and status edits
}
core.MessageSnackbar(sr, "Jobs Status completed")
sr.UpdateSims()
}
// FetchJob downloads results files from the server.
// If force is true, already-Fetched jobs are re-fetched;
// otherwise they are skipped.
func (sr *Simmer) FetchJob(jid string, force bool) {
if sr.IsSlurm() {
sr.FetchJobSlurm(jid, force)
} else {
sr.FetchJobBare(jid, force)
}
}
// Fetch retrieves all the .tsv data files from the server
// for any jobs not already marked as Fetched.
// Operates on the jobs selected in the Jobs table,
// or on all jobs if none selected.
func (sr *Simmer) Fetch() { //types:add
goalrun.Run("@0")
tv := sr.JobsTableView
jobs := tv.SelectedColumnStrings("JobID")
if len(jobs) == 0 {
dpath := filepath.Join(sr.DataRoot, "jobs")
jobs = fsx.Dirs(dpath)
}
for _, jid := range jobs {
sr.FetchJob(jid, false)
}
core.MessageSnackbar(sr, "Fetch Jobs completed")
sr.UpdateSims()
}
// Cancel cancels the jobs selected in the Jobs table,
// with a confirmation prompt.
func (sr *Simmer) Cancel() { //types:add
tv := sr.JobsTableView
jobs := tv.SelectedColumnStrings("JobID")
if len(jobs) == 0 {
core.MessageSnackbar(sr, "No jobs selected for cancel")
return
}
lab.PromptOKCancel(sr, "Ok to cancel these jobs: "+strings.Join(jobs, " "), func() {
if sr.IsSlurm() {
sr.CancelJobsSlurm(jobs)
} else {
sr.CancelJobsBare(jobs)
}
sr.UpdateSims()
})
}
// DeleteJobs deletes the given jobs
func (sr *Simmer) DeleteJobs(jobs []string) {
goalrun.Run("@0")
dpath := filepath.Join(sr.DataRoot, "jobs")
spath := filepath.Join(sr.Config.Server.Root, "jobs")
for _, jid := range jobs {
goalrun.Run("@0")
goalrun.Run("cd", dpath)
goalrun.RunErrOK("/bin/rm", "-rf", jid)
goalrun.Run("@1")
goalrun.Run("cd")
// todo: [cd {spath} && /bin/rm -rf {jid}]
goalrun.Run("cd", spath, "&&", "/bin/rm", "-rf", jid)
goalrun.Run("@0")
}
goalrun.Run("@1")
goalrun.Run("cd")
goalrun.Run("@0")
core.MessageSnackbar(sr, "Done deleting jobs")
}
// Delete deletes the selected Jobs, with a confirmation prompt.
func (sr *Simmer) Delete() { //types:add
tv := sr.JobsTableView
jobs := tv.SelectedColumnStrings("JobID")
if len(jobs) == 0 {
core.MessageSnackbar(sr, "No jobs selected for deletion")
return
}
lab.PromptOKCancel(sr, "Ok to delete these jobs: "+strings.Join(jobs, " "), func() {
sr.DeleteJobs(jobs)
sr.UpdateSims()
})
}
// ArchiveJobs archives the given jobs
func (sr *Simmer) ArchiveJobs(jobs []string) {
goalrun.Run("@0")
dpath := filepath.Join(sr.DataRoot, "jobs")
apath := filepath.Join(sr.DataRoot, "archive", "jobs")
goalrun.Run("mkdir", "-p", apath)
spath := filepath.Join(sr.Config.Server.Root, "jobs")
for _, jid := range jobs {
goalrun.Run("@1")
goalrun.Run("cd")
goalrun.Run("cd", spath)
goalrun.RunErrOK("/bin/rm", "-rf", jid)
goalrun.Run("@0")
dj := filepath.Join(dpath, jid)
aj := filepath.Join(apath, jid)
goalrun.Run("/bin/mv", dj, aj)
}
goalrun.Run("@1")
goalrun.Run("cd")
goalrun.Run("@0")
core.MessageSnackbar(sr, "Done archiving jobs")
}
// Archive moves the selected Jobs to the Archive directory,
// locally, and deletes them from the server,
// for results that are useful but not immediately relevant.
func (sr *Simmer) Archive() { //types:add
tv := sr.JobsTableView
jobs := tv.SelectedColumnStrings("JobID")
if len(jobs) == 0 {
core.MessageSnackbar(sr, "No jobs selected for archiving")
return
}
lab.PromptOKCancel(sr, "Ok to archive these jobs: "+strings.Join(jobs, " "), func() {
sr.ArchiveJobs(jobs)
sr.UpdateSims()
})
}
// Recover recovers the jobs selected in the Jobs table,
// with a confirmation prompt.
func (sr *Simmer) Recover() { //types:add
tv := sr.JobsTableView
jobs := tv.SelectedColumnStrings("JobID")
if len(jobs) == 0 {
core.MessageSnackbar(sr, "No jobs selected for cancel")
return
}
lab.PromptOKCancel(sr, "Ok to recover these jobs: "+strings.Join(jobs, " "), func() {
// if sr.IsSlurm() {
// sr.CancelJobsSlurm(jobs)
// } else {
sr.RecoverJobsBare(jobs)
// }
sr.UpdateSims()
})
}
// Code generated by "goal build"; DO NOT EDIT.
//line results.goal:1
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"fmt"
"path/filepath"
"reflect"
"strings"
"cogentcore.org/core/base/fsx"
"cogentcore.org/core/core"
"cogentcore.org/core/styles"
"cogentcore.org/lab/lab"
"cogentcore.org/lab/plot"
"cogentcore.org/lab/stats/stats"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
"cogentcore.org/lab/tensorfs"
)
// FindResult finds an existing result record for given job id and path,
// returning index and result record if found.
func (sr *Simmer) FindResult(jid, path string) (int, *Result) {
for i, r := range sr.ResultsList {
if r.JobID == jid && r.Path == path {
return i, r
}
}
return -1, nil
}
// OpenResultFiles opens the given result files.
func (sr *Simmer) OpenResultFiles(jobs []string, filter FilterResults) {
ts := sr.Tabs.AsLab()
for _, jid := range jobs {
jpath := sr.JobPath(jid)
message := sr.ValueForJob(jid, "Message")
label := sr.ValueForJob(jid, "Label")
args := sr.ValueForJob(jid, "Args")
fls := fsx.Filenames(jpath, filter.Ext)
for _, fn := range fls {
if filter.FileContains != "" && !strings.Contains(fn, filter.FileContains) {
continue
}
dt := table.New()
fpath := filepath.Join(jpath, fn)
err := dt.OpenCSV(core.Filename(fpath), tensor.Tab)
if err != nil {
fmt.Println(err.Error())
}
// rpath := strings.TrimPrefix(fpath, sr.DataRoot)
rpath := fn // actually visible in table
if ri, _ := sr.FindResult(jid, rpath); ri >= 0 {
sr.ResultsList[ri] = &Result{JobID: jid, Label: label, Message: message, Args: args, Path: rpath, Table: dt}
} else {
sr.ResultsList = append(sr.ResultsList, &Result{JobID: jid, Label: label, Message: message, Args: args, Path: rpath, Table: dt})
}
}
}
if len(sr.ResultsList) == 0 {
core.MessageSnackbar(sr, "No files containing: "+filter.FileContains+" with extension: "+filter.Ext)
return
}
tv := ts.SliceTable("Results", &sr.ResultsList)
if sr.ResultsTableView != tv {
sr.ResultsTableView = tv
sr.styleResults()
}
sr.ResultsTableView.Update()
sr.UpdateSims()
}
// Results loads specific .tsv data files from the jobs selected
// in the Jobs table, into the Results table. There are often
// multiple result files per job, so this step is necessary to
// choose which such files to select for plotting.
func (sr *Simmer) Results() { //types:add
tv := sr.JobsTableView
jobs := tv.SelectedColumnStrings("JobID")
if len(jobs) == 0 {
fmt.Println("No Jobs rows selected")
return
}
// fmt.Println(jobs)
if sr.Config.Filter.Ext == "" {
sr.Config.Filter.Ext = ".tsv"
}
lab.PromptStruct(sr, &sr.Config.Filter, "Open results data for files", func() {
if sr.Config.Filter.Fetch {
core.MessageSnackbar(sr, "Fetching jobs..")
for _, jid := range jobs {
sr.FetchJob(jid, false)
}
core.MessageSnackbar(sr, "Fetch Jobs completed")
sr.UpdateSims()
}
sr.OpenResultFiles(jobs, sr.Config.Filter)
})
}
// Diff shows the differences between two selected jobs, or if only
// one job is selected, between that job and the current source directory.
func (sr *Simmer) Diff() { //types:add
goalrun.Run("@0")
tv := sr.JobsTableView
jobs := tv.SelectedColumnStrings("JobID")
nj := len(jobs)
if nj == 0 || nj > 2 {
core.MessageSnackbar(sr, "Diff requires two Job rows to be selected")
return
}
if nj == 1 {
ja := sr.JobPath(jobs[0])
lab.NewDiffBrowserDirs(ja, sr.StartDir)
return
}
ja := sr.JobPath(jobs[0])
jb := sr.JobPath(jobs[1])
lab.NewDiffBrowserDirs(ja, jb)
}
// Plot concatenates selected Results data files and generates a plot
// of the resulting data.
func (sr *Simmer) Plot() { //types:add
ts := sr.Tabs.AsLab()
tv := sr.ResultsTableView
jis := tv.SelectedIndexesList(false)
if len(jis) == 0 {
fmt.Println("No Results rows selected")
return
}
var AggTable *table.Table
for _, i := range jis {
res := sr.ResultsList[i]
jid := res.JobID
label := res.Label
dt := res.Table.InsertKeyColumns("JobID", jid, "JobLabel", label) // this clones the table
if AggTable == nil {
AggTable = dt
plot.SetFirstStyler(dt.Columns.Values[0], func(s *plot.Style) {
s.Role = plot.Split
s.On = true
})
} else {
AggTable.AppendRows(dt)
}
}
ts.PlotTable("Plot", AggTable)
}
// PlotMean concatenates selected Results data files and generates a plot
// of the resulting data, computing the mean over the values in
// [Config.GroupColumns] to group values (e.g., across Epochs).
func (sr *Simmer) PlotMean() { //types:add
ts := sr.Tabs.AsLab()
tv := sr.ResultsTableView
jis := tv.SelectedIndexesList(false)
if len(jis) == 0 {
fmt.Println("No Results rows selected")
return
}
nc := len(sr.Config.GroupColumns)
rdir := "Stats/" + sr.Config.GroupColumns[nc-1] // results dir
var AggTable *table.Table
for _, i := range jis {
res := sr.ResultsList[i]
jid := res.JobID
label := res.Label
dir, _ := tensorfs.NewDir("root")
rdt := res.Table
stats.TableGroups(dir, rdt, sr.Config.GroupColumns...)
var fcols []string
for ci, cl := range rdt.Columns.Values {
if cl.DataType() != reflect.Float32 && cl.DataType() != reflect.Float64 {
continue
}
fcols = append(fcols, rdt.Columns.Keys[ci])
}
stats.TableGroupStats(dir, stats.StatMean, rdt, fcols...)
edir := dir.Dir(rdir)
sdt := tensorfs.DirTable(edir, nil)
dt := sdt.InsertKeyColumns("JobID", jid, "JobLabel", label) // this clones the table
if AggTable == nil {
AggTable = dt
plot.SetFirstStyler(dt.Columns.Values[0], func(s *plot.Style) {
s.Role = plot.Split
s.On = true
})
} else {
AggTable.AppendRows(dt)
}
}
ts.PlotTable("Plot", AggTable)
}
func (sr *Simmer) styleResults() {
tv := sr.ResultsTableView
tv.ShowIndexes = true
tv.ReadOnlyMultiSelect = true
tv.Styler(func(s *styles.Style) {
s.SetReadOnly(true)
})
}
// Reset resets the Results table
func (sr *Simmer) Reset() { //types:add
ts := sr.Tabs.AsLab()
sr.ResultsList = []*Result{}
tv := ts.SliceTable("Results", &sr.ResultsList)
if sr.ResultsTableView != tv {
sr.ResultsTableView = tv
sr.styleResults()
}
sr.UpdateSims()
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
//go:generate core generate -add-types -add-funcs
import (
"os"
"path/filepath"
"reflect"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/fsx"
"cogentcore.org/core/cli"
"cogentcore.org/core/core"
"cogentcore.org/core/events"
"cogentcore.org/core/icons"
"cogentcore.org/core/styles"
"cogentcore.org/core/tree"
"cogentcore.org/core/yaegicore/coresymbols"
"cogentcore.org/lab/examples/baremetal"
"cogentcore.org/lab/goal"
"cogentcore.org/lab/goal/interpreter"
"cogentcore.org/lab/lab"
"cogentcore.org/lab/lab/labscripts"
_ "cogentcore.org/lab/lab/labscripts"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensorcore"
"cogentcore.org/lab/yaegilab/labsymbols"
"github.com/cogentcore/yaegi/interp"
)
// goalrun is needed for running goal commands.
var goalrun *goal.Goal
var defaultJobFormat = `Name, Type
JobID, string
Status, string
Label, string
Message, string
Version, string
Args, string
Server, string
ServerJob, string
ServerStatus, string
Submit, string
Start, string
End, string
`
// Simmer manages the running and data analysis of results from simulations
// that are run on remote server(s), within a Cogent Lab browser environment,
// with the files as the left panel, and the Tabber as the right panel.
type Simmer struct {
core.Frame
lab.Browser
// Config holds all the configuration settings.
Config Configuration
// JobsTableView is the view of the jobs table.
JobsTableView *tensorcore.Table
// JobsTable is the jobs Table with one row per job.
JobsTable *table.Table
// ResultsTableView has the results table.
ResultsTableView *core.Table
// ResultsList is the list of result records.
ResultsList []*Result
// BareMetal RPC client.
BareMetal *baremetal.Client
// Status info from BareMetal
BareMetalActive []*baremetal.Job
BareMetalActiveTable *core.Table
}
// Important: must be run from an interactive terminal;
// it will quit immediately if not!
func main() {
opts := cli.DefaultOptions("simmer", "interactive simulation running and data analysis.")
cfg := &interpreter.Config{}
cfg.InteractiveFunc = Interactive
cli.Run(opts, cfg, interpreter.Run, interpreter.Build)
}
// Interactive is the cli function that gets called by default at gui startup.
func Interactive(c *interpreter.Config, in *interpreter.Interpreter) error {
b, _ := NewSimmerWindow(in)
b.OnShow(func(e events.Event) {
// note: comment out if not running interactively (e.g., debugger)
go func() {
if c.Expr != "" {
in.Eval(c.Expr)
}
in.Interactive()
}()
})
b.RunWindow()
core.Wait()
return nil
}
// NewSimmerWindow returns a new Simmer window using the given interpreter.
// Call RunWindow on the resulting [core.Body] to open the window.
func NewSimmerWindow(in *interpreter.Interpreter) (*core.Body, *Simmer) {
startDir, _ := os.Getwd()
startDir = errors.Log1(filepath.Abs(startDir))
b := core.NewBody("Simmer: " + fsx.DirAndFile(startDir))
sr := NewSimmer(b)
sr.Interpreter = in
b.AddTopBar(func(bar *core.Frame) {
tb := core.NewToolbar(bar)
sr.Toolbar = tb
tb.Maker(sr.MakeToolbar)
})
sr.InitSimmer(startDir)
return b, sr
}
// InitSimmer initializes the simmer configuration and data
// for given starting directory, which should be the main github
// current working directory for the simulation being run.
// All the simmer data is contained within a "simdata" directory
// under the startDir: this dir is typically a symbolic link
// to a common collection of such simdata directories for all
// the different simulations being run.
// The goal Interpreter is typically already set by this point
// but will be created if not.
func (sr *Simmer) InitSimmer(startDir string) {
sr.StartDir = startDir
ddr := errors.Log1(filepath.Abs("simdata"))
sr.SetDataRoot(ddr)
labscripts.InitInterpreter(&sr.Browser)
in, _ := labscripts.Interpreter(&sr.Browser)
in.Interp.Use(coresymbols.Symbols) // gui imports
in.Interp.Use(labsymbols.Symbols)
in.Interp.Use(interp.Exports{
"cogentcore.org/lab/lab/lab": map[string]reflect.Value{
"LabBrowser": reflect.ValueOf(sr), // our Simmer is available as lab.Lab
},
})
in.Config()
sr.SetScriptsDir(filepath.Join(sr.DataRoot, "labscripts"))
lab.LabBrowser = &sr.Browser
lab.Lab = sr.Browser.Tabs
goalrun = in.Goal
sr.Config.StartDir = sr.StartDir
sr.Config.DataRoot = sr.DataRoot
sr.Config.Defaults()
sr.JobsTable = table.New()
sr.UpdateScripts() // automatically runs lowercase init scripts
if !sr.IsSlurm() {
sr.BareMetal = baremetal.NewClient()
err := sr.BareMetal.Connect()
if errors.Log(err) != nil {
sr.BareMetal = nil
}
}
}
func (sr *Simmer) Init() {
sr.Frame.Init()
sr.Styler(func(s *styles.Style) {
s.Grow.Set(1, 1)
})
sr.OnShow(func(e events.Event) {
sr.UpdateFiles()
})
tree.AddChildAt(sr, "splits", func(w *core.Splits) {
sr.Splits = w
w.SetSplits(.15, .85)
tree.AddChildAt(w, "fileframe", func(w *core.Frame) {
w.Styler(func(s *styles.Style) {
s.Direction = styles.Column
s.Overflow.Set(styles.OverflowAuto)
s.Grow.Set(1, 1)
})
tree.AddChildAt(w, "filetree", func(w *lab.DataTree) {
sr.Files = w
})
})
tree.AddChildAt(w, "tabs", func(w *lab.Tabs) {
sr.Tabs = w
})
})
sr.Updater(func() {
if sr.Files != nil {
sr.Files.Tabber = sr.Tabs
}
})
}
// AsyncMessageSnackbar must be used for MessageSnackbar in a goroutine.
func (sr *Simmer) AsyncMessageSnackbar(message string) {
sr.AsyncLock()
core.MessageSnackbar(sr, message)
sr.AsyncUnlock()
}
// IsSlurm returns true if using slurm (vs. baremetal)
func (sr *Simmer) IsSlurm() bool {
return sr.Config.Server.Slurm
}
func (sr *Simmer) MakeToolbar(p *tree.Plan) {
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(sr.UpdateFiles).SetText("").SetIcon(icons.Refresh)
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(sr.UpdateSims).SetText("Jobs").SetIcon(icons.ViewList).SetShortcut("Command+U")
})
tree.Add(p, func(w *core.Button) {
w.SetText("Bare").SetIcon(icons.Refresh).
SetTooltip("Update BareMetal jobs").OnClick(func(e events.Event) {
sr.UpdateBare()
})
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(sr.Queue).SetIcon(icons.List)
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(sr.Status).SetIcon(icons.Sync)
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(sr.Fetch).SetIcon(icons.Download)
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(sr.Submit).SetIcon(icons.Add)
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(sr.Search).SetIcon(icons.Search)
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(sr.Results).SetIcon(icons.Open)
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(sr.Reset).SetIcon(icons.Refresh)
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(sr.Diff).SetIcon(icons.Difference)
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(sr.Plot).SetIcon(icons.ShowChart)
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(sr.PlotMean).SetIcon(icons.ShowChart)
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(sr.Cancel).SetIcon(icons.Refresh)
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(sr.Delete).SetIcon(icons.Delete)
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(sr.Archive).SetIcon(icons.Archive)
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(sr.Recover).SetIcon(icons.Archive)
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(sr.EditConfig).SetIcon(icons.Edit)
})
tree.Add(p, func(w *core.Button) {
w.SetText("README").SetIcon(icons.FileMarkdown).
SetTooltip("open README help file").OnClick(func(e events.Event) {
core.TheApp.OpenURL("https://github.com/cogentcore/lab/blob/main/examples/simmer/README.md")
})
})
}
// Code generated by "goal build"; DO NOT EDIT.
//line slurm.goal:1
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"fmt"
"io"
"os"
"path/filepath"
"strings"
"cogentcore.org/core/core"
"cogentcore.org/lab/goal/goalib"
"cogentcore.org/lab/table"
)
// WriteSBatchHeader writes the header of a SLURM SBatch script
// that is common across all three scripts.
// IMPORTANT: set the job parameters here!
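// With the JobParams defaults, the emitted header looks roughly like this
// sketch (project, job id, and user are hypothetical):
//
//	#SBATCH --job-name=myproj_abc00042
//	#SBATCH --mem-per-cpu=1G
//	#SBATCH --time=1:00:00
//	#SBATCH --ntasks=1
//	#SBATCH --cpus-per-task=8
//	#SBATCH --ntasks-per-node=1
//	#SBATCH --mail-type=FAIL
//	#SBATCH --mail-user=myuser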
func (sr *Simmer) WriteSBatchHeader(w io.Writer, jid string) {
fmt.Fprintf(w, "#SBATCH --job-name=%s_%s\n", sr.Config.Project, jid)
fmt.Fprintf(w, "#SBATCH --mem-per-cpu=%dG\n", sr.Config.Job.Memory)
fmt.Fprintf(w, "#SBATCH --time=%d:00:00\n", sr.Config.Job.Hours)
fmt.Fprintf(w, "#SBATCH --ntasks=%d\n", sr.Config.Job.Tasks)
fmt.Fprintf(w, "#SBATCH --cpus-per-task=%d\n", sr.Config.Job.CPUsPerTask)
fmt.Fprintf(w, "#SBATCH --ntasks-per-node=%d\n", sr.Config.Job.TasksPerNode)
if sr.Config.ExcludeNodes != "" {
fmt.Fprintf(w, "#SBATCH --exclude=%s\n", sr.Config.ExcludeNodes)
}
// fmt.Fprint(w, "#SBATCH --nodelist=agate-[2,19]\n")
// fmt.Fprintf(w, "#SBATCH --qos=%s\n", qos)
// fmt.Fprintf(w, "#SBATCH --partition=%s\n", qosShort)
fmt.Fprintf(w, "#SBATCH --mail-type=FAIL\n")
fmt.Fprintf(w, "#SBATCH --mail-user=%s\n", sr.Config.User)
// these might be needed depending on environment in head node vs. compute nodes
// fmt.Fprintf(w, "#SBATCH --export=NONE\n")
// fmt.Fprintf(w, "unset SLURM_EXPORT_ENV\n")
}
func (sr *Simmer) WriteSBatchSetup(w io.Writer, jid string) {
fmt.Fprintf(w, "#!/bin/bash -l\n") // -l = login session, sources your .bash_profile
fmt.Fprint(w, "#SBATCH --output=job.setup.out\n")
fmt.Fprint(w, "#SBATCH --error=job.setup.err\n")
sr.WriteSBatchHeader(w, jid)
//////////////////////////////////////////////////////////
// now we do all the setup, like building the executable
fmt.Fprintf(w, "\n\n")
// fmt.Fprintf(w, "go build -mod=mod -tags mpi\n")
fmt.Fprintf(w, "go build -mod=mod\n")
// fmt.Fprintf(w, "/bin/rm images\n")
// fmt.Fprintf(w, "ln -s $HOME/ccn_images/CU3D100_20obj8inst_8tick4sac images\n")
cmd := "date '+%Y-%m-%d %T %Z' > job.start"
fmt.Fprintln(w, cmd)
}
func (sr *Simmer) WriteSBatchArray(w io.Writer, jid, setup_id, args string) {
fmt.Fprintf(w, "#!/bin/bash -l\n") // -l = login session, sources your .bash_profile
fmt.Fprintf(w, "#SBATCH --array=0-%d\n", sr.Config.Job.NRuns-1)
fmt.Fprint(w, "#SBATCH --output=job.%A_%a.out\n")
// fmt.Fprint(w, "#SBATCH --error=job.%A_%a.err\n")
fmt.Fprintf(w, "#SBATCH --dependency=afterany:%s\n", setup_id)
sr.WriteSBatchHeader(w, jid)
//////////////////////////////////////////////////////////
// now we run the job
fmt.Fprintf(w, "echo $SLURM_ARRAY_JOB_ID\n")
fmt.Fprintf(w, "\n\n")
// note: could use srun to run job; -runs = 1 is number to run from run start
fmt.Fprintf(w, "./%s -nogui -cfg config_job.toml -run $SLURM_ARRAY_TASK_ID -runs 1 %s\n", sr.Config.Project, args)
}
func (sr *Simmer) WriteSBatchCleanup(w io.Writer, jid, array_id string) {
fmt.Fprintf(w, "#!/bin/bash -l\n") // -l = login session, sources your .bash_profile
fmt.Fprint(w, "#SBATCH --output=job.cleanup.out\n")
// fmt.Fprint(w, "#SBATCH --error=job.cleanup.err")
fmt.Fprintf(w, "#SBATCH --dependency=afterany:%s\n", array_id)
sr.WriteSBatchHeader(w, jid)
fmt.Fprintf(w, "\n\n")
//////////////////////////////////////////////////////////
// now we cleanup after all the jobs have run
// can cat results files etc.
fmt.Fprintf(w, "cat job.*.out > job.out\n")
fmt.Fprintf(w, "/bin/rm job.*.out\n")
fmt.Fprintf(w, "cat *_train_run.tsv > all_run.tsv\n")
fmt.Fprintf(w, "/bin/rm *_train_run.tsv\n")
fmt.Fprintf(w, "cat *_train_epoch.tsv > all_epc.tsv\n")
fmt.Fprintf(w, "/bin/rm *_train_epoch.tsv\n")
cmd := "date '+%Y-%m-%d %T %Z' > job.end"
fmt.Fprintln(w, cmd)
}
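// SubmitSBatch writes the three sbatch scripts (setup, array, cleanup) for the
// given job, submits each in turn with afterany dependencies chaining them,
// and returns the slurm job id of the array job.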
func (sr *Simmer) SubmitSBatch(jid, args string) string {
goalrun.Run("@0")
f, _ := os.Create("job.setup.sbatch")
sr.WriteSBatchSetup(f, jid)
f.Close()
goalrun.Run("scp", "job.setup.sbatch", "@1:job.setup.sbatch")
sid := sr.RunSBatch("job.setup.sbatch")
f, _ = os.Create("job.sbatch")
sr.WriteSBatchArray(f, jid, sid, args)
f.Close()
goalrun.Run("scp", "job.sbatch", "@1:job.sbatch")
aid := sr.RunSBatch("job.sbatch")
f, _ = os.Create("job.cleanup.sbatch")
sr.WriteSBatchCleanup(f, jid, aid)
f.Close()
goalrun.Run("scp", "job.cleanup.sbatch", "@1:job.cleanup.sbatch")
sr.RunSBatch("job.cleanup.sbatch")
sr.GetMeta(jid)
return aid
}
// RunSBatch runs sbatch on the given sbatch file,
// returning the resulting job id.
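// sbatch typically prints a line like "Submitted batch job 1234567" (sketch);
// that output is captured in the server-side job.slurm file, and the job id is
// taken as its last whitespace-separated field.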
func (sr *Simmer) RunSBatch(sbatch string) string {
goalrun.Run("@1")
goalrun.Run("sbatch", sbatch, ">", "job.slurm")
goalrun.Run("@0")
ss := goalrun.Output("@1", "cat", "job.slurm")
if ss == "" {
fmt.Println("JobStatus ERROR: no server job.slurm file to get server job id from")
goalrun.Run("@1", "cd")
goalrun.Run("@0")
return ""
}
ssf := strings.Fields(ss)
sj := ssf[len(ssf)-1]
return sj
}
// QueueSlurm runs a queue query command on the server and shows the results.
func (sr *Simmer) QueueSlurm() {
ts := sr.Tabs.AsLab()
goalrun.Run("@1")
goalrun.Run("cd")
myq := goalrun.Output("squeue", "-l", "-u", "$USER")
sinfoall := goalrun.Output("sinfo")
goalrun.Run("@0")
sis := []string{}
for _, l := range goalib.SplitLines(sinfoall) {
if strings.HasPrefix(l, "low") || strings.HasPrefix(l, "med") {
continue
}
sis = append(sis, l)
}
sinfo := strings.Repeat("#", 60) + "\n" + strings.Join(sis, "\n")
qstr := myq + "\n" + sinfo
ts.EditorString("Queue", qstr)
}
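// FetchJobSlurm downloads results files matching [Config.FetchFiles] from the
// slurm server for the given job, and for Finalized / Fetched jobs
// post-processes all_epc.tsv and all_run.tsv into averaged versions.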
func (sr *Simmer) FetchJobSlurm(jid string, force bool) {
spath := sr.ServerJobPath(jid)
jpath := sr.JobPath(jid)
goalrun.Run("@1")
goalrun.Run("cd")
goalrun.Run("@0")
goalrun.Run("cd", jpath)
sstat := goalib.ReadFile("job.status")
if !force && sstat == "Fetched" {
return
}
goalrun.Run("@1", "cd", spath)
goalrun.Run("@0")
ffiles := goalrun.Output("@1", "/bin/ls", "-1", sr.Config.FetchFiles)
if len(ffiles) > 0 {
core.MessageSnackbar(sr, fmt.Sprintf("Fetching %d data files for job: %s", len(ffiles), jid))
}
for _, ff := range goalib.SplitLines(ffiles) {
// fmt.Println(ff)
rfn := "@1:" + ff
goalrun.Run("scp", rfn, ff)
if (sstat == "Finalized" || sstat == "Fetched") && strings.HasSuffix(ff, ".tsv") {
if ff == "all_epc.tsv" {
table.CleanCatTSV(ff, "Run", "Epoch")
idx := strings.Index(ff, "_epc.tsv")
goalrun.Run("tablecat", "-colavg", "-col", "Epoch", "-o", ff[:idx+1]+"avg"+ff[idx+1:], ff)
} else if ff == "all_run.tsv" {
table.CleanCatTSV(ff, "Run")
idx := strings.Index(ff, "_run.tsv")
goalrun.Run("tablecat", "-colavg", "-o", ff[:idx+1]+"avg"+ff[idx+1:], ff)
// } else {
// table.CleanCatTSV(ff, "Run")
}
}
}
goalrun.Run("@0")
if sstat == "Finalized" {
// fmt.Println("status finalized")
goalib.WriteFile("job.status", "Fetched")
goalib.ReplaceInFile("dbmeta.toml", "\"Finalized\"", "\"Fetched\"")
} else {
fmt.Println("status:", sstat)
}
}
// CancelJobsSlurm cancels the given jobs, for slurm
func (sr *Simmer) CancelJobsSlurm(jobs []string) {
goalrun.Run("@0")
goalrun.Run("@1")
for _, jid := range jobs {
sjob := sr.ValueForJob(jid, "ServerJob")
if sjob != "" {
goalrun.Run("scancel", sjob)
}
}
goalrun.Run("@1")
goalrun.Run("cd")
goalrun.Run("@0")
core.MessageSnackbar(sr, "Done canceling jobs")
}
// Code generated by "goal build"; DO NOT EDIT.
//line submit.goal:1
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"fmt"
"os"
"path"
"path/filepath"
"strconv"
"strings"
"time"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/fsx"
"cogentcore.org/lab/goal/goalib"
"cogentcore.org/lab/lab"
)
// NextJobNumber returns the next sequential job number to use,
// incrementing value saved in last_job.number file
func (sr *Simmer) NextJobNumber() int {
jf := "last_job.number"
jnf := goalib.ReadFile(jf)
jn := 0
if jnf != "" {
jn, _ = strconv.Atoi(strings.TrimSpace(jnf))
}
jn++
goalib.WriteFile(jf, strconv.Itoa(jn))
return jn
}
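// NextJobID returns the next job id string: the UserShort prefix followed by
// the zero-padded 5-digit job number, e.g. "abc00042" (hypothetical).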
func (sr *Simmer) NextJobID() string {
jn := sr.NextJobNumber()
jstr := fmt.Sprintf("%s%05d", sr.Config.UserShort, jn)
return jstr
}
// FindGoMod finds the directory containing the go.mod file, starting from the
// given directory and walking up toward the filesystem root.
// Returns "" if no go.mod file is found.
func (sr *Simmer) FindGoMod(dir string) string {
for {
if goalib.FileExists(filepath.Join(dir, "go.mod")) {
return dir
}
parent := filepath.Dir(dir)
if parent == dir || parent == "" { // reached the root without finding go.mod
return ""
}
dir = parent
}
}
// CopyFilesToJob copies files with given extensions (none for all),
// from localSrc to localJob and remote hostJob (@1).
// Ensures directories are made in the job locations.
// If origProjPath and projPath are non-empty and the file is a .go file,
// then origProjPath is replaced with projPath within the copied file.
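// A usage sketch with hypothetical paths and module names, copying .go files
// and rewriting the module path:
//
//	sr.CopyFilesToJob("/home/me/mysim/code", "/data/jobs/abc00042/code",
//		"simserver/mysim/me/jobs/abc00042/code",
//		"github.com/example/mysim", "emer/mysim", ".go")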
func (sr *Simmer) CopyFilesToJob(localSrc, localJob, hostJob, origProjPath, projPath string, exts ...string) {
goalrun.Run("@0")
goalrun.Run("mkdir", "-p", localJob)
goalrun.Run("cd", localJob)
if sr.IsSlurm() {
goalrun.Run("@1")
goalrun.Run("cd")
goalrun.Run("mkdir", "-p", hostJob)
goalrun.Run("cd", hostJob)
goalrun.Run("@0")
}
efls := fsx.Filenames(localSrc, exts...)
for _, f := range efls {
sfn := filepath.Join(localSrc, f)
goalrun.Run("/bin/cp", sfn, f)
if projPath != "" && strings.HasSuffix(f, ".go") {
goalib.ReplaceInFile(f, origProjPath, projPath)
}
if sr.IsSlurm() {
goalrun.Run("scp", f, "@1:"+f)
}
}
}
// NewJob runs a new job with given parameters.
// This is run as a separate goroutine!
func (sr *Simmer) NewJob(jp SubmitParams) {
message := jp.Message
args := jp.Args
label := jp.Label
goalrun.Run("@0")
os.Chdir(sr.DataRoot)
jid := sr.NextJobID()
spath := sr.ServerJobPath(jid)
jpath := sr.JobPath(jid)
sr.AsyncMessageSnackbar("Submitting Job: " + jid)
isSlurm := sr.IsSlurm()
gomodDir := sr.FindGoMod(sr.StartDir)
subDir := strings.TrimPrefix(sr.StartDir, gomodDir)
projPath := path.Join("emer", subDir)
origProjPath := path.Join(sr.Config.Package, subDir)
// fmt.Println("gmd:", gomodDir, "sd:", subDir, "pp:", projPath, "opp:", origProjPath)
// fmt.Println(jpath)
os.MkdirAll(jpath, 0750)
os.Chdir(jpath)
goalib.WriteFile("job.message", message)
goalib.WriteFile("job.args", args)
goalib.WriteFile("job.label", label)
goalib.WriteFile("job.version", sr.Config.Version)
goalib.WriteFile("job.submit", time.Now().Format(sr.Config.TimeFormat))
goalib.WriteFile("job.status", "Submitted")
// need to do sub-code first and update paths in copied files
cdirs := sr.Config.CodeDirs
if sr.Config.Job.SubCmd {
cdirs = append(cdirs, sr.Config.Project)
}
for _, ed := range cdirs {
goalrun.Run("@0")
loce := filepath.Join(sr.StartDir, ed)
jpathe := filepath.Join(jpath, ed)
spathe := filepath.Join(spath, ed)
sr.CopyFilesToJob(loce, jpathe, spathe, origProjPath, projPath, ".go")
}
// copy local files:
if isSlurm {
goalrun.Run("@1")
goalrun.Run("cd")
goalrun.Run("mkdir", "-p", spath)
goalrun.Run("cd", spath)
}
goalrun.Run("@0")
goalrun.Run("cd", jpath)
fls := fsx.Filenames(sr.StartDir, ".go")
for _, f := range fls {
sfn := filepath.Join(sr.StartDir, f)
goalrun.Run("/bin/cp", sfn, f)
goalib.ReplaceInFile(f, origProjPath, projPath)
if isSlurm {
goalrun.Run("scp", f, "@1:"+f)
}
}
for _, f := range sr.Config.ExtraFiles {
sfn := filepath.Join(sr.StartDir, f)
goalrun.Run("/bin/cp", sfn, f)
if isSlurm {
goalrun.Run("scp", sfn, "@1:"+f)
}
}
for _, ed := range sr.Config.ExtraDirs {
jpathe := filepath.Join(jpath, ed)
spathe := filepath.Join(spath, ed)
loce := filepath.Join(sr.StartDir, ed)
sr.CopyFilesToJob(loce, jpathe, spathe, "", "")
}
if isSlurm {
goalrun.Run("@1")
goalrun.Run("cd")
goalrun.Run("cd", spath)
}
goalrun.Run("@0")
goalrun.Run("cd", jpath)
sr.AsyncMessageSnackbar("Job: " + jid + " files copied")
if gomodDir != "" {
sfn := filepath.Join(gomodDir, "go.mod")
// fmt.Println("go.mod dir:", gomodDir, sfn)
if isSlurm {
goalrun.Run("scp", sfn, "@1:go.mod")
sfn = filepath.Join(gomodDir, "go.sum")
goalrun.Run("scp", sfn, "@1:go.sum")
goalrun.Run("@1")
} else {
goalrun.Run("cp", sfn, "go.mod")
sfn = filepath.Join(gomodDir, "go.sum")
goalrun.Run("cp", sfn, "go.sum")
}
// note: using local go here for baremetal
goalrun.Run("go", "mod", "edit", "-module", projPath)
if sr.Config.Package != "" {
goalrun.Run("go", "get", sr.Config.Package+"@"+sr.Config.Version)
}
if sr.Config.ExtraGoGet != "" {
goalrun.Run("go", "get", sr.Config.ExtraGoGet)
}
goalrun.Run("go", "mod", "tidy")
if isSlurm {
goalrun.Run("@0")
goalrun.Run("scp", "@1:go.mod", "go.mod")
goalrun.Run("scp", "@1:go.sum", "go.sum")
}
} else {
fmt.Println("go.mod file not found!")
}
if isSlurm {
sid := sr.SubmitSBatch(jid, args)
goalib.WriteFile("job.job", sid)
fmt.Println("server job id:", sid)
goalrun.Run("scp", "job.job", "@1:job.job")
sr.AsyncMessageSnackbar("Job: " + jid + " server job: " + sid + " successfully submitted")
goalrun.Run("@1", "cd")
} else {
sid := sr.SubmitBare(jid, args)
goalib.WriteFile("job.job", sid)
fmt.Println("server job id:", sid)
sr.AsyncMessageSnackbar("Job: " + jid + " server job: " + sid + " successfully submitted")
}
goalrun.Run("@0")
sr.UpdateSimsAsync()
}
// Submit submits a job on the server.
// Creates a new job dir based on incrementing counter,
// synchronizing the job files.
func (sr *Simmer) Submit() { //types:add
lab.PromptStruct(sr, &sr.Config.Submit, "Submit a new job", func() {
go sr.NewJob(sr.Config.Submit)
})
}
// Search runs parameter search jobs, one for each parameter.
// The number of parameters is obtained by running the currently built
// simulation executable locally with the -search-n argument, which
// returns the total number of parameter searches to run.
// THUS, YOU MUST BUILD THE LOCAL SIM WITH THE PARAM SEARCH CONFIGURED.
// Then, it launches that number of jobs with -search-at values from 1..n.
// The jobs should write a job.label file for the searched parameter.
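// As a sketch of the flow (the count of 12 is hypothetical):
//
//	./myproj/myproj -nogui -search-n   -> prints "12"
//	12 jobs are then submitted, with extra args "-search-at 1" .. "-search-at 12"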
func (sr *Simmer) Search() { //types:add
lab.PromptStruct(sr, &sr.Config.Submit, "Submit a new param search: make sure local sim has been built with search configured!", func() {
go sr.NewSearch(sr.Config.Submit)
})
}
// NewSearch runs a new parameter search with given parameters.
// This is run as a separate goroutine!
func (sr *Simmer) NewSearch(jp SubmitParams) {
goalrun.Run("@0")
goalrun.Run("cd", sr.StartDir)
proj := sr.Config.Project
lsim := filepath.Join(proj, proj)
nss := goalrun.Output(lsim, "-nogui", "-search-n")
sl := goalib.SplitLines(nss)
if len(sl) == 0 {
fmt.Println("Search: no output")
return
}
ns, err := strconv.Atoi(sl[0])
if errors.Log(err) != nil {
return
}
if ns <= 0 {
fmt.Println("Number of search params <= 0:", ns)
return
}
fmt.Println("Param search launching jobs:", ns)
sarg := jp.Args
if sarg != "" {
sarg += " "
}
for i := range ns {
cjp := jp
cjp.Args = sarg + fmt.Sprintf("-search-at %d", i+1)
sr.NewJob(cjp)
}
}
// Code generated by "core generate -add-types -add-funcs"; DO NOT EDIT.
package main
import (
"cogentcore.org/core/core"
"cogentcore.org/core/tree"
"cogentcore.org/core/types"
"cogentcore.org/lab/examples/baremetal"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensorcore"
)
var _ = types.AddType(&types.Type{Name: "main.FilterResults", IDName: "filter-results", Doc: "FilterResults specifies which results files to open.", Fields: []types.Field{{Name: "FileContains", Doc: "File name contains this string, e.g., \"_epoch\" or \"_run\""}, {Name: "Ext", Doc: "extension of files, e.g., .tsv"}, {Name: "Fetch", Doc: "if true, fetch results before opening."}}})
var _ = types.AddType(&types.Type{Name: "main.SubmitParams", IDName: "submit-params", Doc: "SubmitParams specifies the parameters for submitting a job.", Fields: []types.Field{{Name: "Message", Doc: "Message describing the simulation:\nthis is key info for what is special about this job, like a github commit message"}, {Name: "Label", Doc: "Label is brief, unique label used for plots to label this job"}, {Name: "Args", Doc: "\targuments to pass on the command line.\n\n-nogui is already passed by default"}}})
var _ = types.AddType(&types.Type{Name: "main.JobParams", IDName: "job-params", Doc: "JobParams are parameters for running the job", Fields: []types.Field{{Name: "NRuns", Doc: "NRuns is the number of parallel runs; can also set to 1\nand run multiple runs per job using args."}, {Name: "Hours", Doc: "Hours is the max number of hours: slurm will terminate if longer,\nso be generous 2d = 48, 3d = 72, 4d = 96, 5d = 120, 6d = 144, 7d = 168"}, {Name: "Memory", Doc: "Memory per CPU in gigabytes"}, {Name: "Tasks", Doc: "Tasks is the number of mpi \"tasks\" (procs in MPI terminology)."}, {Name: "CPUsPerTask", Doc: "CPUsPerTask is the number of cpu cores (threads) per task."}, {Name: "TasksPerNode", Doc: "TasksPerNode is how to allocate tasks within compute nodes\ncpus_per_task * tasks_per_node <= total cores per node."}, {Name: "Qos", Doc: "Qos is the queue \"quality of service\" name."}, {Name: "SubCmd", Doc: "If true, the executable is in a subdirectory with the same name as [Config.Project],\nfilename main.go, to allow the primary directory to be imported into other apps.\nManages the copying and building of this sub-command."}, {Name: "BuildArgs", Doc: "BuildArgs are extra arts to pass during building, such as -tags mpi for mpi"}}})
var _ = types.AddType(&types.Type{Name: "main.ServerParams", IDName: "server-params", Doc: "ServerParams are parameters for the server.", Fields: []types.Field{{Name: "Name", Doc: "Name is the name of current server using to run jobs;\ngets recorded with each job."}, {Name: "Root", Doc: "Root is the root path from user home dir on server.\nis auto-set to: filepath.Join(\"simserver\", Project, User)"}, {Name: "Slurm", Doc: "Slurm uses the slurm job manager. Otherwise uses a bare job manager."}}})
var _ = types.AddType(&types.Type{Name: "main.Configuration", IDName: "configuration", Doc: "Configuration holds all of the user-settable parameters", Fields: []types.Field{{Name: "DataRoot", Doc: "DataRoot is the path to the root of the data to browse."}, {Name: "StartDir", Doc: "StartDir is the starting directory, where the app was originally started."}, {Name: "User", Doc: "User id as in system login name (i.e., user@system)."}, {Name: "UserShort", Doc: "UserShort is the first 3 letters of User,\nfor naming jobs (auto-set from User)."}, {Name: "Project", Doc: "Project is the name of simulation project, lowercase\n(should be name of source dir)."}, {Name: "Package", Doc: "Package is the parent package: e.g., github.com/emer/axon/v2\nThis is used to update the go.mod, along with the Version."}, {Name: "Version", Doc: "Version is the current git version string, from git describe --tags."}, {Name: "Job", Doc: "Job has the parameters for job resources etc."}, {Name: "Server", Doc: "Server has server parameters."}, {Name: "GroupColumns", Doc: "GroupColumns are the column(s) to use for grouping result data, for PlotMean.\ne.g., Epoch for epoch-level results."}, {Name: "FetchFiles", Doc: "FetchFiles is a glob expression for files to fetch from server,\nfor Fetch command. Is *.tsv by default."}, {Name: "ExcludeNodes", Doc: "ExcludeNodes are nodes to exclude from job, based on what is slow."}, {Name: "ExtraFiles", Doc: "ExtraFiles has extra files to upload with job submit, from same dir."}, {Name: "ExtraDirs", Doc: "ExtraDirs has subdirs with other files to upload with job submit\n(non-code -- see CodeDirs)."}, {Name: "CodeDirs", Doc: "CodeDirs has subdirs with code to upload with job submit;\ngo.mod auto-updated to use."}, {Name: "ExtraGoGet", Doc: "ExtraGoGet is an extra package to do \"go get\" with, for launching the job."}, {Name: "JobScript", Doc: "JobScript is a job script to use for running the simulation,\ninstead of the basic default, if non-empty.\nThis is written to the job.sbatch file. If it contains a $JOB_ARGS string\nthen that is replaced with the args entered during submission.\nIf using slurm, this switches to a simple direct sbatch submission instead\nof the default parallel job submission. All standard slurm job parameters\nare automatically inserted at the start of the file, so this script should\njust be the actual job running actions after that point."}, {Name: "SetupScript", Doc: "SetupScript contains optional lines of bash script code to insert at\nthe start of the job submission script, which is then followed by\nthe default script. For example, if a symbolic link to a large shared\nresource is needed, make that link here."}, {Name: "TimeFormat", Doc: "TimeFormat is the format for timestamps,\ndefaults to \"2006-01-02 15:04:05 MST\""}, {Name: "Filter", Doc: "Filter has parameters for filtering results."}, {Name: "Submit", Doc: "Submit has parameters for submitting jobs; set from last job run."}}})
var _ = types.AddType(&types.Type{Name: "main.Result", IDName: "result", Doc: "Result has info for one loaded result, as a table.Table", Fields: []types.Field{{Name: "JobID", Doc: "job id for results"}, {Name: "Label", Doc: "short label used as a legend in the plot"}, {Name: "Message", Doc: "description of job"}, {Name: "Args", Doc: "args used in running job"}, {Name: "Path", Doc: "path to data"}, {Name: "Table", Doc: "result data"}}})
var _ = types.AddType(&types.Type{Name: "main.Simmer", IDName: "simmer", Doc: "Simmer manages the running and data analysis of results from simulations\nthat are run on remote server(s), within a Cogent Lab browser environment,\nwith the files as the left panel, and the Tabber as the right panel.", Methods: []types.Method{{Name: "FetchJobBare", Doc: "FetchJobBare downloads results files from bare metal server.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"jid", "force"}}, {Name: "EditConfig", Doc: "EditConfig edits the configuration", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "Jobs", Doc: "Jobs updates the Jobs tab with a Table showing all the Jobs\nwith their meta data. Uses the dbmeta.toml data compiled from\nthe Status function.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "UpdateSims", Doc: "Jobs updates the Jobs tab with a Table showing all the Jobs\nwith their meta data. Uses the dbmeta.toml data compiled from\nthe Status function.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "Queue", Doc: "Queue runs a queue query command on the server and shows the results.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "Status", Doc: "Status gets updated job.* files from the server for any job that\ndoesn't have a Finalized or Fetched status. It updates the\nstatus based on the server job status query, assigning a\nstatus of Finalized if job is done. Updates the dbmeta.toml\ndata based on current job data.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "Fetch", Doc: "Fetch retrieves all the .tsv data files from the server\nfor any jobs not already marked as Fetched.\nOperates on the jobs selected in the Jobs table,\nor on all jobs if none selected.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "Cancel", Doc: "Cancel cancels the jobs selected in the Jobs table,\nwith a confirmation prompt.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "Delete", Doc: "Delete deletes the selected Jobs, with a confirmation prompt.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "Archive", Doc: "Archive moves the selected Jobs to the Archive directory,\nlocally, and deletes them from the server,\nfor results that are useful but not immediately relevant.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "Recover", Doc: "Recover recovers the jobs selected in the Jobs table,\nwith a confirmation prompt.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "Results", Doc: "Results loads specific .tsv data files from the jobs selected\nin the Jobs table, into the Results table. 
There are often\nmultiple result files per job, so this step is necessary to\nchoose which such files to select for plotting.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "Diff", Doc: "Diff shows the differences between two selected jobs, or if only\none job is selected, between that job and the current source directory.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "Plot", Doc: "Plot concatenates selected Results data files and generates a plot\nof the resulting data.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "PlotMean", Doc: "PlotMean concatenates selected Results data files and generates a plot\nof the resulting data, computing the mean over the values in\n[Config.GroupColumns] to group values (e.g., across Epochs).", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "Reset", Doc: "Reset resets the Results table", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "Submit", Doc: "Submit submits a job on the server.\nCreates a new job dir based on incrementing counter,\nsynchronizing the job files.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "Search", Doc: "Search runs parameter search jobs, one for each parameter.\nThe number of parameters is obtained by running the currently built\nsimulation executable locally with the -search-n argument, which\nreturns the total number of parameter searches to run.\nTHUS, YOU MUST BUILD THE LOCAL SIM WITH THE PARAM SEARCH CONFIGURED.\nThen, it launches that number of jobs with -search-at values from 1..n.\nThe jobs should write a job.label file for the searched parameter.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}}, Embeds: []types.Field{{Name: "Frame"}, {Name: "Browser"}}, Fields: []types.Field{{Name: "Config", Doc: "Config holds all the configuration settings."}, {Name: "JobsTableView", Doc: "JobsTableView is the view of the jobs table."}, {Name: "JobsTable", Doc: "JobsTable is the jobs Table with one row per job."}, {Name: "ResultsTableView", Doc: "ResultsTableView has the results table."}, {Name: "ResultsList", Doc: "ResultsList is the list of result records."}, {Name: "BareMetal", Doc: "BareMetal RPC client."}, {Name: "BareMetalActive", Doc: "Status info from BareMetal"}, {Name: "BareMetalActiveTable"}}})
// NewSimmer returns a new [Simmer] with the given optional parent:
// Simmer manages the running and data analysis of results from simulations
// that are run on remote server(s), within a Cogent Lab browser environment,
// with the files as the left panel, and the Tabber as the right panel.
func NewSimmer(parent ...tree.Node) *Simmer { return tree.New[Simmer](parent...) }
// SetConfig sets the [Simmer.Config]:
// Config holds all the configuration settings.
func (t *Simmer) SetConfig(v Configuration) *Simmer { t.Config = v; return t }
// SetJobsTableView sets the [Simmer.JobsTableView]:
// JobsTableView is the view of the jobs table.
func (t *Simmer) SetJobsTableView(v *tensorcore.Table) *Simmer { t.JobsTableView = v; return t }
// SetJobsTable sets the [Simmer.JobsTable]:
// JobsTable is the jobs Table with one row per job.
func (t *Simmer) SetJobsTable(v *table.Table) *Simmer { t.JobsTable = v; return t }
// SetResultsTableView sets the [Simmer.ResultsTableView]:
// ResultsTableView has the results table.
func (t *Simmer) SetResultsTableView(v *core.Table) *Simmer { t.ResultsTableView = v; return t }
// SetResultsList sets the [Simmer.ResultsList]:
// ResultsList is the list of result records.
func (t *Simmer) SetResultsList(v ...*Result) *Simmer { t.ResultsList = v; return t }
// SetBareMetal sets the [Simmer.BareMetal]:
// BareMetal RPC client.
func (t *Simmer) SetBareMetal(v *baremetal.Client) *Simmer { t.BareMetal = v; return t }
// SetBareMetalActive sets the [Simmer.BareMetalActive]:
// Status info from BareMetal
func (t *Simmer) SetBareMetalActive(v ...*baremetal.Job) *Simmer { t.BareMetalActive = v; return t }
// SetBareMetalActiveTable sets the [Simmer.BareMetalActiveTable]
func (t *Simmer) SetBareMetalActiveTable(v *core.Table) *Simmer { t.BareMetalActiveTable = v; return t }
var _ = types.AddFunc(&types.Func{Name: "main.main", Doc: "important: must be run from an interactive terminal.\nWill quit immediately if not!"})
var _ = types.AddFunc(&types.Func{Name: "main.Interactive", Doc: "Interactive is the cli function that gets called by default at gui startup.", Args: []string{"c", "in"}, Returns: []string{"error"}})
var _ = types.AddFunc(&types.Func{Name: "main.NewSimmerWindow", Doc: "NewSimmerWindow returns a new Simmer window using given interpreter.\ndo RunWindow on resulting [core.Body] to open the window.", Args: []string{"in"}, Returns: []string{"Body", "Simmer"}})
// Code generated by "core generate"; DO NOT EDIT.
package main
import (
"cogentcore.org/core/enums"
)
var _TimesValues = []Times{0, 1, 2}
// TimesN is the highest valid value for type Times, plus one.
const TimesN Times = 3
var _TimesValueMap = map[string]Times{`Trial`: 0, `Epoch`: 1, `Run`: 2}
var _TimesDescMap = map[Times]string{0: ``, 1: ``, 2: ``}
var _TimesMap = map[Times]string{0: `Trial`, 1: `Epoch`, 2: `Run`}
// String returns the string representation of this Times value.
func (i Times) String() string { return enums.String(i, _TimesMap) }
// SetString sets the Times value from its string representation,
// and returns an error if the string is invalid.
func (i *Times) SetString(s string) error { return enums.SetString(i, s, _TimesValueMap, "Times") }
// Int64 returns the Times value as an int64.
func (i Times) Int64() int64 { return int64(i) }
// SetInt64 sets the Times value from an int64.
func (i *Times) SetInt64(in int64) { *i = Times(in) }
// Desc returns the description of the Times value.
func (i Times) Desc() string { return enums.Desc(i, _TimesDescMap) }
// TimesValues returns all possible values for the type Times.
func TimesValues() []Times { return _TimesValues }
// Values returns all possible values for the type Times.
func (i Times) Values() []enums.Enum { return enums.Values(_TimesValues) }
// MarshalText implements the [encoding.TextMarshaler] interface.
func (i Times) MarshalText() ([]byte, error) { return []byte(i.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *Times) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "Times") }
var _LoopPhaseValues = []LoopPhase{0, 1}
// LoopPhaseN is the highest valid value for type LoopPhase, plus one.
const LoopPhaseN LoopPhase = 2
var _LoopPhaseValueMap = map[string]LoopPhase{`Start`: 0, `Step`: 1}
var _LoopPhaseDescMap = map[LoopPhase]string{0: `Start is the start of the loop: resets accumulated stats, initializes.`, 1: `Step is each iteration of the loop.`}
var _LoopPhaseMap = map[LoopPhase]string{0: `Start`, 1: `Step`}
// String returns the string representation of this LoopPhase value.
func (i LoopPhase) String() string { return enums.String(i, _LoopPhaseMap) }
// SetString sets the LoopPhase value from its string representation,
// and returns an error if the string is invalid.
func (i *LoopPhase) SetString(s string) error {
return enums.SetString(i, s, _LoopPhaseValueMap, "LoopPhase")
}
// Int64 returns the LoopPhase value as an int64.
func (i LoopPhase) Int64() int64 { return int64(i) }
// SetInt64 sets the LoopPhase value from an int64.
func (i *LoopPhase) SetInt64(in int64) { *i = LoopPhase(in) }
// Desc returns the description of the LoopPhase value.
func (i LoopPhase) Desc() string { return enums.Desc(i, _LoopPhaseDescMap) }
// LoopPhaseValues returns all possible values for the type LoopPhase.
func LoopPhaseValues() []LoopPhase { return _LoopPhaseValues }
// Values returns all possible values for the type LoopPhase.
func (i LoopPhase) Values() []enums.Enum { return enums.Values(_LoopPhaseValues) }
// MarshalText implements the [encoding.TextMarshaler] interface.
func (i LoopPhase) MarshalText() ([]byte, error) { return []byte(i.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *LoopPhase) UnmarshalText(text []byte) error {
return enums.UnmarshalText(i, text, "LoopPhase")
}
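// timesExample is an illustrative sketch (not generated code, and not called
// anywhere) showing how the generated enum methods above are typically used.
func timesExample() (string, Times, error) {
t := Epoch
s := t.String() // "Epoch"
var u Times
err := u.SetString("Run") // on success, u == Run and u.Int64() == 2
return s, u, err
}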
// Code generated by "goal build"; DO NOT EDIT.
//line sim.goal:1
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
//go:generate core generate
import (
"math/rand/v2"
"cogentcore.org/core/core"
"cogentcore.org/lab/lab"
"cogentcore.org/lab/plot"
"cogentcore.org/lab/stats/stats"
"cogentcore.org/lab/tensor"
"cogentcore.org/lab/tensorfs"
)
// Times are the looping time levels for running and statistics.
type Times int32 //enums:enum
const (
Trial Times = iota
Epoch
Run
)
// LoopPhase is the phase of loop processing for given time.
type LoopPhase int32 //enums:enum
const (
// Start is the start of the loop: resets accumulated stats, initializes.
Start LoopPhase = iota
// Step is each iteration of the loop.
Step
)
type Sim struct {
// Root is the root data dir.
Root *tensorfs.Node
// Config has config data.
Config *tensorfs.Node
// Stats has all stats data.
Stats *tensorfs.Node
// Current has current value of all stats
Current *tensorfs.Node
// StatFuncs are statistics functions, per stat, handles everything.
StatFuncs []func(ltime Times, lphase LoopPhase)
// Counters are current values of counters: normally in looper.
Counters [TimesN]int
}
// ConfigAll configures the sim data: the Config settings and the Stats.
func (ss *Sim) ConfigAll() {
ss.Root, _ = tensorfs.NewDir("Root")
ss.Config = ss.Root.Dir("Config")
mx := tensorfs.Value[int](ss.Config, "Max", int(TimesN)).(*tensor.Int)
mx.Set1D(5, int(Trial))
mx.Set1D(4, int(Epoch))
mx.Set1D(3, int(Run))
// todo: failing - assigns 3 to all
// # mx[Trial] = 5
// # mx[Epoch] = 4
// # mx[Run] = 3
ss.ConfigStats()
}
func (ss *Sim) AddStat(f func(ltime Times, lphase LoopPhase)) {
ss.StatFuncs = append(ss.StatFuncs, f)
}
func (ss *Sim) RunStats(ltime Times, lphase LoopPhase) {
for _, sf := range ss.StatFuncs {
sf(ltime, lphase)
}
}
func (ss *Sim) ConfigStats() {
ss.Stats = ss.Root.Dir("Stats")
ss.Current = ss.Stats.Dir("Current")
ctrs := []Times{Run, Epoch, Trial}
for _, ctr := range ctrs {
ss.AddStat(func(ltime Times, lphase LoopPhase) {
if ltime > ctr { // don't record counter for time above it
return
}
name := ctr.String() // name of stat = counter
timeDir := ss.Stats.Dir(ltime.String())
tsr := tensorfs.Value[int](timeDir, name)
if lphase == Start {
tsr.SetNumRows(0)
plot.SetFirstStyler(tsr, func(s *plot.Style) {
s.Range.SetMin(0)
})
return
}
ctv := ss.Counters[ctr]
tensorfs.Scalar[int](ss.Current, name).SetInt1D(ctv, 0)
tsr.AppendRowInt(ctv)
})
}
// note: it is essential to have only 1 stat per function
// so that generic names can be used for everything.
ss.AddStat(func(ltime Times, lphase LoopPhase) {
name := "SSE"
timeDir := ss.Stats.Dir(ltime.String())
tsr := timeDir.Float64(name)
if lphase == Start {
tsr.SetNumRows(0)
plot.SetFirstStyler(tsr, func(s *plot.Style) {
s.Range.SetMin(0).SetMax(1)
s.On = true
})
return
}
switch ltime {
case Trial:
stat := rand.Float64()
tensorfs.Scalar[float64](ss.Current, name).SetFloat(stat, 0)
tsr.AppendRowFloat(stat)
case Epoch:
subd := ss.Stats.Dir((ltime - 1).String())
stat := stats.StatMean.Call(subd.Float64(name))
tsr.AppendRow(stat)
case Run:
subd := ss.Stats.Dir((ltime - 1).String())
stat := stats.StatMean.Call(subd.Float64(name))
tsr.AppendRow(stat)
}
})
ss.AddStat(func(ltime Times, lphase LoopPhase) {
name := "Err"
timeDir := ss.Stats.Dir(ltime.String())
tsr := tensorfs.Value[float64](timeDir, name)
if lphase == Start {
tsr.SetNumRows(0)
plot.SetFirstStyler(tsr, func(s *plot.Style) {
s.Range.SetMin(0).SetMax(1)
s.On = true
})
return
}
switch ltime {
case Trial:
sse := tensorfs.Scalar[float64](ss.Current, "SSE").Float1D(0)
stat := 1.0
if sse < 0.5 {
stat = 0
}
tensorfs.Scalar[float64](ss.Current, name).SetFloat(stat, 0)
tsr.AppendRowFloat(stat)
case Epoch:
subd := ss.Stats.Dir((ltime - 1).String())
stat := stats.StatMean.Call(subd.Value(name))
tsr.AppendRow(stat)
case Run:
subd := ss.Stats.Dir((ltime - 1).String())
stat := stats.StatMean.Call(subd.Value(name))
tsr.AppendRow(stat)
}
})
}
func (ss *Sim) Run() {
mx := ss.Config.Value("Max").(*tensor.Int)
nrun := mx.Value1D(int(Run))
nepc := mx.Value1D(int(Epoch))
ntrl := mx.Value1D(int(Trial))
ss.RunStats(Run, Start)
for run := range nrun {
ss.Counters[Run] = run
ss.RunStats(Epoch, Start)
for epc := range nepc {
ss.Counters[Epoch] = epc
ss.RunStats(Trial, Start)
for trl := range ntrl {
ss.Counters[Trial] = trl
ss.RunStats(Trial, Step)
}
ss.RunStats(Epoch, Step)
}
ss.RunStats(Run, Step)
}
// todo: could do final analysis here
// alldt := ss.Logs.Item("AllTrials").GetDirTable(nil)
// dir := ss.Logs.Dir("Stats")
// stats.TableGroups(dir, alldt, "Run", "Epoch", "Trial")
// sts := []string{"SSE", "AvgSSE", "TrlErr"}
// stats.TableGroupStats(dir, stats.StatMean, alldt, sts...)
// stats.TableGroupStats(dir, stats.StatSem, alldt, sts...)
}
func main() {
ss := &Sim{}
ss.ConfigAll()
ss.Run()
b, _ := lab.NewBasicWindow(ss.Root, "Root")
b.RunWindow()
core.Wait()
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package goal
import (
"context"
"fmt"
"log/slog"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/exec"
"cogentcore.org/core/base/logx"
"cogentcore.org/core/base/sshclient"
"cogentcore.org/core/base/stringsx"
"github.com/mitchellh/go-homedir"
)
// InstallBuiltins adds the builtin goal commands to [Goal.Builtins].
func (gl *Goal) InstallBuiltins() {
gl.Builtins = make(map[string]func(cmdIO *exec.CmdIO, args ...string) error)
gl.Builtins["cd"] = gl.Cd
gl.Builtins["exit"] = gl.Exit
gl.Builtins["jobs"] = gl.JobsCmd
gl.Builtins["kill"] = gl.Kill
gl.Builtins["set"] = gl.Set
gl.Builtins["unset"] = gl.Unset
gl.Builtins["add-path"] = gl.AddPath
gl.Builtins["which"] = gl.Which
gl.Builtins["source"] = gl.Source
gl.Builtins["gossh"] = gl.GoSSH
gl.Builtins["scp"] = gl.Scp
gl.Builtins["debug"] = gl.Debug
gl.Builtins["history"] = gl.History
}
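// installExampleBuiltin is an illustrative sketch (hypothetical, not installed
// by default) showing how an additional builtin with the same signature as the
// entries above could be registered.
func (gl *Goal) installExampleBuiltin() {
gl.Builtins["hello"] = func(cmdIO *exec.CmdIO, args ...string) error {
cmdIO.Println("hello:", strings.Join(args, " "))
return nil
}
}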
// Cd changes the current directory.
func (gl *Goal) Cd(cmdIO *exec.CmdIO, args ...string) error {
if len(args) > 1 {
return fmt.Errorf("no more than one argument can be passed to cd")
}
dir := ""
if len(args) == 1 {
dir = args[0]
}
dir, err := homedir.Expand(dir)
if err != nil {
return err
}
if dir == "" {
dir, err = homedir.Dir()
if err != nil {
return err
}
}
dir, err = filepath.Abs(dir)
if err != nil {
return err
}
err = os.Chdir(dir)
if err != nil {
return err
}
gl.Config.Dir = dir
return nil
}
// Exit exits the shell.
func (gl *Goal) Exit(cmdIO *exec.CmdIO, args ...string) error {
os.Exit(0)
return nil
}
// Set sets the given environment variable to the given value.
func (gl *Goal) Set(cmdIO *exec.CmdIO, args ...string) error {
if len(args) != 2 {
return fmt.Errorf("expected two arguments, got %d", len(args))
}
val := args[1]
if strings.Count(val, ":") > 1 || strings.Contains(val, "~") {
vl := stringsx.UniqueList(strings.Split(val, ":"))
vl = AddHomeExpand([]string{}, vl...)
val = strings.Join(vl, ":")
}
err := os.Setenv(args[0], val)
if runtime.GOOS == "darwin" {
gl.Config.RunIO(cmdIO, "/bin/launchctl", "setenv", args[0], val)
}
return err
}
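// setExample is an illustrative sketch (hypothetical values, not part of the
// shell) of calling the set builtin directly: values containing multiple ":"
// separators or a "~" are de-duplicated and home-expanded, as the PATH value
// here would be.
func (gl *Goal) setExample(cmdIO *exec.CmdIO) error {
if err := gl.Set(cmdIO, "EDITOR", "vim"); err != nil {
return err
}
return gl.Set(cmdIO, "PATH", "~/go/bin:/usr/local/bin:/usr/local/bin")
}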
// Unset un-sets the given environment variable.
func (gl *Goal) Unset(cmdIO *exec.CmdIO, args ...string) error {
if len(args) != 1 {
return fmt.Errorf("expected one argument, got %d", len(args))
}
err := os.Unsetenv(args[0])
if runtime.GOOS == "darwin" {
gl.Config.RunIO(cmdIO, "/bin/launchctl", "unsetenv", args[0])
}
return err
}
// JobsCmd is the builtin jobs command, which lists the current background jobs.
func (gl *Goal) JobsCmd(cmdIO *exec.CmdIO, args ...string) error {
for i, jb := range gl.Jobs {
cmdIO.Printf("[%d] %s\n", i+1, jb.String())
}
return nil
}
// Kill kills a job by job number or PID.
// Just expands the job id expressions %n into PIDs and calls system kill.
func (gl *Goal) Kill(cmdIO *exec.CmdIO, args ...string) error {
if len(args) == 0 {
return fmt.Errorf("goal kill: expected at least one argument")
}
gl.JobIDExpand(args)
gl.Config.RunIO(cmdIO, "kill", args...)
return nil
}
// Fg foregrounds a job by job number (not yet fully implemented).
func (gl *Goal) Fg(cmdIO *exec.CmdIO, args ...string) error {
if len(args) != 1 {
return fmt.Errorf("goal fg: requires exactly one job id argument")
}
jid := args[0]
exp := gl.JobIDExpand(args)
if exp != 1 {
return fmt.Errorf("goal fg: argument was not a job id in the form %%n")
}
jno, _ := strconv.Atoi(jid[1:]) // guaranteed good
job := gl.Jobs[jno-1] // job numbers are 1-based in the displayed list
cmdIO.Printf("foregrounding job [%d]\n", jno)
_ = job
// todo: the problem here is we need to change the stdio for running job
// job.Cmd.Wait() // wait
// * probably need to have wrapper StdIO for every exec so we can flexibly redirect for fg, bg commands.
// * likewise, need to run everything effectively as a bg job with our own explicit Wait, which we can then communicate with to move from fg to bg.
return nil
}
// AddHomeExpand adds given strings to the given list of strings,
// expanding any ~ symbols with the home directory,
// and returns the updated list.
func AddHomeExpand(list []string, adds ...string) []string {
for _, add := range adds {
add, err := homedir.Expand(add)
errors.Log(err)
has := false
for _, s := range list {
if s == add {
has = true
}
}
if !has {
list = append(list, add)
}
}
return list
}
// AddPath adds the given path(s) to $PATH.
func (gl *Goal) AddPath(cmdIO *exec.CmdIO, args ...string) error {
if len(args) == 0 {
return fmt.Errorf("goal add-path expected at least one argument")
}
path := os.Getenv("PATH")
ps := strings.Split(path, ":")
ps = stringsx.UniqueList(ps)
ps = AddHomeExpand(ps, args...)
path = strings.Join(ps, ":")
err := os.Setenv("PATH", path)
// if runtime.GOOS == "darwin" {
// this is what would be required to work:
// sudo launchctl config user path $PATH -- the following does not work:
// gl.Config.RunIO(cmdIO, "/bin/launchctl", "setenv", "PATH", path)
// }
return err
}
// Which reports the executable associated with the given command.
// Processes builtins and commands, and if not found, then passes on
// to exec which.
func (gl *Goal) Which(cmdIO *exec.CmdIO, args ...string) error {
if len(args) != 1 {
return fmt.Errorf("goal which: requires one argument")
}
cmd := args[0]
if _, hasCmd := gl.Commands[cmd]; hasCmd {
cmdIO.Println(cmd, "is a user-defined command")
return nil
}
if _, hasBlt := gl.Builtins[cmd]; hasBlt {
cmdIO.Println(cmd, "is a goal builtin command")
return nil
}
gl.Config.RunIO(cmdIO, "which", args...)
return nil
}
// Source loads and evaluates the given file(s)
func (gl *Goal) Source(cmdIO *exec.CmdIO, args ...string) error {
if len(args) == 0 {
return fmt.Errorf("goal source: requires at least one argument")
}
for _, fn := range args {
gl.TranspileCodeFromFile(fn)
}
// note that Source does not execute the file here; it just transpiles and loads it
return nil
}
// GoSSH manages SSH connections, which are referenced by the @name
// identifier. It handles the following cases:
// - @name -- switches to using given host for all subsequent commands
// - host [name] -- connects to a server specified in first arg and switches
// to using it, with optional name instead of default sequential number.
// - close -- closes all open connections, or the specified one
func (gl *Goal) GoSSH(cmdIO *exec.CmdIO, args ...string) error {
if len(args) < 1 {
return fmt.Errorf("gossh: requires at least one argument")
}
cmd := args[0]
var err error
host := ""
name := fmt.Sprintf("%d", 1+len(gl.SSHClients))
con := false
switch {
case cmd == "close":
gl.CloseSSH()
return nil
case cmd == "@" && len(args) == 2:
name = args[1]
case len(args) == 2:
con = true
host = args[0]
name = args[1]
default:
con = true
host = args[0]
}
if con {
cl := sshclient.NewClient(gl.SSH)
err = cl.Connect(host)
if err != nil {
return err
}
gl.SSHClients[name] = cl
gl.SSHActive = name
} else {
if name == "0" {
gl.SSHActive = ""
} else {
gl.SSHActive = name
cl := gl.ActiveSSH()
if cl == nil {
err = fmt.Errorf("goal: ssh connection named: %q not found", name)
}
}
}
return err
}
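// gosshExample is an illustrative sketch (hypothetical host and name) of the
// cases handled above: connect to a host under a name, then close all
// connections. In the shell these correspond to "gossh user@server.edu cluster"
// and "gossh close".
func (gl *Goal) gosshExample(cmdIO *exec.CmdIO) error {
if err := gl.GoSSH(cmdIO, "user@server.edu", "cluster"); err != nil {
return err // could not connect
}
// commands now run on the active "cluster" connection
return gl.GoSSH(cmdIO, "close")
}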
// Scp performs file copy over SSH connection, with the remote filename
// prefixed with the @name: and the local filename un-prefixed.
// The order is from -> to, as in standard cp.
// The remote filename is automatically relative to the current working
// directory on the remote host.
func (gl *Goal) Scp(cmdIO *exec.CmdIO, args ...string) error {
if len(args) != 2 {
return fmt.Errorf("scp: requires exactly two arguments")
}
var lfn, hfn string
toHost := false
if args[0][0] == '@' {
hfn = args[0]
lfn = args[1]
} else if args[1][0] == '@' {
hfn = args[1]
lfn = args[0]
toHost = true
} else {
return fmt.Errorf("scp: one of the files must a remote host filename, specified by @name:")
}
ci := strings.Index(hfn, ":")
if ci < 0 {
return fmt.Errorf("scp: remote host filename does not contain a : after the host name")
}
host := hfn[1:ci]
hfn = hfn[ci+1:]
cl, err := gl.SSHByHost(host)
if err != nil {
return err
}
ctx := gl.Ctx
if ctx == nil {
ctx = context.Background()
}
if toHost {
err = cl.CopyLocalFileToHost(ctx, lfn, hfn)
} else {
err = cl.CopyHostToLocalFile(ctx, hfn, lfn)
}
return err
}
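// scpExample is an illustrative sketch (hypothetical connection name and
// paths) of the from -> to argument order handled above: first copy a remote
// file to local, then copy a local file back to the remote host.
func (gl *Goal) scpExample(cmdIO *exec.CmdIO) error {
if err := gl.Scp(cmdIO, "@cluster:results/run1.tsv", "run1.tsv"); err != nil {
return err
}
return gl.Scp(cmdIO, "job.sh", "@cluster:job.sh")
}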
// Debug toggles the debug log level, or sets it from an on/off argument.
func (gl *Goal) Debug(cmdIO *exec.CmdIO, args ...string) error {
if len(args) == 0 {
if logx.UserLevel == slog.LevelDebug {
logx.UserLevel = slog.LevelInfo
} else {
logx.UserLevel = slog.LevelDebug
}
}
if len(args) == 1 {
lev := args[0]
if lev == "on" || lev == "true" || lev == "1" {
logx.UserLevel = slog.LevelDebug
} else {
logx.UserLevel = slog.LevelInfo
}
}
return nil
}
// History shows the command history, optionally limited to the last n items.
func (gl *Goal) History(cmdIO *exec.CmdIO, args ...string) error {
n := len(gl.Hist)
nh := n
if len(args) == 1 {
an, err := strconv.Atoi(args[0])
if err != nil {
return fmt.Errorf("history: error parsing number of history items: %q, error: %s", args[0], err.Error())
}
nh = min(n, an)
} else if len(args) > 1 {
return fmt.Errorf("history: uses at most one argument")
}
for i := n - nh; i < n; i++ {
cmdIO.Printf("%d:\t%s\n", i, gl.Hist[i])
}
return nil
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Command goal is an interactive cli for running and compiling Goal code.
package main
import (
"cogentcore.org/core/cli"
"cogentcore.org/lab/goal/interpreter"
)
func main() { //types:skip
opts := cli.DefaultOptions("goal", "An interactive tool for running and compiling Goal (Go augmented language).")
cfg := &interpreter.Config{}
cfg.InteractiveFunc = interpreter.Interactive
cli.Run(opts, cfg, interpreter.Run, interpreter.Build)
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package goal
import (
"os"
"path/filepath"
"strings"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/icons"
"cogentcore.org/core/text/parse/complete"
"github.com/mitchellh/go-homedir"
)
// CompleteMatch is the [complete.MatchFunc] for the shell.
func (gl *Goal) CompleteMatch(data any, text string, posLine, posChar int) (md complete.Matches) {
comps := complete.Completions{}
text = text[:posChar]
md.Seed = complete.SeedPath(text)
fullPath := complete.SeedSpace(text)
fullPath = errors.Log1(homedir.Expand(fullPath))
parent := strings.TrimSuffix(fullPath, md.Seed)
dir := filepath.Join(gl.Config.Dir, parent)
if filepath.IsAbs(parent) {
dir = parent
}
entries := errors.Log1(os.ReadDir(dir))
for _, entry := range entries {
icon := icons.File
if entry.IsDir() {
icon = icons.Folder
}
name := strings.ReplaceAll(entry.Name(), " ", `\ `) // escape spaces
comps = append(comps, complete.Completion{
Text: name,
Icon: icon,
Desc: filepath.Join(gl.Config.Dir, name),
})
}
if parent == "" {
for cmd := range gl.Builtins {
comps = append(comps, complete.Completion{
Text: cmd,
Icon: icons.Terminal,
Desc: "Builtin command: " + cmd,
})
}
for cmd := range gl.Commands {
comps = append(comps, complete.Completion{
Text: cmd,
Icon: icons.Terminal,
Desc: "Command: " + cmd,
})
}
// todo: write something that looks up all files on path -- should cache that per
// path string setting
}
md.Matches = complete.MatchSeedCompletion(comps, md.Seed)
return md
}
// CompleteEdit is the [complete.EditFunc] for the shell.
func (gl *Goal) CompleteEdit(data any, text string, cursorPos int, completion complete.Completion, seed string) (ed complete.Edit) {
return complete.EditWord(text, cursorPos, completion.Text, seed)
}
// ReadlineCompleter implements [github.com/cogentcore/readline.AutoCompleter].
type ReadlineCompleter struct {
Goal *Goal
}
func (rc *ReadlineCompleter) Do(line []rune, pos int) (newLine [][]rune, length int) {
text := string(line)
md := rc.Goal.CompleteMatch(nil, text, 0, pos)
res := [][]rune{}
for _, match := range md.Matches {
after := strings.TrimPrefix(match.Text, md.Seed)
if md.Seed != "" && after == match.Text {
continue // no overlap
}
if match.Icon == icons.Folder {
after += string(filepath.Separator)
} else {
after += " "
}
res = append(res, []rune(after))
}
return res, len(md.Seed)
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package goal
import (
"bytes"
"fmt"
"os"
"path/filepath"
"slices"
"strings"
"sync"
"time"
"cogentcore.org/core/base/exec"
"cogentcore.org/core/base/reflectx"
"cogentcore.org/core/base/sshclient"
"github.com/mitchellh/go-homedir"
)
// Exec handles command execution for all cases, parameterized by the args.
// It executes the given command string, waiting for the command to finish,
// handling the given arguments appropriately.
// If there is any error, it adds it to the goal, and triggers CancelExecution.
// - errOk = don't call AddError so execution will not stop on error
// - start = calls Start on the command, which then runs asynchronously, with
// a goroutine forked to Wait for it and close its IO
// - output = return the output of the command as a string (otherwise return is "")
func (gl *Goal) Exec(errOk, start, output bool, cmd any, args ...any) string {
out := ""
if !errOk && len(gl.Errors) > 0 {
return out
}
cmdIO := exec.NewCmdIO(&gl.Config)
cmdIO.StackStart()
if start {
cmdIO.PushIn(nil) // no stdin for bg
}
cl, scmd, sargs := gl.ExecArgs(cmdIO, errOk, cmd, args...)
if scmd == "" {
return out
}
var err error
if cl != nil {
switch {
case scmd == "set":
if len(sargs) != 2 {
err := fmt.Errorf("expected two arguments, got %d", len(sargs))
if !errOk {
gl.AddError(err)
}
return ""
}
cl.SetEnv(sargs[0], sargs[1])
case start:
err = cl.Start(&cmdIO.StdIOState, scmd, sargs...)
case output:
cmdIO.PushOut(nil)
out, err = cl.Output(&cmdIO.StdIOState, scmd, sargs...)
default:
err = cl.Run(&cmdIO.StdIOState, scmd, sargs...)
}
if !errOk {
gl.AddError(err)
}
} else {
ran := false
ran, out = gl.RunBuiltinOrCommand(cmdIO, errOk, start, output, scmd, sargs...)
if !ran {
gl.isCommand.Push(false)
switch {
case start:
// fmt.Fprintf(gl.debugTrace, "start exe %s in: %#v out: %#v %v\n ", scmd, cmdIO.In, cmdIO.Out, cmdIO.OutIsPipe())
err = gl.Config.StartIO(cmdIO, scmd, sargs...)
job := &Job{CmdIO: cmdIO}
gl.Jobs.Push(job)
go func() {
if !cmdIO.OutIsPipe() {
fmt.Printf("[%d] %s\n", len(gl.Jobs), cmdIO.String())
}
cmdIO.Cmd.Wait()
cmdIO.PopToStart()
gl.DeleteJob(job)
}()
case output:
cmdIO.PushOut(nil)
out, err = gl.Config.OutputIO(cmdIO, scmd, sargs...)
default:
// fmt.Fprintf(gl.debugTrace, "run exe %s in: %#v out: %#v %v\n ", scmd, cmdIO.In, cmdIO.Out, cmdIO.OutIsPipe())
err = gl.Config.RunIO(cmdIO, scmd, sargs...)
}
if !errOk {
gl.AddError(err)
}
gl.isCommand.Pop()
}
}
if !start {
cmdIO.PopToStart()
}
return out
}
// RunBuiltinOrCommand runs a builtin or a command, returning true if it ran,
// and the output string if running in output mode.
func (gl *Goal) RunBuiltinOrCommand(cmdIO *exec.CmdIO, errOk, start, output bool, cmd string, args ...string) (bool, string) {
out := ""
cmdFun, hasCmd := gl.Commands[cmd]
bltFun, hasBlt := gl.Builtins[cmd]
if !hasCmd && !hasBlt {
return false, out
}
if hasCmd {
gl.commandArgs.Push(args)
gl.isCommand.Push(true)
}
// note: we need to set both os. and wrapper versions, so it works the same
// in compiled vs. interpreted mode
var oldsh, oldwrap, oldstd *exec.StdIO
save := func() {
oldsh = gl.Config.StdIO.Set(&cmdIO.StdIO)
oldwrap = gl.StdIOWrappers.SetWrappers(&cmdIO.StdIO)
oldstd = cmdIO.SetToOS()
}
done := func() {
if hasCmd {
gl.isCommand.Pop()
gl.commandArgs.Pop()
}
// fmt.Fprintf(gl.debugTrace, "%s restore %#v\n", cmd, oldstd.In)
oldstd.SetToOS()
gl.StdIOWrappers.SetWrappers(oldwrap)
gl.Config.StdIO = *oldsh
}
switch {
case start:
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
if !cmdIO.OutIsPipe() {
fmt.Printf("[%d] %s\n", len(gl.Jobs), cmd)
}
if hasCmd {
oldwrap = gl.StdIOWrappers.SetWrappers(&cmdIO.StdIO)
// oldstd = cmdIO.SetToOS()
// fmt.Fprintf(gl.debugTrace, "%s oldstd in: %#v out: %#v\n", cmd, oldstd.In, oldstd.Out)
cmdFun(args...)
// oldstd.SetToOS()
gl.StdIOWrappers.SetWrappers(oldwrap)
gl.isCommand.Pop()
gl.commandArgs.Pop()
} else {
gl.AddError(bltFun(cmdIO, args...))
}
time.Sleep(time.Millisecond)
wg.Done()
}()
// fmt.Fprintf(gl.debugTrace, "%s push: %#v out: %#v %v\n", cmd, cmdIO.In, cmdIO.Out, cmdIO.OutIsPipe())
job := &Job{CmdIO: cmdIO}
gl.Jobs.Push(job)
go func() {
wg.Wait()
cmdIO.PopToStart()
gl.DeleteJob(job)
}()
case output:
save()
obuf := &bytes.Buffer{}
// os.Stdout = obuf // needs a file
gl.Config.StdIO.Out = obuf
gl.StdIOWrappers.SetWrappedOut(obuf)
cmdIO.PushOut(obuf)
if hasCmd {
cmdFun(args...)
} else {
gl.AddError(bltFun(cmdIO, args...))
}
out = strings.TrimSuffix(obuf.String(), "\n")
done()
default:
save()
if hasCmd {
cmdFun(args...)
} else {
gl.AddError(bltFun(cmdIO, args...))
}
done()
}
return true, out
}
func (gl *Goal) HandleArgErr(errok bool, err error) error {
if err == nil {
return err
}
if errok {
gl.Config.StdIO.ErrPrintln(err.Error())
} else {
gl.AddError(err)
}
return err
}
// ExecArgs processes the args to given exec command,
// handling all of the input / output redirection and
// file globbing, homedir expansion, etc.
func (gl *Goal) ExecArgs(cmdIO *exec.CmdIO, errOk bool, cmd any, args ...any) (*sshclient.Client, string, []string) {
if len(gl.Jobs) > 0 {
jb := gl.Jobs.Peek()
if jb.OutIsPipe() && !jb.GotPipe {
jb.GotPipe = true
cmdIO.PushIn(jb.PipeIn.Peek())
}
}
scmd := reflectx.ToString(cmd)
cl := gl.ActiveSSH()
// isCmd := gl.isCommand.Peek()
sargs := make([]string, 0, len(args))
var err error
homeDir := func(s string) string {
if cl == nil {
s, err = homedir.Expand(s)
gl.HandleArgErr(errOk, err)
// note: globbing is handled in a later pass, to avoid clutter here
} else {
if s[0] == '~' {
s = "$HOME/" + s[1:]
}
}
return s
}
for _, a := range args {
if sa, ok := a.([]string); ok {
for _, s := range sa {
if s == "" {
continue
}
s = homeDir(s)
sargs = append(sargs, s)
}
continue
}
if sa, ok := a.([]any); ok {
for _, aa := range sa {
s := reflectx.ToString(aa)
if s == "" {
continue
}
s = homeDir(s)
sargs = append(sargs, s)
}
continue
}
s := reflectx.ToString(a)
if s == "" {
continue
}
s = homeDir(s)
sargs = append(sargs, s)
}
if scmd[0] == '@' {
newHost := ""
if scmd == "@0" { // local
cl = nil
} else {
hnm := scmd[1:]
if scl, ok := gl.SSHClients[hnm]; ok {
newHost = hnm
cl = scl
} else {
gl.HandleArgErr(errOk, fmt.Errorf("goal: ssh connection named: %q not found", hnm))
}
}
if len(sargs) > 0 {
scmd = sargs[0]
sargs = sargs[1:]
} else { // just an ssh switch
gl.SSHActive = newHost
return nil, "", nil
}
}
for i := 0; i < len(sargs); i++ { // we modify so no range
s := sargs[i]
switch {
case s[0] == '>':
sargs = gl.OutToFile(cl, cmdIO, errOk, sargs, i)
case s[0] == '|':
sargs = gl.OutToPipe(cl, cmdIO, errOk, sargs, i)
case cl == nil && strings.HasPrefix(s, "args"):
sargs = gl.CmdArgs(errOk, sargs, i)
i-- // back up because we consume this one
}
}
// do globbing late here so we don't have to wade through everything.
// only for local.
if cl == nil {
gargs := make([]string, 0, len(sargs))
for _, s := range sargs {
g, err := filepath.Glob(s)
if err != nil || len(g) == 0 { // not valid
gargs = append(gargs, s)
} else {
gargs = append(gargs, g...)
}
}
sargs = gargs
}
return cl, scmd, sargs
}
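// For illustration (a sketch; names hypothetical), a shell line such as
//
//   ls -la ~/data/*.tsv > files.txt
//
// arrives here as cmd "ls" with the remaining tokens as args: the "~" is
// home-expanded, the "> files.txt" redirection is consumed by OutToFile
// (pushing a file onto the output stack), and the glob is expanded locally,
// so the returned sargs contain only the actual arguments to ls.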
// OutToFile processes the > arg that sends output to a file
func (gl *Goal) OutToFile(cl *sshclient.Client, cmdIO *exec.CmdIO, errOk bool, sargs []string, i int) []string {
n := len(sargs)
s := sargs[i]
sn := len(s)
fn := ""
narg := 1
if i < n-1 {
fn = sargs[i+1]
narg = 2
}
appn := false
errf := false
switch {
case sn > 1 && s[1] == '>':
appn = true
if sn > 2 && s[2] == '&' {
errf = true
}
case sn > 1 && s[1] == '&':
errf = true
case sn > 1:
fn = s[1:]
narg = 1
}
if fn == "" {
gl.HandleArgErr(errOk, fmt.Errorf("goal: no output file specified"))
return sargs
}
if cl != nil {
if !strings.HasPrefix(fn, "@0:") {
return sargs
}
fn = fn[3:]
}
sargs = slices.Delete(sargs, i, i+narg)
// todo: process @n: expressions here -- if @0 then it is the same
// if @1, then need to launch an ssh "cat >[>] file" with pipe from command as stdin
var f *os.File
var err error
if appn {
f, err = os.OpenFile(fn, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
} else {
f, err = os.Create(fn)
}
if err == nil {
cmdIO.PushOut(f)
if errf {
cmdIO.PushErr(f)
}
} else {
gl.HandleArgErr(errOk, err)
}
return sargs
}
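// The redirection forms recognized above are, for illustration (sketch):
//
//   cmd > out.txt     // stdout to file (create / truncate)
//   cmd >> out.txt    // append stdout to file
//   cmd >& out.txt    // stdout and stderr to file
//   cmd >>& out.txt   // append stdout and stderr
//   cmd >out.txt      // filename attached directly to >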
// OutToPipe processes the | arg that sends output to a pipe
func (gl *Goal) OutToPipe(cl *sshclient.Client, cmdIO *exec.CmdIO, errOk bool, sargs []string, i int) []string {
s := sargs[i]
sn := len(s)
errf := false
if sn > 1 && s[1] == '&' {
errf = true
}
sargs = slices.Delete(sargs, i, i+1)
cmdIO.PushOutPipe()
if errf {
cmdIO.PushErr(cmdIO.Out)
}
return sargs
}
// CmdArgs processes expressions involving "args" for commands
func (gl *Goal) CmdArgs(errOk bool, sargs []string, i int) []string {
// n := len(sargs)
// s := sargs[i]
// sn := len(s)
args := gl.commandArgs.Peek()
// fmt.Println("command args:", args)
switch {
case sargs[i] == "args...":
sargs = slices.Delete(sargs, i, i+1)
sargs = slices.Insert(sargs, i, args...)
}
return sargs
}
// CancelExecution calls the Cancel() function if set.
func (gl *Goal) CancelExecution() {
if gl.Cancel != nil {
gl.Cancel()
}
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package goal provides the Goal Go augmented language transpiler,
// which combines the best parts of Go, bash, and Python to provide
// an integrated shell and numerical expression processing experience.
package goal
import (
"context"
"fmt"
"io/fs"
"os"
"path/filepath"
"slices"
"strconv"
"strings"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/exec"
"cogentcore.org/core/base/logx"
"cogentcore.org/core/base/reflectx"
"cogentcore.org/core/base/sshclient"
"cogentcore.org/core/base/stack"
"cogentcore.org/lab/goal/transpile"
"github.com/mitchellh/go-homedir"
)
// Goal represents one running Goal language context.
type Goal struct {
// Config is the [exec.Config] used to run commands.
Config exec.Config
// StdIOWrappers are IO wrappers sent to the interpreter, so we can
// control the IO streams used within the interpreter.
// Call SetWrappers on this with another StdIO object to update settings.
StdIOWrappers exec.StdIO
// SSH is the ssh connection configuration.
SSH *sshclient.Config
// collection of ssh clients
SSHClients map[string]*sshclient.Client
// SSHActive is the name of the active SSH client
SSHActive string
// Builtins are all the builtin shell commands
Builtins map[string]func(cmdIO *exec.CmdIO, args ...string) error
// commands that have been defined, which can be run in Exec mode.
Commands map[string]func(args ...string)
// Jobs is a stack of commands running in the background
// (via Start instead of Run)
Jobs stack.Stack[*Job]
// Cancel, while the interpreter is running, can be called
// to stop the code interpreting.
// It is connected to the Ctx context, by StartContext()
// Both can be nil.
Cancel func()
// Errors is a stack of runtime errors.
Errors []error
// CliArgs are input arguments from the command line.
CliArgs []any
// Ctx is the context used for cancelling current shell running
// a single chunk of code, typically from the interpreter.
// We are not able to pass the context around so it is set here,
// in the StartContext function. Clear when done with ClearContext.
Ctx context.Context
// original standard IO settings, to restore
OrigStdIO exec.StdIO
// Hist is the accumulated list of command-line input,
// which is displayed with the history builtin command,
// and saved / restored from ~/.goalhist file
Hist []string
// transpiling state
TrState transpile.State
// commandArgs is a stack of args passed to a command, used for simplified
// processing of args expressions.
commandArgs stack.Stack[[]string]
// isCommand is a stack of bools indicating whether the _immediate_ run context
// is a command, which affects the way that args are processed.
isCommand stack.Stack[bool]
// debugTrace is a file written to for debugging
debugTrace *os.File
}
// NewGoal returns a new [Goal] with default options.
func NewGoal() *Goal {
gl := &Goal{
Config: exec.Config{
Dir: errors.Log1(os.Getwd()),
Env: map[string]string{},
Buffer: false,
},
}
gl.TrState.FuncToVar = true
gl.Config.StdIO.SetFromOS()
gl.SSH = sshclient.NewConfig(&gl.Config)
gl.SSHClients = make(map[string]*sshclient.Client)
gl.Commands = make(map[string]func(args ...string))
gl.InstallBuiltins()
// gl.debugTrace, _ = os.Create("goal.debug") // debugging
return gl
}
// StartContext starts a processing context,
// setting the Ctx and Cancel Fields.
// Call EndContext when current operation finishes.
func (gl *Goal) StartContext() context.Context {
gl.Ctx, gl.Cancel = context.WithCancel(context.Background())
return gl.Ctx
}
// EndContext ends a processing context, clearing the
// Ctx and Cancel fields.
func (gl *Goal) EndContext() {
gl.Ctx = nil
gl.Cancel = nil
}
// SaveOrigStdIO saves the current Config.StdIO as the original to revert to
// after an error, and sets the StdIOWrappers to use them.
func (gl *Goal) SaveOrigStdIO() {
gl.OrigStdIO = gl.Config.StdIO
gl.StdIOWrappers.NewWrappers(&gl.OrigStdIO)
}
// RestoreOrigStdIO reverts to using the saved OrigStdIO
func (gl *Goal) RestoreOrigStdIO() {
gl.Config.StdIO = gl.OrigStdIO
gl.OrigStdIO.SetToOS()
gl.StdIOWrappers.SetWrappers(&gl.OrigStdIO)
}
// Close closes any resources associated with the shell,
// including terminating any commands that are not running "nohup"
// in the background.
func (gl *Goal) Close() {
gl.CloseSSH()
// todo: kill jobs etc
}
// CloseSSH closes all open ssh client connections
func (gl *Goal) CloseSSH() {
gl.SSHActive = ""
for _, cl := range gl.SSHClients {
cl.Close()
}
gl.SSHClients = make(map[string]*sshclient.Client)
}
// ActiveSSH returns the active ssh client
func (gl *Goal) ActiveSSH() *sshclient.Client {
if gl.SSHActive == "" {
return nil
}
return gl.SSHClients[gl.SSHActive]
}
// Host returns the name we're running commands on,
// which is empty if localhost (default).
func (gl *Goal) Host() string {
cl := gl.ActiveSSH()
if cl == nil {
return ""
}
return "@" + gl.SSHActive + ":" + cl.Host
}
// HostAndDir returns the name we're running commands on,
// which is empty if localhost (default),
// and the current directory on that host.
func (gl *Goal) HostAndDir() string {
host := ""
dir := gl.Config.Dir
home := errors.Log1(homedir.Dir())
cl := gl.ActiveSSH()
if cl != nil {
host = "@" + gl.SSHActive + ":" + cl.Host + ":"
dir = cl.Dir
home = cl.HomeDir
}
rel := errors.Log1(filepath.Rel(home, dir))
// if it has to go back, then it is not in home dir, so no ~
if strings.Contains(rel, "..") {
return host + dir + string(filepath.Separator)
}
return host + filepath.Join("~", rel) + string(filepath.Separator)
}
// SSHByHost returns the SSH client for given host name, with err if not found
func (gl *Goal) SSHByHost(host string) (*sshclient.Client, error) {
if scl, ok := gl.SSHClients[host]; ok {
return scl, nil
}
return nil, fmt.Errorf("ssh connection named: %q not found", host)
}
// TranspileCode processes each line of given code,
// adding the results to the LineStack
func (gl *Goal) TranspileCode(code string) {
gl.TrState.TranspileCode(code)
}
// TranspileCodeFromFile transpiles the code in given file
func (gl *Goal) TranspileCodeFromFile(file string) error {
b, err := os.ReadFile(file)
if err != nil {
return err
}
gl.TranspileCode(string(b))
return nil
}
// TranspileFile transpiles the given input goal file to the
// given output Go file. If no existing package declaration
// is found, then package main and func main declarations are
// added. This also affects how functions are interpreted.
func (gl *Goal) TranspileFile(in string, out string) error {
return gl.TrState.TranspileFile(in, out)
}
// AddError adds the given error to the error stack if it is non-nil,
// and calls the Cancel function if set, to stop execution.
// This is the main way that goal errors are handled.
// It also prints the error.
func (gl *Goal) AddError(err error) error {
if err == nil {
return nil
}
gl.Errors = append(gl.Errors, err)
logx.PrintlnError(err)
gl.CancelExecution()
return err
}
// TranspileConfig transpiles the .goal startup config file in the user's
// home directory if it exists.
func (gl *Goal) TranspileConfig() error {
path, err := homedir.Expand("~/.goal")
if err != nil {
return err
}
b, err := os.ReadFile(path)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
return nil
}
return err
}
gl.TranspileCode(string(b))
return nil
}
// AddHistory adds given line to the Hist record of commands
func (gl *Goal) AddHistory(line string) {
gl.Hist = append(gl.Hist, line)
}
// SaveHistory saves up to the given number of lines of current history
// to given file, e.g., ~/.goalhist for the default goal program.
// If n is <= 0 all lines are saved. n is typically 500 by default.
func (gl *Goal) SaveHistory(n int, file string) error {
path, err := homedir.Expand(file)
if err != nil {
return err
}
hn := len(gl.Hist)
sn := hn
if n > 0 {
sn = min(n, hn)
}
lh := strings.Join(gl.Hist[hn-sn:hn], "\n")
err = os.WriteFile(path, []byte(lh), 0666)
if err != nil {
return err
}
return nil
}
// OpenHistory opens Hist history lines from given file,
// e.g., ~/.goalhist
func (gl *Goal) OpenHistory(file string) error {
path, err := homedir.Expand(file)
if err != nil {
return err
}
b, err := os.ReadFile(path)
if err != nil {
return err
}
gl.Hist = strings.Split(string(b), "\n")
return nil
}
// Args returns the command line arguments.
func (gl *Goal) Args() []any {
return gl.CliArgs
}
// AddCommand adds given command to list of available commands.
func (gl *Goal) AddCommand(name string, cmd func(args ...string)) {
gl.Commands[name] = cmd
}
// RunCommands runs the given command(s). This is typically called
// from a Makefile-style goal script.
func (gl *Goal) RunCommands(cmds []any) error {
for _, cmd := range cmds {
if cmdFun, hasCmd := gl.Commands[reflectx.ToString(cmd)]; hasCmd {
cmdFun()
} else {
return errors.Log(fmt.Errorf("command %q not found", cmd))
}
}
return nil
}
// DeleteAllJobs deletes any existing jobs, closing stdio.
func (gl *Goal) DeleteAllJobs() {
n := len(gl.Jobs)
for i := n - 1; i >= 0; i-- {
jb := gl.Jobs.Pop()
jb.CmdIO.PopToStart()
}
}
// DeleteJob deletes the given job and returns true if successful.
func (gl *Goal) DeleteJob(job *Job) bool {
idx := slices.Index(gl.Jobs, job)
if idx >= 0 {
gl.Jobs = slices.Delete(gl.Jobs, idx, idx+1)
return true
}
return false
}
// JobIDExpand expands %n job id values in args with the full PID,
// and returns the number of PIDs expanded.
func (gl *Goal) JobIDExpand(args []string) int {
exp := 0
for i, id := range args {
if id[0] == '%' {
idx, err := strconv.Atoi(id[1:])
if err == nil {
if idx > 0 && idx <= len(gl.Jobs) {
jb := gl.Jobs[idx-1]
if jb.Cmd != nil && jb.Cmd.Process != nil {
args[i] = fmt.Sprintf("%d", jb.Cmd.Process.Pid)
exp++
}
} else {
gl.AddError(fmt.Errorf("goal: job number out of range: %d", idx))
}
}
}
}
return exp
}
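// For example (illustrative): with one background job whose process id is
// 12345, calling JobIDExpand on {"%1"} rewrites it in place to {"12345"} and
// returns 1; an out-of-range id such as %9 adds an error instead.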
// Job represents a job that has been started and we're waiting for it to finish.
type Job struct {
*exec.CmdIO
IsExec bool
GotPipe bool
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package goalib defines convenient utility functions for
// use in the goal shell, available with the goalib prefix.
package goalib
import (
"io/fs"
"os"
"path/filepath"
"strings"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/fsx"
"cogentcore.org/core/base/slicesx"
"cogentcore.org/core/base/stringsx"
)
// SplitLines returns a slice of given string split by lines
// with any extra whitespace trimmed for each line entry.
func SplitLines(s string) []string {
sl := stringsx.SplitLines(s)
for i, s := range sl {
sl[i] = strings.TrimSpace(s)
}
return sl
}
// FileExists returns true if given file exists
func FileExists(path string) bool {
ex := errors.Log1(fsx.FileExists(path))
return ex
}
// WriteFile writes string to given file with standard permissions,
// logging any errors.
func WriteFile(filename, str string) error {
err := os.WriteFile(filename, []byte(str), 0666)
if err != nil {
errors.Log(err)
}
return err
}
// ReadFile reads the string from the given file, logging any errors.
func ReadFile(filename string) string {
str, err := os.ReadFile(filename)
if err != nil {
errors.Log(err)
}
return string(str)
}
// ReplaceInFile replaces all occurrences of given string with replacement
// in given file, rewriting the file. Also returns the updated string.
func ReplaceInFile(filename, old, new string) string {
str := ReadFile(filename)
str = strings.ReplaceAll(str, old, new)
WriteFile(filename, str)
return str
}
// StringsToAnys converts a slice of strings to a slice of any,
// using slicesx.As. The interpreter cannot process generics
// yet, so this wrapper is needed. Use for passing args to
// a command, for example.
func StringsToAnys(s []string) []any {
return slicesx.As[string, any](s)
}
// AllFiles returns a list of all files (excluding directories)
// under the given path.
func AllFiles(path string) []string {
var files []string
filepath.WalkDir(path, func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
if d.IsDir() {
return nil
}
files = append(files, path)
return nil
})
return files
}
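// The following is a minimal illustrative sketch (a separate program, not part
// of goalib) showing typical use of the helpers above; the file name is
// hypothetical.
package main

import (
"fmt"

"cogentcore.org/lab/goal/goalib"
)

func main() {
goalib.WriteFile("notes.txt", "alpha\nbeta\n")
lines := goalib.SplitLines(goalib.ReadFile("notes.txt"))
fmt.Println(lines, goalib.FileExists("notes.txt"))
goalib.ReplaceInFile("notes.txt", "alpha", "gamma")
}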
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package interpreter
import (
"fmt"
"log/slog"
"os"
"path/filepath"
"strings"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/exec"
"cogentcore.org/core/base/fsx"
"cogentcore.org/core/base/logx"
"cogentcore.org/lab/goal"
"cogentcore.org/lab/goal/goalib"
"github.com/cogentcore/yaegi/interp"
)
//go:generate core generate -add-types -add-funcs
// Config is the configuration information for the goal cli.
type Config struct {
// Input is the input file to run/compile.
// If this is provided as the first argument,
// then the program will exit after running,
// unless the Interactive mode is flagged.
Input string `posarg:"0" required:"-"`
// Expr is an optional expression to evaluate, which can be used
// in addition to the Input file to run, to execute commands
// defined within that file for example, or as a command to run
// prior to starting interactive mode if no Input is specified.
Expr string `flag:"e,expr"`
// Dir is a directory path to change to prior to running.
Dir string
// Args is an optional list of arguments to pass in the run command.
// These arguments will be turned into an "args" local variable in the goal.
// These are automatically processed from any leftover arguments passed, so
// you should not need to specify this flag manually.
Args []string `cmd:"run" posarg:"leftover" required:"-"`
// Interactive runs the interactive command line after processing any input file.
// Interactive mode is the default mode for the run command unless an input file
// is specified.
Interactive bool `cmd:"run" flag:"i,interactive"`
// InteractiveFunc is the function to run in interactive mode.
// Set it to your own function as needed.
InteractiveFunc func(c *Config, in *Interpreter) error
}
// Run runs the specified goal file. If no file is specified,
// it runs an interactive shell that allows the user to input goal.
func Run(c *Config) error { //cli:cmd -root
in := NewInterpreter(interp.Options{})
if len(c.Args) > 0 {
in.Goal.CliArgs = goalib.StringsToAnys(c.Args)
}
if c.Dir != "" {
errors.Log(os.Chdir(c.Dir))
}
if c.Input == "" {
return c.InteractiveFunc(c, in)
}
in.Config()
code := ""
if errors.Log1(fsx.FileExists(c.Input)) {
b, err := os.ReadFile(c.Input)
if err != nil && c.Expr == "" {
return err
}
code = string(b)
}
if c.Expr != "" {
if code != "" {
code += "\n"
}
code += c.Expr + "\n"
}
_, _, err := in.Eval(code)
if err == nil {
err = in.Goal.TrState.DepthError()
}
if c.Interactive {
return c.InteractiveFunc(c, in)
}
return err
}
// Interactive runs an interactive shell that allows the user to input goal.
func Interactive(c *Config, in *Interpreter) error {
in.Config()
if c.Expr != "" {
in.Eval(c.Expr)
}
in.Interactive()
return nil
}
// Build builds the specified input goal file, or all .goal files in the current
// directory if no input is specified, to corresponding .go file name(s).
// If the file does not already contain a "package" specification, then
// "package main; func main()..." wrappers are added, which allows the same
// code to be used in interactive and Go compiled modes.
// go build is run after this.
func Build(c *Config) error {
if c.Dir != "" {
errors.Log(os.Chdir(c.Dir))
}
var fns []string
verbose := logx.UserLevel <= slog.LevelInfo
if c.Input != "" {
fns = []string{c.Input}
} else {
fns = fsx.Filenames(".", ".goal")
}
curpkg, _ := exec.Minor().Output("go", "list", "./")
var errs []error
for _, fn := range fns {
fpath := filepath.Join(curpkg, fn)
if verbose {
fmt.Println(fpath)
}
ofn := strings.TrimSuffix(fn, filepath.Ext(fn)) + ".go"
err := goal.NewGoal().TranspileFile(fn, ofn)
if err != nil {
errs = append(errs, err)
}
}
args := []string{"build"}
if verbose {
args = append(args, "-v")
}
inCmd := false
output := ""
if goalib.FileExists("cmd/main.go") {
output = filepath.Base(errors.Log1(os.Getwd()))
inCmd = true
args = append(args, "-o", output)
os.Chdir("cmd")
}
err := exec.Verbose().Run("go", args...)
if err != nil {
errs = append(errs, err)
}
if inCmd {
os.Rename(output, filepath.Join("..", output))
os.Chdir("../")
}
return errors.Join(errs...)
}
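// Illustrative command-line usage of the resulting goal tool (a sketch, based
// on the flags and commands defined above):
//
//   goal                  // interactive shell (Run with no Input)
//   goal script.goal      // run a goal file, then exit
//   goal -i script.goal   // run the file, then enter interactive mode
//   goal -e 'echo hi'     // evaluate an expression first
//   goal build            // transpile .goal files to .go and run go build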
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package interpreter
import (
"reflect"
"github.com/cogentcore/yaegi/interp"
)
var Symbols = map[string]map[string]reflect.Value{}
// ImportGoal makes the methods of goal object available in goalrun package.
func (in *Interpreter) ImportGoal() {
in.Interp.Use(interp.Exports{
"cogentcore.org/lab/goalrun/goalrun": map[string]reflect.Value{
"Run": reflect.ValueOf(in.Goal.Run),
"RunErrOK": reflect.ValueOf(in.Goal.RunErrOK),
"Output": reflect.ValueOf(in.Goal.Output),
"OutputErrOK": reflect.ValueOf(in.Goal.OutputErrOK),
"Start": reflect.ValueOf(in.Goal.Start),
"AddCommand": reflect.ValueOf(in.Goal.AddCommand),
"RunCommands": reflect.ValueOf(in.Goal.RunCommands),
"Args": reflect.ValueOf(in.Goal.Args),
},
})
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package interpreter
import (
"context"
"fmt"
"io"
"log"
"os"
"os/signal"
"reflect"
"strconv"
"strings"
"syscall"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/yaegicore/basesymbols"
"cogentcore.org/lab/goal"
_ "cogentcore.org/lab/stats/metric"
_ "cogentcore.org/lab/stats/stats"
_ "cogentcore.org/lab/tensor/tmath"
"cogentcore.org/lab/tensorfs"
"cogentcore.org/lab/yaegilab/tensorsymbols"
"github.com/cogentcore/readline"
"github.com/cogentcore/yaegi/interp"
"github.com/cogentcore/yaegi/stdlib"
)
// Interpreter represents one running shell context
type Interpreter struct {
// the goal shell
Goal *goal.Goal
// HistFile is the name of the history file to open / save.
// Defaults to ~/.goal-history for the default goal shell.
// Update this prior to running Config() to take effect.
HistFile string
// the yaegi interpreter
Interp *interp.Interpreter
}
func init() {
delete(stdlib.Symbols, "errors/errors") // use our errors package instead
}
// NewInterpreter returns a new [Interpreter] initialized with the given options.
// It automatically imports the standard library and configures necessary shell
// functions. The end-user app must call [Interpreter.Config] after importing any
// additional symbols, prior to running the interpreter.
func NewInterpreter(options interp.Options) *Interpreter {
in := &Interpreter{HistFile: "~/.goal-history"}
in.Goal = goal.NewGoal()
if options.Stdin != nil {
in.Goal.Config.StdIO.In = options.Stdin
}
if options.Stdout != nil {
in.Goal.Config.StdIO.Out = options.Stdout
}
if options.Stderr != nil {
in.Goal.Config.StdIO.Err = options.Stderr
}
in.Goal.SaveOrigStdIO()
options.Stdout = in.Goal.StdIOWrappers.Out
options.Stderr = in.Goal.StdIOWrappers.Err
options.Stdin = in.Goal.StdIOWrappers.In
in.Interp = interp.New(options)
errors.Log(in.Interp.Use(basesymbols.Symbols))
errors.Log(in.Interp.Use(tensorsymbols.Symbols))
in.ImportGoal()
go in.MonitorSignals()
return in
}
// Prompt returns the appropriate REPL prompt to show the user.
func (in *Interpreter) Prompt() string {
dp := in.Goal.TrState.TotalDepth()
pc := ">"
dir := in.Goal.HostAndDir()
if in.Goal.TrState.MathMode {
pc = "#"
dir = tensorfs.CurDir.Path()
}
if dp == 0 {
return dir + " " + pc + " "
}
res := pc + " "
for range dp {
res += " " // note: /t confuses readline
}
return res
}
// Eval evaluates (interprets) the given code,
// returning the value returned from the interpreter.
// hasPrint indicates whether the last line of code
// contains the string print, which is used to determine
// whether to print the result in interactive mode.
// It automatically logs any error in addition to returning it.
func (in *Interpreter) Eval(code string) (v reflect.Value, hasPrint bool, err error) {
in.Goal.TranspileCode(code)
source := false
if in.Goal.SSHActive == "" {
source = strings.HasPrefix(code, "source")
}
if in.Goal.TrState.TotalDepth() == 0 {
nl := len(in.Goal.TrState.Lines)
if nl > 0 {
ln := in.Goal.TrState.Lines[nl-1]
if strings.Contains(strings.ToLower(ln), "print") {
hasPrint = true
}
}
v, err = in.RunCode()
in.Goal.Errors = nil
}
if source {
v, err = in.RunCode() // run accumulated code
}
return
}
// RunCode runs the accumulated set of code lines
// and clears the stack of code lines.
// It automatically logs any error in addition to returning it.
func (in *Interpreter) RunCode() (reflect.Value, error) {
if len(in.Goal.Errors) > 0 {
return reflect.Value{}, errors.Join(in.Goal.Errors...)
}
in.Goal.TrState.AddChunk()
code := in.Goal.TrState.Chunks
in.Goal.TrState.ResetCode()
var v reflect.Value
var err error
for _, ch := range code {
ctx := in.Goal.StartContext()
v, err = in.Interp.EvalWithContext(ctx, ch)
in.Goal.EndContext()
if err != nil {
cancelled := errors.Is(err, context.Canceled)
// fmt.Println("cancelled:", cancelled)
in.Goal.DeleteAllJobs()
in.Goal.RestoreOrigStdIO()
in.Goal.TrState.ResetDepth()
if !cancelled {
in.Goal.AddError(err)
} else {
in.Goal.Errors = nil
}
break
}
}
return v, err
}
// RunConfig runs the .goal startup config file in the user's
// home directory if it exists.
func (in *Interpreter) RunConfig() error {
err := in.Goal.TranspileConfig()
if err != nil {
errors.Log(err)
}
_, err = in.RunCode()
return err
}
// MonitorSignals monitors the operating system signals to appropriately
// stop the interpreter and prevent the shell from closing on Control+C.
// It is called automatically in another goroutine in [NewInterpreter].
func (in *Interpreter) MonitorSignals() {
c := make(chan os.Signal, 1)
// todo: syscall.SIGSEGV not defined on web
signal.Notify(c, os.Interrupt, syscall.SIGTERM, syscall.SIGINT)
// signal.Notify(c, os.Interrupt, syscall.SIGTERM, syscall.SIGINT, syscall.SIGSEGV)
for {
<-c
in.Goal.CancelExecution()
}
}
// Config performs final configuration after all the imports have been Use'd
func (in *Interpreter) Config() {
in.Interp.ImportUsed()
in.RunConfig()
}
// OpenHistory opens history from the current HistFile
// and loads it into the readline history for given rl instance
func (in *Interpreter) OpenHistory(rl *readline.Instance) error {
err := in.Goal.OpenHistory(in.HistFile)
if err == nil {
for _, h := range in.Goal.Hist {
rl.SaveToHistory(h)
}
}
return err
}
// SaveHistory saves last 500 (or HISTFILESIZE env value) lines of history,
// to the current HistFile.
func (in *Interpreter) SaveHistory() error {
n := 500
if hfs := os.Getenv("HISTFILESIZE"); hfs != "" {
en, err := strconv.Atoi(hfs)
if err != nil {
in.Goal.Config.StdIO.ErrPrintf("SaveHistory: environment variable HISTFILESIZE: %q not a number: %s", hfs, err.Error())
} else {
n = en
}
}
return in.Goal.SaveHistory(n, in.HistFile)
}
// Interactive runs an interactive shell that allows the user to input goal.
// Must have done in.Config() prior to calling.
func (in *Interpreter) Interactive() error {
in.Goal.TrState.MathRecord = true
rl, err := readline.NewFromConfig(&readline.Config{
AutoComplete: &goal.ReadlineCompleter{Goal: in.Goal},
Undo: true,
})
if err != nil {
return err
}
in.OpenHistory(rl)
defer rl.Close()
log.SetOutput(rl.Stderr()) // redraw the prompt correctly after log output
for {
rl.SetPrompt(in.Prompt())
line, err := rl.ReadLine()
if errors.Is(err, readline.ErrInterrupt) {
continue
}
if errors.Is(err, io.EOF) {
in.SaveHistory()
os.Exit(0)
}
if err != nil {
in.SaveHistory()
return err
}
if len(line) > 0 && line[0] == '!' { // history command
hl, err := strconv.Atoi(line[1:])
nh := len(in.Goal.Hist)
if err != nil {
in.Goal.Config.StdIO.ErrPrintf("history number: %q not a number: %s", line[1:], err.Error())
line = ""
} else if hl < 0 || hl >= nh {
in.Goal.Config.StdIO.ErrPrintf("history number: %d not in range: [0:%d]", hl, nh)
line = ""
} else {
line = in.Goal.Hist[hl]
fmt.Printf("h:%d\t%s\n", hl, line)
}
} else if line != "" && !strings.HasPrefix(line, "history") && line != "h" {
in.Goal.AddHistory(line)
}
in.Goal.Errors = nil
v, hasPrint, err := in.Eval(line)
if err == nil && !hasPrint && v.IsValid() && !v.IsZero() && v.Kind() != reflect.Func {
in.Goal.Config.StdIO.Println(v.Interface())
}
}
}
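// The following is a minimal illustrative sketch (a separate program) of
// embedding the interpreter: create it, finalize configuration, and evaluate
// a line of goal code.
package main

import (
"cogentcore.org/lab/goal/interpreter"
"github.com/cogentcore/yaegi/interp"
)

func main() {
in := interpreter.NewInterpreter(interp.Options{})
in.Config() // must be called after any additional symbol imports
in.Eval(`echo "hello"`)
}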
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package goal
// Run executes the given command string, waiting for the command to finish,
// handling the given arguments appropriately.
// If there is any error, it adds it to the goal, and triggers CancelExecution.
// It forwards output to [exec.Config.Stdout] and [exec.Config.Stderr] appropriately.
func (gl *Goal) Run(cmd any, args ...any) {
gl.Exec(false, false, false, cmd, args...)
}
// RunErrOK executes the given command string, waiting for the command to finish,
// handling the given arguments appropriately.
// It does not stop execution if there is an error.
// If there is any error, it adds it to the goal. It forwards output to
// [exec.Config.Stdout] and [exec.Config.Stderr] appropriately.
func (gl *Goal) RunErrOK(cmd any, args ...any) {
gl.Exec(true, false, false, cmd, args...)
}
// Start starts the given command string for running in the background,
// handling the given arguments appropriately.
// If there is any error, it adds it to the goal. It forwards output to
// [exec.Config.Stdout] and [exec.Config.Stderr] appropriately.
func (gl *Goal) Start(cmd any, args ...any) {
gl.Exec(false, true, false, cmd, args...)
}
// Output executes the given command string, handling the given arguments
// appropriately. If there is any error, it adds it to the goal. It returns
// the stdout as a string and forwards stderr to [exec.Config.Stderr] appropriately.
func (gl *Goal) Output(cmd any, args ...any) string {
return gl.Exec(false, false, true, cmd, args...)
}
// OutputErrOK executes the given command string, handling the given arguments
// appropriately. If there is any error, it adds it to the goal. It returns
// the stdout as a string and forwards stderr to [exec.Config.Stderr] appropriately.
func (gl *Goal) OutputErrOK(cmd any, args ...any) string {
return gl.Exec(true, false, true, cmd, args...)
}
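// exampleExecCalls is an illustrative sketch (a hypothetical helper, not called
// anywhere) showing how the wrappers above are typically used; the command
// names and arguments are made up for the example.
func exampleExecCalls(gl *Goal) {
	gl.Run("ls", "-la")                // run and wait; output goes to the configured Stdout
	gl.RunErrOK("rm", "maybe-missing") // errors are recorded but do not cancel execution
	gl.Start("server", "--verbose")    // start the command running in the background
	out := gl.Output("echo", "hello")  // capture stdout as a string
	_ = out
}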
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package transpile
import (
"path"
"reflect"
"strings"
"cogentcore.org/lab/tensor"
"cogentcore.org/lab/yaegilab/tensorsymbols"
)
func init() {
AddYaegiTensorFuncs()
}
var yaegiTensorPackages = []string{"/lab/tensor", "/lab/stats", "/lab/vector", "/lab/matrix"}
// AddYaegiTensorFuncs grabs all tensor* package functions registered
// in yaegicore and adds them to the `tensor.Funcs` map so we can
// properly convert symbols to either tensors or basic literals,
// depending on the arg types for the current function.
func AddYaegiTensorFuncs() {
for pth, symap := range tensorsymbols.Symbols {
has := false
for _, p := range yaegiTensorPackages {
if strings.Contains(pth, p) {
has = true
break
}
}
if !has {
continue
}
_, pkg := path.Split(pth)
for name, val := range symap {
if val.Kind() != reflect.Func {
continue
}
pnm := pkg + "." + name
tensor.AddFunc(pnm, val.Interface())
}
}
}
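// For example (an illustrative sketch; actual import paths depend on the
// registered yaegi symbols), a function Mean registered under a path ending in
// "/lab/stats/stats" would be added to tensor.Funcs under the name "stats.Mean",
// since path.Split yields "stats" as the package name.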
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package transpile
import (
"fmt"
"strings"
"unicode"
)
// ExecWords splits the given line into exec command words, handling quotes,
// braces, brackets, and redirection operators.
func ExecWords(ln string) ([]string, error) {
ln = strings.TrimSpace(ln)
n := len(ln)
if n == 0 {
return nil, nil
}
if ln[0] == '$' {
ln = strings.TrimSpace(ln[1:])
n = len(ln)
if n == 0 {
return nil, nil
}
if ln[n-1] == '$' {
ln = strings.TrimSpace(ln[:n-1])
n = len(ln)
if n == 0 {
return nil, nil
}
}
}
word := ""
esc := false
dQuote := false
bQuote := false
brace := 0
brack := 0
redir := false
var words []string
addWord := func() {
if brace > 0 { // always accum into one token inside brace
return
}
if len(word) > 0 {
words = append(words, word)
word = ""
}
}
atStart := true
sbrack := (ln[0] == '[')
if sbrack {
word = "["
addWord()
brack++
ln = ln[1:]
atStart = false
}
for _, r := range ln {
quote := dQuote || bQuote
if redir {
redir = false
if r == '&' {
word += string(r)
addWord()
continue
}
if r == '>' {
word += string(r)
redir = true
continue
}
addWord()
}
switch {
case esc:
if brace == 0 && unicode.IsSpace(r) { // we will be quoted later anyway
word = word[:len(word)-1]
}
word += string(r)
esc = false
case r == '\\':
esc = true
word += string(r)
case r == '"':
if !bQuote {
dQuote = !dQuote
}
word += string(r)
case r == '`':
if !dQuote {
bQuote = !bQuote
}
word += string(r)
case quote: // absorbs quote -- no need to check below
word += string(r)
case unicode.IsSpace(r):
addWord()
continue // don't reset at start
case r == '{':
if brace == 0 {
addWord()
word = "{"
addWord()
}
brace++
case r == '}':
brace--
if brace == 0 {
addWord()
word = "}"
addWord()
}
case r == '[':
word += string(r)
if atStart && brack == 0 {
sbrack = true
addWord()
}
brack++
case r == ']':
brack--
if brack == 0 && sbrack { // only reason for tracking brack is to catch this closing bracket
addWord()
word = "]"
addWord()
} else {
word += string(r)
}
case r == '<' || r == '>' || r == '|':
addWord()
word += string(r)
redir = true
case r == '&': // known to not be redir
addWord()
word += string(r)
case r == ';':
addWord()
word += string(r)
addWord()
atStart = true
continue // avoid reset
default:
word += string(r)
}
atStart = false
}
addWord()
if dQuote || bQuote || brack > 0 {
return words, fmt.Errorf("goal: exec command has unterminated quotes (\": %v, `: %v) or brackets [ %v ]", dQuote, bQuote, brack > 0)
}
return words, nil
}
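// execWordsExample is an illustrative sketch (a hypothetical helper) of how
// ExecWords splits an exec line into words, including redirection operators.
func execWordsExample() {
	words, err := ExecWords("ls -la | grep foo > out.txt")
	if err != nil {
		fmt.Println("exec words err:", err)
		return
	}
	fmt.Println(words) // roughly: [ls -la | grep foo > out.txt]
}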
// ExecWordIsCommand returns true if the given exec word is a command-like string
// (excluding any paths).
func ExecWordIsCommand(f string) bool {
if strings.Contains(f, "(") || strings.Contains(f, "[") || strings.Contains(f, "=") {
return false
}
return true
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package transpile
import (
"fmt"
"go/ast"
"go/token"
"strings"
"cogentcore.org/core/base/stack"
"cogentcore.org/lab/goal/transpile/mparser"
"cogentcore.org/lab/tensor"
)
// TranspileMath does math mode transpiling. fullLine indicates that the code
// should be treated as full statement(s).
func (st *State) TranspileMath(toks Tokens, code string, fullLine bool) Tokens {
nt := len(toks)
if nt == 0 {
return nil
}
// fmt.Println(nt, toks)
str := code[toks[0].Pos-1 : toks[nt-1].Pos]
if toks[nt-1].Str != "" {
str += toks[nt-1].Str[1:]
}
// fmt.Println(str)
mp := mathParse{state: st, toks: toks, code: code}
// mp.trace = true
mods := mparser.AllErrors // | Trace
if fullLine {
ewords, err := ExecWords(str)
if len(ewords) > 0 {
if cmd, ok := tensorfsCommands[ewords[0]]; ok {
mp.ewords = ewords
err := cmd(&mp)
if err != nil {
fmt.Println(ewords[0]+":", err.Error())
return nil
} else {
return mp.out
}
}
}
stmts, err := mparser.ParseLine(str, mods)
if err != nil {
fmt.Println("line code:", str)
fmt.Println("parse err:", err)
}
if len(stmts) == 0 {
return toks
}
mp.stmtList(stmts)
} else {
ex, err := mparser.ParseExpr(str, mods)
if err != nil {
fmt.Println("expr:", str)
fmt.Println("parse err:", err)
}
mp.expr(ex)
}
if !(mp.idx == len(toks) || mp.idx == len(toks)-1) { // -1 is for comment at end
fmt.Println(code)
fmt.Println(mp.out.Code())
fmt.Printf("parsing error: index: %d != len(toks): %d\n", mp.idx, len(toks))
}
return mp.out
}
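// For example (an illustrative sketch; the exact output depends on the
// functions registered in tensor.Funcs), a full-line math-mode statement like
//
//	x := a + b
//
// is transpiled into roughly:
//
//	x := tensor.Tensor(tmath.Add(a, b))
//
// with a trailing tensorfs.Record(x, `x`) appended when MathRecord is on and
// this is the first statement.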
// funcInfo is info about the function being processed
type funcInfo struct {
tensor.Func
// current arg index we are processing
curArg int
}
// mathParse has the parsing state, active only during a parsing pass
// on one specific chunk of code and tokens.
type mathParse struct {
state *State
code string // code string
toks Tokens // source tokens we are parsing
ewords []string // exec words
idx int // current index in source tokens -- critical to sync as we "use" source
out Tokens // output tokens we generate
trace bool // trace of parsing -- turn on to see alignment
inArray bool // we are in an array
// stack of function info -- top of stack reflects the current function
funcs stack.Stack[*funcInfo]
}
// curArg returns the current argument for the current function.
func (mp *mathParse) curArg() *tensor.Arg {
cfun := mp.funcs.Peek()
if cfun == nil {
return nil
}
if cfun.curArg < len(cfun.Args) {
return cfun.Args[cfun.curArg]
}
return nil
}
func (mp *mathParse) nextArg() {
cfun := mp.funcs.Peek()
if cfun == nil || len(cfun.Args) == 0 {
// fmt.Println("next arg no fun or no args")
return
}
n := len(cfun.Args)
if cfun.curArg == n-1 {
carg := cfun.Args[n-1]
if !carg.IsVariadic {
fmt.Println("math transpile: args exceed registered function number:", cfun)
}
return
}
cfun.curArg++
}
func (mp *mathParse) curArgIsTensor() bool {
carg := mp.curArg()
if carg == nil {
return false
}
return carg.IsTensor
}
func (mp *mathParse) curArgIsInts() bool {
carg := mp.curArg()
if carg == nil {
return false
}
return carg.IsInt && carg.IsVariadic
}
// startFunc is called when starting a new function.
// An empty name is the "dummy" assign case, which uses Inc.
// The optional noLookup flag indicates not to look up the function and to just
// push the name -- for internal cases, to prevent arg conversions.
func (mp *mathParse) startFunc(name string, noLookup ...bool) *funcInfo {
fi := &funcInfo{}
sname := name
if name == "" || name == "tensor.Tensor" {
sname = "tmath.Inc" // one arg tensor fun
}
if len(noLookup) == 1 && noLookup[0] {
fi.Name = name
} else {
if tf, err := tensor.FuncByName(sname); err == nil {
fi.Func = *tf
} else {
fi.Name = name
}
}
mp.funcs.Push(fi)
if name != "" {
mp.out.Add(token.IDENT, name)
}
return fi
}
func (mp *mathParse) endFunc() {
mp.funcs.Pop()
}
// addToken adds output token and increments idx
func (mp *mathParse) addToken(tok token.Token) {
mp.out.Add(tok)
if mp.trace {
ctok := &Token{}
if mp.idx < len(mp.toks) {
ctok = mp.toks[mp.idx]
}
fmt.Printf("%d\ttok: %s \t replaces: %s\n", mp.idx, tok, ctok)
}
mp.idx++
}
func (mp *mathParse) addCur() {
if len(mp.toks) > mp.idx {
mp.out.AddTokens(mp.toks[mp.idx])
mp.idx++
return
}
fmt.Println("out of tokens!", mp.idx, mp.toks)
}
func (mp *mathParse) stmtList(sts []ast.Stmt) {
for _, st := range sts {
mp.stmt(st)
}
}
func (mp *mathParse) stmt(st ast.Stmt) {
if st == nil {
return
}
switch x := st.(type) {
case *ast.BadStmt:
fmt.Println("bad stmt!")
case *ast.DeclStmt:
case *ast.ExprStmt:
mp.expr(x.X)
case *ast.SendStmt:
mp.expr(x.Chan)
mp.addToken(token.ARROW)
mp.expr(x.Value)
case *ast.IncDecStmt:
fn := "Inc"
if x.Tok == token.DEC {
fn = "Dec"
}
mp.startFunc("tmath." + fn)
mp.out.Add(token.LPAREN)
mp.expr(x.X)
mp.addToken(token.RPAREN)
case *ast.AssignStmt:
switch x.Tok {
case token.DEFINE:
mp.defineStmt(x)
default:
mp.assignStmt(x)
}
case *ast.GoStmt:
mp.addToken(token.GO)
mp.callExpr(x.Call)
case *ast.DeferStmt:
mp.addToken(token.DEFER)
mp.callExpr(x.Call)
case *ast.ReturnStmt:
mp.addToken(token.RETURN)
mp.exprList(x.Results)
case *ast.BranchStmt:
mp.addToken(x.Tok)
mp.ident(x.Label)
case *ast.BlockStmt:
mp.addToken(token.LBRACE)
mp.stmtList(x.List)
mp.addToken(token.RBRACE)
case *ast.IfStmt:
mp.addToken(token.IF)
mp.stmt(x.Init)
if x.Init != nil {
mp.addToken(token.SEMICOLON)
}
mp.expr(x.Cond)
mp.out.Add(token.IDENT, ".Bool1D(0)") // turn bool expr into actual bool
if x.Body != nil && len(x.Body.List) > 0 {
mp.addToken(token.LBRACE)
mp.stmtList(x.Body.List)
mp.addToken(token.RBRACE)
} else {
mp.addToken(token.LBRACE)
}
if x.Else != nil {
mp.addToken(token.ELSE)
mp.stmt(x.Else)
}
case *ast.ForStmt:
mp.addToken(token.FOR)
mp.stmt(x.Init)
if x.Init != nil {
mp.addToken(token.SEMICOLON)
}
mp.expr(x.Cond)
if x.Cond != nil {
mp.out.Add(token.IDENT, ".Bool1D(0)") // turn bool expr into actual bool
mp.addToken(token.SEMICOLON)
}
mp.stmt(x.Post)
if x.Body != nil && len(x.Body.List) > 0 {
mp.addToken(token.LBRACE)
mp.stmtList(x.Body.List)
mp.addToken(token.RBRACE)
} else {
mp.addToken(token.LBRACE)
}
case *ast.RangeStmt:
if x.Key == nil || x.Value == nil {
fmt.Println("for range statement requires both index and value variables")
return
}
ki, _ := x.Key.(*ast.Ident)
vi, _ := x.Value.(*ast.Ident)
ei, _ := x.X.(*ast.Ident)
if ki == nil || vi == nil || ei == nil {
fmt.Println("for range statement requires all variables (index, value, range) to be variable names, not other expressions")
return
}
knm := ki.Name
vnm := vi.Name
enm := ei.Name
mp.addToken(token.FOR)
mp.expr(x.Key)
mp.idx += 2
mp.addToken(token.DEFINE)
mp.out.Add(token.IDENT, "0")
mp.out.Add(token.SEMICOLON)
mp.out.Add(token.IDENT, knm)
mp.out.Add(token.IDENT, "<")
mp.out.Add(token.IDENT, enm)
mp.out.Add(token.PERIOD)
mp.out.Add(token.IDENT, "Len")
mp.idx++
mp.out.AddMulti(token.LPAREN, token.RPAREN)
mp.idx++
mp.out.Add(token.SEMICOLON)
mp.idx++
mp.out.Add(token.IDENT, knm)
mp.out.AddMulti(token.INC, token.LBRACE)
mp.out.Add(token.IDENT, vnm)
mp.out.Add(token.DEFINE)
mp.out.Add(token.IDENT, enm)
mp.out.Add(token.IDENT, ".Float1D")
mp.out.Add(token.LPAREN)
mp.out.Add(token.IDENT, knm)
mp.out.Add(token.RPAREN)
if x.Body != nil && len(x.Body.List) > 0 {
mp.stmtList(x.Body.List)
mp.addToken(token.RBRACE)
}
// TODO
// CaseClause: SwitchStmt:, TypeSwitchStmt:, CommClause:, SelectStmt:
}
}
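// For example (an illustrative sketch), a range statement over a tensor such as
//
//	for i, v := range x { ... }
//
// is rewritten by the RangeStmt case above into an indexed loop, roughly:
//
//	for i := 0; i < x.Len(); i++ { v := x.Float1D(i); ... }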
func (mp *mathParse) expr(ex ast.Expr) {
if ex == nil {
return
}
switch x := ex.(type) {
case *ast.BadExpr:
fmt.Println("bad expr!")
case *ast.Ident:
mp.ident(x)
case *ast.UnaryExpr:
mp.unaryExpr(x)
case *ast.Ellipsis:
cfun := mp.funcs.Peek()
if cfun != nil && cfun.Name == "tensor.AnySlice" {
mp.out.Add(token.IDENT, "tensor.Ellipsis")
mp.idx++
} else {
mp.addToken(token.ELLIPSIS)
}
case *ast.StarExpr:
mp.addToken(token.MUL)
mp.expr(x.X)
case *ast.BinaryExpr:
mp.binaryExpr(x)
case *ast.BasicLit:
mp.basicLit(x)
case *ast.FuncLit:
case *ast.ParenExpr:
mp.addToken(token.LPAREN)
mp.expr(x.X)
mp.addToken(token.RPAREN)
case *ast.SelectorExpr:
mp.selectorExpr(x)
case *ast.TypeAssertExpr:
case *ast.IndexExpr:
mp.indexExpr(x)
case *ast.IndexListExpr:
if x.X == nil { // array literal
mp.arrayLiteral(x)
} else {
mp.indexListExpr(x)
}
case *ast.SliceExpr:
mp.sliceExpr(x)
case *ast.CallExpr:
mp.callExpr(x)
case *ast.ArrayType:
// note: shouldn't happen normally:
fmt.Println("array type:", x, x.Len)
fmt.Printf("%#v\n", x.Len)
}
}
func (mp *mathParse) exprList(ex []ast.Expr) {
n := len(ex)
if n == 0 {
return
}
if n == 1 {
mp.expr(ex[0])
return
}
for i := range n {
mp.expr(ex[i])
if i < n-1 {
mp.addToken(token.COMMA)
}
}
}
func (mp *mathParse) argsList(ex []ast.Expr) {
n := len(ex)
if n == 0 {
return
}
if n == 1 {
mp.expr(ex[0])
return
}
for i := range n {
// cfun := mp.funcs.Peek()
// if i != cfun.curArg {
// fmt.Println(cfun, "arg should be:", i, "is:", cfun.curArg)
// }
mp.expr(ex[i])
if i < n-1 {
mp.nextArg()
mp.addToken(token.COMMA)
}
}
}
func (mp *mathParse) exprIsBool(ex ast.Expr) bool {
switch x := ex.(type) {
case *ast.BinaryExpr:
if (x.Op >= token.EQL && x.Op <= token.GTR) || (x.Op >= token.NEQ && x.Op <= token.GEQ) {
return true
}
case *ast.ParenExpr:
return mp.exprIsBool(x.X)
}
return false
}
func (mp *mathParse) exprsAreBool(ex []ast.Expr) bool {
for _, x := range ex {
if mp.exprIsBool(x) {
return true
}
}
return false
}
func (mp *mathParse) binaryExpr(ex *ast.BinaryExpr) {
if ex.Op == token.ILLEGAL { // @ = matmul
mp.startFunc("matrix.Mul")
mp.out.Add(token.LPAREN)
mp.expr(ex.X)
mp.out.Add(token.COMMA)
mp.idx++
mp.expr(ex.Y)
mp.out.Add(token.RPAREN)
mp.endFunc()
return
}
fn := ""
switch ex.Op {
case token.ADD:
fn = "Add"
case token.SUB:
fn = "Sub"
case token.MUL:
fn = "Mul"
if un, ok := ex.Y.(*ast.StarExpr); ok { // ** power operator
ex.Y = un.X
fn = "Pow"
}
case token.QUO:
fn = "Div"
case token.EQL:
fn = "Equal"
case token.LSS:
fn = "Less"
case token.GTR:
fn = "Greater"
case token.NEQ:
fn = "NotEqual"
case token.LEQ:
fn = "LessEqual"
case token.GEQ:
fn = "GreaterEqual"
case token.LOR:
fn = "Or"
case token.LAND:
fn = "And"
default:
fmt.Println("binary token:", ex.Op)
}
mp.startFunc("tmath." + fn)
mp.out.Add(token.LPAREN)
mp.expr(ex.X)
mp.out.Add(token.COMMA)
mp.idx++
if fn == "Pow" {
mp.idx++
}
mp.expr(ex.Y)
mp.out.Add(token.RPAREN)
mp.endFunc()
}
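// For example (illustrative), binary expressions map onto function calls:
//
//	a + b   ->  tmath.Add(a, b)
//	a < b   ->  tmath.Less(a, b)
//	a ** b  ->  tmath.Pow(a, b)   // ** arrives as MUL followed by a StarExpr
//	a @ b   ->  matrix.Mul(a, b)  // @ arrives here as token.ILLEGAL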
func (mp *mathParse) unaryExpr(ex *ast.UnaryExpr) {
if _, isbl := ex.X.(*ast.BasicLit); isbl {
mp.addToken(ex.Op)
mp.expr(ex.X)
return
}
fn := ""
switch ex.Op {
case token.NOT:
fn = "Not"
case token.SUB:
fn = "Negate"
case token.ADD:
mp.expr(ex.X)
return
default: // * goes to StarExpr -- not sure what else could happen here?
mp.addToken(ex.Op)
mp.expr(ex.X)
return
}
mp.startFunc("tmath." + fn)
mp.addToken(token.LPAREN)
mp.expr(ex.X)
mp.out.Add(token.RPAREN)
mp.endFunc()
}
func (mp *mathParse) defineStmt(as *ast.AssignStmt) {
firstStmt := mp.idx == 0
mp.exprList(as.Lhs)
mp.addToken(as.Tok)
mp.startFunc("tensor.Tensor")
mp.out.Add(token.LPAREN)
mp.exprList(as.Rhs)
mp.out.Add(token.RPAREN)
mp.endFunc()
if firstStmt && mp.state.MathRecord {
nvar, ok := as.Lhs[0].(*ast.Ident)
if ok {
mp.out.Add(token.SEMICOLON)
mp.out.Add(token.IDENT, "tensorfs.Record("+nvar.Name+",`"+nvar.Name+"`)")
}
}
}
func (mp *mathParse) assignStmt(as *ast.AssignStmt) {
if as.Tok == token.ASSIGN {
if _, ok := as.Lhs[0].(*ast.Ident); ok {
mp.exprList(as.Lhs)
mp.addToken(as.Tok)
mp.startFunc("")
mp.exprList(as.Rhs)
mp.endFunc()
return
}
}
fn := ""
switch as.Tok {
case token.ASSIGN:
fn = "Assign"
case token.ADD_ASSIGN:
fn = "AddAssign"
case token.SUB_ASSIGN:
fn = "SubAssign"
case token.MUL_ASSIGN:
fn = "MulAssign"
case token.QUO_ASSIGN:
fn = "DivAssign"
}
mp.startFunc("tmath." + fn)
mp.out.Add(token.LPAREN)
mp.exprList(as.Lhs)
mp.out.Add(token.COMMA)
mp.idx++
mp.exprList(as.Rhs)
mp.out.Add(token.RPAREN)
mp.endFunc()
}
func (mp *mathParse) basicLit(lit *ast.BasicLit) {
if mp.curArgIsTensor() {
mp.tensorLit(lit)
return
}
mp.out.Add(lit.Kind, lit.Value)
if mp.trace {
fmt.Printf("%d\ttok: %s literal\n", mp.idx, lit.Value)
}
mp.idx++
return
}
func (mp *mathParse) tensorLit(lit *ast.BasicLit) {
switch lit.Kind {
case token.INT:
mp.out.Add(token.IDENT, "tensor.NewIntScalar("+lit.Value+")")
mp.idx++
case token.FLOAT:
mp.out.Add(token.IDENT, "tensor.NewFloat64Scalar("+lit.Value+")")
mp.idx++
case token.STRING:
mp.out.Add(token.IDENT, "tensor.NewStringScalar("+lit.Value+")")
mp.idx++
}
}
// funWrap is a function wrapper for simple numpy property / functions
type funWrap struct {
fun string // function to call on tensor
wrap string // code for wrapping function for results of call
}
// nis: NewIntScalar, niv: NewIntFromValues, etc
var numpyProps = map[string]funWrap{
"ndim": {"NumDims()", "nis"},
"len": {"Len()", "nis"},
"size": {"Len()", "nis"},
"shape": {"Shape().Sizes", "niv"},
"T": {"", "tensor.Transpose"},
}
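// For example (an illustrative sketch of the mapping above):
//
//	a.ndim   ->  tensor.NewIntScalar(a.NumDims())
//	a.shape  ->  tensor.NewIntFromValues(a.Shape().Sizes...)
//	a.T      ->  tensor.Transpose(a)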
// wrapFunc outputs the wrapping function and returns whether it needs an ellipsis.
func (fw *funWrap) wrapFunc(mp *mathParse) bool {
ellip := false
wrapFun := fw.wrap
switch fw.wrap {
case "nis":
wrapFun = "tensor.NewIntScalar"
case "nfs":
wrapFun = "tensor.NewFloat64Scalar"
case "nss":
wrapFun = "tensor.NewStringScalar"
case "niv":
wrapFun = "tensor.NewIntFromValues"
ellip = true
case "nfv":
wrapFun = "tensor.NewFloat64FromValues"
ellip = true
case "nsv":
wrapFun = "tensor.NewStringFromValues"
ellip = true
default:
wrapFun = fw.wrap
}
mp.startFunc(wrapFun, true) // don't lookup -- don't auto-convert args
mp.out.Add(token.LPAREN)
return ellip
}
func (mp *mathParse) selectorExpr(ex *ast.SelectorExpr) {
fw, ok := numpyProps[ex.Sel.Name]
if !ok {
mp.expr(ex.X)
mp.addToken(token.PERIOD)
mp.out.Add(token.IDENT, ex.Sel.Name)
mp.idx++
return
}
ellip := fw.wrapFunc(mp)
mp.expr(ex.X)
if fw.fun != "" {
mp.addToken(token.PERIOD)
mp.out.Add(token.IDENT, fw.fun)
mp.idx++
} else {
mp.idx += 2
}
if ellip {
mp.out.Add(token.ELLIPSIS)
}
mp.out.Add(token.RPAREN)
mp.endFunc()
}
func (mp *mathParse) indexListExpr(il *ast.IndexListExpr) {
// fmt.Println("slice expr", se)
}
func (mp *mathParse) indexExpr(il *ast.IndexExpr) {
if _, ok := il.Index.(*ast.IndexListExpr); ok {
mp.basicSlicingExpr(il)
}
}
func (mp *mathParse) basicSlicingExpr(il *ast.IndexExpr) {
iil := il.Index.(*ast.IndexListExpr)
fun := "tensor.AnySlice"
mp.startFunc(fun)
mp.out.Add(token.LPAREN)
mp.expr(il.X)
mp.nextArg()
mp.addToken(token.COMMA) // use the [ -- can't use ( to preserve X
mp.exprList(iil.Indices)
mp.addToken(token.RPAREN) // replaces ]
mp.endFunc()
}
func (mp *mathParse) sliceExpr(se *ast.SliceExpr) {
if se.Low == nil && se.High == nil && se.Max == nil {
mp.out.Add(token.IDENT, "tensor.FullAxis")
mp.idx++
return
}
mp.out.Add(token.IDENT, "tensor.Slice")
mp.out.Add(token.LBRACE)
prev := false
if se.Low != nil {
mp.out.Add(token.IDENT, "Start:")
mp.expr(se.Low)
prev = true
if se.High == nil && se.Max == nil {
mp.idx++
}
}
if se.High != nil {
if prev {
mp.out.Add(token.COMMA)
}
mp.out.Add(token.IDENT, "Stop:")
mp.idx++
mp.expr(se.High)
prev = true
}
if se.Max != nil {
if prev {
mp.out.Add(token.COMMA)
}
mp.idx++
if se.Low == nil && se.High == nil {
mp.idx++
}
mp.out.Add(token.IDENT, "Step:")
mp.expr(se.Max)
}
mp.out.Add(token.RBRACE)
}
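// For example (an illustrative sketch), slice expressions inside an index map to:
//
//	:       ->  tensor.FullAxis
//	1:3     ->  tensor.Slice{Start: 1, Stop: 3}
//	1:10:2  ->  tensor.Slice{Start: 1, Stop: 10, Step: 2}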
func (mp *mathParse) arrayLiteral(il *ast.IndexListExpr) {
kind := inferKindExprList(il.Indices)
if kind == token.ILLEGAL {
kind = token.FLOAT // default
}
// todo: look for sub-arrays etc.
typ := "float64"
fun := "Float64"
switch kind {
case token.FLOAT:
case token.INT:
typ = "int"
fun = "Int"
case token.STRING:
typ = "string"
fun = "String"
}
if mp.inArray || mp.curArgIsInts() {
mp.idx++ // opening brace we're not using
mp.exprList(il.Indices)
mp.idx++ // closing brace we're not using
return
}
var sh []int
mp.inArray = true
mp.arrayShape(il.Indices, &sh)
if len(sh) > 1 {
mp.startFunc("tensor.Reshape")
mp.out.Add(token.LPAREN)
}
mp.startFunc("tensor.New" + fun + "FromValues")
mp.out.Add(token.LPAREN)
mp.out.Add(token.IDENT, "[]"+typ)
mp.addToken(token.LBRACE)
mp.exprList(il.Indices)
mp.addToken(token.RBRACE)
mp.out.AddMulti(token.ELLIPSIS, token.RPAREN)
mp.endFunc()
if len(sh) > 1 {
mp.out.Add(token.COMMA)
nsh := len(sh)
for i, s := range sh {
mp.out.Add(token.INT, fmt.Sprintf("%d", s))
if i < nsh-1 {
mp.out.Add(token.COMMA)
}
}
mp.out.Add(token.RPAREN)
mp.endFunc()
}
mp.inArray = false
}
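// For example (an illustrative sketch), a nested array literal such as
//
//	[[1., 2.], [3., 4.]]
//
// is transpiled into roughly:
//
//	tensor.Reshape(tensor.NewFloat64FromValues([]float64{1., 2., 3., 4.}...), 2, 2)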
func (mp *mathParse) arrayShape(ex []ast.Expr, sh *[]int) {
n := len(ex)
if n == 0 {
return
}
*sh = append(*sh, n)
for i := range n {
if il, ok := ex[i].(*ast.IndexListExpr); ok {
mp.arrayShape(il.Indices, sh)
return
}
}
}
// nofun = do not accept a function version, just a method
var numpyFuncs = map[string]funWrap{
"array": {"tensor.NewFromValues", ""},
"zeros": {"tensor.NewFloat64", ""},
"full": {"tensor.NewFloat64Full", ""},
"ones": {"tensor.NewFloat64Ones", ""},
"rand": {"tensor.NewFloat64Rand", ""},
"arange": {"tensor.NewIntRange", ""},
"linspace": {"tensor.NewFloat64SpacedLinear", ""},
"reshape": {"tensor.Reshape", ""},
"copy": {"tensor.Clone", ""},
"get": {"tensorfs.Get", ""},
"set": {"tensorfs.Set", ""},
"setcp": {"tensorfs.SetCopy", ""},
"flatten": {"tensor.Flatten", "nofun"},
"squeeze": {"tensor.Squeeze", "nofun"},
}
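// For example (illustrative; the exact argument handling depends on the
// registered function signatures), numpy-style calls map onto tensor functions:
//
//	zeros(3, 4)      ->  tensor.NewFloat64(3, 4)
//	arange(10)       ->  tensor.NewIntRange(10)
//	a.reshape(2, 6)  ->  tensor.Reshape(a, 2, 6)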
func (mp *mathParse) callExpr(ex *ast.CallExpr) {
switch x := ex.Fun.(type) {
case *ast.Ident:
if fw, ok := numpyProps[x.Name]; ok && fw.wrap != "nofun" {
mp.callPropFun(ex, fw)
return
}
mp.callName(ex, x.Name, "")
case *ast.SelectorExpr:
fun := x.Sel.Name
if pkg, ok := x.X.(*ast.Ident); ok {
if fw, ok := numpyFuncs[fun]; ok {
mp.callPropSelFun(ex, x.X, fw)
return
} else {
// fmt.Println("call name:", fun, pkg.Name)
mp.callName(ex, fun, pkg.Name)
}
} else {
if fw, ok := numpyFuncs[fun]; ok {
mp.callPropSelFun(ex, x.X, fw)
return
}
// todo: dot fun?
mp.expr(ex.Fun)
}
default:
mp.expr(ex.Fun)
}
mp.argsList(ex.Args)
// todo: ellipsis
mp.addToken(token.RPAREN)
mp.endFunc()
}
// this calls a "prop" function like ndim(a) on the object.
func (mp *mathParse) callPropFun(cf *ast.CallExpr, fw funWrap) {
ellip := fw.wrapFunc(mp)
mp.idx += 2
mp.exprList(cf.Args) // this is the tensor
mp.addToken(token.PERIOD)
mp.out.Add(token.IDENT, fw.fun)
if ellip {
mp.out.Add(token.ELLIPSIS)
}
mp.out.Add(token.RPAREN)
mp.endFunc()
}
// callPropSelFun calls a global function through a selector expression, such as a.reshape().
func (mp *mathParse) callPropSelFun(cf *ast.CallExpr, ex ast.Expr, fw funWrap) {
mp.startFunc(fw.fun)
mp.out.Add(token.LPAREN) // use the (
mp.expr(ex)
mp.idx += 2
if len(cf.Args) > 0 {
mp.nextArg() // did first
mp.addToken(token.COMMA)
mp.argsList(cf.Args)
} else {
mp.idx++
}
mp.addToken(token.RPAREN)
mp.endFunc()
}
func (mp *mathParse) callName(cf *ast.CallExpr, funName, pkgName string) {
if fw, ok := numpyFuncs[funName]; ok {
mp.startFunc(fw.fun)
mp.addToken(token.LPAREN) // use the (
mp.idx++ // paren too
return
}
var err error // validate name
if pkgName != "" {
funName = pkgName + "." + funName
_, err = tensor.FuncByName(funName)
} else { // non-package-qualified names are _only_ in tmath; they can be lowercase
_, err = tensor.FuncByName("tmath." + funName)
if err != nil {
funName = strings.ToUpper(funName[:1]) + funName[1:] // first letter uppercased
_, err = tensor.FuncByName("tmath." + funName)
}
if err == nil { // registered, must be in tmath
funName = "tmath." + funName
}
}
if err != nil { // not a registered tensor function
// fmt.Println("regular fun", funName)
mp.startFunc(funName)
mp.addToken(token.LPAREN) // use the (
mp.idx += 3
return
}
mp.startFunc(funName)
mp.idx += 1
if pkgName != "" {
mp.idx += 2 // . and selector
}
mp.addToken(token.LPAREN)
}
// basic ident replacements
var consts = map[string]string{
"newaxis": "tensor.NewAxis",
"pi": "tensor.NewFloat64Scalar(math.Pi)",
}
func (mp *mathParse) ident(id *ast.Ident) {
if id == nil {
return
}
if cn, ok := consts[id.Name]; ok {
mp.out.Add(token.IDENT, cn)
mp.idx++
return
}
if mp.curArgIsInts() {
mp.out.Add(token.IDENT, "tensor.AsIntSlice")
mp.out.Add(token.LPAREN)
mp.addCur()
mp.out.AddMulti(token.RPAREN, token.ELLIPSIS)
} else {
mp.addCur()
}
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// mparse is a hacked version of go/parser:
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mparser
import (
"fmt"
"go/ast"
"go/build/constraint"
"go/scanner"
"go/token"
"strings"
)
// ParseLine parses a line of code that could contain one or more statements
func ParseLine(code string, mode Mode) (stmts []ast.Stmt, err error) {
fset := token.NewFileSet()
var p parser
defer func() {
if e := recover(); e != nil {
// resume same panic if it's not a bailout
bail, ok := e.(bailout)
if !ok {
panic(e)
} else if bail.msg != "" {
p.errors.Add(p.file.Position(bail.pos), bail.msg)
}
}
p.errors.Sort()
err = p.errors.Err()
}()
p.init(fset, "", []byte(code), mode)
stmts = p.parseStmtList()
// If a semicolon was inserted, consume it;
// report an error if there are more tokens.
if p.tok == token.SEMICOLON && p.lit == "\n" {
p.next()
}
if p.tok == token.RBRACE {
return
}
p.expect(token.EOF)
return
}
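// parseLineExample is an illustrative sketch (a hypothetical helper) of typical
// ParseLine usage on a single line of code.
func parseLineExample() {
	stmts, err := ParseLine("x := a + b", AllErrors)
	if err != nil {
		fmt.Println("parse err:", err)
		return
	}
	fmt.Println("statements:", len(stmts)) // 1 for this example line
}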
// ParseExpr parses an expression
func ParseExpr(code string, mode Mode) (expr ast.Expr, err error) {
fset := token.NewFileSet()
var p parser
defer func() {
if e := recover(); e != nil {
// resume same panic if it's not a bailout
bail, ok := e.(bailout)
if !ok {
panic(e)
} else if bail.msg != "" {
p.errors.Add(p.file.Position(bail.pos), bail.msg)
}
}
p.errors.Sort()
err = p.errors.Err()
}()
p.init(fset, "", []byte(code), mode)
expr = p.parseRhs()
// If a semicolon was inserted, consume it;
// report an error if there are more tokens.
if p.tok == token.SEMICOLON && p.lit == "\n" {
p.next()
}
p.expect(token.EOF)
return
}
// A Mode value is a set of flags (or 0).
// They control the amount of source code parsed and other optional
// parser functionality.
type Mode uint
const (
ParseComments Mode = 1 << iota // parse comments and add them to AST
Trace // print a trace of parsed productions
DeclarationErrors // report declaration errors
SpuriousErrors // same as AllErrors, for backward-compatibility
AllErrors = SpuriousErrors // report all errors (not just the first 10 on different lines)
)
// The parser structure holds the parser's internal state.
type parser struct {
file *token.File
errors scanner.ErrorList
scanner scanner.Scanner
// Tracing/debugging
mode Mode // parsing mode
trace bool // == (mode&Trace != 0)
indent int // indentation used for tracing output
// Comments
comments []*ast.CommentGroup
leadComment *ast.CommentGroup // last lead comment
lineComment *ast.CommentGroup // last line comment
top bool // in top of file (before package clause)
goVersion string // minimum Go version found in //go:build comment
// Next token
pos token.Pos // token position
tok token.Token // one token look-ahead
lit string // token literal
// Error recovery
// (used to limit the number of calls to parser.advance
// w/o making scanning progress - avoids potential endless
// loops across multiple parser functions during error recovery)
syncPos token.Pos // last synchronization position
syncCnt int // number of parser.advance calls without progress
// Non-syntactic parser control
exprLev int // < 0: in control clause, >= 0: in expression
inRhs bool // if set, the parser is parsing a rhs expression
imports []*ast.ImportSpec // list of imports
// nestLev is used to track and limit the recursion depth
// during parsing.
nestLev int
}
func (p *parser) init(fset *token.FileSet, filename string, src []byte, mode Mode) {
p.file = fset.AddFile(filename, -1, len(src))
eh := func(pos token.Position, msg string) {
if !strings.Contains(msg, "@") {
p.errors.Add(pos, msg)
}
}
p.scanner.Init(p.file, src, eh, scanner.ScanComments)
p.top = true
p.mode = mode
p.trace = mode&Trace != 0 // for convenience (p.trace is used frequently)
p.next()
}
// ----------------------------------------------------------------------------
// Parsing support
func (p *parser) printTrace(a ...any) {
const dots = ". . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . "
const n = len(dots)
pos := p.file.Position(p.pos)
fmt.Printf("%5d:%3d: ", pos.Line, pos.Column)
i := 2 * p.indent
for i > n {
fmt.Print(dots)
i -= n
}
// i <= n
fmt.Print(dots[0:i])
fmt.Println(a...)
}
func trace(p *parser, msg string) *parser {
p.printTrace(msg, "(")
p.indent++
return p
}
// Usage pattern: defer un(trace(p, "..."))
func un(p *parser) {
p.indent--
p.printTrace(")")
}
// maxNestLev is the deepest we're willing to recurse during parsing
const maxNestLev int = 1e5
func incNestLev(p *parser) *parser {
p.nestLev++
if p.nestLev > maxNestLev {
p.error(p.pos, "exceeded max nesting depth")
panic(bailout{})
}
return p
}
// decNestLev is used to track nesting depth during parsing to prevent stack exhaustion.
// It is used along with incNestLev in a similar fashion to how un and trace are used.
func decNestLev(p *parser) {
p.nestLev--
}
// Advance to the next token.
func (p *parser) next0() {
// Because of one-token look-ahead, print the previous token
// when tracing as it provides a more readable output. The
// very first token (!p.pos.IsValid()) is not initialized
// (it is token.ILLEGAL), so don't print it.
if p.trace && p.pos.IsValid() {
s := p.tok.String()
switch {
case p.tok.IsLiteral():
p.printTrace(s, p.lit)
case p.tok.IsOperator(), p.tok.IsKeyword():
p.printTrace("\"" + s + "\"")
default:
p.printTrace(s)
}
}
for {
p.pos, p.tok, p.lit = p.scanner.Scan()
if p.tok == token.COMMENT {
if p.top && strings.HasPrefix(p.lit, "//go:build") {
if x, err := constraint.Parse(p.lit); err == nil {
p.goVersion = constraint.GoVersion(x)
}
}
if p.mode&ParseComments == 0 {
continue
}
} else {
// Found a non-comment; top of file is over.
p.top = false
}
break
}
}
// Consume a comment and return it and the line on which it ends.
func (p *parser) consumeComment() (comment *ast.Comment, endline int) {
// /*-style comments may end on a different line than where they start.
// Scan the comment for '\n' chars and adjust endline accordingly.
endline = p.file.Line(p.pos)
if p.lit[1] == '*' {
// don't use range here - no need to decode Unicode code points
for i := 0; i < len(p.lit); i++ {
if p.lit[i] == '\n' {
endline++
}
}
}
comment = &ast.Comment{Slash: p.pos, Text: p.lit}
p.next0()
return
}
// Consume a group of adjacent comments, add it to the parser's
// comments list, and return it together with the line at which
// the last comment in the group ends. A non-comment token or n
// empty lines terminate a comment group.
func (p *parser) consumeCommentGroup(n int) (comments *ast.CommentGroup, endline int) {
var list []*ast.Comment
endline = p.file.Line(p.pos)
for p.tok == token.COMMENT && p.file.Line(p.pos) <= endline+n {
var comment *ast.Comment
comment, endline = p.consumeComment()
list = append(list, comment)
}
// add comment group to the comments list
comments = &ast.CommentGroup{List: list}
p.comments = append(p.comments, comments)
return
}
// Advance to the next non-comment token. In the process, collect
// any comment groups encountered, and remember the last lead and
// line comments.
//
// A lead comment is a comment group that starts and ends in a
// line without any other tokens and that is followed by a non-comment
// token on the line immediately after the comment group.
//
// A line comment is a comment group that follows a non-comment
// token on the same line, and that has no tokens after it on the line
// where it ends.
//
// Lead and line comments may be considered documentation that is
// stored in the AST.
func (p *parser) next() {
p.leadComment = nil
p.lineComment = nil
prev := p.pos
p.next0()
if p.tok == token.COMMENT {
var comment *ast.CommentGroup
var endline int
if p.file.Line(p.pos) == p.file.Line(prev) {
// The comment is on same line as the previous token; it
// cannot be a lead comment but may be a line comment.
comment, endline = p.consumeCommentGroup(0)
if p.file.Line(p.pos) != endline || p.tok == token.SEMICOLON || p.tok == token.EOF {
// The next token is on a different line, thus
// the last comment group is a line comment.
p.lineComment = comment
}
}
// consume successor comments, if any
endline = -1
for p.tok == token.COMMENT {
comment, endline = p.consumeCommentGroup(1)
}
if endline+1 == p.file.Line(p.pos) {
// The next token is following on the line immediately after the
// comment group, thus the last comment group is a lead comment.
p.leadComment = comment
}
}
}
// A bailout panic is raised to indicate early termination. pos and msg are
// only populated when bailing out of object resolution.
type bailout struct {
pos token.Pos
msg string
}
func (p *parser) error(pos token.Pos, msg string) {
if p.trace {
defer un(trace(p, "error: "+msg))
}
epos := p.file.Position(pos)
// If AllErrors is not set, discard errors reported on the same line
// as the last recorded error and stop parsing if there are more than
// 10 errors.
if p.mode&AllErrors == 0 {
n := len(p.errors)
if n > 0 && p.errors[n-1].Pos.Line == epos.Line {
return // discard - likely a spurious error
}
if n > 10 {
panic(bailout{})
}
}
p.errors.Add(epos, msg)
}
func (p *parser) errorExpected(pos token.Pos, msg string) {
msg = "expected " + msg
if pos == p.pos {
// the error happened at the current position;
// make the error message more specific
switch {
case p.tok == token.SEMICOLON && p.lit == "\n":
msg += ", found newline"
case p.tok.IsLiteral():
// print 123 rather than 'INT', etc.
msg += ", found " + p.lit
default:
msg += ", found '" + p.tok.String() + "'"
}
}
p.error(pos, msg)
}
func (p *parser) expect(tok token.Token) token.Pos {
pos := p.pos
if p.tok != tok {
p.errorExpected(pos, "'"+tok.String()+"'")
}
p.next() // make progress
return pos
}
// expect2 is like expect, but it returns an invalid position
// if the expected token is not found.
func (p *parser) expect2(tok token.Token) (pos token.Pos) {
if p.tok == tok {
pos = p.pos
} else {
p.errorExpected(p.pos, "'"+tok.String()+"'")
}
p.next() // make progress
return
}
// expectClosing is like expect but provides a better error message
// for the common case of a missing comma before a newline.
func (p *parser) expectClosing(tok token.Token, context string) token.Pos {
if p.tok != tok && p.tok == token.SEMICOLON && p.lit == "\n" {
p.error(p.pos, "missing ',' before newline in "+context)
p.next()
}
return p.expect(tok)
}
// expectSemi consumes a semicolon and returns the applicable line comment.
func (p *parser) expectSemi() (comment *ast.CommentGroup) {
// semicolon is optional before a closing ')' or '}'
if p.tok != token.RPAREN && p.tok != token.RBRACE {
switch p.tok {
case token.COMMA:
// permit a ',' instead of a ';' but complain
p.errorExpected(p.pos, "';'")
fallthrough
case token.SEMICOLON:
if p.lit == ";" {
// explicit semicolon
p.next()
comment = p.lineComment // use following comments
} else {
// artificial semicolon
comment = p.lineComment // use preceding comments
p.next()
}
return comment
default:
// math: allow unexpected endings..
// p.errorExpected(p.pos, "';'")
// p.advance(stmtStart)
}
}
return nil
}
func (p *parser) atComma(context string, follow token.Token) bool {
if p.tok == token.COMMA {
return true
}
if p.tok != follow {
msg := "missing ','"
if p.tok == token.SEMICOLON && p.lit == "\n" {
msg += " before newline"
}
p.error(p.pos, msg+" in "+context)
return true // "insert" comma and continue
}
return false
}
func passert(cond bool, msg string) {
if !cond {
panic("go/parser internal error: " + msg)
}
}
// advance consumes tokens until the current token p.tok
// is in the 'to' set, or token.EOF. For error recovery.
func (p *parser) advance(to map[token.Token]bool) {
for ; p.tok != token.EOF; p.next() {
if to[p.tok] {
// Return only if parser made some progress since last
// sync or if it has not reached 10 advance calls without
// progress. Otherwise consume at least one token to
// avoid an endless parser loop (it is possible that
// both parseOperand and parseStmt call advance and
// correctly do not advance, thus the need for the
// invocation limit p.syncCnt).
if p.pos == p.syncPos && p.syncCnt < 10 {
p.syncCnt++
return
}
if p.pos > p.syncPos {
p.syncPos = p.pos
p.syncCnt = 0
return
}
// Reaching here indicates a parser bug, likely an
// incorrect token list in this function, but it only
// leads to skipping of possibly correct code if a
// previous error is present, and thus is preferred
// over a non-terminating parse.
}
}
}
var stmtStart = map[token.Token]bool{
token.BREAK: true,
token.CONST: true,
token.CONTINUE: true,
token.DEFER: true,
token.FALLTHROUGH: true,
token.FOR: true,
token.GO: true,
token.GOTO: true,
token.IF: true,
token.RETURN: true,
token.SELECT: true,
token.SWITCH: true,
token.TYPE: true,
token.VAR: true,
}
var declStart = map[token.Token]bool{
token.IMPORT: true,
token.CONST: true,
token.TYPE: true,
token.VAR: true,
}
var exprEnd = map[token.Token]bool{
token.COMMA: true,
token.COLON: true,
token.SEMICOLON: true,
token.RPAREN: true,
token.RBRACK: true,
token.RBRACE: true,
}
// safePos returns a valid file position for a given position: If pos
// is valid to begin with, safePos returns pos. If pos is out-of-range,
// safePos returns the EOF position.
//
// This is a hack to work around "artificial" end positions in the AST which
// are computed by adding 1 to (presumably valid) token positions. If the
// token positions are invalid due to parse errors, the resulting end position
// may be past the file's EOF position, which would lead to panics if used
// later on.
func (p *parser) safePos(pos token.Pos) (res token.Pos) {
defer func() {
if recover() != nil {
res = token.Pos(p.file.Base() + p.file.Size()) // EOF position
}
}()
_ = p.file.Offset(pos) // trigger a panic if position is out-of-range
return pos
}
// ----------------------------------------------------------------------------
// Identifiers
func (p *parser) parseIdent() *ast.Ident {
pos := p.pos
name := "_"
if p.tok == token.IDENT {
name = p.lit
p.next()
} else {
p.expect(token.IDENT) // use expect() error handling
}
return &ast.Ident{NamePos: pos, Name: name}
}
func (p *parser) parseIdentList() (list []*ast.Ident) {
if p.trace {
defer un(trace(p, "IdentList"))
}
list = append(list, p.parseIdent())
for p.tok == token.COMMA {
p.next()
list = append(list, p.parseIdent())
}
return
}
// ----------------------------------------------------------------------------
// Common productions
// If lhs is set, result list elements which are identifiers are not resolved.
func (p *parser) parseExprList() (list []ast.Expr) {
if p.trace {
defer un(trace(p, "ExpressionList"))
}
list = append(list, p.parseExpr())
for p.tok == token.COMMA {
p.next()
list = append(list, p.parseExpr())
}
return
}
func (p *parser) parseList(inRhs bool) []ast.Expr {
old := p.inRhs
p.inRhs = inRhs
list := p.parseExprList()
p.inRhs = old
return list
}
// math: allow full array list expressions
func (p *parser) parseArrayList(lbrack token.Pos) *ast.IndexListExpr {
if p.trace {
defer un(trace(p, "ArrayList"))
}
p.exprLev++
// x := p.parseRhs()
x := p.parseExprList()
p.exprLev--
rbrack := p.expect(token.RBRACK)
return &ast.IndexListExpr{Lbrack: lbrack, Indices: x, Rbrack: rbrack}
}
// ----------------------------------------------------------------------------
// Types
func (p *parser) parseType() ast.Expr {
if p.trace {
defer un(trace(p, "Type"))
}
typ := p.tryIdentOrType()
if typ == nil {
pos := p.pos
p.errorExpected(pos, "type")
p.advance(exprEnd)
return &ast.BadExpr{From: pos, To: p.pos}
}
return typ
}
func (p *parser) parseQualifiedIdent(ident *ast.Ident) ast.Expr {
if p.trace {
defer un(trace(p, "QualifiedIdent"))
}
typ := p.parseTypeName(ident)
if p.tok == token.LBRACK {
typ = p.parseTypeInstance(typ)
}
return typ
}
// If the result is an identifier, it is not resolved.
func (p *parser) parseTypeName(ident *ast.Ident) ast.Expr {
if p.trace {
defer un(trace(p, "TypeName"))
}
if ident == nil {
ident = p.parseIdent()
}
if p.tok == token.PERIOD {
// ident is a package name
p.next()
sel := p.parseIdent()
return &ast.SelectorExpr{X: ident, Sel: sel}
}
return ident
}
// "[" has already been consumed, and lbrack is its position.
// If len != nil it is the already consumed array length.
func (p *parser) parseArrayType(lbrack token.Pos, len ast.Expr) *ast.ArrayType {
if p.trace {
defer un(trace(p, "ArrayType"))
}
if len == nil {
p.exprLev++
// always permit ellipsis for more fault-tolerant parsing
if p.tok == token.ELLIPSIS {
len = &ast.Ellipsis{Ellipsis: p.pos}
p.next()
} else if p.tok != token.RBRACK {
len = p.parseRhs()
}
p.exprLev--
}
if p.tok == token.COMMA {
// Trailing commas are accepted in type parameter
// lists but not in array type declarations.
// Accept for better error handling but complain.
p.error(p.pos, "unexpected comma; expecting ]")
p.next()
}
p.expect(token.RBRACK)
elt := p.parseType()
return &ast.ArrayType{Lbrack: lbrack, Len: len, Elt: elt}
}
func (p *parser) parseArrayFieldOrTypeInstance(x *ast.Ident) (*ast.Ident, ast.Expr) {
if p.trace {
defer un(trace(p, "ArrayFieldOrTypeInstance"))
}
lbrack := p.expect(token.LBRACK)
trailingComma := token.NoPos // if valid, the position of a trailing comma preceding the ']'
var args []ast.Expr
if p.tok != token.RBRACK {
p.exprLev++
args = append(args, p.parseRhs())
for p.tok == token.COMMA {
comma := p.pos
p.next()
if p.tok == token.RBRACK {
trailingComma = comma
break
}
args = append(args, p.parseRhs())
}
p.exprLev--
}
rbrack := p.expect(token.RBRACK)
_ = rbrack
if len(args) == 0 {
// x []E
elt := p.parseType()
return x, &ast.ArrayType{Lbrack: lbrack, Elt: elt}
}
// x [P]E or x[P]
if len(args) == 1 {
elt := p.tryIdentOrType()
if elt != nil {
// x [P]E
if trailingComma.IsValid() {
// Trailing commas are invalid in array type fields.
p.error(trailingComma, "unexpected comma; expecting ]")
}
return x, &ast.ArrayType{Lbrack: lbrack, Len: args[0], Elt: elt}
}
}
// x[P], x[P1, P2], ...
return nil, nil // typeparams.PackIndexExpr(x, lbrack, args, rbrack)
}
func (p *parser) parseFieldDecl() *ast.Field {
if p.trace {
defer un(trace(p, "FieldDecl"))
}
doc := p.leadComment
var names []*ast.Ident
var typ ast.Expr
switch p.tok {
case token.IDENT:
name := p.parseIdent()
if p.tok == token.PERIOD || p.tok == token.STRING || p.tok == token.SEMICOLON || p.tok == token.RBRACE {
// embedded type
typ = name
if p.tok == token.PERIOD {
typ = p.parseQualifiedIdent(name)
}
} else {
// name1, name2, ... T
names = []*ast.Ident{name}
for p.tok == token.COMMA {
p.next()
names = append(names, p.parseIdent())
}
// Careful dance: We don't know if we have an embedded instantiated
// type T[P1, P2, ...] or a field T of array type []E or [P]E.
if len(names) == 1 && p.tok == token.LBRACK {
name, typ = p.parseArrayFieldOrTypeInstance(name)
if name == nil {
names = nil
}
} else {
// T P
typ = p.parseType()
}
}
case token.MUL:
star := p.pos
p.next()
if p.tok == token.LPAREN {
// *(T)
p.error(p.pos, "cannot parenthesize embedded type")
p.next()
typ = p.parseQualifiedIdent(nil)
// expect closing ')' but no need to complain if missing
if p.tok == token.RPAREN {
p.next()
}
} else {
// *T
typ = p.parseQualifiedIdent(nil)
}
typ = &ast.StarExpr{Star: star, X: typ}
case token.LPAREN:
p.error(p.pos, "cannot parenthesize embedded type")
p.next()
if p.tok == token.MUL {
// (*T)
star := p.pos
p.next()
typ = &ast.StarExpr{Star: star, X: p.parseQualifiedIdent(nil)}
} else {
// (T)
typ = p.parseQualifiedIdent(nil)
}
// expect closing ')' but no need to complain if missing
if p.tok == token.RPAREN {
p.next()
}
default:
pos := p.pos
p.errorExpected(pos, "field name or embedded type")
p.advance(exprEnd)
typ = &ast.BadExpr{From: pos, To: p.pos}
}
var tag *ast.BasicLit
if p.tok == token.STRING {
tag = &ast.BasicLit{ValuePos: p.pos, Kind: p.tok, Value: p.lit}
p.next()
}
comment := p.expectSemi()
field := &ast.Field{Doc: doc, Names: names, Type: typ, Tag: tag, Comment: comment}
return field
}
func (p *parser) parseStructType() *ast.StructType {
if p.trace {
defer un(trace(p, "StructType"))
}
pos := p.expect(token.STRUCT)
lbrace := p.expect(token.LBRACE)
var list []*ast.Field
for p.tok == token.IDENT || p.tok == token.MUL || p.tok == token.LPAREN {
// a field declaration cannot start with a '(' but we accept
// it here for more robust parsing and better error messages
// (parseFieldDecl will check and complain if necessary)
list = append(list, p.parseFieldDecl())
}
rbrace := p.expect(token.RBRACE)
return &ast.StructType{
Struct: pos,
Fields: &ast.FieldList{
Opening: lbrace,
List: list,
Closing: rbrace,
},
}
}
func (p *parser) parsePointerType() *ast.StarExpr {
if p.trace {
defer un(trace(p, "PointerType"))
}
star := p.expect(token.MUL)
base := p.parseType()
return &ast.StarExpr{Star: star, X: base}
}
func (p *parser) parseDotsType() *ast.Ellipsis {
if p.trace {
defer un(trace(p, "DotsType"))
}
pos := p.expect(token.ELLIPSIS)
elt := p.parseType()
return &ast.Ellipsis{Ellipsis: pos, Elt: elt}
}
type field struct {
name *ast.Ident
typ ast.Expr
}
func (p *parser) parseParamDecl(name *ast.Ident, typeSetsOK bool) (f field) {
// TODO(rFindley) refactor to be more similar to paramDeclOrNil in the syntax
// package
if p.trace {
defer un(trace(p, "ParamDeclOrNil"))
}
ptok := p.tok
if name != nil {
p.tok = token.IDENT // force token.IDENT case in switch below
} else if typeSetsOK && p.tok == token.TILDE {
// "~" ...
return field{nil, p.embeddedElem(nil)}
}
switch p.tok {
case token.IDENT:
// name
if name != nil {
f.name = name
p.tok = ptok
} else {
f.name = p.parseIdent()
}
switch p.tok {
case token.IDENT, token.MUL, token.ARROW, token.FUNC, token.CHAN, token.MAP, token.STRUCT, token.INTERFACE, token.LPAREN:
// name type
f.typ = p.parseType()
case token.LBRACK:
// name "[" type1, ..., typeN "]" or name "[" n "]" type
f.name, f.typ = p.parseArrayFieldOrTypeInstance(f.name)
case token.ELLIPSIS:
// name "..." type
f.typ = p.parseDotsType()
return // don't allow ...type "|" ...
case token.PERIOD:
// name "." ...
f.typ = p.parseQualifiedIdent(f.name)
f.name = nil
case token.TILDE:
if typeSetsOK {
f.typ = p.embeddedElem(nil)
return
}
case token.OR:
if typeSetsOK {
// name "|" typeset
f.typ = p.embeddedElem(f.name)
f.name = nil
return
}
}
case token.MUL, token.ARROW, token.FUNC, token.LBRACK, token.CHAN, token.MAP, token.STRUCT, token.INTERFACE, token.LPAREN:
// type
f.typ = p.parseType()
case token.ELLIPSIS:
// "..." type
// (always accepted)
f.typ = p.parseDotsType()
return // don't allow ...type "|" ...
default:
// TODO(rfindley): this is incorrect in the case of type parameter lists
// (should be "']'" in that case)
p.errorExpected(p.pos, "')'")
p.advance(exprEnd)
}
// [name] type "|"
if typeSetsOK && p.tok == token.OR && f.typ != nil {
f.typ = p.embeddedElem(f.typ)
}
return
}
func (p *parser) parseParameterList(name0 *ast.Ident, typ0 ast.Expr, closing token.Token) (params []*ast.Field) {
if p.trace {
defer un(trace(p, "ParameterList"))
}
// Type parameters are the only parameter list closed by ']'.
tparams := closing == token.RBRACK
pos0 := p.pos
if name0 != nil {
pos0 = name0.Pos()
} else if typ0 != nil {
pos0 = typ0.Pos()
}
// Note: The code below matches the corresponding code in the syntax
// parser closely. Changes must be reflected in either parser.
// For the code to match, we use the local []field list that
// corresponds to []syntax.Field. At the end, the list must be
// converted into an []*ast.Field.
var list []field
var named int // number of parameters that have an explicit name and type
var typed int // number of parameters that have an explicit type
for name0 != nil || p.tok != closing && p.tok != token.EOF {
var par field
if typ0 != nil {
if tparams {
typ0 = p.embeddedElem(typ0)
}
par = field{name0, typ0}
} else {
par = p.parseParamDecl(name0, tparams)
}
name0 = nil // 1st name was consumed if present
typ0 = nil // 1st typ was consumed if present
if par.name != nil || par.typ != nil {
list = append(list, par)
if par.name != nil && par.typ != nil {
named++
}
if par.typ != nil {
typed++
}
}
if !p.atComma("parameter list", closing) {
break
}
p.next()
}
if len(list) == 0 {
return // not uncommon
}
// distribute parameter types (len(list) > 0)
if named == 0 {
// all unnamed => found names are type names
for i := 0; i < len(list); i++ {
par := &list[i]
if typ := par.name; typ != nil {
par.typ = typ
par.name = nil
}
}
if tparams {
// This is the same error handling as below, adjusted for type parameters only.
// See comment below for details. (go.dev/issue/64534)
var errPos token.Pos
var msg string
if named == typed /* same as typed == 0 */ {
errPos = p.pos // position error at closing ]
msg = "missing type constraint"
} else {
errPos = pos0 // position at opening [ or first name
msg = "missing type parameter name"
if len(list) == 1 {
msg += " or invalid array length"
}
}
p.error(errPos, msg)
}
} else if named != len(list) {
// some named or we're in a type parameter list => all must be named
var errPos token.Pos // left-most error position (or invalid)
var typ ast.Expr // current type (from right to left)
for i := len(list) - 1; i >= 0; i-- {
if par := &list[i]; par.typ != nil {
typ = par.typ
if par.name == nil {
errPos = typ.Pos()
n := ast.NewIdent("_")
n.NamePos = errPos // correct position
par.name = n
}
} else if typ != nil {
par.typ = typ
} else {
// par.typ == nil && typ == nil => we only have a par.name
errPos = par.name.Pos()
par.typ = &ast.BadExpr{From: errPos, To: p.pos}
}
}
if errPos.IsValid() {
var msg string
if tparams {
// Not all parameters are named because named != len(list).
// If named == typed we must have parameters that have no types,
// and they must be at the end of the parameter list, otherwise
// the types would have been filled in by the right-to-left sweep
// above and we wouldn't have an error. Since we are in a type
// parameter list, the missing types are constraints.
if named == typed {
errPos = p.pos // position error at closing ]
msg = "missing type constraint"
} else {
msg = "missing type parameter name"
// go.dev/issue/60812
if len(list) == 1 {
msg += " or invalid array length"
}
}
} else {
msg = "mixed named and unnamed parameters"
}
p.error(errPos, msg)
}
}
// Convert list to []*ast.Field.
// If list contains types only, each type gets its own ast.Field.
if named == 0 {
// parameter list consists of types only
for _, par := range list {
passert(par.typ != nil, "nil type in unnamed parameter list")
params = append(params, &ast.Field{Type: par.typ})
}
return
}
// If the parameter list consists of named parameters with types,
// collect all names with the same types into a single ast.Field.
var names []*ast.Ident
var typ ast.Expr
addParams := func() {
passert(typ != nil, "nil type in named parameter list")
field := &ast.Field{Names: names, Type: typ}
params = append(params, field)
names = nil
}
for _, par := range list {
if par.typ != typ {
if len(names) > 0 {
addParams()
}
typ = par.typ
}
names = append(names, par.name)
}
if len(names) > 0 {
addParams()
}
return
}
func (p *parser) parseParameters(acceptTParams bool) (tparams, params *ast.FieldList) {
if p.trace {
defer un(trace(p, "Parameters"))
}
if acceptTParams && p.tok == token.LBRACK {
opening := p.pos
p.next()
// [T any](params) syntax
list := p.parseParameterList(nil, nil, token.RBRACK)
rbrack := p.expect(token.RBRACK)
tparams = &ast.FieldList{Opening: opening, List: list, Closing: rbrack}
// Type parameter lists must not be empty.
if tparams.NumFields() == 0 {
p.error(tparams.Closing, "empty type parameter list")
tparams = nil // avoid follow-on errors
}
}
opening := p.expect(token.LPAREN)
var fields []*ast.Field
if p.tok != token.RPAREN {
fields = p.parseParameterList(nil, nil, token.RPAREN)
}
rparen := p.expect(token.RPAREN)
params = &ast.FieldList{Opening: opening, List: fields, Closing: rparen}
return
}
func (p *parser) parseResult() *ast.FieldList {
if p.trace {
defer un(trace(p, "Result"))
}
if p.tok == token.LPAREN {
_, results := p.parseParameters(false)
return results
}
typ := p.tryIdentOrType()
if typ != nil {
list := make([]*ast.Field, 1)
list[0] = &ast.Field{Type: typ}
return &ast.FieldList{List: list}
}
return nil
}
func (p *parser) parseFuncType() *ast.FuncType {
if p.trace {
defer un(trace(p, "FuncType"))
}
pos := p.expect(token.FUNC)
tparams, params := p.parseParameters(true)
if tparams != nil {
p.error(tparams.Pos(), "function type must have no type parameters")
}
results := p.parseResult()
return &ast.FuncType{Func: pos, Params: params, Results: results}
}
func (p *parser) parseMethodSpec() *ast.Field {
if p.trace {
defer un(trace(p, "MethodSpec"))
}
doc := p.leadComment
var idents []*ast.Ident
var typ ast.Expr
x := p.parseTypeName(nil)
if ident, _ := x.(*ast.Ident); ident != nil {
switch {
case p.tok == token.LBRACK:
// generic method or embedded instantiated type
lbrack := p.pos
p.next()
p.exprLev++
x := p.parseExpr()
p.exprLev--
if name0, _ := x.(*ast.Ident); name0 != nil && p.tok != token.COMMA && p.tok != token.RBRACK {
// generic method m[T any]
//
// Interface methods do not have type parameters. We parse them for a
// better error message and improved error recovery.
_ = p.parseParameterList(name0, nil, token.RBRACK)
_ = p.expect(token.RBRACK)
p.error(lbrack, "interface method must have no type parameters")
// TODO(rfindley) refactor to share code with parseFuncType.
_, params := p.parseParameters(false)
results := p.parseResult()
idents = []*ast.Ident{ident}
typ = &ast.FuncType{
Func: token.NoPos,
Params: params,
Results: results,
}
} else {
// embedded instantiated type
// TODO(rfindley) should resolve all identifiers in x.
list := []ast.Expr{x}
if p.atComma("type argument list", token.RBRACK) {
p.exprLev++
p.next()
for p.tok != token.RBRACK && p.tok != token.EOF {
list = append(list, p.parseType())
if !p.atComma("type argument list", token.RBRACK) {
break
}
p.next()
}
p.exprLev--
}
// rbrack := p.expectClosing(token.RBRACK, "type argument list")
// typ = typeparams.PackIndexExpr(ident, lbrack, list, rbrack)
}
case p.tok == token.LPAREN:
// ordinary method
// TODO(rfindley) refactor to share code with parseFuncType.
_, params := p.parseParameters(false)
results := p.parseResult()
idents = []*ast.Ident{ident}
typ = &ast.FuncType{Func: token.NoPos, Params: params, Results: results}
default:
// embedded type
typ = x
}
} else {
// embedded, possibly instantiated type
typ = x
if p.tok == token.LBRACK {
// embedded instantiated interface
typ = p.parseTypeInstance(typ)
}
}
// Comment is added at the callsite: the field below may be joined with
// additional type specs using '|'.
// TODO(rfindley) this should be refactored.
// TODO(rfindley) add more tests for comment handling.
return &ast.Field{Doc: doc, Names: idents, Type: typ}
}
func (p *parser) embeddedElem(x ast.Expr) ast.Expr {
if p.trace {
defer un(trace(p, "EmbeddedElem"))
}
if x == nil {
x = p.embeddedTerm()
}
for p.tok == token.OR {
t := new(ast.BinaryExpr)
t.OpPos = p.pos
t.Op = token.OR
p.next()
t.X = x
t.Y = p.embeddedTerm()
x = t
}
return x
}
func (p *parser) embeddedTerm() ast.Expr {
if p.trace {
defer un(trace(p, "EmbeddedTerm"))
}
if p.tok == token.TILDE {
t := new(ast.UnaryExpr)
t.OpPos = p.pos
t.Op = token.TILDE
p.next()
t.X = p.parseType()
return t
}
t := p.tryIdentOrType()
if t == nil {
pos := p.pos
p.errorExpected(pos, "~ term or type")
p.advance(exprEnd)
return &ast.BadExpr{From: pos, To: p.pos}
}
return t
}
func (p *parser) parseInterfaceType() *ast.InterfaceType {
if p.trace {
defer un(trace(p, "InterfaceType"))
}
pos := p.expect(token.INTERFACE)
lbrace := p.expect(token.LBRACE)
var list []*ast.Field
parseElements:
for {
switch {
case p.tok == token.IDENT:
f := p.parseMethodSpec()
if f.Names == nil {
f.Type = p.embeddedElem(f.Type)
}
f.Comment = p.expectSemi()
list = append(list, f)
case p.tok == token.TILDE:
typ := p.embeddedElem(nil)
comment := p.expectSemi()
list = append(list, &ast.Field{Type: typ, Comment: comment})
default:
if t := p.tryIdentOrType(); t != nil {
typ := p.embeddedElem(t)
comment := p.expectSemi()
list = append(list, &ast.Field{Type: typ, Comment: comment})
} else {
break parseElements
}
}
}
// TODO(rfindley): the error produced here could be improved, since we could
// accept an identifier, 'type', or a '}' at this point.
rbrace := p.expect(token.RBRACE)
return &ast.InterfaceType{
Interface: pos,
Methods: &ast.FieldList{
Opening: lbrace,
List: list,
Closing: rbrace,
},
}
}
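// Illustrative note (editorial sketch, not part of the original source): the
// element loop above accepts the three interface element forms, e.g.
//
//	interface {
//		Read(p []byte) (n int, err error) // method spec (parseMethodSpec)
//		Closer                            // embedded type
//		~int | ~string                    // type-set terms joined by '|' (embeddedElem/embeddedTerm)
//	}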
func (p *parser) parseMapType() *ast.MapType {
if p.trace {
defer un(trace(p, "MapType"))
}
pos := p.expect(token.MAP)
p.expect(token.LBRACK)
key := p.parseType()
p.expect(token.RBRACK)
value := p.parseType()
return &ast.MapType{Map: pos, Key: key, Value: value}
}
func (p *parser) parseChanType() *ast.ChanType {
if p.trace {
defer un(trace(p, "ChanType"))
}
pos := p.pos
dir := ast.SEND | ast.RECV
var arrow token.Pos
if p.tok == token.CHAN {
p.next()
if p.tok == token.ARROW {
arrow = p.pos
p.next()
dir = ast.SEND
}
} else {
arrow = p.expect(token.ARROW)
p.expect(token.CHAN)
dir = ast.RECV
}
value := p.parseType()
return &ast.ChanType{Begin: pos, Arrow: arrow, Dir: dir, Value: value}
}
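// Illustrative note (editorial sketch, not part of the original source): the
// direction handling above maps the source forms to ast.ChanType.Dir:
//
//	chan T    // dir == ast.SEND|ast.RECV (bidirectional)
//	chan<- T  // dir == ast.SEND
//	<-chan T  // dir == ast.RECV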
func (p *parser) parseTypeInstance(typ ast.Expr) ast.Expr {
if p.trace {
defer un(trace(p, "TypeInstance"))
}
opening := p.expect(token.LBRACK)
p.exprLev++
var list []ast.Expr
for p.tok != token.RBRACK && p.tok != token.EOF {
list = append(list, p.parseType())
if !p.atComma("type argument list", token.RBRACK) {
break
}
p.next()
}
p.exprLev--
closing := p.expectClosing(token.RBRACK, "type argument list")
if len(list) == 0 {
p.errorExpected(closing, "type argument list")
return &ast.IndexExpr{
X: typ,
Lbrack: opening,
Index: &ast.BadExpr{From: opening + 1, To: closing},
Rbrack: closing,
}
}
return nil // typeparams.PackIndexExpr(typ, opening, list, closing)
}
func (p *parser) tryIdentOrType() ast.Expr {
defer decNestLev(incNestLev(p))
switch p.tok {
case token.IDENT:
typ := p.parseTypeName(nil)
if p.tok == token.LBRACK {
typ = p.parseTypeInstance(typ)
}
return typ
case token.LBRACK:
lbrack := p.expect(token.LBRACK)
return p.parseArrayList(lbrack) // math: full array exprs
// return p.parseArrayType(lbrack, nil)
case token.STRUCT:
return p.parseStructType()
case token.MUL:
return p.parsePointerType()
case token.FUNC:
return p.parseFuncType()
case token.INTERFACE:
return p.parseInterfaceType()
case token.MAP:
return p.parseMapType()
case token.CHAN, token.ARROW:
return p.parseChanType()
case token.LPAREN:
lparen := p.pos
p.next()
typ := p.parseType()
rparen := p.expect(token.RPAREN)
return &ast.ParenExpr{Lparen: lparen, X: typ, Rparen: rparen}
}
// no type found
return nil
}
// ----------------------------------------------------------------------------
// Blocks
func (p *parser) parseStmtList() (list []ast.Stmt) {
if p.trace {
defer un(trace(p, "StatementList"))
}
for p.tok != token.CASE && p.tok != token.DEFAULT && p.tok != token.RBRACE && p.tok != token.EOF {
list = append(list, p.parseStmt())
}
return
}
func (p *parser) parseBody() *ast.BlockStmt {
if p.trace {
defer un(trace(p, "Body"))
}
lbrace := p.expect(token.LBRACE)
list := p.parseStmtList()
rbrace := p.expect2(token.RBRACE)
return &ast.BlockStmt{Lbrace: lbrace, List: list, Rbrace: rbrace}
}
func (p *parser) parseBlockStmt() *ast.BlockStmt {
if p.trace {
defer un(trace(p, "BlockStmt"))
}
lbrace := p.expect(token.LBRACE)
if p.tok == token.EOF { // math: allow start only
return &ast.BlockStmt{Lbrace: lbrace}
}
list := p.parseStmtList()
rbrace := p.expect2(token.RBRACE)
return &ast.BlockStmt{Lbrace: lbrace, List: list, Rbrace: rbrace}
}
// ----------------------------------------------------------------------------
// Expressions
func (p *parser) parseFuncTypeOrLit() ast.Expr {
if p.trace {
defer un(trace(p, "FuncTypeOrLit"))
}
typ := p.parseFuncType()
if p.tok != token.LBRACE {
// function type only
return typ
}
p.exprLev++
body := p.parseBody()
p.exprLev--
return &ast.FuncLit{Type: typ, Body: body}
}
// parseOperand may return an expression or a raw type (incl. array
// types of the form [...]T). Callers must verify the result.
func (p *parser) parseOperand() ast.Expr {
if p.trace {
defer un(trace(p, "Operand"))
}
switch p.tok {
case token.IDENT:
x := p.parseIdent()
return x
case token.INT, token.FLOAT, token.IMAG, token.CHAR, token.STRING:
x := &ast.BasicLit{ValuePos: p.pos, Kind: p.tok, Value: p.lit}
// fmt.Println("operand lit:", p.lit)
p.next()
if p.tok == token.COLON {
return p.parseSliceExpr(x)
} else {
return x
}
case token.LPAREN:
lparen := p.pos
p.next()
p.exprLev++
x := p.parseRhs() // types may be parenthesized: (some type)
p.exprLev--
rparen := p.expect(token.RPAREN)
return &ast.ParenExpr{Lparen: lparen, X: x, Rparen: rparen}
case token.FUNC:
return p.parseFuncTypeOrLit()
case token.COLON:
p.expect(token.COLON)
return p.parseSliceExpr(nil)
}
if typ := p.tryIdentOrType(); typ != nil { // do not consume trailing type parameters
// could be type for composite literal or conversion
_, isIdent := typ.(*ast.Ident)
passert(!isIdent, "type cannot be identifier")
return typ
}
// we have an error
pos := p.pos
p.errorExpected(pos, "operand")
p.advance(stmtStart)
return &ast.BadExpr{From: pos, To: p.pos}
}
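// Illustrative note (editorial sketch, not part of the original source): two
// math-mode extensions appear above. A basic literal followed by ':' (as in
// 1:3) and a bare leading ':' are both routed to parseSliceExpr, so that
// range-style operands can stand on their own, for example as elements of a
// math-mode index list parsed by parseArrayList.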
func (p *parser) parseSelector(x ast.Expr) ast.Expr {
if p.trace {
defer un(trace(p, "Selector"))
}
sel := p.parseIdent()
return &ast.SelectorExpr{X: x, Sel: sel}
}
func (p *parser) parseTypeAssertion(x ast.Expr) ast.Expr {
if p.trace {
defer un(trace(p, "TypeAssertion"))
}
lparen := p.expect(token.LPAREN)
var typ ast.Expr
if p.tok == token.TYPE {
// type switch: typ == nil
p.next()
} else {
typ = p.parseType()
}
rparen := p.expect(token.RPAREN)
return &ast.TypeAssertExpr{X: x, Type: typ, Lparen: lparen, Rparen: rparen}
}
func (p *parser) parseIndexOrSliceOrInstance(x ast.Expr) ast.Expr {
if p.trace {
defer un(trace(p, "parseIndexOrSliceOrInstance"))
}
lbrack := p.expect(token.LBRACK)
if p.tok == token.RBRACK {
// empty index, slice or index expressions are not permitted;
// accept them for parsing tolerance, but complain
p.errorExpected(p.pos, "operand")
rbrack := p.pos
p.next()
return &ast.IndexExpr{
X: x,
Lbrack: lbrack,
Index: &ast.BadExpr{From: rbrack, To: rbrack},
Rbrack: rbrack,
}
}
ix := p.parseArrayList(lbrack)
return &ast.IndexExpr{
X: x,
Lbrack: lbrack,
Index: ix,
Rbrack: ix.Rbrack,
}
}
func (p *parser) parseSliceExpr(ex ast.Expr) *ast.SliceExpr {
if p.trace {
defer un(trace(p, "parseSliceExpr"))
}
lbrack := p.pos
p.exprLev++
const N = 3 // change the 3 to 2 to disable 3-index slices
var index [N]ast.Expr
index[0] = ex
var colons [N - 1]token.Pos
ncolons := 0
if ex == nil {
ncolons++
}
var rpos token.Pos
// fmt.Println(ncolons, p.tok)
switch p.tok {
case token.COLON:
// slice expression
for p.tok == token.COLON && ncolons < len(colons) {
colons[ncolons] = p.pos
ncolons++
p.next()
if p.tok != token.COMMA && p.tok != token.COLON && p.tok != token.RBRACK && p.tok != token.EOF {
ix := p.parseRhs()
if se, ok := ix.(*ast.SliceExpr); ok {
index[ncolons] = se.Low
if ncolons == 1 && se.High != nil {
ncolons++
index[ncolons] = se.High
}
// fmt.Printf("nc: %d low: %#v hi: %#v max: %#v\n", ncolons, se.Low, se.High, se.Max)
} else {
// fmt.Printf("nc: %d low: %#v\n", ncolons, ix)
if _, ok := ix.(*ast.BadExpr); !ok {
index[ncolons] = ix
}
}
// } else {
// fmt.Println(ncolons, "else")
}
}
case token.COMMA:
rpos = p.pos // expect(token.COMMA)
case token.RBRACK:
rpos = p.pos // expect(token.RBRACK)
// instance expression
// args = append(args, index[0])
// for p.tok == token.COMMA {
// p.next()
// if p.tok != token.RBRACK && p.tok != token.EOF {
// args = append(args, p.parseType())
// }
// }
default:
ix := p.parseRhs()
// fmt.Printf("nc: %d ix: %#v\n", ncolons, ix)
index[ncolons] = ix
}
p.exprLev--
// rbrack := p.expect(token.RBRACK)
// slice expression
slice3 := false
if ncolons == 2 {
slice3 = true
// Check presence of middle and final index here rather than during type-checking
// to prevent erroneous programs from passing through gofmt (was go.dev/issue/7305).
// if index[1] == nil {
// p.error(colons[0], "middle index required in 3-index slice")
// index[1] = &ast.BadExpr{From: colons[0] + 1, To: colons[1]}
// }
// if index[2] == nil {
// p.error(colons[1], "final index required in 3-index slice")
// index[2] = &ast.BadExpr{From: colons[1] + 1} // , To: rbrack
// }
}
se := &ast.SliceExpr{Lbrack: lbrack, Low: index[0], High: index[1], Max: index[2], Slice3: slice3, Rbrack: rpos}
// fmt.Printf("final: %#v\n", se)
return se
//
// if len(args) == 0 {
// // index expression
// return &ast.IndexExpr{X: x, Lbrack: lbrack, Index: index[0], Rbrack: rbrack}
// }
// instance expression
// return nil // typeparams.PackIndexExpr(x, lbrack, args, rbrack)
}
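// Illustrative note (editorial sketch, not part of the original source): the
// index array above holds up to three parts, so this covers the standard
// 2- and 3-index slice forms as well as the bare leading-colon operand used
// in math mode:
//
//	x[low:high]      // ncolons == 1
//	x[low:high:max]  // ncolons == 2, Slice3 set
//	:                // ex == nil: bare colon operand, all parts may be nil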
func (p *parser) parseCallOrConversion(fun ast.Expr) *ast.CallExpr {
if p.trace {
defer un(trace(p, "CallOrConversion"))
}
lparen := p.expect(token.LPAREN)
p.exprLev++
var list []ast.Expr
var ellipsis token.Pos
for p.tok != token.RPAREN && p.tok != token.EOF && !ellipsis.IsValid() {
list = append(list, p.parseRhs()) // builtins may expect a type: make(some type, ...)
if p.tok == token.ELLIPSIS {
ellipsis = p.pos
p.next()
}
if !p.atComma("argument list", token.RPAREN) {
break
}
p.next()
}
p.exprLev--
rparen := p.expectClosing(token.RPAREN, "argument list")
return &ast.CallExpr{Fun: fun, Lparen: lparen, Args: list, Ellipsis: ellipsis, Rparen: rparen}
}
func (p *parser) parseValue() ast.Expr {
if p.trace {
defer un(trace(p, "Element"))
}
if p.tok == token.LBRACE {
return p.parseLiteralValue(nil)
}
x := p.parseExpr()
return x
}
func (p *parser) parseElement() ast.Expr {
if p.trace {
defer un(trace(p, "Element"))
}
x := p.parseValue()
if p.tok == token.COLON {
colon := p.pos
p.next()
x = &ast.KeyValueExpr{Key: x, Colon: colon, Value: p.parseValue()}
}
return x
}
func (p *parser) parseElementList() (list []ast.Expr) {
if p.trace {
defer un(trace(p, "ElementList"))
}
for p.tok != token.RBRACE && p.tok != token.EOF {
list = append(list, p.parseElement())
if !p.atComma("composite literal", token.RBRACE) {
break
}
p.next()
}
return
}
func (p *parser) parseLiteralValue(typ ast.Expr) ast.Expr {
defer decNestLev(incNestLev(p))
if p.trace {
defer un(trace(p, "LiteralValue"))
}
lbrace := p.expect(token.LBRACE)
var elts []ast.Expr
p.exprLev++
if p.tok != token.RBRACE {
elts = p.parseElementList()
}
p.exprLev--
rbrace := p.expectClosing(token.RBRACE, "composite literal")
return &ast.CompositeLit{Type: typ, Lbrace: lbrace, Elts: elts, Rbrace: rbrace}
}
func (p *parser) parsePrimaryExpr(x ast.Expr) ast.Expr {
if p.trace {
defer un(trace(p, "PrimaryExpr"))
}
// math: ellipses can show up in index expression.
if p.tok == token.ELLIPSIS {
pos := p.pos
p.next()
return &ast.Ellipsis{Ellipsis: pos}
}
if x == nil {
x = p.parseOperand()
}
// We track the nesting here rather than at the entry for the function,
// since it can iteratively produce a nested output, and we want to
// limit how deep a structure we generate.
var n int
defer func() { p.nestLev -= n }()
for n = 1; ; n++ {
incNestLev(p)
switch p.tok {
case token.PERIOD:
p.next()
switch p.tok {
case token.IDENT:
x = p.parseSelector(x)
case token.LPAREN:
x = p.parseTypeAssertion(x)
default:
pos := p.pos
p.errorExpected(pos, "selector or type assertion")
// TODO(rFindley) The check for token.RBRACE below is a targeted fix
// to error recovery sufficient to make the x/tools tests
// pass with the new parsing logic introduced for type
// parameters. Remove this once error recovery has been
// more generally reconsidered.
if p.tok != token.RBRACE {
p.next() // make progress
}
sel := &ast.Ident{NamePos: pos, Name: "_"}
x = &ast.SelectorExpr{X: x, Sel: sel}
}
case token.LBRACK:
x = p.parseIndexOrSliceOrInstance(x)
case token.LPAREN:
x = p.parseCallOrConversion(x)
case token.LBRACE:
// operand may have returned a parenthesized complit
// type; accept it but complain if we have a complit
t := ast.Unparen(x)
// determine if '{' belongs to a composite literal or a block statement
switch t.(type) {
case *ast.BadExpr, *ast.Ident, *ast.SelectorExpr:
if p.exprLev < 0 {
return x
}
// x is possibly a composite literal type
case *ast.IndexExpr, *ast.IndexListExpr:
if p.exprLev < 0 {
return x
}
// x is possibly a composite literal type
case *ast.ArrayType, *ast.StructType, *ast.MapType:
// x is a composite literal type
default:
return x
}
if t != x {
p.error(t.Pos(), "cannot parenthesize type in composite literal")
// already progressed, no need to advance
}
x = p.parseLiteralValue(x)
default:
return x
}
}
}
func (p *parser) parseUnaryExpr() ast.Expr {
defer decNestLev(incNestLev(p))
if p.trace {
defer un(trace(p, "UnaryExpr"))
}
switch p.tok {
case token.ADD, token.SUB, token.NOT, token.XOR, token.AND, token.TILDE:
pos, op := p.pos, p.tok
p.next()
x := p.parseUnaryExpr()
return &ast.UnaryExpr{OpPos: pos, Op: op, X: x}
case token.ARROW:
// channel type or receive expression
arrow := p.pos
p.next()
// If the next token is token.CHAN we still don't know if it
// is a channel type or a receive operation - we only know
// once we have found the end of the unary expression. There
// are two cases:
//
// <- type => (<-type) must be channel type
// <- expr => <-(expr) is a receive from an expression
//
// In the first case, the arrow must be re-associated with
// the channel type parsed already:
//
// <- (chan type) => (<-chan type)
// <- (chan<- type) => (<-chan (<-type))
x := p.parseUnaryExpr()
// determine which case we have
if typ, ok := x.(*ast.ChanType); ok {
// (<-type)
// re-associate position info and <-
dir := ast.SEND
for ok && dir == ast.SEND {
if typ.Dir == ast.RECV {
// error: (<-type) is (<-(<-chan T))
p.errorExpected(typ.Arrow, "'chan'")
}
arrow, typ.Begin, typ.Arrow = typ.Arrow, arrow, arrow
dir, typ.Dir = typ.Dir, ast.RECV
typ, ok = typ.Value.(*ast.ChanType)
}
if dir == ast.SEND {
p.errorExpected(arrow, "channel type")
}
return x
}
// <-(expr)
return &ast.UnaryExpr{OpPos: arrow, Op: token.ARROW, X: x}
case token.MUL:
// pointer type or unary "*" expression
pos := p.pos
p.next()
x := p.parseUnaryExpr()
return &ast.StarExpr{Star: pos, X: x}
}
return p.parsePrimaryExpr(nil)
}
func (p *parser) tokPrec() (token.Token, int) {
tok := p.tok
if p.inRhs && tok == token.ASSIGN {
tok = token.EQL
}
if p.tok == token.ILLEGAL && p.lit == "@" {
// fmt.Println("@ token")
return token.ILLEGAL, 5
}
return tok, tok.Precedence()
}
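// Illustrative note (editorial sketch, not part of the original source): the
// math-mode '@' token (scanned as ILLEGAL with literal "@") is given
// precedence 5, the same level as '*' and '/' in go/token, so in
// parseBinaryExpr an expression such as
//
//	a @ b + c
//
// groups as (a @ b) + c, since '+' has the lower precedence 4.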
// parseBinaryExpr parses a (possibly) binary expression.
// If x is non-nil, it is used as the left operand.
//
// TODO(rfindley): parseBinaryExpr has become overloaded. Consider refactoring.
func (p *parser) parseBinaryExpr(x ast.Expr, prec1 int) ast.Expr {
if p.trace {
defer un(trace(p, "BinaryExpr"))
}
if x == nil {
x = p.parseUnaryExpr()
}
// We track the nesting here rather than at the entry for the function,
// since it can iteratively produce a nested output, and we want to
// limit how deep a structure we generate.
var n int
defer func() { p.nestLev -= n }()
for n = 1; ; n++ {
incNestLev(p)
op, oprec := p.tokPrec()
if oprec < prec1 {
return x
}
pos := p.pos
if op == token.ILLEGAL {
p.next()
} else {
pos = p.expect(op)
}
y := p.parseBinaryExpr(nil, oprec+1)
x = &ast.BinaryExpr{X: x, OpPos: pos, Op: op, Y: y}
}
}
// The result may be a type or even a raw type ([...]int).
func (p *parser) parseExpr() ast.Expr {
if p.trace {
defer un(trace(p, "Expression"))
}
return p.parseBinaryExpr(nil, token.LowestPrec+1)
}
func (p *parser) parseRhs() ast.Expr {
old := p.inRhs
p.inRhs = true
x := p.parseExpr()
p.inRhs = old
return x
}
// ----------------------------------------------------------------------------
// Statements
// Parsing modes for parseSimpleStmt.
const (
basic = iota
labelOk
rangeOk
)
// parseSimpleStmt returns true as 2nd result if it parsed the assignment
// of a range clause (with mode == rangeOk). The returned statement is an
// assignment with a right-hand side that is a single unary expression of
// the form "range x". No guarantees are given for the left-hand side.
func (p *parser) parseSimpleStmt(mode int) (ast.Stmt, bool) {
if p.trace {
defer un(trace(p, "SimpleStmt"))
}
x := p.parseList(false)
switch p.tok {
case
token.DEFINE, token.ASSIGN, token.ADD_ASSIGN,
token.SUB_ASSIGN, token.MUL_ASSIGN, token.QUO_ASSIGN,
token.REM_ASSIGN, token.AND_ASSIGN, token.OR_ASSIGN,
token.XOR_ASSIGN, token.SHL_ASSIGN, token.SHR_ASSIGN, token.AND_NOT_ASSIGN:
// assignment statement, possibly part of a range clause
pos, tok := p.pos, p.tok
p.next()
var y []ast.Expr
isRange := false
if mode == rangeOk && p.tok == token.RANGE && (tok == token.DEFINE || tok == token.ASSIGN) {
pos := p.pos
p.next()
y = []ast.Expr{&ast.UnaryExpr{OpPos: pos, Op: token.RANGE, X: p.parseRhs()}}
isRange = true
} else {
y = p.parseList(true)
}
return &ast.AssignStmt{Lhs: x, TokPos: pos, Tok: tok, Rhs: y}, isRange
}
if len(x) > 1 {
p.errorExpected(x[0].Pos(), "1 expression")
// continue with first expression
}
switch p.tok {
case token.COLON:
// labeled statement
colon := p.pos
p.next()
if label, isIdent := x[0].(*ast.Ident); mode == labelOk && isIdent {
// Go spec: The scope of a label is the body of the function
// in which it is declared and excludes the body of any nested
// function.
stmt := &ast.LabeledStmt{Label: label, Colon: colon, Stmt: p.parseStmt()}
return stmt, false
}
// The label declaration typically starts at x[0].Pos(), but the label
// declaration may be erroneous due to a token after that position (and
// before the ':'). If SpuriousErrors is not set, the (only) error
// reported for the line is the illegal label error instead of the token
// before the ':' that caused the problem. Thus, use the (latest) colon
// position for error reporting.
p.error(colon, "illegal label declaration")
return &ast.BadStmt{From: x[0].Pos(), To: colon + 1}, false
case token.ARROW:
// send statement
arrow := p.pos
p.next()
y := p.parseRhs()
return &ast.SendStmt{Chan: x[0], Arrow: arrow, Value: y}, false
case token.INC, token.DEC:
// increment or decrement
s := &ast.IncDecStmt{X: x[0], TokPos: p.pos, Tok: p.tok}
p.next()
return s, false
}
// expression
return &ast.ExprStmt{X: x[0]}, false
}
func (p *parser) parseCallExpr(callType string) *ast.CallExpr {
x := p.parseRhs() // could be a conversion: (some type)(x)
if t := ast.Unparen(x); t != x {
p.error(x.Pos(), fmt.Sprintf("expression in %s must not be parenthesized", callType))
x = t
}
if call, isCall := x.(*ast.CallExpr); isCall {
return call
}
if _, isBad := x.(*ast.BadExpr); !isBad {
// only report error if it's a new one
p.error(p.safePos(x.End()), fmt.Sprintf("expression in %s must be function call", callType))
}
return nil
}
func (p *parser) parseGoStmt() ast.Stmt {
if p.trace {
defer un(trace(p, "GoStmt"))
}
pos := p.expect(token.GO)
call := p.parseCallExpr("go")
p.expectSemi()
if call == nil {
return &ast.BadStmt{From: pos, To: pos + 2} // len("go")
}
return &ast.GoStmt{Go: pos, Call: call}
}
func (p *parser) parseDeferStmt() ast.Stmt {
if p.trace {
defer un(trace(p, "DeferStmt"))
}
pos := p.expect(token.DEFER)
call := p.parseCallExpr("defer")
p.expectSemi()
if call == nil {
return &ast.BadStmt{From: pos, To: pos + 5} // len("defer")
}
return &ast.DeferStmt{Defer: pos, Call: call}
}
func (p *parser) parseReturnStmt() *ast.ReturnStmt {
if p.trace {
defer un(trace(p, "ReturnStmt"))
}
pos := p.pos
p.expect(token.RETURN)
var x []ast.Expr
if p.tok != token.SEMICOLON && p.tok != token.RBRACE {
x = p.parseList(true)
}
p.expectSemi()
return &ast.ReturnStmt{Return: pos, Results: x}
}
func (p *parser) parseBranchStmt(tok token.Token) *ast.BranchStmt {
if p.trace {
defer un(trace(p, "BranchStmt"))
}
pos := p.expect(tok)
var label *ast.Ident
if tok != token.FALLTHROUGH && p.tok == token.IDENT {
label = p.parseIdent()
}
p.expectSemi()
return &ast.BranchStmt{TokPos: pos, Tok: tok, Label: label}
}
func (p *parser) makeExpr(s ast.Stmt, want string) ast.Expr {
if s == nil {
return nil
}
if es, isExpr := s.(*ast.ExprStmt); isExpr {
return es.X
}
found := "simple statement"
if _, isAss := s.(*ast.AssignStmt); isAss {
found = "assignment"
}
p.error(s.Pos(), fmt.Sprintf("expected %s, found %s (missing parentheses around composite literal?)", want, found))
return &ast.BadExpr{From: s.Pos(), To: p.safePos(s.End())}
}
// parseIfHeader is an adjusted version of parser.header
// in cmd/compile/internal/syntax/parser.go, which has
// been tuned for better error handling.
func (p *parser) parseIfHeader() (init ast.Stmt, cond ast.Expr) {
if p.tok == token.LBRACE {
p.error(p.pos, "missing condition in if statement")
cond = &ast.BadExpr{From: p.pos, To: p.pos}
return
}
// p.tok != token.LBRACE
prevLev := p.exprLev
p.exprLev = -1
if p.tok != token.SEMICOLON {
// accept potential variable declaration but complain
if p.tok == token.VAR {
p.next()
p.error(p.pos, "var declaration not allowed in if initializer")
}
init, _ = p.parseSimpleStmt(basic)
}
var condStmt ast.Stmt
var semi struct {
pos token.Pos
lit string // ";" or "\n"; valid if pos.IsValid()
}
if p.tok != token.LBRACE {
if p.tok == token.SEMICOLON {
semi.pos = p.pos
semi.lit = p.lit
p.next()
} else {
p.expect(token.SEMICOLON)
}
if p.tok != token.LBRACE {
condStmt, _ = p.parseSimpleStmt(basic)
}
} else {
condStmt = init
init = nil
}
if condStmt != nil {
cond = p.makeExpr(condStmt, "boolean expression")
} else if semi.pos.IsValid() {
if semi.lit == "\n" {
p.error(semi.pos, "unexpected newline, expecting { after if clause")
} else {
p.error(semi.pos, "missing condition in if statement")
}
}
// make sure we have a valid AST
if cond == nil {
cond = &ast.BadExpr{From: p.pos, To: p.pos}
}
p.exprLev = prevLev
return
}
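// Illustrative note (editorial sketch, not part of the original source): for a
// header such as
//
//	if x := f(); x > 0 { ... }
//
// the statement before ';' becomes init and the one after it is converted to
// the cond expression by makeExpr; a missing condition (as in "if {") is
// reported and replaced with a BadExpr so the resulting AST stays valid.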
func (p *parser) parseIfStmt() *ast.IfStmt {
defer decNestLev(incNestLev(p))
if p.trace {
defer un(trace(p, "IfStmt"))
}
pos := p.expect(token.IF)
init, cond := p.parseIfHeader()
body := p.parseBlockStmt()
var else_ ast.Stmt
if p.tok == token.ELSE {
p.next()
switch p.tok {
case token.IF:
else_ = p.parseIfStmt()
case token.LBRACE:
else_ = p.parseBlockStmt()
p.expectSemi()
default:
p.errorExpected(p.pos, "if statement or block")
else_ = &ast.BadStmt{From: p.pos, To: p.pos}
}
} else {
p.expectSemi()
}
return &ast.IfStmt{If: pos, Init: init, Cond: cond, Body: body, Else: else_}
}
func (p *parser) parseCaseClause() *ast.CaseClause {
if p.trace {
defer un(trace(p, "CaseClause"))
}
pos := p.pos
var list []ast.Expr
if p.tok == token.CASE {
p.next()
list = p.parseList(true)
} else {
p.expect(token.DEFAULT)
}
colon := p.expect(token.COLON)
body := p.parseStmtList()
return &ast.CaseClause{Case: pos, List: list, Colon: colon, Body: body}
}
func isTypeSwitchAssert(x ast.Expr) bool {
a, ok := x.(*ast.TypeAssertExpr)
return ok && a.Type == nil
}
func (p *parser) isTypeSwitchGuard(s ast.Stmt) bool {
switch t := s.(type) {
case *ast.ExprStmt:
// x.(type)
return isTypeSwitchAssert(t.X)
case *ast.AssignStmt:
// v := x.(type)
if len(t.Lhs) == 1 && len(t.Rhs) == 1 && isTypeSwitchAssert(t.Rhs[0]) {
switch t.Tok {
case token.ASSIGN:
// permit v = x.(type) but complain
p.error(t.TokPos, "expected ':=', found '='")
fallthrough
case token.DEFINE:
return true
}
}
}
return false
}
func (p *parser) parseSwitchStmt() ast.Stmt {
if p.trace {
defer un(trace(p, "SwitchStmt"))
}
pos := p.expect(token.SWITCH)
var s1, s2 ast.Stmt
if p.tok != token.LBRACE {
prevLev := p.exprLev
p.exprLev = -1
if p.tok != token.SEMICOLON {
s2, _ = p.parseSimpleStmt(basic)
}
if p.tok == token.SEMICOLON {
p.next()
s1 = s2
s2 = nil
if p.tok != token.LBRACE {
// A TypeSwitchGuard may declare a variable in addition
// to the variable declared in the initial SimpleStmt.
// Introduce extra scope to avoid redeclaration errors:
//
// switch t := 0; t := x.(T) { ... }
//
// (this code is not valid Go because the first t
// cannot be accessed and thus is never used, the extra
// scope is needed for the correct error message).
//
// If we don't have a type switch, s2 must be an expression.
// Having the extra nested but empty scope won't affect it.
s2, _ = p.parseSimpleStmt(basic)
}
}
p.exprLev = prevLev
}
typeSwitch := p.isTypeSwitchGuard(s2)
lbrace := p.expect(token.LBRACE)
var list []ast.Stmt
for p.tok == token.CASE || p.tok == token.DEFAULT {
list = append(list, p.parseCaseClause())
}
rbrace := p.expect(token.RBRACE)
p.expectSemi()
body := &ast.BlockStmt{Lbrace: lbrace, List: list, Rbrace: rbrace}
if typeSwitch {
return &ast.TypeSwitchStmt{Switch: pos, Init: s1, Assign: s2, Body: body}
}
return &ast.SwitchStmt{Switch: pos, Init: s1, Tag: p.makeExpr(s2, "switch expression"), Body: body}
}
func (p *parser) parseCommClause() *ast.CommClause {
if p.trace {
defer un(trace(p, "CommClause"))
}
pos := p.pos
var comm ast.Stmt
if p.tok == token.CASE {
p.next()
lhs := p.parseList(false)
if p.tok == token.ARROW {
// SendStmt
if len(lhs) > 1 {
p.errorExpected(lhs[0].Pos(), "1 expression")
// continue with first expression
}
arrow := p.pos
p.next()
rhs := p.parseRhs()
comm = &ast.SendStmt{Chan: lhs[0], Arrow: arrow, Value: rhs}
} else {
// RecvStmt
if tok := p.tok; tok == token.ASSIGN || tok == token.DEFINE {
// RecvStmt with assignment
if len(lhs) > 2 {
p.errorExpected(lhs[0].Pos(), "1 or 2 expressions")
// continue with first two expressions
lhs = lhs[0:2]
}
pos := p.pos
p.next()
rhs := p.parseRhs()
comm = &ast.AssignStmt{Lhs: lhs, TokPos: pos, Tok: tok, Rhs: []ast.Expr{rhs}}
} else {
// lhs must be single receive operation
if len(lhs) > 1 {
p.errorExpected(lhs[0].Pos(), "1 expression")
// continue with first expression
}
comm = &ast.ExprStmt{X: lhs[0]}
}
}
} else {
p.expect(token.DEFAULT)
}
colon := p.expect(token.COLON)
body := p.parseStmtList()
return &ast.CommClause{Case: pos, Comm: comm, Colon: colon, Body: body}
}
func (p *parser) parseSelectStmt() *ast.SelectStmt {
if p.trace {
defer un(trace(p, "SelectStmt"))
}
pos := p.expect(token.SELECT)
lbrace := p.expect(token.LBRACE)
var list []ast.Stmt
for p.tok == token.CASE || p.tok == token.DEFAULT {
list = append(list, p.parseCommClause())
}
rbrace := p.expect(token.RBRACE)
p.expectSemi()
body := &ast.BlockStmt{Lbrace: lbrace, List: list, Rbrace: rbrace}
return &ast.SelectStmt{Select: pos, Body: body}
}
func (p *parser) parseForStmt() ast.Stmt {
if p.trace {
defer un(trace(p, "ForStmt"))
}
pos := p.expect(token.FOR)
var s1, s2, s3 ast.Stmt
var isRange bool
if p.tok != token.LBRACE {
prevLev := p.exprLev
p.exprLev = -1
if p.tok != token.SEMICOLON {
if p.tok == token.RANGE {
// "for range x" (nil lhs in assignment)
pos := p.pos
p.next()
y := []ast.Expr{&ast.UnaryExpr{OpPos: pos, Op: token.RANGE, X: p.parseRhs()}}
s2 = &ast.AssignStmt{Rhs: y}
isRange = true
} else {
s2, isRange = p.parseSimpleStmt(rangeOk)
}
}
if !isRange && p.tok == token.SEMICOLON {
p.next()
s1 = s2
s2 = nil
if p.tok != token.SEMICOLON {
s2, _ = p.parseSimpleStmt(basic)
}
p.expectSemi()
if p.tok != token.LBRACE {
s3, _ = p.parseSimpleStmt(basic)
}
}
p.exprLev = prevLev
}
body := p.parseBlockStmt()
p.expectSemi()
if isRange {
as := s2.(*ast.AssignStmt)
// check lhs
var key, value ast.Expr
switch len(as.Lhs) {
case 0:
// nothing to do
case 1:
key = as.Lhs[0]
case 2:
key, value = as.Lhs[0], as.Lhs[1]
default:
p.errorExpected(as.Lhs[len(as.Lhs)-1].Pos(), "at most 2 expressions")
return &ast.BadStmt{From: pos, To: p.safePos(body.End())}
}
// parseSimpleStmt returned a right-hand side that
// is a single unary expression of the form "range x"
x := as.Rhs[0].(*ast.UnaryExpr).X
return &ast.RangeStmt{
For: pos,
Key: key,
Value: value,
TokPos: as.TokPos,
Tok: as.Tok,
Range: as.Rhs[0].Pos(),
X: x,
Body: body,
}
}
// regular for statement
return &ast.ForStmt{
For: pos,
Init: s1,
Cond: p.makeExpr(s2, "boolean or range expression"),
Post: s3,
Body: body,
}
}
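// Illustrative note (editorial sketch, not part of the original source): the
// clause handling above produces the usual statement shapes:
//
//	for { ... }                  // no clauses: Init, Cond, Post all nil
//	for cond { ... }             // s2 only, converted by makeExpr
//	for init; cond; post { ... } // s1, s2, s3
//	for k, v := range x { ... }  // isRange: rewritten into an ast.RangeStmt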
func (p *parser) parseStmt() (s ast.Stmt) {
defer decNestLev(incNestLev(p))
if p.trace {
defer un(trace(p, "Statement"))
}
switch p.tok {
case token.CONST, token.TYPE, token.VAR:
s = &ast.DeclStmt{Decl: p.parseDecl(stmtStart)}
case
// tokens that may start an expression
token.IDENT, token.INT, token.FLOAT, token.IMAG, token.CHAR, token.STRING, token.FUNC, token.LPAREN, // operands
token.LBRACK, token.STRUCT, token.MAP, token.CHAN, token.INTERFACE, // composite types
token.ADD, token.SUB, token.MUL, token.AND, token.XOR, token.ARROW, token.NOT: // unary operators
s, _ = p.parseSimpleStmt(labelOk)
// because of the required look-ahead, labeled statements are
// parsed by parseSimpleStmt - don't expect a semicolon after
// them
if _, isLabeledStmt := s.(*ast.LabeledStmt); !isLabeledStmt {
p.expectSemi()
}
case token.GO:
s = p.parseGoStmt()
case token.DEFER:
s = p.parseDeferStmt()
case token.RETURN:
s = p.parseReturnStmt()
case token.BREAK, token.CONTINUE, token.GOTO, token.FALLTHROUGH:
s = p.parseBranchStmt(p.tok)
case token.LBRACE:
s = p.parseBlockStmt()
p.expectSemi()
case token.IF:
s = p.parseIfStmt()
case token.SWITCH:
s = p.parseSwitchStmt()
case token.SELECT:
s = p.parseSelectStmt()
case token.FOR:
s = p.parseForStmt()
case token.SEMICOLON:
// Is it ever possible to have an implicit semicolon
// producing an empty statement in a valid program?
// (handle correctly anyway)
s = &ast.EmptyStmt{Semicolon: p.pos, Implicit: p.lit == "\n"}
p.next()
case token.RBRACE:
// a semicolon may be omitted before a closing "}"
s = &ast.EmptyStmt{Semicolon: p.pos, Implicit: true}
default:
// no statement found
pos := p.pos
p.errorExpected(pos, "statement")
p.advance(stmtStart)
s = &ast.BadStmt{From: pos, To: p.pos}
}
return
}
// ----------------------------------------------------------------------------
// Declarations
type parseSpecFunction func(doc *ast.CommentGroup, keyword token.Token, iota int) ast.Spec
func (p *parser) parseImportSpec(doc *ast.CommentGroup, _ token.Token, _ int) ast.Spec {
if p.trace {
defer un(trace(p, "ImportSpec"))
}
var ident *ast.Ident
switch p.tok {
case token.IDENT:
ident = p.parseIdent()
case token.PERIOD:
ident = &ast.Ident{NamePos: p.pos, Name: "."}
p.next()
}
pos := p.pos
var path string
if p.tok == token.STRING {
path = p.lit
p.next()
} else if p.tok.IsLiteral() {
p.error(pos, "import path must be a string")
p.next()
} else {
p.error(pos, "missing import path")
p.advance(exprEnd)
}
comment := p.expectSemi()
// collect imports
spec := &ast.ImportSpec{
Doc: doc,
Name: ident,
Path: &ast.BasicLit{ValuePos: pos, Kind: token.STRING, Value: path},
Comment: comment,
}
p.imports = append(p.imports, spec)
return spec
}
func (p *parser) parseValueSpec(doc *ast.CommentGroup, keyword token.Token, iota int) ast.Spec {
if p.trace {
defer un(trace(p, keyword.String()+"Spec"))
}
idents := p.parseIdentList()
var typ ast.Expr
var values []ast.Expr
switch keyword {
case token.CONST:
// always permit optional type and initialization for more tolerant parsing
if p.tok != token.EOF && p.tok != token.SEMICOLON && p.tok != token.RPAREN {
typ = p.tryIdentOrType()
if p.tok == token.ASSIGN {
p.next()
values = p.parseList(true)
}
}
case token.VAR:
if p.tok != token.ASSIGN {
typ = p.parseType()
}
if p.tok == token.ASSIGN {
p.next()
values = p.parseList(true)
}
default:
panic("unreachable")
}
comment := p.expectSemi()
spec := &ast.ValueSpec{
Doc: doc,
Names: idents,
Type: typ,
Values: values,
Comment: comment,
}
return spec
}
func (p *parser) parseGenericType(spec *ast.TypeSpec, openPos token.Pos, name0 *ast.Ident, typ0 ast.Expr) {
if p.trace {
defer un(trace(p, "parseGenericType"))
}
list := p.parseParameterList(name0, typ0, token.RBRACK)
closePos := p.expect(token.RBRACK)
spec.TypeParams = &ast.FieldList{Opening: openPos, List: list, Closing: closePos}
// Let the type checker decide whether to accept type parameters on aliases:
// see go.dev/issue/46477.
if p.tok == token.ASSIGN {
// type alias
spec.Assign = p.pos
p.next()
}
spec.Type = p.parseType()
}
func (p *parser) parseTypeSpec(doc *ast.CommentGroup, _ token.Token, _ int) ast.Spec {
if p.trace {
defer un(trace(p, "TypeSpec"))
}
name := p.parseIdent()
spec := &ast.TypeSpec{Doc: doc, Name: name}
if p.tok == token.LBRACK {
// spec.Name "[" ...
// array/slice type or type parameter list
lbrack := p.pos
p.next()
if p.tok == token.IDENT {
// We may have an array type or a type parameter list.
// In either case we expect an expression x (which may
// just be a name, or a more complex expression) which
// we can analyze further.
//
// A type parameter list may have a type bound starting
// with a "[" as in: P []E. In that case, simply parsing
// an expression would lead to an error: P[] is invalid.
// But since index or slice expressions are never constant
// and thus invalid array length expressions, if the name
// is followed by "[" it must be the start of an array or
// slice constraint. Only if we don't see a "[" do we
// need to parse a full expression. Notably, name <- x
// is not a concern because name <- x is a statement and
// not an expression.
var x ast.Expr = p.parseIdent()
if p.tok != token.LBRACK {
// To parse the expression starting with name, expand
// the call sequence we would get by passing in name
// to parser.expr, and pass in name to parsePrimaryExpr.
p.exprLev++
lhs := p.parsePrimaryExpr(x)
x = p.parseBinaryExpr(lhs, token.LowestPrec+1)
p.exprLev--
}
// Analyze expression x. If we can split x into a type parameter
// name, possibly followed by a type parameter type, we consider
// this the start of a type parameter list, with some caveats:
// a single name followed by "]" tilts the decision towards an
// array declaration; a type parameter type that could also be
// an ordinary expression but which is followed by a comma tilts
// the decision towards a type parameter list.
if pname, ptype := extractName(x, p.tok == token.COMMA); pname != nil && (ptype != nil || p.tok != token.RBRACK) {
// spec.Name "[" pname ...
// spec.Name "[" pname ptype ...
// spec.Name "[" pname ptype "," ...
p.parseGenericType(spec, lbrack, pname, ptype) // ptype may be nil
} else {
// spec.Name "[" pname "]" ...
// spec.Name "[" x ...
spec.Type = p.parseArrayType(lbrack, x)
}
} else {
// array type
spec.Type = p.parseArrayType(lbrack, nil)
}
} else {
// no type parameters
if p.tok == token.ASSIGN {
// type alias
spec.Assign = p.pos
p.next()
}
spec.Type = p.parseType()
}
spec.Comment = p.expectSemi()
return spec
}
// extractName splits the expression x into (name, expr) if syntactically
// x can be written as name expr. The split only happens if expr is a type
// element (per the isTypeElem predicate) or if force is set.
// If x is just a name, the result is (name, nil). If the split succeeds,
// the result is (name, expr). Otherwise the result is (nil, x).
// Examples:
//
//	x          force   name   expr
//	-----------------------------------
//	P*[]int    T/F     P      *[]int
//	P*E        T       P      *E
//	P*E        F       nil    P*E
//	P([]int)   T/F     P      []int
//	P(E)       T       P      E
//	P(E)       F       nil    P(E)
//	P*E|F|~G   T/F     P      *E|F|~G
//	P*E|F|G    T       P      *E|F|G
//	P*E|F|G    F       nil    P*E|F|G
func extractName(x ast.Expr, force bool) (*ast.Ident, ast.Expr) {
switch x := x.(type) {
case *ast.Ident:
return x, nil
case *ast.BinaryExpr:
switch x.Op {
case token.MUL:
if name, _ := x.X.(*ast.Ident); name != nil && (force || isTypeElem(x.Y)) {
// x = name *x.Y
return name, &ast.StarExpr{Star: x.OpPos, X: x.Y}
}
case token.OR:
if name, lhs := extractName(x.X, force || isTypeElem(x.Y)); name != nil && lhs != nil {
// x = name lhs|x.Y
op := *x
op.X = lhs
return name, &op
}
}
case *ast.CallExpr:
if name, _ := x.Fun.(*ast.Ident); name != nil {
if len(x.Args) == 1 && x.Ellipsis == token.NoPos && (force || isTypeElem(x.Args[0])) {
// x = name "(" x.ArgList[0] ")"
return name, x.Args[0]
}
}
}
return nil, x
}
// isTypeElem reports whether x is a (possibly parenthesized) type element expression.
// The result is false if x could be a type element OR an ordinary (value) expression.
func isTypeElem(x ast.Expr) bool {
switch x := x.(type) {
case *ast.ArrayType, *ast.StructType, *ast.FuncType, *ast.InterfaceType, *ast.MapType, *ast.ChanType:
return true
case *ast.BinaryExpr:
return isTypeElem(x.X) || isTypeElem(x.Y)
case *ast.UnaryExpr:
return x.Op == token.TILDE
case *ast.ParenExpr:
return isTypeElem(x.X)
}
return false
}
func (p *parser) parseGenDecl(keyword token.Token, f parseSpecFunction) *ast.GenDecl {
if p.trace {
defer un(trace(p, "GenDecl("+keyword.String()+")"))
}
doc := p.leadComment
pos := p.expect(keyword)
var lparen, rparen token.Pos
var list []ast.Spec
if p.tok == token.LPAREN {
lparen = p.pos
p.next()
for iota := 0; p.tok != token.RPAREN && p.tok != token.EOF; iota++ {
list = append(list, f(p.leadComment, keyword, iota))
}
rparen = p.expect(token.RPAREN)
p.expectSemi()
} else {
list = append(list, f(nil, keyword, 0))
}
return &ast.GenDecl{
Doc: doc,
TokPos: pos,
Tok: keyword,
Lparen: lparen,
Specs: list,
Rparen: rparen,
}
}
func (p *parser) parseFuncDecl() *ast.FuncDecl {
if p.trace {
defer un(trace(p, "FunctionDecl"))
}
doc := p.leadComment
pos := p.expect(token.FUNC)
var recv *ast.FieldList
if p.tok == token.LPAREN {
_, recv = p.parseParameters(false)
}
ident := p.parseIdent()
tparams, params := p.parseParameters(true)
if recv != nil && tparams != nil {
// Method declarations do not have type parameters. We parse them for a
// better error message and improved error recovery.
p.error(tparams.Opening, "method must have no type parameters")
tparams = nil
}
results := p.parseResult()
var body *ast.BlockStmt
switch p.tok {
case token.LBRACE:
body = p.parseBody()
p.expectSemi()
case token.SEMICOLON:
p.next()
if p.tok == token.LBRACE {
// opening { of function declaration on next line
p.error(p.pos, "unexpected semicolon or newline before {")
body = p.parseBody()
p.expectSemi()
}
default:
p.expectSemi()
}
decl := &ast.FuncDecl{
Doc: doc,
Recv: recv,
Name: ident,
Type: &ast.FuncType{
Func: pos,
TypeParams: tparams,
Params: params,
Results: results,
},
Body: body,
}
return decl
}
func (p *parser) parseDecl(sync map[token.Token]bool) ast.Decl {
if p.trace {
defer un(trace(p, "Declaration"))
}
var f parseSpecFunction
switch p.tok {
case token.IMPORT:
f = p.parseImportSpec
case token.CONST, token.VAR:
f = p.parseValueSpec
case token.TYPE:
f = p.parseTypeSpec
case token.FUNC:
return p.parseFuncDecl()
default:
pos := p.pos
p.errorExpected(pos, "declaration")
p.advance(sync)
return &ast.BadDecl{From: pos, To: p.pos}
}
return p.parseGenDecl(p.tok, f)
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package transpile
import (
"go/token"
)
// ReplaceIdentAt replaces an identifier spanning n tokens,
// starting at the given index, with a single IDENT token holding the given string.
// This is used in Exec mode for dealing with identifiers and paths that
// Go scans as multiple separate tokens.
func (tk Tokens) ReplaceIdentAt(at int, str string, n int) Tokens {
ntk := append(tk[:at], &Token{Tok: token.IDENT, Str: str})
ntk = append(ntk, tk[at+n:]...)
return ntk
}
// Path extracts a standard path or URL expression from the current
// list of tokens (starting at index 0), returning the path string
// and the number of tokens included in the path.
// Processing is restricted to contiguous tokens with no intervening spaces.
// If it is not a path, it returns an empty string and 0.
func (tk Tokens) Path(idx0 bool) (string, int) {
n := len(tk)
if n == 0 {
return "", 0
}
t0 := tk[0]
ispath := (t0.IsPathDelim() || t0.Tok == token.TILDE)
if n == 1 {
if ispath {
return t0.String(), 1
}
return "", 0
}
str := tk[0].String()
lastEnd := int(tk[0].Pos) + len(str)
ci := 1
if !ispath {
lastEnd = int(tk[0].Pos)
ci = 0
if t0.Tok != token.IDENT {
return "", 0
}
tin := 1
tid := t0.Str
tindelim := tk[tin].IsPathDelim()
if idx0 {
tindelim = tk[tin].Tok == token.QUO
}
if (int(tk[tin].Pos) > lastEnd+len(tid)) || !(tk[tin].Tok == token.COLON || tindelim) {
return "", 0
}
ci += tin + 1
str = tid + tk[tin].String()
lastEnd += len(str)
if ci >= n || int(tk[ci].Pos) > lastEnd { // just 2 or 2 and a space
if tk[tin].Tok == token.COLON { // go Ident: static initializer
return "", 0
}
}
}
prevWasDelim := true
for {
if ci >= n || int(tk[ci].Pos) > lastEnd {
return str, ci
}
ct := tk[ci]
if ct.IsPathDelim() || ct.IsPathExtraDelim() {
prevWasDelim = true
str += ct.String()
lastEnd += len(ct.String())
ci++
continue
}
if ct.Tok == token.STRING {
prevWasDelim = true
str += EscapeQuotes(ct.String())
lastEnd += len(ct.String())
ci++
continue
}
if !prevWasDelim {
if ct.Tok == token.ILLEGAL && ct.Str == `\` && ci+1 < n && int(tk[ci+1].Pos) == lastEnd+2 {
prevWasDelim = true
str += " "
ci++
lastEnd += 2
continue
}
return str, ci
}
if ct.IsWord() {
prevWasDelim = false
str += ct.String()
lastEnd += len(ct.String())
ci++
continue
}
return str, ci
}
}
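// examplePath is an illustrative sketch added for documentation (not part of
// the original source): Path recognizes the contiguous file path at the start
// of the token list, stopping at the space before the trailing word.
func examplePath() {
	toks := TokensFromString("./sub/file.go extra")
	path, n := toks.Path(false) // idx0 does not matter here: the first token is already a path delimiter
	_, _ = path, n              // path == "./sub/file.go", n == 7 tokens consumed
}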
func (tk *Token) IsPathDelim() bool {
return tk.Tok == token.PERIOD || tk.Tok == token.QUO
}
func (tk *Token) IsPathExtraDelim() bool {
return tk.Tok == token.SUB || tk.Tok == token.ASSIGN || tk.Tok == token.REM || (tk.Tok == token.ILLEGAL && (tk.Str == "?" || tk.Str == "#"))
}
// IsWord returns true if the token is some kind of word-like entity,
// including IDENT, STRING, CHAR, or one of the Go keywords.
// This is for exec filtering.
func (tk *Token) IsWord() bool {
return tk.Tok == token.IDENT || tk.IsGo() || tk.Tok == token.STRING || tk.Tok == token.CHAR
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package transpile
import (
"errors"
"fmt"
"go/format"
"log/slog"
"os"
"strings"
"cogentcore.org/core/base/logx"
"cogentcore.org/core/base/num"
"cogentcore.org/core/base/stringsx"
"golang.org/x/tools/imports"
)
// State holds the transpiling state
type State struct {
// FuncToVar translates function definitions into variable definitions,
// which is the default for interactive use of random code fragments
// without the complete go formatting.
// For pure transpiling of a complete codebase with full proper Go formatting
// this should be turned off.
FuncToVar bool
// MathMode indicates that multiline math mode, toggled by ##, is currently active.
MathMode bool
// MathRecord is state of auto-recording of data into current directory
// in math mode.
MathRecord bool
// ParenDepth etc record the depth of each unclosed delimiter at the end of
// the current line; 0 means the line was complete.
ParenDepth, BraceDepth, BrackDepth, TypeDepth, DeclDepth int
// Chunks of code lines that are accumulated during Transpile,
// each of which should be evaluated separately, to avoid
// issues with contextual effects from import, package etc.
Chunks []string
// current stack of transpiled lines, that are accumulated into
// code Chunks.
Lines []string
// stack of runtime errors.
Errors []error
// lastCommand, if non-empty, is the name of the last command defined.
// It triggers insertion of the AddCommand call to add to the list of defined commands.
lastCommand string
}
// NewState returns a new transpiling state; mostly for testing
func NewState() *State {
st := &State{FuncToVar: true}
return st
}
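// exampleState is an illustrative usage sketch added for documentation (not
// part of the original source): transpile a few interactive lines, check for
// unbalanced delimiters, and retrieve the accumulated Go code.
func exampleState() {
	st := NewState()
	st.TranspileCode("x := 3")     // Go line, routed to TranspileGo
	st.TranspileCode("echo hello") // exec-style line, routed to TranspileExec (assuming "echo" resolves as a command)
	if err := st.DepthError(); err != nil {
		fmt.Println(err)
	}
	fmt.Println(st.Code())
}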
// TranspileCode processes each line of the given code,
// adding the results to the Lines stack.
func (st *State) TranspileCode(code string) {
lns := strings.Split(code, "\n")
n := len(lns)
if n == 0 {
return
}
for _, ln := range lns {
hasDecl := st.DeclDepth > 0
tl := st.TranspileLine(ln)
st.AddLine(tl)
if st.BraceDepth == 0 && st.BrackDepth == 0 && st.ParenDepth == 1 && st.lastCommand != "" {
st.lastCommand = ""
nl := len(st.Lines)
st.Lines[nl-1] = st.Lines[nl-1] + ")"
st.ParenDepth--
}
if hasDecl && st.DeclDepth == 0 { // break at decl
st.AddChunk()
}
}
}
// TranspileFile transpiles the given input goal file to the
// given output Go file. If no existing package declaration
// is found, then package main and func main declarations are
// added. This also affects how functions are interpreted.
func (st *State) TranspileFile(in string, out string) error {
b, err := os.ReadFile(in)
if err != nil {
return err
}
code := string(b)
lns := stringsx.SplitLines(code)
hasPackage := false
for _, ln := range lns {
if strings.HasPrefix(ln, "package ") {
hasPackage = true
break
}
}
if hasPackage {
st.FuncToVar = false // use raw functions
}
st.TranspileCode(code)
st.FuncToVar = true
if err != nil {
return err
}
hdr := `package main
import (
"cogentcore.org/lab/goal"
"cogentcore.org/lab/goal/goalib"
"cogentcore.org/lab/tensor"
_ "cogentcore.org/lab/tensor/tmath"
_ "cogentcore.org/lab/stats/stats"
_ "cogentcore.org/lab/stats/metric"
)
func main() {
goalrun := goal.NewGoal()
_ = goalrun
`
src := st.Code()
res := []byte(src)
bsrc := res
gen := fmt.Sprintf("// Code generated by \"goal build\"; DO NOT EDIT.\n//line %s:1\n", in)
if hasPackage {
bsrc = []byte(gen + src)
res, err = format.Source(bsrc)
} else {
bsrc = []byte(gen + hdr + src + "\n}")
res, err = imports.Process(out, bsrc, nil)
}
if err != nil {
res = bsrc
fmt.Println(err.Error())
} else {
err = st.DepthError()
}
werr := os.WriteFile(out, res, 0666)
return errors.Join(err, werr)
}
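// Illustrative usage sketch (editorial addition, not part of the original
// source); the file names are hypothetical:
//
//	st := NewState()
//	if err := st.TranspileFile("analysis.goal", "analysis.go"); err != nil {
//		fmt.Println(err)
//	}
//
// If analysis.goal contains no package declaration, the output is wrapped in
// the package main / func main scaffold from hdr above and run through
// imports.Process; otherwise it is only formatted via format.Source.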
// TotalDepth returns the sum of any unresolved paren, brace, or bracket depths.
func (st *State) TotalDepth() int {
return num.Abs(st.ParenDepth) + num.Abs(st.BraceDepth) + num.Abs(st.BrackDepth)
}
// ResetCode resets the stack of transpiled code
func (st *State) ResetCode() {
st.Chunks = nil
st.Lines = nil
}
// ResetDepth resets the current depths to 0
func (st *State) ResetDepth() {
st.ParenDepth, st.BraceDepth, st.BrackDepth, st.TypeDepth, st.DeclDepth = 0, 0, 0, 0, 0
}
// DepthError reports an error if any of the parsing depths are not zero,
// to be called at the end of transpiling a complete block of code.
func (st *State) DepthError() error {
if st.TotalDepth() == 0 {
return nil
}
str := ""
if st.ParenDepth != 0 {
str += fmt.Sprintf("Incomplete parentheses (), remaining depth: %d\n", st.ParenDepth)
}
if st.BraceDepth != 0 {
str += fmt.Sprintf("Incomplete braces [], remaining depth: %d\n", st.BraceDepth)
}
if st.BrackDepth != 0 {
str += fmt.Sprintf("Incomplete brackets {}, remaining depth: %d\n", st.BrackDepth)
}
if str != "" {
slog.Error(str)
return errors.New(str)
}
return nil
}
// AddLine adds line on the stack
func (st *State) AddLine(ln string) {
st.Lines = append(st.Lines, ln)
}
// Code returns the current transpiled lines,
// split into chunks that should be compiled separately.
func (st *State) Code() string {
st.AddChunk()
if len(st.Chunks) == 0 {
return ""
}
return strings.Join(st.Chunks, "\n")
}
// AddChunk adds current lines into a chunk of code
// that should be compiled separately.
func (st *State) AddChunk() {
if len(st.Lines) == 0 {
return
}
st.Chunks = append(st.Chunks, strings.Join(st.Lines, "\n"))
st.Lines = nil
}
// AddError adds the given error to the error stack if it is non-nil,
// and calls the Cancel function if set, to stop execution.
// This is the main way that goal errors are handled.
// It also prints the error.
func (st *State) AddError(err error) error {
if err == nil {
return nil
}
st.Errors = append(st.Errors, err)
logx.PrintlnError(err)
return err
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package transpile
import (
"errors"
"go/token"
)
var tensorfsCommands = map[string]func(mp *mathParse) error{
"cd": cd,
"mkdir": mkdir,
"ls": ls,
}
func cd(mp *mathParse) error {
var dir string
if len(mp.ewords) > 1 {
dir = mp.ewords[1]
}
mp.out.Add(token.IDENT, "tensorfs.Chdir")
mp.out.Add(token.LPAREN)
mp.out.Add(token.STRING, `"`+dir+`"`)
mp.out.Add(token.RPAREN)
return nil
}
func mkdir(mp *mathParse) error {
if len(mp.ewords) == 1 {
return errors.New("tensorfs mkdir requires a directory name")
}
dir := mp.ewords[1]
mp.out.Add(token.IDENT, "tensorfs.Mkdir")
mp.out.Add(token.LPAREN)
mp.out.Add(token.STRING, `"`+dir+`"`)
mp.out.Add(token.RPAREN)
return nil
}
func ls(mp *mathParse) error {
mp.out.Add(token.IDENT, "tensorfs.List")
mp.out.Add(token.LPAREN)
for i := 1; i < len(mp.ewords); i++ {
mp.out.Add(token.STRING, `"`+mp.ewords[i]+`"`)
}
mp.out.Add(token.RPAREN)
return nil
}
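// Illustrative note (editorial sketch, not part of the original source): these
// handlers rewrite math-mode shell-style commands into tensorfs calls, e.g.
//
//	cd data     ->  tensorfs.Chdir("data")
//	mkdir data  ->  tensorfs.Mkdir("data")
//	ls          ->  tensorfs.List()        // any remaining words become quoted string args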
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package transpile
import (
"go/scanner"
"go/token"
"log/slog"
"slices"
"strings"
"cogentcore.org/core/base/logx"
)
// Token provides full data for one token
type Token struct {
// Go token classification
Tok token.Token
// Literal string
Str string
// position in the original string.
// this is only set for the original parse,
// not for transpiled additions.
Pos token.Pos
}
// TokensFromString converts the given string into Tokens
func TokensFromString(ln string) Tokens {
fset := token.NewFileSet()
f := fset.AddFile("", fset.Base(), len(ln))
var sc scanner.Scanner
sc.Init(f, []byte(ln), errHandler, scanner.ScanComments|2) // 2 is non-exported dontInsertSemis
// note to Go team: just export this stuff. seriously.
var toks Tokens
for {
pos, tok, lit := sc.Scan()
if tok == token.EOF {
break
}
// logx.PrintfDebug(" token: %s\t%s\t%q\n", fset.Position(pos), tok, lit)
toks = append(toks, &Token{Tok: tok, Pos: pos, Str: lit})
}
return toks
}
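// exampleTokens is an illustrative sketch added for documentation (not part of
// the original source): scan a line into Tokens and re-render it.
func exampleTokens() {
	toks := TokensFromString(`x := a[1:3]`)
	_ = toks.String() // annotated form: each token as "[TOK]" followed by its literal
	_ = toks.Code()   // surface code string reassembled from the token strings
}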
func errHandler(pos token.Position, msg string) {
logx.PrintlnDebug("Scan Error:", pos, msg)
}
// Tokens is a slice of Token
type Tokens []*Token
// NewToken returns a new token, for generated tokens without Pos
func NewToken(tok token.Token, str ...string) *Token {
tk := &Token{Tok: tok}
if len(str) > 0 {
tk.Str = str[0]
}
return tk
}
// Add adds a new token, for generated tokens without Pos
func (tk *Tokens) Add(tok token.Token, str ...string) *Token {
nt := NewToken(tok, str...)
*tk = append(*tk, nt)
return nt
}
// AddMulti adds new basic tokens (not IDENT).
func (tk *Tokens) AddMulti(tok ...token.Token) {
for _, t := range tok {
tk.Add(t)
}
}
// AddTokens adds given tokens to our list
func (tk *Tokens) AddTokens(toks ...*Token) *Tokens {
*tk = append(*tk, toks...)
return tk
}
// Insert inserts a new token at given position
func (tk *Tokens) Insert(i int, tok token.Token, str ...string) *Token {
nt := NewToken(tok, str...)
*tk = slices.Insert(*tk, i, nt)
return nt
}
// Last returns the final token in the list
func (tk Tokens) Last() *Token {
n := len(tk)
if n == 0 {
return nil
}
return tk[n-1]
}
// DeleteLastComma removes any final COMMA token.
// It is easier to generate a trailing comma and delete it at the end.
func (tk *Tokens) DeleteLastComma() {
lt := tk.Last()
if lt == nil {
return
}
if lt.Tok == token.COMMA {
*tk = (*tk)[:len(*tk)-1]
}
}
// String returns the string for the token
func (tk *Token) String() string {
if tk.Str != "" {
return tk.Str
}
return tk.Tok.String()
}
// IsBacktickString returns true if the given STRING uses backticks
func (tk *Token) IsBacktickString() bool {
if tk.Tok != token.STRING {
return false
}
return (tk.Str[0] == '`')
}
// IsGo returns true if the given token is a Go Keyword or Comment
func (tk *Token) IsGo() bool {
if tk.Tok >= token.BREAK && tk.Tok <= token.VAR {
return true
}
if tk.Tok == token.COMMENT {
return true
}
return false
}
// IsValidExecIdent returns true if the given token is a valid component
// of an Exec mode identifier
func (tk *Token) IsValidExecIdent() bool {
return (tk.IsGo() || tk.Tok == token.IDENT || tk.Tok == token.SUB || tk.Tok == token.DEC || tk.Tok == token.INT || tk.Tok == token.FLOAT || tk.Tok == token.ASSIGN)
}
// String is the stringer version which includes the token ID
// in addition to the string literal
func (tk Tokens) String() string {
str := ""
for _, tok := range tk {
str += "[" + tok.Tok.String() + "] "
if tok.Str != "" {
str += tok.Str + " "
}
}
if len(str) == 0 {
return str
}
return str[:len(str)-1] // remove trailing space
}
// Code returns concatenated Str values of the tokens,
// to generate a surface-valid code string.
func (tk Tokens) Code() string {
n := len(tk)
if n == 0 {
return ""
}
str := ""
prvIdent := false
for _, tok := range tk {
switch {
case tok.IsOp():
switch tok.Tok {
case token.INC, token.DEC:
str += tok.String() + " "
case token.MUL:
str += " " + tok.String()
case token.SUB:
str += tok.String()
default:
str += " " + tok.String() + " "
}
prvIdent = false
case tok.Tok == token.ELLIPSIS:
str += " " + tok.String()
prvIdent = false
case tok.IsBracket() || tok.Tok == token.PERIOD:
if tok.Tok == token.RBRACE || tok.Tok == token.LBRACE {
if len(str) > 0 && str[len(str)-1] != ' ' {
str += " "
}
str += tok.String() + " "
} else {
str += tok.String()
}
prvIdent = false
case tok.Tok == token.COMMA || tok.Tok == token.COLON || tok.Tok == token.SEMICOLON:
str += tok.String() + " "
prvIdent = false
case tok.Tok == token.STRUCT:
str += " " + tok.String() + " "
case tok.Tok == token.FUNC:
if prvIdent {
str += " "
}
str += tok.String()
prvIdent = true
case tok.Tok == token.COMMENT:
if str != "" {
str += " "
}
str += tok.String()
case tok.IsGo():
if prvIdent {
str += " "
}
str += tok.String()
if tok.Tok != token.MAP {
str += " "
}
prvIdent = false
case tok.Tok == token.IDENT || tok.Tok == token.STRING:
if prvIdent {
str += " "
}
str += tok.String()
prvIdent = true
default:
str += tok.String()
prvIdent = false
}
}
if len(str) == 0 {
return str
}
if str[len(str)-1] == ' ' {
return str[:len(str)-1]
}
return str
}
// IsOp returns true if the given token is an operator
func (tk *Token) IsOp() bool {
if tk.Tok >= token.ADD && tk.Tok <= token.DEFINE {
return true
}
return false
}
// Contains returns true if the token list contains any of the given token(s)
func (tk Tokens) Contains(toks ...token.Token) bool {
if len(toks) == 0 {
slog.Error("programmer error: tokens.Contains with no args")
return false
}
for _, t := range tk {
for _, st := range toks {
if t.Tok == st {
return true
}
}
}
return false
}
// EscapeQuotes replaces any " with \"
func EscapeQuotes(str string) string {
return strings.ReplaceAll(str, `"`, `\"`)
}
// AddQuotes surrounds given string with quotes,
// also escaping any contained quotes
func AddQuotes(str string) string {
return `"` + EscapeQuotes(str) + `"`
}
// IsBracket returns true if the given token is a bracket delimiter:
// paren, brace, bracket
func (tk *Token) IsBracket() bool {
if (tk.Tok >= token.LPAREN && tk.Tok <= token.LBRACE) || (tk.Tok >= token.RPAREN && tk.Tok <= token.RBRACE) {
return true
}
return false
}
// RightMatching returns the position (or -1 if not found) for the
// right matching [paren, bracket, brace] given the left one that
// is at the 0 position of the current set of tokens.
func (tk Tokens) RightMatching() int {
n := len(tk)
if n == 0 {
return -1
}
rb := token.RPAREN
lb := tk[0].Tok
switch lb {
case token.LPAREN:
rb = token.RPAREN
case token.LBRACK:
rb = token.RBRACK
case token.LBRACE:
rb = token.RBRACE
}
depth := 0
for i := 1; i < n; i++ {
tok := tk[i].Tok
switch tok {
case rb:
if depth <= 0 {
return i
}
depth--
case lb:
depth++
}
}
return -1
}
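// exampleRightMatching is an illustrative sketch added for documentation (not
// part of the original source): find the index of the ')' matching the '(' at
// index 0, skipping the nested pair.
func exampleRightMatching() {
	toks := TokensFromString("(a + (b))")
	ri := toks.RightMatching()
	_ = ri // 6: index of the final RPAREN
}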
// BracketDepths returns the net depths for the three bracket delimiters
// (paren, brace, bracket), counting unmatched left delimiters as positive
// and unmatched right delimiters as negative.
func (tk Tokens) BracketDepths() (paren, brace, brack int) {
n := len(tk)
if n == 0 {
return
}
for i := 0; i < n; i++ {
tok := tk[i].Tok
switch tok {
case token.LPAREN:
paren++
case token.LBRACE:
brace++
case token.LBRACK:
brack++
case token.RPAREN:
paren--
case token.RBRACE:
brace--
case token.RBRACK:
brack--
}
}
return
}
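// exampleBracketDepths is an illustrative sketch added for documentation (not
// part of the original source): an opening loop line leaves one unclosed
// brace, which TranspileLine accumulates into State.BraceDepth.
func exampleBracketDepths() {
	toks := TokensFromString("for i := 0; i < 3; i++ {")
	paren, brace, brack := toks.BracketDepths()
	_, _, _ = paren, brace, brack // 0, 1, 0
}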
// ModeEnd returns the position (or -1 if not found) for the
// next ILLEGAL mode token ($ or #) given the starting one that
// is at the 0 position of the current set of tokens.
func (tk Tokens) ModeEnd() int {
n := len(tk)
if n == 0 {
return -1
}
st := tk[0].Str
for i := 1; i < n; i++ {
if tk[i].Tok != token.ILLEGAL {
continue
}
if tk[i].Str == st {
return i
}
}
return -1
}
// IsAssignExpr checks if there are any Go assignment or define tokens
// outside of { } Go code.
func (tk Tokens) IsAssignExpr() bool {
n := len(tk)
if n == 0 {
return false
}
for i := 1; i < n; i++ {
tok := tk[i].Tok
if tok == token.ASSIGN || tok == token.DEFINE || (tok >= token.ADD_ASSIGN && tok <= token.AND_NOT_ASSIGN) {
return true
}
if tok == token.LBRACE { // skip Go mode
rp := tk[i:n].RightMatching()
if rp > 0 {
i += rp
}
}
}
return false
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package transpile
import (
"fmt"
"go/token"
"slices"
"strings"
"cogentcore.org/core/base/logx"
)
// TranspileLine is the main function for parsing a single line of goal input,
// returning a new transpiled line of code that converts Exec code into corresponding
// Go function calls.
func (st *State) TranspileLine(code string) string {
if len(code) == 0 {
return code
}
if strings.HasPrefix(code, "#!") {
return ""
}
toks := st.TranspileLineTokens(code)
paren, brace, brack := toks.BracketDepths()
st.ParenDepth += paren
st.BraceDepth += brace
st.BrackDepth += brack
// logx.PrintlnDebug("depths: ", st.ParenDepth, st.BraceDepth, st.BrackDepth, st.TypeDepth, st.DeclDepth)
if st.TypeDepth > 0 && st.BraceDepth == 0 {
st.TypeDepth = 0
}
if st.DeclDepth > 0 && (st.ParenDepth == 0 && st.BraceDepth == 0) {
st.DeclDepth = 0
}
// logx.PrintlnDebug("depths: ", st.ParenDepth, st.BraceDepth, st.BrackDepth, st.TypeDepth, st.DeclDepth)
return toks.Code()
}
// TranspileLineTokens returns the tokens for the full line
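// The switch below classifies the line: a leading # enters math mode,
// a leading { or a Go keyword dispatches to TranspileGo, a leading [
// or a first word that is a recognized command dispatches to
// TranspileExec, and assignment or define expressions outside of
// { } braces are treated as Go.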
func (st *State) TranspileLineTokens(code string) Tokens {
if code == "" {
return nil
}
toks := TokensFromString(code)
n := len(toks)
if n == 0 {
return toks
}
if st.MathMode {
if len(toks) >= 2 {
if toks[0].Tok == token.ILLEGAL && toks[0].Str == "#" && toks[1].Tok == token.ILLEGAL && toks[1].Str == "#" {
st.MathMode = false
return nil
}
}
return st.TranspileMath(toks, code, true)
}
ewords, err := ExecWords(code)
if err != nil {
st.AddError(err)
return nil
}
logx.PrintlnDebug("\n########## line:\n", code, "\nTokens:", len(toks), "\n", toks.String(), "\nWords:", len(ewords), "\n", ewords)
if toks[0].Tok == token.TYPE {
st.TypeDepth++
}
if toks[0].Tok == token.IMPORT || toks[0].Tok == token.VAR || toks[0].Tok == token.CONST {
st.DeclDepth++
}
// logx.PrintlnDebug("depths: ", st.ParenDepth, st.BraceDepth, st.BrackDepth, st.TypeDepth, st.DeclDepth)
if st.TypeDepth > 0 || st.DeclDepth > 0 {
logx.PrintlnDebug("go: type / decl defn")
return st.TranspileGo(toks, code)
}
t0 := toks[0]
_, t0pn := toks.Path(true) // true = first position
en := len(ewords)
f0exec := (t0.Tok == token.IDENT && ExecWordIsCommand(ewords[0]))
if f0exec && n > 1 && toks[1].Tok == token.COLON { // go Ident: static initializer
f0exec = false
}
switch {
case t0.Tok == token.ILLEGAL:
if t0.Str == "#" {
logx.PrintlnDebug("math #")
if toks[1].Tok == token.ILLEGAL && toks[1].Str == "#" {
st.MathMode = true
return st.TranspileMath(toks[2:], code, true)
}
return st.TranspileMath(toks[1:], code, true)
}
return st.TranspileExec(ewords, false)
case t0.Tok == token.LBRACE:
logx.PrintlnDebug("go: { } line")
return st.TranspileGoRange(toks, code, 1, n-1)
case t0.Tok == token.LBRACK:
logx.PrintlnDebug("exec: [ ] line")
return st.TranspileExec(ewords, false) // it processes the [ ]
case t0.Tok == token.IDENT && t0.Str == "command":
st.lastCommand = toks[1].Str // 1 is the name -- triggers AddCommand
toks = toks[2:] // get rid of first
toks.Insert(0, token.IDENT, "goalrun.AddCommand")
toks.Insert(1, token.LPAREN)
toks.Insert(2, token.STRING, `"`+st.lastCommand+`"`)
toks.Insert(3, token.COMMA)
toks.Insert(4, token.FUNC)
toks.Insert(5, token.LPAREN)
toks.Insert(6, token.IDENT, "args")
toks.Insert(7, token.ELLIPSIS)
toks.Insert(8, token.IDENT, "string")
toks.Insert(9, token.RPAREN)
toks.AddTokens(st.TranspileGo(toks[11:], code)...)
case t0.IsGo():
if t0.Tok == token.GO {
if !toks.Contains(token.LPAREN) {
logx.PrintlnDebug("exec: go command")
return st.TranspileExec(ewords, false)
}
}
logx.PrintlnDebug("go keyword")
return st.TranspileGo(toks, code)
case toks[n-1].Tok == token.INC || toks[n-1].Tok == token.DEC:
logx.PrintlnDebug("go ++ / --")
return st.TranspileGo(toks, code)
case t0pn > 0: // path expr
logx.PrintlnDebug("exec: path...")
return st.TranspileExec(ewords, false)
case t0.Tok == token.STRING:
logx.PrintlnDebug("exec: string...")
return st.TranspileExec(ewords, false)
case f0exec && en == 1:
logx.PrintlnDebug("exec: 1 word")
return st.TranspileExec(ewords, false)
case !f0exec: // exec must be IDENT
logx.PrintlnDebug("go: not ident")
return st.TranspileGo(toks, code)
case f0exec && en > 1 && ewords[0] != "set" && toks.IsAssignExpr():
logx.PrintlnDebug("go: assignment or defn")
return st.TranspileGo(toks, code)
case f0exec: // now any ident
logx.PrintlnDebug("exec: ident..")
return st.TranspileExec(ewords, false)
default:
logx.PrintlnDebug("go: default")
return st.TranspileGo(toks, code)
}
return toks
}
// TranspileGoRange returns transpiled tokens assuming Go code,
// for given start, end (exclusive) range of given tokens and code.
// In general the positions in the tokens apply to the _original_ code,
// so you should normally keep the original code string; this range
// version is needed only for specific cases such as a { } line.
func (st *State) TranspileGoRange(toks Tokens, code string, start, end int) Tokens {
codeSt := toks[start].Pos - 1
codeEd := token.Pos(len(code))
if end <= len(toks)-1 {
codeEd = toks[end].Pos - 1
}
return st.TranspileGo(toks[start:end], code[codeSt:codeEd])
}
// TranspileGo returns transpiled tokens assuming Go code.
// Unpacks any encapsulated shell or math expressions.
func (st *State) TranspileGo(toks Tokens, code string) Tokens {
n := len(toks)
if n == 0 {
return toks
}
if st.FuncToVar && toks[0].Tok == token.FUNC { // reorder as an assignment
if len(toks) > 1 && toks[1].Tok == token.IDENT {
toks[0] = toks[1]
toks.Insert(1, token.DEFINE)
toks[2] = &Token{Tok: token.FUNC}
n = len(toks)
}
}
gtoks := make(Tokens, 0, len(toks)) // return tokens
for i := 0; i < n; i++ {
tok := toks[i]
switch {
case tok.Tok == token.ILLEGAL:
et := toks[i:].ModeEnd()
if et > 0 {
if tok.Str == "#" {
gtoks.AddTokens(st.TranspileMath(toks[i+1:i+et], code, false)...)
} else {
gtoks.AddTokens(st.TranspileExecTokens(toks[i+1:i+et+1], code, true)...)
}
i += et
continue
} else {
gtoks = append(gtoks, tok)
}
case tok.Tok == token.LBRACK && i > 0 && toks[i-1].Tok == token.IDENT: // index expr
ixtoks := toks[i:]
rm := ixtoks.RightMatching()
if rm < 3 {
gtoks = append(gtoks, tok)
continue
}
idx := st.TranspileGoNDimIndex(toks, code, &gtoks, i-1, rm+i)
if idx > 0 {
i = idx
} else {
gtoks = append(gtoks, tok)
}
default:
gtoks = append(gtoks, tok)
}
}
return gtoks
}
// TranspileExecString returns transpiled tokens assuming Exec code,
// from a string, with the given bool indicating whether [Output] is needed.
func (st *State) TranspileExecString(str string, output bool) Tokens {
if len(str) <= 1 {
return nil
}
ewords, err := ExecWords(str)
if err != nil {
st.AddError(err)
}
return st.TranspileExec(ewords, output)
}
// TranspileExecTokens returns transpiled tokens assuming Exec code,
// from given tokens, with the given bool indicating
// whether [Output] is needed.
func (st *State) TranspileExecTokens(toks Tokens, code string, output bool) Tokens {
nt := len(toks)
if nt == 0 {
return nil
}
str := code[toks[0].Pos-1 : toks[nt-1].Pos-1]
return st.TranspileExecString(str, output)
}
// TranspileExec returns transpiled tokens assuming Exec code,
// with the given bool indicating whether [Output] is needed.
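// For example, "ls -la" becomes roughly goalrun.Run("ls", "-la"),
// a trailing & switches Run to Start for a background job, and a pipe
// such as "ls -la | grep foo" becomes roughly
// goalrun.Start("ls", "-la", "|"); goalrun.Run("grep", "foo")
// (a sketch of the token output; exact rendering comes from Tokens.Code).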
func (st *State) TranspileExec(ewords []string, output bool) Tokens {
n := len(ewords)
if n == 0 {
return nil
}
etoks := make(Tokens, 0, n+5) // return tokens
var execTok *Token
bgJob := false
noStop := false
if ewords[0] == "[" {
ewords = ewords[1:]
n--
noStop = true
}
startExec := func() {
bgJob = false
etoks.Add(token.IDENT, "goalrun")
etoks.Add(token.PERIOD)
switch {
case output && noStop:
execTok = etoks.Add(token.IDENT, "OutputErrOK")
case output && !noStop:
execTok = etoks.Add(token.IDENT, "Output")
case !output && noStop:
execTok = etoks.Add(token.IDENT, "RunErrOK")
case !output && !noStop:
execTok = etoks.Add(token.IDENT, "Run")
}
etoks.Add(token.LPAREN)
}
endExec := func() {
if bgJob {
execTok.Str = "Start"
}
etoks.DeleteLastComma()
etoks.Add(token.RPAREN)
}
startExec()
for i := 0; i < n; i++ {
f := ewords[i]
switch {
case f == "{": // embedded go
if n < i+3 {
st.AddError(fmt.Errorf("goal: no matching right brace } found in exec command line"))
} else {
gstr := ewords[i+1]
etoks.AddTokens(st.TranspileGo(TokensFromString(gstr), gstr)...)
etoks.Add(token.COMMA)
i += 2
}
case f == "[":
noStop = true
case f == "]": // solo is def end
// just skip
noStop = false
case f == "&":
bgJob = true
case f[0] == '|':
execTok.Str = "Start"
etoks.Add(token.IDENT, AddQuotes(f))
etoks.Add(token.COMMA)
endExec()
etoks.Add(token.SEMICOLON)
etoks.AddTokens(st.TranspileExec(ewords[i+1:], output)...)
return etoks
case f == ";":
endExec()
etoks.Add(token.SEMICOLON)
etoks.AddTokens(st.TranspileExec(ewords[i+1:], output)...)
return etoks
default:
if f[0] == '"' || f[0] == '`' {
etoks.Add(token.STRING, f)
} else {
etoks.Add(token.IDENT, AddQuotes(f)) // mark as an IDENT but add quotes!
}
etoks.Add(token.COMMA)
}
}
endExec()
return etoks
}
// TranspileGoNDimIndex processes an ident[*] sequence of tokens,
// translating it into a corresponding tensor Value or Set expression,
// if it is a multi-dimensional indexing expression which is not valid in Go,
// to support simple n-dimensional tensor indexing in Go (not math) mode.
// Gets the current sequence of toks tokens, where the ident starts at idIdx
// and the ] is at rbIdx. It puts the results in gtoks generated tokens.
// Returns a positive index to resume processing at, if it is actually an
// n-dimensional expr, and -1 if not, in which case the normal process resumes.
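// For example, "x := a[i, j]" becomes "x := a.Value(int(i), int(j))",
// "a[i, j] = x" becomes "a.Set(x, int(i), int(j))",
// "a[i, j] += x" becomes "a.SetAdd(x, int(i), int(j))",
// and "&a[i, j]" becomes "a.ValuePtr(int(i), int(j))".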
func (st *State) TranspileGoNDimIndex(toks Tokens, code string, gtoks *Tokens, idIdx, rbIdx int) int {
var commas []int
for i := idIdx + 2; i < rbIdx; i++ {
tk := toks[i]
if tk.Tok == token.COMMA {
commas = append(commas, i)
}
if tk.Tok == token.LPAREN || tk.Tok == token.LBRACE || tk.Tok == token.LBRACK {
rp := toks[i:rbIdx].RightMatching()
if rp > 0 {
i += rp
}
}
}
if len(commas) == 0 { // not multidim
return -1
}
isPtr := false
if idIdx > 0 && toks[idIdx-1].Tok == token.AND {
isPtr = true
lgt := len(*gtoks)
*gtoks = slices.Delete(*gtoks, lgt-2, lgt-1) // get rid of &
}
// now we need to determine if it is a Set based on what happens after rb
isSet := false
stok := token.ILLEGAL
n := len(toks)
hasComment := false
if toks[n-1].Tok == token.COMMENT {
hasComment = true
n--
}
if n-rbIdx > 1 {
ntk := toks[rbIdx+1].Tok
if ntk == token.ASSIGN || (ntk >= token.ADD_ASSIGN && ntk <= token.QUO_ASSIGN) {
isSet = true
stok = ntk
}
}
fun := "Value"
if isPtr {
fun = "ValuePtr"
isSet = false
} else if isSet {
fun = "Set"
switch stok {
case token.ADD_ASSIGN:
fun += "Add"
case token.SUB_ASSIGN:
fun += "Sub"
case token.MUL_ASSIGN:
fun += "Mul"
case token.QUO_ASSIGN:
fun += "Div"
}
}
gtoks.Add(token.PERIOD)
gtoks.Add(token.IDENT, fun)
gtoks.Add(token.LPAREN)
if isSet {
gtoks.AddTokens(st.TranspileGo(toks[rbIdx+2:n], code)...)
gtoks.Add(token.COMMA)
}
sti := idIdx + 2
for _, cp := range commas {
gtoks.Add(token.IDENT, "int")
gtoks.Add(token.LPAREN)
gtoks.AddTokens(st.TranspileGo(toks[sti:cp], code)...)
gtoks.Add(token.RPAREN)
gtoks.Add(token.COMMA)
sti = cp + 1
}
gtoks.Add(token.IDENT, "int")
gtoks.Add(token.LPAREN)
gtoks.AddTokens(st.TranspileGo(toks[sti:rbIdx], code)...)
gtoks.Add(token.RPAREN)
gtoks.Add(token.RPAREN)
if isSet {
if hasComment {
gtoks.AddTokens(toks[len(toks)-1])
}
return len(toks)
} else {
return rbIdx
}
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package transpile
import (
"go/ast"
"go/token"
)
// inferKindExpr infers the basic Kind level type from given expression
func inferKindExpr(ex ast.Expr) token.Token {
if ex == nil {
return token.ILLEGAL
}
switch x := ex.(type) {
case *ast.BadExpr:
return token.ILLEGAL
case *ast.Ident:
// todo: getting the type of the object is not possible here!
case *ast.BinaryExpr:
ta := inferKindExpr(x.X)
tb := inferKindExpr(x.Y)
if ta == tb {
return ta
}
if ta != token.ILLEGAL {
return ta
} else {
return tb
}
case *ast.BasicLit:
return x.Kind // key grounding
case *ast.FuncLit:
case *ast.ParenExpr:
return inferKindExpr(x.X)
case *ast.SelectorExpr:
case *ast.TypeAssertExpr:
case *ast.IndexListExpr:
if x.X == nil { // array literal
return inferKindExprList(x.Indices)
} else {
return inferKindExpr(x.X)
}
case *ast.SliceExpr:
case *ast.CallExpr:
}
return token.ILLEGAL
}
func inferKindExprList(ex []ast.Expr) token.Token {
n := len(ex)
for i := range n {
t := inferKindExpr(ex[i])
if t != token.ILLEGAL {
return t
}
}
return token.ILLEGAL
}
// Copyright (c) 2022, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
package alignsl performs 16-byte alignment checking of struct fields
and total size modulus checking of struct types to ensure WGSL
(and GSL) compatibility.
Checks that struct sizes are an even multiple of 16 bytes
(4 float32's), fields are 32 bit types: [U]Int32, Float32,
and that fields that are other struct types are aligned
at even 16 byte multiples.
*/
package alignsl
import (
"errors"
"fmt"
"go/types"
"strings"
"golang.org/x/tools/go/packages"
)
// Context holds the state for a given package checking run
type Context struct {
Sizes types.Sizes // from package
Structs map[*types.Struct]string // structs that have been processed already -- value is name
Stack map[*types.Struct]string // structs to process in a second pass -- structs encountered during processing of other structs
StructTypes map[string]bool // top level list of struct types to examine -- skip anything at a top-level that is not in this list.
Errs []string // accumulating list of error strings -- empty if all good
}
func NewContext(sz types.Sizes, structTypes map[string]bool) *Context {
cx := &Context{Sizes: sz}
cx.Structs = make(map[*types.Struct]string)
cx.Stack = make(map[*types.Struct]string)
cx.StructTypes = structTypes
return cx
}
func (cx *Context) IsNewStruct(st *types.Struct) bool {
if _, has := cx.Structs[st]; has {
return false
}
cx.Structs[st] = st.String()
return true
}
func (cx *Context) AddError(ers string, hasErr bool, stName string) bool {
if !hasErr {
cx.Errs = append(cx.Errs, stName)
}
cx.Errs = append(cx.Errs, ers)
return true
}
func TypeName(tp types.Type) string {
switch x := tp.(type) {
case *types.Named:
return x.Obj().Name()
}
return tp.String()
}
// CheckStruct is the top-level checker -- returns hasErr = true if there
// are any mis-aligned fields or total size of struct is not an
// even multiple of 16 bytes -- adds details to Errs.
// If struct is not on the cx.StructTypes list, it is skipped.
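// For example, a struct with four 32-bit fields (16 bytes) followed by a
// 16-byte sub-struct passes: its total size is 32 bytes and the sub-struct
// field starts at a mod-16 byte offset; the ParamStruct / SubStruct pair in
// the gosl examples follows this padding pattern.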
func CheckStruct(cx *Context, st *types.Struct, stName string) bool {
if _, ok := cx.StructTypes[stName]; !ok {
return false
}
return CheckStructImpl(cx, st, stName)
}
// CheckStructImpl can be used for CheckStack -- doesn't check for
// top-level StructTypes membership.
func CheckStructImpl(cx *Context, st *types.Struct, stName string) bool {
if !cx.IsNewStruct(st) {
return false
}
var flds []*types.Var
nf := st.NumFields()
if nf == 0 {
return false
}
hasErr := false
for i := 0; i < nf; i++ {
fl := st.Field(i)
flds = append(flds, fl)
ft := fl.Type()
ut := ft.Underlying()
if bt, isBasic := ut.(*types.Basic); isBasic {
kind := bt.Kind()
if kind == types.Invalid {
hasErr = cx.AddError(fmt.Sprintf(` %s: %s: add //gosl:import "package"`, fl.Name(), bt.String()), hasErr, stName)
} else if !(kind == types.Uint32 || kind == types.Int32 || kind == types.Float32 || kind == types.Uint64) {
hasErr = cx.AddError(fmt.Sprintf(" %s: basic type != [U]Int32 or Float32: %s", fl.Name(), bt.String()), hasErr, stName)
fmt.Println("kind:", kind, "ft:", ft.String())
}
} else {
if sst, is := ut.(*types.Struct); is {
cx.Stack[sst] = TypeName(ft)
} else {
hasErr = cx.AddError(fmt.Sprintf(" %s: unsupported type: %s", fl.Name(), ft.String()), hasErr, stName)
}
}
}
offs := cx.Sizes.Offsetsof(flds)
last := cx.Sizes.Sizeof(flds[nf-1].Type())
totsz := int(offs[nf-1] + last)
mod := totsz % 16
vectyp := strings.Contains(strings.ToLower(stName), "vec") // vector types are ok
if !vectyp && mod != 0 {
needs := 4 - (mod / 4)
hasErr = cx.AddError(fmt.Sprintf(" total size: %d not even multiple of 16 -- needs %d extra 32bit padding fields", totsz, needs), hasErr, stName)
}
// check that struct starts at mod 16 byte offset
for i, fl := range flds {
ft := fl.Type()
ut := ft.Underlying()
if _, is := ut.(*types.Struct); is {
off := offs[i]
if off%16 != 0 {
hasErr = cx.AddError(fmt.Sprintf(" %s: struct type: %s is not at mod-16 byte offset: %d", fl.Name(), TypeName(ft), off), hasErr, stName)
}
}
}
return hasErr
}
// CheckPackage is the main entry point for checking a package;
// it returns an error listing any problems found.
// structTypes is a map of struct type names to check for alignment;
// any other struct types are purely internal, not used for variables,
// so they don't need to be checked.
func CheckPackage(pkg *packages.Package, structTypes map[string]bool) error {
cx := NewContext(pkg.TypesSizes, structTypes)
sc := pkg.Types.Scope()
hasErr := CheckScope(cx, sc, 0)
er := CheckStack(cx)
if hasErr || er {
str := `
WARNING: in struct type alignment checking:
Checks that struct sizes are an even multiple of 16 bytes (4 float32's),
and fields are 32 bit types: [U]Int32, Float32 or other struct,
and that fields that are other struct types are aligned at even 16 byte multiples.
List of errors found follow below, by struct type name:
` + strings.Join(cx.Errs, "\n")
return errors.New(str)
}
return nil
}
func CheckStack(cx *Context) bool {
hasErr := false
for {
if len(cx.Stack) == 0 {
break
}
stack := cx.Stack
cx.Stack = make(map[*types.Struct]string) // new stack
for st, nm := range stack {
er := CheckStructImpl(cx, st, nm)
if er {
hasErr = true
}
}
}
return hasErr
}
func CheckScope(cx *Context, sc *types.Scope, level int) bool {
nms := sc.Names()
ntyp := 0
hasErr := false
for _, nm := range nms {
ob := sc.Lookup(nm)
tp := ob.Type()
if tp == nil {
continue
}
if nt, is := tp.(*types.Named); is {
ut := nt.Underlying()
if ut == nil {
continue
}
if st, is := ut.(*types.Struct); is {
er := CheckStruct(cx, st, nt.Obj().Name())
if er {
hasErr = true
}
ntyp++
}
}
}
if ntyp == 0 {
for i := 0; i < sc.NumChildren(); i++ {
cs := sc.Child(i)
er := CheckScope(cx, cs, level+1)
if er {
hasErr = true
}
}
}
return hasErr
}
// Code generated by "goal build"; DO NOT EDIT.
//line atomic.goal:1
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import "sync/atomic"
//gosl:start
// Atomic does an atomic computation on the data.
func Atomic(i uint32) { //gosl:kernel
atomic.AddInt32(IntData.ValuePtr(int(i), int(Integ)), 1)
}
//gosl:end
// Code generated by "goal build"; DO NOT EDIT.
//line compute.goal:1
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"cogentcore.org/core/math32"
"cogentcore.org/lab/tensor"
)
//gosl:start
//gosl:import "cogentcore.org/core/math32"
//gosl:vars
var (
// Params are the parameters for the computation.
//
//gosl:group Params
//gosl:read-only
Params []ParamStruct
// Data is the data on which the computation operates.
// 2D: outer index is data, inner index is: Raw, Integ, Exp vars.
//
//gosl:group Data
//gosl:dims 2
//gosl:nbuffs 8
Data *tensor.Float32
// IntData is the int data on which the computation operates.
// 2D: outer index is data, inner index is: Raw, Integ, Exp vars.
//
//gosl:dims 2
IntData *tensor.Int32
)
const (
Raw int = iota
Integ
Exp
NVars
)
// SubStruct is a sub struct within the overall param struct.
// There are tricky rules about how to access such things.
type SubStruct struct {
// rate constant in msec
Tau float32
// 1/Tau
Dt float32
pad float32
pad1 float32
}
// ParamStruct has the test params
type ParamStruct struct {
// rate constant in msec
Tau float32
// 1/Tau
Dt float32
// number of data items -- must be available on the GPU so extra indexes can be excluded.
DataLen uint32
pad1 float32
Sub SubStruct
}
// IntegFromRaw computes integrated value from current raw value
func (ps *SubStruct) IntegFromRaw(idx int) {
integ := Data.Value(int(idx), int(Integ))
integ += ps.Dt * (Data.Value(int(idx), int(Raw)) - integ)
Data.Set(integ, int(idx), int(Integ))
Data.Set(math32.FastExp(-integ), int(idx), int(Exp))
}
// IntegFromRaw computes integrated value from current raw value
func (ps *ParamStruct) IntegFromRaw(idx int) {
integ := Data.Value(int(idx), int(Integ))
integ += ps.Dt * (Data.Value(int(idx), int(Raw)) - integ)
Data.Set(integ, int(idx), int(Integ))
Data.Set(math32.FastExp(-integ), int(idx), int(Exp))
ps.Sub.IntegFromRaw(idx)
}
// Compute does the main computation.
func Compute(i uint32) { //gosl:kernel
if i >= Params[0].DataLen { // note: essential to bounds check b/c i in 64 blocks
return
}
Params[0].IntegFromRaw(int(i))
}
//gosl:end
// note: only core compute code needs to be in shader -- all init is done CPU-side
func (ps *ParamStruct) Defaults() {
ps.Tau = 5
ps.Update()
ps.Sub.Defaults()
}
func (ps *ParamStruct) Update() {
ps.Dt = 1.0 / ps.Tau
}
func (ps *SubStruct) Defaults() {
ps.Tau = 5
ps.Update()
}
func (ps *SubStruct) Update() {
ps.Dt = 1.0 / ps.Tau
}
// Code generated by "gosl"; DO NOT EDIT
package main
import (
"embed"
"fmt"
"math"
"unsafe"
"cogentcore.org/core/gpu"
"cogentcore.org/lab/tensor"
)
//go:embed shaders/*.wgsl
var shaders embed.FS
var (
// GPUInitialized is true once the GPU system has been initialized.
// Prevents multiple initializations.
GPUInitialized bool
// ComputeGPU is the compute gpu device.
// Set this prior to calling GPUInit() to use an existing device.
ComputeGPU *gpu.GPU
// BorrowedGPU is true if our ComputeGPU is set externally,
// versus created specifically for this system. If external,
// we don't release it.
BorrowedGPU bool
// UseGPU indicates whether to use GPU vs. CPU.
UseGPU bool
)
// GPUSystem is a GPU compute System with kernels operating on the
// same set of data variables.
var GPUSystem *gpu.ComputeSystem
// GPUVars is an enum for GPU variables, for specifying what to sync.
type GPUVars int32 //enums:enum
const (
ParamsVar GPUVars = 0
DataVar GPUVars = 1
IntDataVar GPUVars = 2
)
// Tensor stride variables
var TensorStrides tensor.Uint32
// GPUInit initializes the GPU compute system,
// configuring system(s), variables and kernels.
// It is safe to call multiple times: detects if already run.
func GPUInit() {
if GPUInitialized {
return
}
GPUInitialized = true
if ComputeGPU == nil { // set prior to this call to use an external
ComputeGPU = gpu.NewComputeGPU()
} else {
BorrowedGPU = true
}
gp := ComputeGPU
_ = fmt.Sprintf("%g", math.NaN()) // keep imports happy
{
sy := gpu.NewComputeSystem(gp, "Default")
GPUSystem = sy
vars := sy.Vars()
{
sgp := vars.AddGroup(gpu.Storage, "Params")
var vr *gpu.Var
_ = vr
vr = sgp.Add("TensorStrides", gpu.Uint32, 1, gpu.ComputeShader)
vr.ReadOnly = true
vr = sgp.AddStruct("Params", int(unsafe.Sizeof(ParamStruct{})), 1, gpu.ComputeShader)
vr.ReadOnly = true
sgp.SetNValues(1)
}
{
sgp := vars.AddGroup(gpu.Storage, "Data")
var vr *gpu.Var
_ = vr
vr = sgp.Add("Data0", gpu.Float32, 1, gpu.ComputeShader)
vr = sgp.Add("Data1", gpu.Float32, 1, gpu.ComputeShader)
vr = sgp.Add("Data2", gpu.Float32, 1, gpu.ComputeShader)
vr = sgp.Add("Data3", gpu.Float32, 1, gpu.ComputeShader)
vr = sgp.Add("Data4", gpu.Float32, 1, gpu.ComputeShader)
vr = sgp.Add("Data5", gpu.Float32, 1, gpu.ComputeShader)
vr = sgp.Add("Data6", gpu.Float32, 1, gpu.ComputeShader)
vr = sgp.Add("Data7", gpu.Float32, 1, gpu.ComputeShader)
vr = sgp.Add("IntData", gpu.Int32, 1, gpu.ComputeShader)
sgp.SetNValues(1)
}
var pl *gpu.ComputePipeline
pl = gpu.NewComputePipelineShaderFS(shaders, "shaders/Atomic.wgsl", sy)
pl.AddVarUsed(0, "TensorStrides")
pl.AddVarUsed(1, "IntData")
pl = gpu.NewComputePipelineShaderFS(shaders, "shaders/Compute.wgsl", sy)
pl.AddVarUsed(0, "TensorStrides")
pl.AddVarUsed(1, "Data0")
pl.AddVarUsed(1, "Data1")
pl.AddVarUsed(1, "Data2")
pl.AddVarUsed(1, "Data3")
pl.AddVarUsed(1, "Data4")
pl.AddVarUsed(1, "Data5")
pl.AddVarUsed(1, "Data6")
pl.AddVarUsed(1, "Data7")
pl.AddVarUsed(0, "Params")
sy.Config()
}
}
// GPURelease releases the GPU compute system resources.
// Call this at program exit.
func GPURelease() {
if GPUSystem != nil {
GPUSystem.Release()
GPUSystem = nil
}
if !BorrowedGPU && ComputeGPU != nil {
ComputeGPU.Release()
}
ComputeGPU = nil
}
// RunAtomic runs the Atomic kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// Can call multiple Run* kernels in a row, which are then all launched
// in the same command submission on the GPU, which is by far the most efficient.
// MUST call RunDone (with optional vars to sync) after all Run calls.
// Alternatively, a single-shot RunOneAtomic call does Run and Done for a
// single run-and-sync case.
func RunAtomic(n int) {
if UseGPU {
RunAtomicGPU(n)
} else {
RunAtomicCPU(n)
}
}
// RunAtomicGPU runs the Atomic kernel on the GPU. See [RunAtomic] for more info.
func RunAtomicGPU(n int) {
sy := GPUSystem
pl := sy.ComputePipelines["Atomic"]
ce, _ := sy.BeginComputePass()
pl.Dispatch1D(ce, n, 64)
}
// RunAtomicCPU runs the Atomic kernel on the CPU.
func RunAtomicCPU(n int) {
gpu.VectorizeFunc(0, n, Atomic)
}
// RunOneAtomic runs the Atomic kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// This version then calls RunDone with the given variables to sync
// after the Run, for a single-shot Run-and-Done call. If multiple kernels
// can be run in sequence, it is much more efficient to do multiple Run*
// calls followed by a RunDone call.
func RunOneAtomic(n int, syncVars ...GPUVars) {
if UseGPU {
RunAtomicGPU(n)
RunDone(syncVars...)
} else {
RunAtomicCPU(n)
}
}
// RunCompute runs the Compute kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// Can call multiple Run* kernels in a row, which are then all launched
// in the same command submission on the GPU, which is by far the most efficient.
// MUST call RunDone (with optional vars to sync) after all Run calls.
// Alternatively, a single-shot RunOneCompute call does Run and Done for a
// single run-and-sync case.
func RunCompute(n int) {
if UseGPU {
RunComputeGPU(n)
} else {
RunComputeCPU(n)
}
}
// RunComputeGPU runs the Compute kernel on the GPU. See [RunCompute] for more info.
func RunComputeGPU(n int) {
sy := GPUSystem
pl := sy.ComputePipelines["Compute"]
ce, _ := sy.BeginComputePass()
pl.Dispatch1D(ce, n, 64)
}
// RunComputeCPU runs the Compute kernel on the CPU.
func RunComputeCPU(n int) {
gpu.VectorizeFunc(0, n, Compute)
}
// RunOneCompute runs the Compute kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// This version then calls RunDone with the given variables to sync
// after the Run, for a single-shot Run-and-Done call. If multiple kernels
// can be run in sequence, it is much more efficient to do multiple Run*
// calls followed by a RunDone call.
func RunOneCompute(n int, syncVars ...GPUVars) {
if UseGPU {
RunComputeGPU(n)
RunDone(syncVars...)
} else {
RunComputeCPU(n)
}
}
// RunDone must be called after Run* calls to start compute kernels.
// This actually submits the kernel jobs to the GPU, and adds commands
// to synchronize the given variables back from the GPU to the CPU.
// After this function completes, the GPU results will be available in
// the specified variables.
func RunDone(syncVars ...GPUVars) {
if !UseGPU {
return
}
sy := GPUSystem
sy.ComputeEncoder.End()
ReadFromGPU(syncVars...)
sy.EndComputePass()
SyncFromGPU(syncVars...)
}
// ToGPU copies given variables to the GPU for the system.
func ToGPU(vars ...GPUVars) {
if !UseGPU {
return
}
sy := GPUSystem
syVars := sy.Vars()
for _, vr := range vars {
switch vr {
case ParamsVar:
v, _ := syVars.ValueByIndex(0, "Params", 0)
gpu.SetValueFrom(v, Params)
case DataVar:
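// bsz corresponds to the default MaxBufferSize of 2147483616 bytes
// divided by 4 bytes per float32 element, i.e., the maximum number of
// elements per storage buffer.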
bsz := 536870904
n := Data.Len()
nb := int(math.Ceil(float64(n) / float64(bsz)))
for bi := range nb {
v, _ := syVars.ValueByIndex(1, fmt.Sprintf("Data%d", bi), 0)
st := bsz * bi
ed := min(bsz * (bi+1), n)
gpu.SetValueFrom(v, Data.Values[st:ed])
}
case IntDataVar:
v, _ := syVars.ValueByIndex(1, "IntData", 0)
gpu.SetValueFrom(v, IntData.Values)
}
}
}
// RunGPUSync can be called to synchronize data between CPU and GPU.
// Any prior ToGPU* calls will execute to send data to the GPU,
// and any subsequent RunDone* calls will copy data back from the GPU.
func RunGPUSync() {
if !UseGPU {
return
}
sy := GPUSystem
sy.BeginComputePass()
}
// ToGPUTensorStrides gets tensor strides and starts copying to the GPU.
func ToGPUTensorStrides() {
if !UseGPU {
return
}
sy := GPUSystem
syVars := sy.Vars()
TensorStrides.SetShapeSizes(20)
TensorStrides.SetInt1D(Data.Shape().Strides[0], 0)
TensorStrides.SetInt1D(Data.Shape().Strides[1], 1)
TensorStrides.SetInt1D(IntData.Shape().Strides[0], 10)
TensorStrides.SetInt1D(IntData.Shape().Strides[1], 11)
v, _ := syVars.ValueByIndex(0, "TensorStrides", 0)
gpu.SetValueFrom(v, TensorStrides.Values)
}
// ReadFromGPU starts the process of copying vars back from the GPU.
func ReadFromGPU(vars ...GPUVars) {
sy := GPUSystem
syVars := sy.Vars()
for _, vr := range vars {
switch vr {
case ParamsVar:
v, _ := syVars.ValueByIndex(0, "Params", 0)
v.GPUToRead(sy.CommandEncoder)
case DataVar:
bsz := 536870904
n := Data.Len()
nb := int(math.Ceil(float64(n) / float64(bsz)))
for bi := range nb {
v, _ := syVars.ValueByIndex(1, fmt.Sprintf("Data%d", bi), 0)
v.GPUToRead(sy.CommandEncoder)
}
case IntDataVar:
v, _ := syVars.ValueByIndex(1, "IntData", 0)
v.GPUToRead(sy.CommandEncoder)
}
}
}
// SyncFromGPU synchronizes vars from the GPU to the actual variable.
func SyncFromGPU(vars ...GPUVars) {
sy := GPUSystem
syVars := sy.Vars()
for _, vr := range vars {
switch vr {
case ParamsVar:
v, _ := syVars.ValueByIndex(0, "Params", 0)
v.ReadSync()
gpu.ReadToBytes(v, Params)
case DataVar:
bsz := 536870904
n := Data.Len()
nb := int(math.Ceil(float64(n) / float64(bsz)))
for bi := range nb {
v, _ := syVars.ValueByIndex(1, fmt.Sprintf("Data%d", bi), 0)
v.ReadSync()
st := bsz * bi
ed := min(bsz * (bi+1), n)
gpu.ReadToBytes(v, Data.Values[st:ed])
}
case IntDataVar:
v, _ := syVars.ValueByIndex(1, "IntData", 0)
v.ReadSync()
gpu.ReadToBytes(v, IntData.Values)
}
}
}
// GetParams returns a pointer to the given global variable:
// [Params] []ParamStruct at given index. This is processed directly
// in the GPU code, so this function provides the CPU equivalent.
func GetParams(idx uint32) *ParamStruct {
return &Params[idx]
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This example just does some basic calculations on data structures and
// reports the time difference between the CPU and GPU.
package main
import (
"fmt"
"math/rand"
"runtime"
"cogentcore.org/core/base/timer"
"cogentcore.org/core/gpu"
"cogentcore.org/lab/tensor"
)
//go:generate gosl
func init() {
// must lock main thread for gpu!
runtime.LockOSThread()
}
func main() {
gpu.Debug = true
GPUInit()
rand.Seed(0)
// gpu.NumThreads = 1 // to restrict to sequential for loop
n := 160_000_000 // nbuffs = 8
// n := 16_000_000 // fits in 1 buffer
// n := 2_000_000
Params = make([]ParamStruct, 1)
Params[0].Defaults()
Data = tensor.NewFloat32()
Data.SetShapeSizes(n, 3)
nt := Data.Len()
Params[0].DataLen = uint32(nt)
IntData = tensor.NewInt32()
IntData.SetShapeSizes(n, 3)
for i := range nt {
Data.Set1D(rand.Float32(), i)
}
sid := tensor.NewInt32()
sid.SetShapeSizes(n, 3)
sd := tensor.NewFloat32()
sd.SetShapeSizes(n, 3)
for i := range nt {
sd.Set1D(Data.Value1D(i), i)
}
cpuTmr := timer.Time{}
cpuTmr.Start()
RunOneAtomic(n)
RunOneCompute(n)
cpuTmr.Stop()
cd := Data
cid := IntData
Data = sd
IntData = sid
gpuFullTmr := timer.Time{}
gpuFullTmr.Start()
UseGPU = true
ToGPUTensorStrides()
ToGPU(ParamsVar, DataVar, IntDataVar)
gpuTmr := timer.Time{}
gpuTmr.Start()
RunAtomic(n)
RunCompute(n)
gpuTmr.Stop()
RunDone(DataVar, IntDataVar)
gpuFullTmr.Stop()
mx := min(n, 5)
for i := 0; i < mx; i++ {
fmt.Printf("%d\t CPU IntData: %d\t GPU: %d\n", i, cid.Value(i, Integ), sid.Value(i, Integ))
}
fmt.Println()
for i := 0; i < mx; i++ {
d := cd.Value(i, Exp) - sd.Value(i, Exp)
fmt.Printf("CPU:\t%d\t Raw: %6.4g\t Integ: %6.4g\t Exp: %6.4g\tGPU: %6.4g\tDiff: %g\n", i, cd.Value(i, Raw), cd.Value(i, Integ), cd.Value(i, Exp), sd.Value(i, Exp), d)
fmt.Printf("GPU:\t%d\t Raw: %6.4g\t Integ: %6.4g\t Exp: %6.4g\tCPU: %6.4g\tDiff: %g\n\n", i, sd.Value(i, Raw), sd.Value(i, Integ), sd.Value(i, Exp), cd.Value(i, Exp), d)
}
fmt.Printf("\n")
cpu := cpuTmr.Total
gpu := gpuTmr.Total
gpuFull := gpuFullTmr.Total
fmt.Printf("N: %d\t CPU: %v\t GPU: %v\t Full: %v\t CPU/GPU: %6.4g\n", n, cpu, gpu, gpuFull, float64(cpu)/float64(gpu))
GPURelease()
}
// Code generated by "gosl"; DO NOT EDIT
package main
import (
"embed"
"fmt"
"math"
"unsafe"
"cogentcore.org/core/gpu"
"cogentcore.org/lab/tensor"
)
//go:embed shaders/*.wgsl
var shaders embed.FS
var (
// GPUInitialized is true once the GPU system has been initialized.
// Prevents multiple initializations.
GPUInitialized bool
// ComputeGPU is the compute gpu device.
// Set this prior to calling GPUInit() to use an existing device.
ComputeGPU *gpu.GPU
// BorrowedGPU is true if our ComputeGPU is set externally,
// versus created specifically for this system. If external,
// we don't release it.
BorrowedGPU bool
// UseGPU indicates whether to use GPU vs. CPU.
UseGPU bool
)
// GPUSystem is a GPU compute System with kernels operating on the
// same set of data variables.
var GPUSystem *gpu.ComputeSystem
// GPUVars is an enum for GPU variables, for specifying what to sync.
type GPUVars int32 //enums:enum
const (
SeedVar GPUVars = 0
FloatsVar GPUVars = 1
UintsVar GPUVars = 2
)
// Tensor stride variables
var TensorStrides tensor.Uint32
// GPUInit initializes the GPU compute system,
// configuring system(s), variables and kernels.
// It is safe to call multiple times: detects if already run.
func GPUInit() {
if GPUInitialized {
return
}
GPUInitialized = true
if ComputeGPU == nil { // set prior to this call to use an external
ComputeGPU = gpu.NewComputeGPU()
} else {
BorrowedGPU = true
}
gp := ComputeGPU
_ = fmt.Sprintf("%g", math.NaN()) // keep imports happy
{
sy := gpu.NewComputeSystem(gp, "Default")
GPUSystem = sy
vars := sy.Vars()
{
sgp := vars.AddGroup(gpu.Storage, "Group_0")
var vr *gpu.Var
_ = vr
vr = sgp.Add("TensorStrides", gpu.Uint32, 1, gpu.ComputeShader)
vr.ReadOnly = true
vr = sgp.AddStruct("Seed", int(unsafe.Sizeof(Seeds{})), 1, gpu.ComputeShader)
vr.ReadOnly = true
vr = sgp.Add("Floats", gpu.Float32, 1, gpu.ComputeShader)
vr = sgp.Add("Uints", gpu.Uint32, 1, gpu.ComputeShader)
sgp.SetNValues(1)
}
var pl *gpu.ComputePipeline
pl = gpu.NewComputePipelineShaderFS(shaders, "shaders/Compute.wgsl", sy)
pl.AddVarUsed(0, "TensorStrides")
pl.AddVarUsed(0, "Floats")
pl.AddVarUsed(0, "Seed")
pl.AddVarUsed(0, "Uints")
sy.Config()
}
}
// GPURelease releases the GPU compute system resources.
// Call this at program exit.
func GPURelease() {
if GPUSystem != nil {
GPUSystem.Release()
GPUSystem = nil
}
if !BorrowedGPU && ComputeGPU != nil {
ComputeGPU.Release()
}
ComputeGPU = nil
}
// RunCompute runs the Compute kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// Can call multiple Run* kernels in a row, which are then all launched
// in the same command submission on the GPU, which is by far the most efficient.
// MUST call RunDone (with optional vars to sync) after all Run calls.
// Alternatively, a single-shot RunOneCompute call does Run and Done for a
// single run-and-sync case.
func RunCompute(n int) {
if UseGPU {
RunComputeGPU(n)
} else {
RunComputeCPU(n)
}
}
// RunComputeGPU runs the Compute kernel on the GPU. See [RunCompute] for more info.
func RunComputeGPU(n int) {
sy := GPUSystem
pl := sy.ComputePipelines["Compute"]
ce, _ := sy.BeginComputePass()
pl.Dispatch1D(ce, n, 64)
}
// RunComputeCPU runs the Compute kernel on the CPU.
func RunComputeCPU(n int) {
gpu.VectorizeFunc(0, n, Compute)
}
// RunOneCompute runs the Compute kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// This version then calls RunDone with the given variables to sync
// after the Run, for a single-shot Run-and-Done call. If multiple kernels
// can be run in sequence, it is much more efficient to do multiple Run*
// calls followed by a RunDone call.
func RunOneCompute(n int, syncVars ...GPUVars) {
if UseGPU {
RunComputeGPU(n)
RunDone(syncVars...)
} else {
RunComputeCPU(n)
}
}
// RunDone must be called after Run* calls to start compute kernels.
// This actually submits the kernel jobs to the GPU, and adds commands
// to synchronize the given variables back from the GPU to the CPU.
// After this function completes, the GPU results will be available in
// the specified variables.
func RunDone(syncVars ...GPUVars) {
if !UseGPU {
return
}
sy := GPUSystem
sy.ComputeEncoder.End()
ReadFromGPU(syncVars...)
sy.EndComputePass()
SyncFromGPU(syncVars...)
}
// ToGPU copies given variables to the GPU for the system.
func ToGPU(vars ...GPUVars) {
if !UseGPU {
return
}
sy := GPUSystem
syVars := sy.Vars()
for _, vr := range vars {
switch vr {
case SeedVar:
v, _ := syVars.ValueByIndex(0, "Seed", 0)
gpu.SetValueFrom(v, Seed)
case FloatsVar:
v, _ := syVars.ValueByIndex(0, "Floats", 0)
gpu.SetValueFrom(v, Floats.Values)
case UintsVar:
v, _ := syVars.ValueByIndex(0, "Uints", 0)
gpu.SetValueFrom(v, Uints.Values)
}
}
}
// RunGPUSync can be called to synchronize data between CPU and GPU.
// Any prior ToGPU* calls will execute to send data to the GPU,
// and any subsequent RunDone* calls will copy data back from the GPU.
func RunGPUSync() {
if !UseGPU {
return
}
sy := GPUSystem
sy.BeginComputePass()
}
// ToGPUTensorStrides gets tensor strides and starts copying to the GPU.
func ToGPUTensorStrides() {
if !UseGPU {
return
}
sy := GPUSystem
syVars := sy.Vars()
TensorStrides.SetShapeSizes(20)
TensorStrides.SetInt1D(Floats.Shape().Strides[0], 0)
TensorStrides.SetInt1D(Floats.Shape().Strides[1], 1)
TensorStrides.SetInt1D(Uints.Shape().Strides[0], 10)
TensorStrides.SetInt1D(Uints.Shape().Strides[1], 11)
v, _ := syVars.ValueByIndex(0, "TensorStrides", 0)
gpu.SetValueFrom(v, TensorStrides.Values)
}
// ReadFromGPU starts the process of copying vars back from the GPU.
func ReadFromGPU(vars ...GPUVars) {
sy := GPUSystem
syVars := sy.Vars()
for _, vr := range vars {
switch vr {
case SeedVar:
v, _ := syVars.ValueByIndex(0, "Seed", 0)
v.GPUToRead(sy.CommandEncoder)
case FloatsVar:
v, _ := syVars.ValueByIndex(0, "Floats", 0)
v.GPUToRead(sy.CommandEncoder)
case UintsVar:
v, _ := syVars.ValueByIndex(0, "Uints", 0)
v.GPUToRead(sy.CommandEncoder)
}
}
}
// SyncFromGPU synchronizes vars from the GPU to the actual variable.
func SyncFromGPU(vars ...GPUVars) {
sy := GPUSystem
syVars := sy.Vars()
for _, vr := range vars {
switch vr {
case SeedVar:
v, _ := syVars.ValueByIndex(0, "Seed", 0)
v.ReadSync()
gpu.ReadToBytes(v, Seed)
case FloatsVar:
v, _ := syVars.ValueByIndex(0, "Floats", 0)
v.ReadSync()
gpu.ReadToBytes(v, Floats.Values)
case UintsVar:
v, _ := syVars.ValueByIndex(0, "Uints", 0)
v.ReadSync()
gpu.ReadToBytes(v, Uints.Values)
}
}
}
// GetSeed returns a pointer to the given global variable:
// [Seed] []Seeds at given index. This is processed directly
// in the GPU code, so this function provides the CPU equivalent.
func GetSeed(idx uint32) *Seeds {
return &Seed[idx]
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"fmt"
"runtime"
"log/slog"
"cogentcore.org/core/base/timer"
"cogentcore.org/lab/tensor"
)
//go:generate gosl
func init() {
// must lock main thread for gpu!
runtime.LockOSThread()
}
func main() {
GPUInit()
// n := 10
n := 16_000_000 // max for macbook M*
// n := 200_000
UseGPU = false
Seed = make([]Seeds, 1)
dataCU := tensor.NewUint32(n, 2)
dataGU := tensor.NewUint32(n, 2)
dataCF := tensor.NewFloat32(n, NVars)
dataGF := tensor.NewFloat32(n, NVars)
Uints = dataCU
Floats = dataCF
cpuTmr := timer.Time{}
cpuTmr.Start()
RunOneCompute(n)
cpuTmr.Stop()
UseGPU = true
Uints = dataGU
Floats = dataGF
gpuFullTmr := timer.Time{}
gpuFullTmr.Start()
ToGPUTensorStrides()
ToGPU(SeedVar, FloatsVar, UintsVar)
gpuTmr := timer.Time{}
gpuTmr.Start()
RunCompute(n)
gpuTmr.Stop()
RunDone(FloatsVar, UintsVar)
gpuFullTmr.Stop()
anyDiffEx := false
anyDiffTol := false
mx := min(n, 5)
fmt.Printf("Index\tDif(Ex,Tol)\t CPU \t then GPU\n")
for i := 0; i < n; i++ {
smEx, smTol := IsSame(dataCU, dataGU, dataCF, dataGF, i)
if !smEx {
anyDiffEx = true
}
if !smTol {
anyDiffTol = true
}
if i > mx {
continue
}
exS := " "
if !smEx {
exS = "*"
}
tolS := " "
if !smTol {
tolS = "*"
}
fmt.Printf("%d\t%s %s\t%s\n\t\t%s\n", i, exS, tolS, String(dataCU, dataCF, i), String(dataGU, dataGF, i))
}
fmt.Printf("\n")
if anyDiffEx {
slog.Error("Differences between CPU and GPU detected at Exact level (excludes Gauss)")
}
if anyDiffTol {
slog.Error("Differences between CPU and GPU detected at Tolerance level", "tolerance", Tol)
}
cpu := cpuTmr.Total
gpu := gpuTmr.Total
fmt.Printf("N: %d\t CPU: %v\t GPU: %v\t Full: %v\t CPU/GPU: %6.4g\n", n, cpu, gpu, gpuFullTmr.Total, float64(cpu)/float64(gpu))
GPURelease()
}
// Code generated by "goal build"; DO NOT EDIT.
//
//line rand.goal:1
package main
import (
"fmt"
"cogentcore.org/core/math32"
"cogentcore.org/lab/gosl/slrand"
"cogentcore.org/lab/tensor"
)
//gosl:start
//gosl:vars
var (
//gosl:read-only
Seed []Seeds
// Floats has random float values: [idx][6]
//
//gosl:dims 2
Floats *tensor.Float32
// Uints has random uint32 values: [idx][2]
//
//gosl:dims 2
Uints *tensor.Uint32
)
type Seeds struct {
Seed uint64
pad, pad1 int32
}
const (
FloatX int = iota
FloatY
Float11X
Float11Y
GaussX
GaussY
NVars
)
// RndGen calls random function calls to test generator.
// Note that the counter to the outer-most computation function
// is passed by *value*, so the same counter goes to each element
// as it is computed, but within this scope, counter is passed by
// reference (as a pointer) so subsequent calls get a new counter value.
// The counter should be incremented by the number of random calls
// outside of the overall update function.
func RndGen(counter uint64, idx uint32) {
uints := slrand.Uint32Vec2(counter, uint32(0), idx)
floats := slrand.Float32Vec2(counter, uint32(1), idx)
floats11 := slrand.Float32Range11Vec2(counter, uint32(2), idx)
gauss := slrand.Float32NormVec2(counter, uint32(3), idx)
Uints.Set(uints.X, int(idx), int(0))
Uints.Set(uints.Y, int(idx), int(1))
Floats.Set(floats.X, int(idx), int(FloatX))
Floats.Set(floats.Y, int(idx), int(FloatY))
Floats.Set(floats11.X, int(idx), int(Float11X))
Floats.Set(floats11.Y, int(idx), int(Float11Y))
Floats.Set(gauss.X, int(idx), int(GaussX))
Floats.Set(gauss.Y, int(idx), int(GaussY))
}
func Compute(i uint32) { //gosl:kernel
// note: this should have a bounds check here on i -- can be larger than Floats
RndGen(Seed[0].Seed, i)
}
//gosl:end
const Tol = 1.0e-4 // fails at a lower tolerance eventually -- 1.0e-6 works for many cases
func FloatSame(f1, f2 float32) (exact, tol bool) {
exact = f1 == f2
tol = math32.Abs(f1-f2) < Tol
return
}
func Float32Vec2Same(ax, bx, ay, by float32) (exact, tol bool) {
e1, t1 := FloatSame(ax, bx)
e2, t2 := FloatSame(ay, by)
exact = e1 && e2
tol = t1 && t2
return
}
// IsSame compares values at two levels: exact and with Tol
func IsSame(au, bu *tensor.Uint32, af, bf *tensor.Float32, idx int) (exact, tol bool) {
e1 := au.Value(int(idx), int(0)) == bu.Value(int(idx), int(0)) && au.Value(int(idx), int(1)) == bu.Value(int(idx), int(1))
e2, t2 := Float32Vec2Same(af.Value(int(idx), int(FloatX)), bf.Value(int(idx), int(FloatX)), af.Value(int(idx), int(FloatY)), bf.Value(int(idx), int(FloatY)))
e3, t3 := Float32Vec2Same(af.Value(int(idx), int(Float11X)), bf.Value(int(idx), int(Float11X)), af.Value(int(idx), int(Float11Y)), bf.Value(int(idx), int(Float11Y)))
_, t4 := Float32Vec2Same(af.Value(int(idx), int(GaussX)), bf.Value(int(idx), int(GaussX)), af.Value(int(idx), int(GaussY)), bf.Value(int(idx), int(GaussY)))
exact = e1 && e2 && e3 // skip e4 -- know it isn't
tol = t2 && t3 && t4
return
}
func String(u *tensor.Uint32, f *tensor.Float32, idx int) string {
return fmt.Sprintf("U: %x\t%x\tF: %g\t%g\tF11: %g\t%g\tG: %g\t%g", u.Value(int(idx), int(0)), u.Value(int(idx), int(1)), f.Value(int(idx), int(FloatX)), f.Value(int(idx), int(FloatY)), f.Value(int(idx), int(Float11X)), f.Value(int(idx), int(Float11Y)), f.Value(int(idx), int(GaussX)), f.Value(int(idx), int(GaussY)))
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"cogentcore.org/core/cli"
"cogentcore.org/lab/gosl/gotosl"
)
func main() { //types:skip
opts := cli.DefaultOptions("gosl", "Go as a shader language converts Go code to WGSL WebGPU shader code, which can be run on the GPU through WebGPU.")
cfg := &gotosl.Config{}
cli.Run(opts, cfg, gotosl.Run)
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gotosl
import (
"fmt"
"sort"
"golang.org/x/exp/maps"
)
// Function represents the call graph of functions
type Function struct {
Name string
Funcs map[string]*Function
Atomics map[string]*Var // variables that have atomic operations in this function
VarsUsed map[string]*Var // all global variables referenced by this function.
}
func NewFunction(name string) *Function {
return &Function{Name: name, Funcs: make(map[string]*Function)}
}
func (fn *Function) AddAtomic(vr *Var) {
if fn.Atomics == nil {
fn.Atomics = make(map[string]*Var)
}
fn.Atomics[vr.Name] = vr
}
func (fn *Function) AddVarUsed(vr *Var) {
if fn.VarsUsed == nil {
fn.VarsUsed = make(map[string]*Var)
}
fn.VarsUsed[vr.Name] = vr
}
// RecycleFunc returns the function of the given name,
// adding it to the function graph if not already present.
func (st *State) RecycleFunc(name string) *Function {
fn, ok := st.FuncGraph[name]
if !ok {
fn = NewFunction(name)
st.FuncGraph[name] = fn
}
return fn
}
func getAllFuncs(f *Function, all map[string]*Function) {
for fnm, fn := range f.Funcs {
_, ok := all[fnm]
if ok {
continue
}
all[fnm] = fn
getAllFuncs(fn, all)
}
}
// AllFuncs returns the aggregated map of all functions called by the given function, including itself.
func (st *State) AllFuncs(name string) map[string]*Function {
fn, ok := st.FuncGraph[name]
if !ok {
fmt.Printf("gosl: ERROR kernel function named: %q not found\n", name)
return nil
}
all := make(map[string]*Function)
all[name] = fn
getAllFuncs(fn, all)
// cfs := maps.Keys(all)
// sort.Strings(cfs)
// for _, cfnm := range cfs {
// fmt.Println("\t" + cfnm)
// }
return all
}
// VarsUsed returns the atomic and all used global variables
// for the given list of functions, along with the total number
// of used vars, which includes the NBuffs counts.
func (st *State) VarsUsed(funcs map[string]*Function) (avars, uvars map[string]*Var, nvars int) {
avars = make(map[string]*Var)
uvars = make(map[string]*Var)
for _, fn := range funcs {
for vn, v := range fn.Atomics {
avars[vn] = v
}
for vn, v := range fn.VarsUsed {
uvars[vn] = v
}
}
nvars = 1 // assume TensorStrides always
for _, vr := range uvars {
if vr.NBuffs > 1 {
nvars += vr.NBuffs
} else {
nvars++
}
}
return
}
func (st *State) PrintFuncGraph() {
funs := maps.Keys(st.FuncGraph)
sort.Strings(funs)
for _, fname := range funs {
fmt.Println(fname)
fn := st.FuncGraph[fname]
cfs := maps.Keys(fn.Funcs)
sort.Strings(cfs)
for _, cfnm := range cfs {
fmt.Println("\t" + cfnm)
}
}
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file is largely copied from the Go source,
// src/go/printer/comment.go:
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gotosl
import (
"go/ast"
"go/doc/comment"
"strings"
)
// formatDocComment reformats the doc comment list,
// returning the canonical formatting.
func formatDocComment(list []*ast.Comment) []*ast.Comment {
// Extract comment text (removing comment markers).
var kind, text string
var directives []*ast.Comment
if len(list) == 1 && strings.HasPrefix(list[0].Text, "/*") {
kind = "/*"
text = list[0].Text
if !strings.Contains(text, "\n") || allStars(text) {
// Single-line /* .. */ comment in doc comment position,
// or multiline old-style comment like
// /*
// * Comment
// * text here.
// */
// Should not happen, since it will not work well as a
// doc comment, but if it does, just ignore:
// reformatting it will only make the situation worse.
return list
}
text = text[2 : len(text)-2] // cut /* and */
} else if strings.HasPrefix(list[0].Text, "//") {
kind = "//"
var b strings.Builder
for _, c := range list {
after, found := strings.CutPrefix(c.Text, "//")
if !found {
return list
}
// Accumulate //go:build etc lines separately.
if isDirective(after) {
directives = append(directives, c)
continue
}
b.WriteString(strings.TrimPrefix(after, " "))
b.WriteString("\n")
}
text = b.String()
} else {
// Not sure what this is, so leave alone.
return list
}
if text == "" {
return list
}
// Parse comment and reformat as text.
var p comment.Parser
d := p.Parse(text)
var pr comment.Printer
text = string(pr.Comment(d))
// For /* */ comment, return one big comment with text inside.
slash := list[0].Slash
if kind == "/*" {
c := &ast.Comment{
Slash: slash,
Text: "/*\n" + text + "*/",
}
return []*ast.Comment{c}
}
// For // comment, return sequence of // lines.
var out []*ast.Comment
for text != "" {
var line string
line, text, _ = strings.Cut(text, "\n")
if line == "" {
line = "//"
} else if strings.HasPrefix(line, "\t") {
line = "//" + line
} else {
line = "// " + line
}
out = append(out, &ast.Comment{
Slash: slash,
Text: line,
})
}
if len(directives) > 0 {
out = append(out, &ast.Comment{
Slash: slash,
Text: "//",
})
for _, c := range directives {
out = append(out, &ast.Comment{
Slash: slash,
Text: c.Text,
})
}
}
return out
}
// isDirective reports whether c is a comment directive.
// See go.dev/issue/37974.
// This code is also in go/ast.
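// For example, "go:generate foo", "gosl:start", and "line file.go:10"
// are directives, while "note: see below" is not: everything up to and
// including the character after the colon must be lowercase letters or
// digits (aside from the colon itself).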
func isDirective(c string) bool {
// "//line " is a line directive.
// "//extern " is for gccgo.
// "//export " is for cgo.
// (The // has been removed.)
if strings.HasPrefix(c, "line ") || strings.HasPrefix(c, "extern ") || strings.HasPrefix(c, "export ") {
return true
}
// "//[a-z0-9]+:[a-z0-9]"
// (The // has been removed.)
colon := strings.Index(c, ":")
if colon <= 0 || colon+1 >= len(c) {
return false
}
for i := 0; i <= colon+1; i++ {
if i == colon {
continue
}
b := c[i]
if !('a' <= b && b <= 'z' || '0' <= b && b <= '9') {
return false
}
}
return true
}
// allStars reports whether text is the interior of an
// old-style /* */ comment with a star at the start of each line.
func allStars(text string) bool {
for i := 0; i < len(text); i++ {
if text[i] == '\n' {
j := i + 1
for j < len(text) && (text[j] == ' ' || text[j] == '\t') {
j++
}
if j < len(text) && text[j] != '*' {
return false
}
}
}
return true
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gotosl
//go:generate core generate -add-types -add-funcs
// Keep these in sync with go/format/format.go.
const (
tabWidth = 8
printerMode = UseSpaces | TabIndent | printerNormalizeNumbers
// printerNormalizeNumbers means to canonicalize number literal prefixes
// and exponents while printing. See https://golang.org/doc/go1.13#gosl.
//
// This value is defined in go/printer specifically for go/format and cmd/gosl.
printerNormalizeNumbers = 1 << 30
)
// Config has the configuration info for the gosl system.
type Config struct {
// Output is the output directory for shader code,
// relative to where gosl is invoked; must not be an empty string.
Output string `flag:"out" default:"shaders"`
// Exclude is a comma-separated list of names of functions to exclude from exporting to WGSL.
Exclude string `default:"Update,Defaults"`
// Keep keeps temporary converted versions of the source files, for debugging.
Keep bool
// Debug enables debugging messages while running.
Debug bool
// MaxBufferSize is the maximum size for any buffer.
// This is often platform-dependent, but is needed for
// accessing variables that have multiple buffers.
// It is compiled into the kernel code as a constant,
// and must fit in a uint32 number.
// The default is the NVIDIA maximum of 2147483647, aligned down to a 32 byte boundary.
MaxBufferSize uint32 `default:"2147483616"`
}
//cli:cmd -root
func Run(cfg *Config) error { //types:add
st := &State{}
st.Init(cfg)
return st.Run()
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gotosl
import (
"bytes"
"fmt"
"path/filepath"
"strings"
"slices"
)
// ExtractFiles processes all the package files and saves the corresponding
// .go files with a simple Go header.
func (st *State) ExtractFiles() {
st.ImportPackages = make(map[string]bool)
for impath := range st.GoImports {
_, pkg := filepath.Split(impath)
if pkg != "math32" {
st.ImportPackages[pkg] = true
}
}
for fn, fl := range st.GoFiles {
hasVars := false
fl.Lines, hasVars = st.ExtractGosl(fl.Lines)
if hasVars {
st.GoVarsFiles[fn] = fl
delete(st.GoFiles, fn)
}
WriteFileLines(filepath.Join(st.ImportsDir, fn), st.AppendGoHeader(fl.Lines))
}
}
// ExtractImports processes all the imported files and saves the corresponding
// .go files with a simple Go header.
func (st *State) ExtractImports() {
if len(st.GoImports) == 0 {
return
}
for impath, im := range st.GoImports {
_, pkg := filepath.Split(impath)
for fn, fl := range im {
fl.Lines, _ = st.ExtractGosl(fl.Lines)
WriteFileLines(filepath.Join(st.ImportsDir, pkg+"-"+fn), st.AppendGoHeader(fl.Lines))
}
}
}
// ExtractGosl extracts the gosl comment-directive tagged regions from the given
// file lines, reporting whether any //gosl:vars regions were found.
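// A minimal tagged region looks like this (cf. the gosl atomic example):
//
//	//gosl:start
//	func Compute(i uint32) { //gosl:kernel
//		...
//	}
//	//gosl:end
//
// Lines between start and end are copied through, func lines tagged with
// //gosl:kernel are registered as kernels on the named system, and
// //gosl:vars marks a file as containing the global variable declarations.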
func (st *State) ExtractGosl(lines [][]byte) (outLines [][]byte, hasVars bool) {
key := []byte("//gosl:")
start := []byte("start")
wgsl := []byte("wgsl")
nowgsl := []byte("nowgsl")
end := []byte("end")
vars := []byte("vars")
imp := []byte("import")
kernel := []byte("//gosl:kernel")
fnc := []byte("func")
comment := []byte("// ")
inReg := false
inWgsl := false
inNoWgsl := false
for li, ln := range lines {
tln := bytes.TrimSpace(ln)
isKey := bytes.HasPrefix(tln, key)
var keyStr []byte
if isKey {
keyStr = tln[len(key):]
// fmt.Printf("key: %s\n", string(keyStr))
}
switch {
case inReg && isKey && bytes.HasPrefix(keyStr, end):
if inWgsl || inNoWgsl {
inWgsl = false
inNoWgsl = false
} else {
inReg = false
}
case inReg && isKey && bytes.HasPrefix(keyStr, vars):
hasVars = true
outLines = append(outLines, ln)
case isKey && bytes.HasPrefix(keyStr, nowgsl):
inReg = true
inNoWgsl = true
outLines = append(outLines, ln) // key to include self here
case isKey && bytes.HasPrefix(keyStr, wgsl):
inReg = true
inWgsl = true
case inWgsl:
if bytes.HasPrefix(tln, comment) {
outLines = append(outLines, tln[3:])
} else {
outLines = append(outLines, ln)
}
case inReg:
for pkg := range st.ImportPackages { // remove package prefixes
if !bytes.Contains(ln, imp) {
ln = bytes.ReplaceAll(ln, []byte(pkg+"."), []byte{})
}
}
if bytes.HasPrefix(ln, fnc) && bytes.Contains(ln, kernel) {
opts := strings.TrimSpace(string(ln[bytes.LastIndex(ln, kernel)+len(kernel):]))
rw := "read-write:"
rwvars := make(map[string]bool)
flds := strings.Fields(opts)
nf := len(flds)
if nf > 0 && strings.HasPrefix(flds[nf-1], rw) {
rwf := flds[nf-1]
flds = slices.Delete(flds, nf-1, nf)
varlist := strings.Split(rwf[len(rw):], ",")
for _, v := range varlist {
rwvars[v] = true
}
}
sysnm := ""
if len(flds) > 0 {
sysnm = flds[0]
}
sy := st.System(sysnm)
fcall := string(ln[5:])
lp := strings.Index(fcall, "(")
rp := strings.LastIndex(fcall, ")")
args := fcall[lp+1 : rp]
fnm := fcall[:lp]
funcode := ""
for ki := li + 1; ki < len(lines); ki++ {
kl := lines[ki]
if len(kl) > 0 && kl[0] == '}' {
break
}
funcode += string(kl) + "\n"
}
kn := &Kernel{Name: fnm, Args: args, FuncCode: funcode, ReadWriteVars: rwvars}
sy.Kernels[fnm] = kn
if st.Config.Debug {
fmt.Println("\tAdded kernel:", fnm, "args:", args, "system:", sy.Name)
}
}
outLines = append(outLines, ln)
case isKey && bytes.HasPrefix(keyStr, start):
inReg = true
}
}
return
}
// AppendGoHeader adds the standard Go package and import header to the given lines
func (st *State) AppendGoHeader(lines [][]byte) [][]byte {
olns := make([][]byte, 0, len(lines)+10)
olns = append(olns, []byte("package imports"))
olns = append(olns, []byte(`import (
"math"
"sync/atomic"
"cogentcore.org/core/math32"
"cogentcore.org/lab/gosl/slbool"
"cogentcore.org/lab/gosl/slrand"
"cogentcore.org/lab/gosl/sltype"
"cogentcore.org/lab/gosl/slvec"
"cogentcore.org/lab/tensor"
`))
for impath := range st.GoImports {
if strings.Contains(impath, "core/goal/gosl") {
continue
}
olns = append(olns, []byte("\t\""+impath+"\""))
}
olns = append(olns, []byte(")"))
olns = append(olns, lines...)
SlBoolReplace(olns)
return olns
}
// ExtractWGSL strips the package and import directives from the given
// lines, keeping only the content relevant for WGSL code.
func (st *State) ExtractWGSL(lines [][]byte) [][]byte {
pack := []byte("package")
imp := []byte("import")
lparen := []byte("(")
rparen := []byte(")")
mx := min(10, len(lines))
stln := 0
gotImp := false
for li := 0; li < mx; li++ {
ln := lines[li]
switch {
case bytes.HasPrefix(ln, pack):
stln = li + 1
case bytes.HasPrefix(ln, imp):
if bytes.HasSuffix(ln, lparen) {
gotImp = true
} else {
stln = li + 1
}
case gotImp && bytes.HasPrefix(ln, rparen):
stln = li + 1
}
}
lines = lines[stln:] // get rid of package, import
return lines
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gotosl
import (
"bytes"
"fmt"
"io/fs"
"log"
"os"
"path/filepath"
"strings"
"cogentcore.org/core/base/fsx"
"golang.org/x/tools/go/packages"
)
// wgslFile returns the file with a ".wgsl" extension
func wgslFile(fn string) string {
f, _ := fsx.ExtSplit(fn)
return f + ".wgsl"
}
// bareFile returns the file with no extension
func bareFile(fn string) string {
f, _ := fsx.ExtSplit(fn)
return f
}
func ReadFileLines(fn string) ([][]byte, error) {
nl := []byte("\n")
buf, err := os.ReadFile(fn)
if err != nil {
fmt.Println(err)
return nil, err
}
lines := bytes.Split(buf, nl)
return lines, nil
}
func WriteFileLines(fn string, lines [][]byte) error {
res := bytes.Join(lines, []byte("\n"))
return os.WriteFile(fn, res, 0644)
}
// HasGoslTag returns true if given file has a //gosl: tag.
// It also records the package name from the package declaration, if not already set.
func (st *State) HasGoslTag(lines [][]byte) bool {
key := []byte("//gosl:")
pkg := []byte("package ")
for _, ln := range lines {
tln := bytes.TrimSpace(ln)
if st.Package == "" {
if bytes.HasPrefix(tln, pkg) {
st.Package = string(bytes.TrimPrefix(tln, pkg))
}
}
if bytes.HasPrefix(tln, key) {
return true
}
}
return false
}
func IsGoFile(f fs.DirEntry) bool {
name := f.Name()
return !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go") && !f.IsDir()
}
func IsWGSLFile(f fs.DirEntry) bool {
name := f.Name()
return !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".wgsl") && !f.IsDir()
}
// ProjectFiles gets the .go files in the current directory that have gosl tags.
func (st *State) ProjectFiles() {
fls := fsx.Filenames(".", ".go")
st.GoFiles = make(map[string]*File)
st.GoVarsFiles = make(map[string]*File)
for _, fn := range fls {
fl := &File{Name: fn}
var err error
fl.Lines, err = ReadFileLines(fn)
if err != nil {
continue
}
if !st.HasGoslTag(fl.Lines) {
continue
}
st.GoFiles[fn] = fl
st.ImportFiles(fl.Lines)
}
}
// ImportFiles checks the given content for //gosl:import tags
// and imports the package if so.
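// A minimal example of the directive form parsed here (the quotes are optional):
//
//	//gosl:import "cogentcore.org/lab/gosl/slrand"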
func (st *State) ImportFiles(lines [][]byte) {
key := []byte("//gosl:import ")
for _, ln := range lines {
tln := bytes.TrimSpace(ln)
if !bytes.HasPrefix(tln, key) {
continue
}
impath := strings.TrimSpace(string(tln[len(key):]))
if impath[0] == '"' {
impath = impath[1:]
}
if impath[len(impath)-1] == '"' {
impath = impath[:len(impath)-1]
}
_, ok := st.GoImports[impath]
if ok {
continue
}
var pkgs []*packages.Package
var err error
pkgs, err = packages.Load(&packages.Config{Mode: packages.NeedName | packages.NeedFiles}, impath)
if err != nil {
fmt.Println(err)
continue
}
pfls := make(map[string]*File)
st.GoImports[impath] = pfls
pkg := pkgs[0]
gofls := pkg.GoFiles
if len(gofls) == 0 {
fmt.Printf("WARNING: no go files found in path: %s\n", impath)
}
for _, gf := range gofls {
lns, err := ReadFileLines(gf)
if err != nil {
continue
}
if !st.HasGoslTag(lns) {
continue
}
_, fo := filepath.Split(gf)
pfls[fo] = &File{Name: fo, Lines: lns}
st.ImportFiles(lns)
// fmt.Printf("added file: %s from package: %s\n", gf, impath)
}
st.GoImports[impath] = pfls
}
}
// RemoveGenFiles removes generated .go and .wgsl files in the given shader output directory.
func RemoveGenFiles(dir string) {
err := filepath.WalkDir(dir, func(path string, f fs.DirEntry, err error) error {
if err != nil {
return err
}
if IsGoFile(f) || IsWGSLFile(f) {
os.Remove(path)
}
return nil
})
if err != nil {
log.Println(err)
}
}
// CopyPackageFile copies given file name from given package path
// into the current imports directory.
// e.g., "slrand.wgsl", "cogentcore.org/lab/gosl/slrand"
func (st *State) CopyPackageFile(fnm, packagePath string) error {
for _, f := range st.SLImportFiles {
if f.Name == fnm { // don't re-import
return nil
}
}
tofn := filepath.Join(st.ImportsDir, fnm)
pkgs, err := packages.Load(&packages.Config{Mode: packages.NeedName | packages.NeedFiles}, packagePath)
if err != nil {
fmt.Println(err)
return err
}
if len(pkgs) != 1 {
err = fmt.Errorf("%s package not found", packagePath)
fmt.Println(err)
return err
}
pkg := pkgs[0]
var fn string
if len(pkg.GoFiles) > 0 {
fn = pkg.GoFiles[0]
} else if len(pkg.OtherFiles) > 0 {
fn = pkg.OtherFiles[0]
} else {
err = fmt.Errorf("no files found in package: %s", packagePath)
fmt.Println(err)
return err
}
dir, _ := filepath.Split(fn)
fmfn := filepath.Join(dir, fnm)
lines, err := CopyFile(fmfn, tofn)
if err != nil {
return err
}
lines = SlRemoveComments(lines)
st.SLImportFiles = append(st.SLImportFiles, &File{Name: fnm, Lines: lines})
return nil
}
func CopyFile(src, dst string) ([][]byte, error) {
lines, err := ReadFileLines(src)
if err != nil {
return lines, err
}
err = WriteFileLines(dst, lines)
if err != nil {
return lines, err
}
return lines, err
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gotosl
import (
"fmt"
"os"
"path/filepath"
"slices"
"strings"
"golang.org/x/exp/maps"
)
// genSysName returns the name to use for the system in generated code;
// if there is only one system, the name is empty.
func (st *State) genSysName(sy *System) string {
if len(st.Systems) == 1 {
return ""
}
return sy.Name
}
// genSysVar returns the name of the global system variable for the given system,
// e.g., GPUSystem when there is only one system, or GPU<Name>System otherwise.
func (st *State) genSysVar(sy *System) string {
return fmt.Sprintf("GPU%sSystem", st.genSysName(sy))
}
// GenGPU generates and writes the Go GPU helper code.
// If imports is true, it generates into the imports directory.
func (st *State) GenGPU(imports bool) {
var b strings.Builder
header := `// Code generated by "gosl"; DO NOT EDIT
package %s
import (
"embed"
"fmt"
"math"
"unsafe"
"cogentcore.org/core/gpu"
"cogentcore.org/lab/tensor"
)
%s
var (
// GPUInitialized is true once the GPU system has been initialized.
// Prevents multiple initializations.
GPUInitialized bool
// ComputeGPU is the compute gpu device.
// Set this prior to calling GPUInit() to use an existing device.
ComputeGPU *gpu.GPU
// BorrowedGPU is true if our ComputeGPU is set externally,
// versus created specifically for this system. If external,
// we don't release it.
BorrowedGPU bool
// UseGPU indicates whether to use GPU vs. CPU.
UseGPU bool
)
`
pkg := st.Package
shaders := fmt.Sprintf(`//go:embed %s/*.wgsl
var shaders embed.FS`, st.Config.Output)
if imports {
shaders = `var shaders embed.FS`
pkg = "imports"
}
b.WriteString(fmt.Sprintf(header, pkg, shaders))
sys := maps.Keys(st.Systems)
slices.Sort(sys)
for _, synm := range sys {
sy := st.Systems[synm]
b.WriteString(fmt.Sprintf("// %s is a GPU compute System with kernels operating on the\n// same set of data variables.\n", st.genSysVar(sy)))
b.WriteString(fmt.Sprintf("var %s *gpu.ComputeSystem\n", st.genSysVar(sy)))
}
venum := `
// GPUVars is an enum for GPU variables, for specifying what to sync.
type GPUVars int32 //enums:enum
const (
`
b.WriteString(venum)
vidx := 0
hasTensors := false
for _, synm := range sys {
sy := st.Systems[synm]
if sy.NTensors > 0 {
hasTensors = true
}
for _, gp := range sy.Groups {
for _, vr := range gp.Vars {
b.WriteString(fmt.Sprintf("\t%sVar GPUVars = %d\n", vr.Name, vidx))
vidx++
}
}
}
b.WriteString(")\n")
if hasTensors {
b.WriteString("\n// Tensor stride variables\n")
for _, synm := range sys {
sy := st.Systems[synm]
genSynm := st.genSysName(sy)
b.WriteString(fmt.Sprintf("var %sTensorStrides tensor.Uint32\n", genSynm))
}
} else {
b.WriteString("\n// Dummy tensor stride variable to avoid import error\n")
b.WriteString("var __TensorStrides tensor.Uint32\n")
}
initf := `
// GPUInit initializes the GPU compute system,
// configuring system(s), variables and kernels.
// It is safe to call multiple times: detects if already run.
func GPUInit() {
if GPUInitialized {
return
}
GPUInitialized = true
if ComputeGPU == nil { // set prior to this call to use an external device
ComputeGPU = gpu.NewComputeGPU()
} else {
BorrowedGPU = true
}
gp := ComputeGPU
_ = fmt.Sprintf("%g",math.NaN()) // keep imports happy
`
b.WriteString(initf)
for _, synm := range sys {
sy := st.Systems[synm]
b.WriteString(st.GenGPUSystemInit(sy))
}
b.WriteString("}\n\n")
release := `// GPURelease releases the GPU compute system resources.
// Call this at program exit.
func GPURelease() {
`
b.WriteString(release)
sysRelease := ` if %[1]s != nil {
%[1]s.Release()
%[1]s = nil
}
`
for _, synm := range sys {
sy := st.Systems[synm]
b.WriteString(fmt.Sprintf(sysRelease, st.genSysVar(sy)))
}
gpuRelease := `
if !BorrowedGPU && ComputeGPU != nil {
ComputeGPU.Release()
}
ComputeGPU = nil
}
`
b.WriteString(gpuRelease)
for _, synm := range sys {
sy := st.Systems[synm]
b.WriteString(st.GenGPUSystemOps(sy))
}
gs := b.String()
fn := "gosl.go"
if imports {
fn = filepath.Join(st.Config.Output, "imports", fn)
}
os.WriteFile(fn, []byte(gs), 0644)
}
// GenGPUSystemInit generates GPU Init code for given system.
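// The generated init code looks roughly like the following sketch, for a
// hypothetical single "Default" system with one group "Group0" containing a
// struct variable Data of type []DataStruct, used by shaders/Compute.wgsl:
//
//	{
//		sy := gpu.NewComputeSystem(gp, "Default")
//		GPUSystem = sy
//		vars := sy.Vars()
//		{
//			sgp := vars.AddGroup(gpu.Storage, "Group0")
//			var vr *gpu.Var
//			_ = vr
//			vr = sgp.AddStruct("Data", int(unsafe.Sizeof(DataStruct{})), 1, gpu.ComputeShader)
//			sgp.SetNValues(1)
//		}
//		var pl *gpu.ComputePipeline
//		pl = gpu.NewComputePipelineShaderFS(shaders, "shaders/Compute.wgsl", sy)
//		pl.AddVarUsed(0, "Data")
//		sy.Config()
//	}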
func (st *State) GenGPUSystemInit(sy *System) string {
var b strings.Builder
syvar := st.genSysVar(sy)
b.WriteString("\t{\n")
b.WriteString(fmt.Sprintf("\t\tsy := gpu.NewComputeSystem(gp, %q)\n", sy.Name))
b.WriteString(fmt.Sprintf("\t\t%s = sy\n", syvar))
kns := maps.Keys(sy.Kernels)
slices.Sort(kns)
b.WriteString("\t\tvars := sy.Vars()\n")
for gi, gp := range sy.Groups {
b.WriteString("\t\t{\n")
gtyp := "gpu.Storage"
if gp.Uniform {
gtyp = "gpu.Uniform"
}
b.WriteString(fmt.Sprintf("\t\t\tsgp := vars.AddGroup(%s, %q)\n", gtyp, gp.Name))
b.WriteString("\t\t\tvar vr *gpu.Var\n\t\t\t_ = vr\n")
if sy.NTensors > 0 && gi == 0 {
b.WriteString(fmt.Sprintf("\t\t\tvr = sgp.Add(%q, gpu.%s, 1, gpu.ComputeShader)\n", "TensorStrides", "Uint32"))
b.WriteString("\t\t\tvr.ReadOnly = true\n")
}
for _, vr := range gp.Vars {
if vr.Tensor {
typ := strings.TrimPrefix(vr.Type, "tensor.")
if vr.NBuffs > 1 {
for bi := range vr.NBuffs {
b.WriteString(fmt.Sprintf("\t\t\tvr = sgp.Add(\"%s%d\", gpu.%s, 1, gpu.ComputeShader)\n", vr.Name, bi, typ))
}
} else {
b.WriteString(fmt.Sprintf("\t\t\tvr = sgp.Add(%q, gpu.%s, 1, gpu.ComputeShader)\n", vr.Name, typ))
}
} else {
b.WriteString(fmt.Sprintf("\t\t\tvr = sgp.AddStruct(%q, int(unsafe.Sizeof(%s{})), 1, gpu.ComputeShader)\n", vr.Name, vr.SLType()))
}
if vr.ReadOnly && !vr.ReadOrWrite {
b.WriteString("\t\t\tvr.ReadOnly = true\n")
}
}
b.WriteString("\t\t\tsgp.SetNValues(1)\n")
b.WriteString("\t\t}\n")
}
b.WriteString("\t\tvar pl *gpu.ComputePipeline\n")
for _, knm := range kns {
kn := sy.Kernels[knm]
b.WriteString(fmt.Sprintf("\t\tpl = gpu.NewComputePipelineShaderFS(shaders, %q, sy)\n", kn.Filename))
if sy.NTensors > 0 {
b.WriteString(fmt.Sprintf("\t\tpl.AddVarUsed(%d, %q)\n", 0, "TensorStrides"))
}
vnms := maps.Keys(kn.VarsUsed)
slices.Sort(vnms)
for _, vnm := range vnms {
vr := kn.VarsUsed[vnm]
if vr.NBuffs > 1 {
for bi := range vr.NBuffs {
b.WriteString(fmt.Sprintf("\t\tpl.AddVarUsed(%d, \"%s%d\")\n", vr.Group, vr.Name, bi))
}
} else {
b.WriteString(fmt.Sprintf("\t\tpl.AddVarUsed(%d, %q)\n", vr.Group, vr.Name))
}
}
}
b.WriteString("\t\tsy.Config()\n")
b.WriteString("\t}\n")
return b.String()
}
// GenGPUSystemOps generates GPU helper functions for given system.
func (st *State) GenGPUSystemOps(sy *System) string {
var b strings.Builder
syvar := st.genSysVar(sy)
synm := st.genSysName(sy)
// 1 = kernel, 2 = system var, 3 = sysname (blank for 1 default)
run := `// Run%[1]s runs the %[1]s kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// Can call multiple Run* kernels in a row, which are then all launched
// in the same command submission on the GPU, which is by far the most efficient.
// MUST call RunDone (with optional vars to sync) after all Run calls.
// Alternatively, a single-shot RunOne%[1]s call does Run and Done for a
// single run-and-sync case.
func Run%[1]s(n int) {
if UseGPU {
Run%[1]sGPU(n)
} else {
Run%[1]sCPU(n)
}
}
// Run%[1]sGPU runs the %[1]s kernel on the GPU. See [Run%[1]s] for more info.
func Run%[1]sGPU(n int) {
sy := %[2]s
pl := sy.ComputePipelines[%[1]q]
ce, _ := sy.BeginComputePass()
pl.Dispatch1D(ce, n, 64)
}
// Run%[1]sCPU runs the %[1]s kernel on the CPU.
func Run%[1]sCPU(n int) {
gpu.VectorizeFunc(0, n, %[1]s)
}
// RunOne%[1]s runs the %[1]s kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// This version then calls RunDone with the given variables to sync
// after the Run, for a single-shot Run-and-Done call. If multiple kernels
// can be run in sequence, it is much more efficient to do multiple Run*
// calls followed by a RunDone call.
func RunOne%[1]s(n int, syncVars ...GPUVars) {
if UseGPU {
Run%[1]sGPU(n)
RunDone%[3]s(syncVars...)
} else {
Run%[1]sCPU(n)
}
}
`
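// A sketch of how the generated Run API is typically used, assuming a single
// default system with a kernel named Compute and a global variable Data
// (hypothetical names):
//
//	GPUInit()
//	UseGPU = true
//	ToGPU(DataVar)   // copy Data to the GPU
//	RunCompute(n)    // multiple Run* calls share one command submission
//	RunDone(DataVar) // submit, and sync Data back from the GPU
//	GPURelease()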
// 1 = sysname (blank for 1 default), 2 = system var
runDone := `// RunDone%[1]s must be called after Run* calls to start compute kernels.
// This actually submits the kernel jobs to the GPU, and adds commands
// to synchronize the given variables back from the GPU to the CPU.
// After this function completes, the GPU results will be available in
// the specified variables.
func RunDone%[1]s(syncVars ...GPUVars) {
if !UseGPU {
return
}
sy := %[2]s
sy.ComputeEncoder.End()
%[1]sReadFromGPU(syncVars...)
sy.EndComputePass()
%[1]sSyncFromGPU(syncVars...)
}
// %[1]sToGPU copies given variables to the GPU for the system.
func %[1]sToGPU(vars ...GPUVars) {
if !UseGPU {
return
}
sy := %[2]s
syVars := sy.Vars()
for _, vr := range vars {
switch vr {
`
kns := maps.Keys(sy.Kernels)
slices.Sort(kns)
for _, knm := range kns {
kn := sy.Kernels[knm]
b.WriteString(fmt.Sprintf(run, kn.Name, syvar, synm))
}
b.WriteString(fmt.Sprintf(runDone, synm, syvar))
for gi, gp := range sy.Groups {
for _, vr := range gp.Vars {
vv := vr.Name
if vr.Tensor {
vv += ".Values"
}
b.WriteString(fmt.Sprintf("\t\tcase %sVar:\n", vr.Name))
if vr.NBuffs > 1 {
bsz := st.Config.MaxBufferSize / 4
b.WriteString(fmt.Sprintf("\t\t\tbsz := %d\n", bsz))
b.WriteString(fmt.Sprintf("\t\t\tn := %s.Len()\n", vr.Name))
b.WriteString("\t\t\tnb := int(math.Ceil(float64(n) / float64(bsz)))\n")
b.WriteString("\t\t\tfor bi := range nb {\n")
b.WriteString(fmt.Sprintf("\t\t\t\tv, _ := syVars.ValueByIndex(%d, fmt.Sprintf(\"%s%%d\", bi), 0)\n", gi, vr.Name))
b.WriteString("\t\t\t\tst := bsz * bi\n")
b.WriteString("\t\t\t\ted := min(bsz * (bi+1), n)\n")
b.WriteString(fmt.Sprintf("\t\t\t\tgpu.SetValueFrom(v, %s[st:ed])\n", vv))
b.WriteString("\t\t\t}\n")
} else {
b.WriteString(fmt.Sprintf("\t\t\tv, _ := syVars.ValueByIndex(%d, %q, 0)\n", gi, vr.Name))
b.WriteString(fmt.Sprintf("\t\t\tgpu.SetValueFrom(v, %s)\n", vv))
}
}
}
b.WriteString("\t\t}\n\t}\n}\n")
runSync := `// Run%[1]sGPUSync can be called to synchronize data between CPU and GPU.
// Any prior ToGPU* calls will execute to send data to the GPU,
// and any subsequent RunDone* calls will copy data back from the GPU.
func Run%[1]sGPUSync() {
if !UseGPU {
return
}
sy := %[2]s
sy.BeginComputePass()
}
`
b.WriteString(fmt.Sprintf(runSync, synm, syvar))
if sy.NTensors > 0 {
tensorStrides := `
// %[1]sToGPUTensorStrides gets tensor strides and starts copying to the GPU.
func %[1]sToGPUTensorStrides() {
if !UseGPU {
return
}
sy := %[2]s
syVars := sy.Vars()
`
b.WriteString(fmt.Sprintf(tensorStrides, synm, syvar))
strvar := synm + "TensorStrides"
b.WriteString(fmt.Sprintf("\t%s.SetShapeSizes(%d)\n", strvar, sy.NTensors*10))
for _, gp := range sy.Groups {
for _, vr := range gp.Vars {
if !vr.Tensor {
continue
}
for d := range vr.TensorDims {
b.WriteString(fmt.Sprintf("\t%sTensorStrides.SetInt1D(%s.Shape().Strides[%d], %d)\n", synm, vr.Name, d, vr.TensorIndex*10+d))
}
}
}
b.WriteString(fmt.Sprintf("\tv, _ := syVars.ValueByIndex(0, %q, 0)\n", strvar))
b.WriteString(fmt.Sprintf("\tgpu.SetValueFrom(v, %s.Values)\n", strvar))
b.WriteString("}\n")
}
fmGPU := `
// %[1]sReadFromGPU starts the process of copying vars back from the GPU.
func %[1]sReadFromGPU(vars ...GPUVars) {
sy := %[2]s
syVars := sy.Vars()
for _, vr := range vars {
switch vr {
`
b.WriteString(fmt.Sprintf(fmGPU, synm, syvar))
for gi, gp := range sy.Groups {
for _, vr := range gp.Vars {
b.WriteString(fmt.Sprintf("\t\tcase %sVar:\n", vr.Name))
if vr.NBuffs > 1 {
bsz := st.Config.MaxBufferSize / 4
b.WriteString(fmt.Sprintf("\t\t\tbsz := %d\n", bsz))
b.WriteString(fmt.Sprintf("\t\t\tn := %s.Len()\n", vr.Name))
b.WriteString("\t\t\tnb := int(math.Ceil(float64(n) / float64(bsz)))\n")
b.WriteString("\t\t\tfor bi := range nb {\n")
b.WriteString(fmt.Sprintf("\t\t\t\tv, _ := syVars.ValueByIndex(%d, fmt.Sprintf(\"%s%%d\", bi), 0)\n", gi, vr.Name))
b.WriteString("\t\t\t\tv.GPUToRead(sy.CommandEncoder)\n")
b.WriteString("\t\t\t}\n")
} else {
b.WriteString(fmt.Sprintf("\t\t\tv, _ := syVars.ValueByIndex(%d, %q, 0)\n", gi, vr.Name))
b.WriteString("\t\t\tv.GPUToRead(sy.CommandEncoder)\n")
}
}
}
b.WriteString("\t\t}\n\t}\n}\n")
syncGPU := `
// %[1]sSyncFromGPU synchronizes vars from the GPU to the actual variable.
func %[1]sSyncFromGPU(vars ...GPUVars) {
sy := %[2]s
syVars := sy.Vars()
for _, vr := range vars {
switch vr {
`
b.WriteString(fmt.Sprintf(syncGPU, synm, syvar))
for gi, gp := range sy.Groups {
for _, vr := range gp.Vars {
vv := vr.Name
if vr.Tensor {
vv += ".Values"
}
b.WriteString(fmt.Sprintf("\t\tcase %sVar:\n", vr.Name))
if vr.NBuffs > 1 {
bsz := st.Config.MaxBufferSize / 4
b.WriteString(fmt.Sprintf("\t\t\tbsz := %d\n", bsz))
b.WriteString(fmt.Sprintf("\t\t\tn := %s.Len()\n", vr.Name))
b.WriteString("\t\t\tnb := int(math.Ceil(float64(n) / float64(bsz)))\n")
b.WriteString("\t\t\tfor bi := range nb {\n")
b.WriteString(fmt.Sprintf("\t\t\t\tv, _ := syVars.ValueByIndex(%d, fmt.Sprintf(\"%s%%d\", bi), 0)\n", gi, vr.Name))
b.WriteString(fmt.Sprintf("\t\t\t\tv.ReadSync()\n"))
b.WriteString("\t\t\t\tst := bsz * bi\n")
b.WriteString("\t\t\t\ted := min(bsz * (bi+1), n)\n")
b.WriteString(fmt.Sprintf("\t\t\t\tgpu.ReadToBytes(v, %s[st:ed])\n", vv))
b.WriteString("\t\t\t}\n")
} else {
b.WriteString(fmt.Sprintf("\t\t\tv, _ := syVars.ValueByIndex(%d, %q, 0)\n", gi, vr.Name))
b.WriteString(fmt.Sprintf("\t\t\tv.ReadSync()\n"))
b.WriteString(fmt.Sprintf("\t\t\tgpu.ReadToBytes(v, %s)\n", vv))
}
}
}
b.WriteString("\t\t}\n\t}\n}\n")
getFun := `
// Get%[1]s returns a pointer to the given global variable:
// [%[1]s] []%[2]s at given index. This is processed directly in the GPU code,
// so this function provides the equivalent access for the CPU.
func Get%[1]s(idx uint32) *%[2]s {
return &%[1]s[idx]
}
`
for _, gp := range sy.Groups {
for _, vr := range gp.Vars {
if vr.Tensor {
continue
}
b.WriteString(fmt.Sprintf(getFun, vr.Name, vr.SLType()))
}
}
return b.String()
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gotosl
import (
"fmt"
"strings"
)
// GenKernelHeader returns the newly generated WGSL kernel header code
// for the given kernel, which goes at the top of the resulting shader file.
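// The generated header looks roughly like the following sketch, for a
// hypothetical kernel Compute over a single tensor.Float32 variable Data
// (when tensors are present, TensorStrides occupies group 0, binding 0):
//
//	// Code generated by "gosl"; DO NOT EDIT
//	// kernel: Compute
//
//	@group(0) @binding(0)
//	var<storage, read> TensorStrides: array<u32>;
//	@group(0) @binding(1)
//	var<storage, read_write> Data: array<f32>;
//
//	alias GPUVars = i32;
//
//	@compute @workgroup_size(64, 1, 1)
//	fn main(@builtin(workgroup_id) wgid: vec3<u32>, @builtin(num_workgroups) nwg: vec3<u32>, @builtin(local_invocation_index) loci: u32) {
//		let idx = loci + (wgid.x + wgid.y * nwg.x + wgid.z * nwg.x * nwg.y) * 64;
//		Compute(idx);
//	}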
func (st *State) GenKernelHeader(sy *System, kn *Kernel) string {
var b strings.Builder
b.WriteString("// Code generated by \"gosl\"; DO NOT EDIT\n")
b.WriteString("// kernel: " + kn.Name + "\n\n")
for gi, gp := range sy.Groups {
if gp.Doc != "" {
b.WriteString("// " + gp.Doc + "\n")
}
str := "storage"
if gp.Uniform {
str = "uniform"
}
if gi == 0 && sy.NTensors > 0 {
access := ", read"
if gp.Uniform {
access = ""
}
b.WriteString("@group(0) @binding(0)\n")
b.WriteString(fmt.Sprintf("var<%s%s> TensorStrides: array<u32>;\n", str, access))
}
for _, vr := range gp.Vars {
_, isAtomic := kn.Atomics[vr.Name]
_, isUsed := kn.VarsUsed[vr.Name]
if !isUsed {
continue
}
access := ", read_write"
if vr.ReadOnly && !vr.ReadOrWrite {
access = ", read"
}
if gp.Uniform {
access = ""
}
if vr.Doc != "" {
b.WriteString("// " + vr.Doc + "\n")
}
if vr.NBuffs <= 1 {
b.WriteString(fmt.Sprintf("@group(%d) @binding(%d)\n", vr.Group, vr.Binding))
b.WriteString(fmt.Sprintf("var<%s%s> %s: ", str, access, vr.Name))
if isAtomic {
b.WriteString(fmt.Sprintf("array<atomic<%s>>;\n", vr.SLType()))
} else {
b.WriteString(fmt.Sprintf("array<%s>;\n", vr.SLType()))
}
continue
}
vn := vr.Binding
for bi := range vr.NBuffs {
b.WriteString(fmt.Sprintf("@group(%d) @binding(%d)\n", gi, vn))
b.WriteString(fmt.Sprintf("var<%s%s> %s%d: ", str, access, vr.Name, bi))
b.WriteString(fmt.Sprintf("array<%s>;\n", vr.SLType()))
vn++
}
}
}
b.WriteString("\nalias GPUVars = i32;\n\n") // gets included when iteratively processing enumgen.go
b.WriteString("@compute @workgroup_size(64, 1, 1)\n")
b.WriteString("fn main(@builtin(workgroup_id) wgid: vec3<u32>, @builtin(num_workgroups) nwg: vec3<u32>, @builtin(local_invocation_index) loci: u32) {\n")
b.WriteString("\tlet idx = loci + (wgid.x + wgid.y * nwg.x + wgid.z * nwg.x * nwg.y) * 64;\n")
b.WriteString(fmt.Sprintf("\t%s(idx);\n", kn.Name))
b.WriteString("}\n")
b.WriteString(st.GenTensorFuncs(sy, kn))
return b.String()
}
// GenTensorFuncs returns the generated WGSL code
// for indexing the tensors in given system.
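// For a hypothetical 2D tensor variable, the generated index function is:
//
//	fn Index2D(s0: u32, s1: u32, i0: u32, i1: u32) -> u32 {
//		return s0 * i0 + s1 * i1;
//	}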
func (st *State) GenTensorFuncs(sy *System, kn *Kernel) string {
var b strings.Builder
done := make(map[string]bool)
for _, gp := range sy.Groups {
for _, vr := range gp.Vars {
_, isUsed := kn.VarsUsed[vr.Name]
if !vr.Tensor || !isUsed {
continue
}
if vr.NBuffs > 1 {
b.WriteString(st.GenNBuffFuncs(sy, vr))
}
fn := vr.IndexFunc()
if _, ok := done[fn]; ok {
continue
}
done[fn] = true
typ := "u32"
b.WriteString("\nfn " + fn + "(")
nd := vr.TensorDims
for d := range nd {
b.WriteString(fmt.Sprintf("s%d: %s, ", d, typ))
}
for d := range nd {
b.WriteString(fmt.Sprintf("i%d: u32", d))
if d < nd-1 {
b.WriteString(", ")
}
}
b.WriteString(") -> u32 {\n\treturn ")
for d := range nd {
b.WriteString(fmt.Sprintf("s%d * i%d", d, d))
if d < nd-1 {
b.WriteString(" + ")
}
}
b.WriteString(";\n}\n")
}
}
return b.String()
}
// GenNBuffFuncs returns the generated WGSL code
// for accessing data in multi-buffer variables.
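// For a hypothetical f32 variable Data split across 2 buffers, with
// MaxBufferSize = 2 GiB (536870912 4-byte elements per buffer), the
// generated getter is:
//
//	fn DataGet(ix: u32) -> f32 {
//		let ii = ix / 536870912;
//		switch ii {
//		case u32(0): {
//			return Data0[ix];
//		}
//		default: {
//			return Data1[ix - 536870912];
//		}
//		}
//	}
//
// Matching DataSet, DataSetAdd, DataSetSub, DataSetMul, and DataSetDiv
// setters are generated with the same switch structure.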
func (st *State) GenNBuffFuncs(sy *System, vr *Var) string {
var b strings.Builder
b.WriteString("\nfn " + vr.Name + "Get(ix: u32) -> " + vr.SLType() + " {\n")
bsz := st.Config.MaxBufferSize / 4 // assumes 4 bytes per element
b.WriteString(fmt.Sprintf("\tlet ii = ix / %d;\n", bsz))
b.WriteString("\tswitch ii {\n")
for bi := range vr.NBuffs {
if bi == vr.NBuffs-1 {
b.WriteString("\tdefault: {\n")
} else {
b.WriteString(fmt.Sprintf("\tcase u32(%d): {\n", bi))
}
if bi > 0 {
b.WriteString(fmt.Sprintf("\t\treturn %s%d[ix - %d];\n", vr.Name, bi, bsz*uint32(bi)))
} else {
b.WriteString(fmt.Sprintf("\t\treturn %s%d[ix];\n", vr.Name, bi))
}
b.WriteString("\t}\n")
}
b.WriteString("\t}\n}\n")
methNames := []string{"Set", "SetAdd", "SetSub", "SetMul", "SetDiv"}
methOps := []string{"=", "+=", "-=", "*=", "/="}
for mi, mn := range methNames {
mop := methOps[mi]
b.WriteString("\nfn " + vr.Name + mn + "(vl: " + vr.SLType() + ", ix: u32) {\n")
b.WriteString(fmt.Sprintf("\tlet ii = ix / %d;\n", bsz))
b.WriteString("\tswitch ii {\n")
for bi := range vr.NBuffs {
if bi == vr.NBuffs-1 {
b.WriteString("\tdefault: {\n")
} else {
b.WriteString(fmt.Sprintf("\tcase u32(%d): {\n", bi))
}
if bi > 0 {
b.WriteString(fmt.Sprintf("\t\t%s%d[ix - %d] %s vl;\n", vr.Name, bi, bsz*uint32(bi), mop))
} else {
b.WriteString(fmt.Sprintf("\t\t%s%d[ix] %s vl;\n", vr.Name, bi, mop))
}
b.WriteString("\t}\n")
}
b.WriteString("\t}\n}\n")
}
return b.String()
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file is largely copied from the Go source,
// src/go/printer/gobuild.go:
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gotosl
import (
"go/build/constraint"
"sort"
"text/tabwriter"
)
func (p *printer) fixGoBuildLines() {
if len(p.goBuild)+len(p.plusBuild) == 0 {
return
}
// Find latest possible placement of //go:build and // +build comments.
// That's just after the last blank line before we find a non-comment.
// (We'll add another blank line after our comment block.)
// When we start dropping // +build comments, we can skip over /* */ comments too.
// Note that we are processing tabwriter input, so every comment
// begins and ends with a tabwriter.Escape byte.
// And some newlines have turned into \f bytes.
insert := 0
for pos := 0; ; {
// Skip leading space at beginning of line.
blank := true
for pos < len(p.output) && (p.output[pos] == ' ' || p.output[pos] == '\t') {
pos++
}
// Skip over // comment if any.
if pos+3 < len(p.output) && p.output[pos] == tabwriter.Escape && p.output[pos+1] == '/' && p.output[pos+2] == '/' {
blank = false
for pos < len(p.output) && !isNL(p.output[pos]) {
pos++
}
}
// Skip over \n at end of line.
if pos >= len(p.output) || !isNL(p.output[pos]) {
break
}
pos++
if blank {
insert = pos
}
}
// If there is a //go:build comment before the place we identified,
// use that point instead. (Earlier in the file is always fine.)
if len(p.goBuild) > 0 && p.goBuild[0] < insert {
insert = p.goBuild[0]
} else if len(p.plusBuild) > 0 && p.plusBuild[0] < insert {
insert = p.plusBuild[0]
}
var x constraint.Expr
switch len(p.goBuild) {
case 0:
// Synthesize //go:build expression from // +build lines.
for _, pos := range p.plusBuild {
y, err := constraint.Parse(p.commentTextAt(pos))
if err != nil {
x = nil
break
}
if x == nil {
x = y
} else {
x = &constraint.AndExpr{X: x, Y: y}
}
}
case 1:
// Parse //go:build expression.
x, _ = constraint.Parse(p.commentTextAt(p.goBuild[0]))
}
var block []byte
if x == nil {
// Don't have a valid //go:build expression to treat as truth.
// Bring all the lines together but leave them alone.
// Note that these are already tabwriter-escaped.
for _, pos := range p.goBuild {
block = append(block, p.lineAt(pos)...)
}
for _, pos := range p.plusBuild {
block = append(block, p.lineAt(pos)...)
}
} else {
block = append(block, tabwriter.Escape)
block = append(block, "//go:build "...)
block = append(block, x.String()...)
block = append(block, tabwriter.Escape, '\n')
if len(p.plusBuild) > 0 {
lines, err := constraint.PlusBuildLines(x)
if err != nil {
lines = []string{"// +build error: " + err.Error()}
}
for _, line := range lines {
block = append(block, tabwriter.Escape)
block = append(block, line...)
block = append(block, tabwriter.Escape, '\n')
}
}
}
block = append(block, '\n')
// Build sorted list of lines to delete from remainder of output.
toDelete := append(p.goBuild, p.plusBuild...)
sort.Ints(toDelete)
// Collect output after insertion point, with lines deleted, into after.
var after []byte
start := insert
for _, end := range toDelete {
if end < start {
continue
}
after = appendLines(after, p.output[start:end])
start = end + len(p.lineAt(end))
}
after = appendLines(after, p.output[start:])
if n := len(after); n >= 2 && isNL(after[n-1]) && isNL(after[n-2]) {
after = after[:n-1]
}
p.output = p.output[:insert]
p.output = append(p.output, block...)
p.output = append(p.output, after...)
}
// appendLines is like append(x, y...)
// but it avoids creating doubled blank lines,
// which would not be gofmt-standard output.
// It assumes that only whole blocks of lines are being appended,
// not line fragments.
func appendLines(x, y []byte) []byte {
if len(y) > 0 && isNL(y[0]) && // y starts in blank line
(len(x) == 0 || len(x) >= 2 && isNL(x[len(x)-1]) && isNL(x[len(x)-2])) { // x is empty or ends in blank line
y = y[1:] // delete y's leading blank line
}
return append(x, y...)
}
func (p *printer) lineAt(start int) []byte {
pos := start
for pos < len(p.output) && !isNL(p.output[pos]) {
pos++
}
if pos < len(p.output) {
pos++
}
return p.output[start:pos]
}
func (p *printer) commentTextAt(start int) string {
if start < len(p.output) && p.output[start] == tabwriter.Escape {
start++
}
pos := start
for pos < len(p.output) && p.output[pos] != tabwriter.Escape && !isNL(p.output[pos]) {
pos++
}
return string(p.output[start:pos])
}
func isNL(b byte) bool {
return b == '\n' || b == '\f'
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gotosl
import (
"fmt"
"go/ast"
"os"
"path/filepath"
"reflect"
"strings"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/stack"
)
// System represents a ComputeSystem, and its kernels and variables.
type System struct {
Name string
// Kernels are the kernels using this compute system.
Kernels map[string]*Kernel
// Groups are the variables for this compute system.
Groups []*Group
// NTensors is the number of tensor vars.
NTensors int
}
func NewSystem(name string) *System {
sy := &System{Name: name}
sy.Kernels = make(map[string]*Kernel)
return sy
}
// Kernel represents a kernel function, which is the basis for
// each wgsl generated code file.
type Kernel struct {
Name string
Args string
// Filename is the name of the kernel shader file, e.g., shaders/Compute.wgsl
Filename string
// FuncCode is the source code of the kernel function body.
FuncCode string
// Lines is the full shader code for this kernel.
Lines [][]byte
// ReadWriteVars are variables marked as read-write for current kernel.
ReadWriteVars map[string]bool
Atomics map[string]*Var
VarsUsed map[string]*Var
}
// Var represents one global system buffer variable.
type Var struct {
Name string
// Group number that we are in.
Group int
// Binding number within group.
Binding int
// comment docs about this var.
Doc string
// Type of variable: either []Type for a struct array, or a tensor type
// (tensor.Float32, tensor.Uint32, tensor.Int32) for tensors.
Type string
// ReadOnly indicates that this variable is never read back from GPU,
// specified by the gosl:read-only property in the variable comments.
// It is important to optimize GPU memory usage to indicate this.
ReadOnly bool
// ReadOrWrite indicates that this variable defaults to ReadOnly
// but is also flagged as read-write in some cases. It is registered
// as read_write in the gpu ComputeSystem, but processed as ReadOnly
// by default except for kernels that declare it as read-write.
ReadOrWrite bool
// True if a tensor type
Tensor bool
// Number of dimensions
TensorDims int
// data kind of the tensor
TensorKind reflect.Kind
// index of tensor in list of tensor variables, for indexing.
TensorIndex int
// NBuffs is the number of buffers to allocate to this variable; default is 1,
// which provides direct access. Otherwise, a wrapper function is generated
// that allows > max buffer size total storage.
// The index still has to fit in a uint32 variable, so 4g max value.
// Assuming 4 bytes per element, that means a total of 16g max total storage.
// The Config.MaxBufferSize (set at compile time, defaults to 2g) determines
// how many buffers: if 2g, then 16 / 2 = 8 max buffers.
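// For example (a sketch): with MaxBufferSize = 2 GiB and 4-byte elements,
// each buffer holds 536870912 elements, so a variable with 1.2 billion
// elements is split across ceil(1.2e9 / 536870912) = 3 buffers, matching
// the nb computation in the generated sync code.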
NBuffs int
}
func (vr *Var) SetTensorKind() {
kindStr := strings.TrimPrefix(vr.Type, "tensor.")
kind := reflect.Float32
switch kindStr {
case "Float32":
kind = reflect.Float32
case "Uint32":
kind = reflect.Uint32
case "Int32":
kind = reflect.Int32
default:
errors.Log(fmt.Errorf("gosl: variable %q type is not supported: %q", vr.Name, kindStr))
}
vr.TensorKind = kind
}
// SLType returns the WGSL type string
func (vr *Var) SLType() string {
if vr.Tensor {
switch vr.TensorKind {
case reflect.Float32:
return "f32"
case reflect.Int32:
return "i32"
case reflect.Uint32:
return "u32"
}
} else {
return vr.Type[2:]
}
return ""
}
// GoType returns the Go type string for tensors
func (vr *Var) GoType() string {
if vr.Tensor {
switch vr.TensorKind {
case reflect.Float32:
return "float32"
case reflect.Int32:
return "int32"
case reflect.Uint32:
return "uint32"
}
}
return ""
}
// IndexFunc returns the tensor index function name
func (vr *Var) IndexFunc() string {
return fmt.Sprintf("Index%dD", vr.TensorDims)
}
// IndexStride returns the tensor stride variable reference
func (vr *Var) IndexStride(dim int) string {
return fmt.Sprintf("TensorStrides[%d]", vr.TensorIndex*10+dim)
}
// Group represents one variable group.
type Group struct {
Name string
// comment docs about this group
Doc string
// Uniform indicates a uniform group; else default is Storage.
Uniform bool
Vars []*Var
}
// File has contents of a file as lines of bytes.
type File struct {
Name string
Lines [][]byte
}
// GetGlobalVar holds GetVar expression, to Set variable back when done.
type GetGlobalVar struct {
// global variable
Var *Var
// name of temporary variable
TmpVar string
// index passed to the Get function
IdxExpr ast.Expr
// rw override
ReadWrite bool
}
// State holds the current Go -> WGSL processing state.
type State struct {
// Config options.
Config *Config
// path to shaders/imports directory.
ImportsDir string
// name of the package
Package string
// GoFiles are all the files with gosl content in current directory.
GoFiles map[string]*File
// GoVarsFiles are all the files with gosl:vars content in current directory.
// These must be processed first! they are moved from GoFiles to here.
GoVarsFiles map[string]*File
// GoImports has all the imported files.
GoImports map[string]map[string]*File
// ImportPackages has short package names, to remove from go code
// so everything lives in same main package.
ImportPackages map[string]bool
// Systems has the kernels and variables for each system.
// There is an initial "Default" system when system is not specified.
Systems map[string]*System
// GetFuncs is a map of GetVar, SetVar function names for global vars.
GetFuncs map[string]*Var
// VarStructTypes is a map of struct type names to vars that use them.
VarStructTypes map[string]*Var
// SLImportFiles are all the extracted and translated WGSL files in shaders/imports,
// which are copied into the generated shader kernel files.
SLImportFiles []*File
// generated Go GPU gosl.go file contents
GPUFile File
// ExcludeMap is the compiled map of functions to exclude in Go -> WGSL translation.
ExcludeMap map[string]bool
// GetVarStack is a stack per function definition of GetVar variables
// that need to be set at the end.
GetVarStack stack.Stack[map[string]*GetGlobalVar]
// GetFuncGraph is true if getting the function graph (first pass).
GetFuncGraph bool
// CurKernel is the current Kernel for second pass processing.
CurKernel *Kernel
// KernelFuncs are the list of functions to include for current kernel.
KernelFuncs map[string]*Function
// FuncGraph is the call graph of functions, for dead code elimination
FuncGraph map[string]*Function
}
func (st *State) Init(cfg *Config) {
st.Config = cfg
st.GoImports = make(map[string]map[string]*File)
st.Systems = make(map[string]*System)
st.ExcludeMap = make(map[string]bool)
ex := strings.Split(cfg.Exclude, ",")
for _, fn := range ex {
st.ExcludeMap[fn] = true
}
st.Systems["Default"] = NewSystem("Default")
}
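// Run runs the full Go to WGSL translation: gathering project files,
// extracting tagged regions, translating to WGSL in the output directory,
// and generating the Go GPU helper code. A minimal sketch of driving it
// programmatically (a hypothetical direct caller; normally the gosl tool
// does this):
//
//	st := &State{}
//	st.Init(&Config{}) // zero Config: Output defaults to "shaders" in Run
//	if err := st.Run(); err != nil {
//		// handle error
//	}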
func (st *State) Run() error {
if gomod := os.Getenv("GO111MODULE"); gomod == "off" {
err := errors.New("gosl only works in go modules mode, but GO111MODULE=off")
return err
}
if st.Config.Output == "" {
st.Config.Output = "shaders"
}
st.ProjectFiles() // get list of all files, recursively gets imports etc.
if len(st.GoFiles) == 0 {
return nil
}
st.ImportsDir = filepath.Join(st.Config.Output, "imports")
os.MkdirAll(st.Config.Output, 0755)
os.MkdirAll(st.ImportsDir, 0755)
RemoveGenFiles(st.Config.Output)
RemoveGenFiles(st.ImportsDir)
st.ExtractFiles() // get .go from project files
st.ExtractImports() // get .go from imports
st.TranslateDir("./" + st.ImportsDir)
st.GenGPU(false)
return nil
}
// System returns the given system by name, making it if it does not yet exist.
// If name is empty, "Default" is used.
func (st *State) System(sysname string) *System {
if sysname == "" {
sysname = "Default"
}
sy, ok := st.Systems[sysname]
if ok {
return sy
}
sy = NewSystem(sysname)
st.Systems[sysname] = sy
return sy
}
// GlobalVar returns global variable of given name, if found.
func (st *State) GlobalVar(vrnm string) *Var {
if st == nil {
return nil
}
if st.Systems == nil {
return nil
}
for _, sy := range st.Systems {
for _, gp := range sy.Groups {
for _, vr := range gp.Vars {
if vr.Name == vrnm {
return vr
}
}
}
}
return nil
}
// VarIsReadWrite returns true if var of name is set as read-write
// for current kernel.
func (st *State) VarIsReadWrite(vrnm string) bool {
if st.CurKernel == nil {
return false
}
if _, rw := st.CurKernel.ReadWriteVars[vrnm]; rw {
return true
}
return false
}
// GetTempVar returns temp var for global variable of given name, if found.
func (st *State) GetTempVar(vrnm string) *GetGlobalVar {
if st == nil || st.GetVarStack == nil {
return nil
}
nv := len(st.GetVarStack)
for i := nv - 1; i >= 0; i-- {
gvars := st.GetVarStack[i]
if gv, ok := gvars[vrnm]; ok {
return gv
}
}
return nil
}
// VarsAdded is called when a set of vars has been added; update relevant maps etc.
func (st *State) VarsAdded() {
st.GetFuncs = make(map[string]*Var)
st.VarStructTypes = make(map[string]*Var)
for _, sy := range st.Systems {
tensorIdx := 0
for gi, gp := range sy.Groups {
vn := 0
if gi == 0 { // leave room for TensorStrides
vn++
}
for _, vr := range gp.Vars {
vr.Group = gi
vr.Binding = vn
if vr.Tensor {
vr.TensorIndex = tensorIdx
tensorIdx++
if vr.NBuffs > 1 {
vn += vr.NBuffs
} else {
vn++
}
continue
}
st.GetFuncs["Get"+vr.Name] = vr
jtyp := strings.TrimPrefix(vr.Type, "[]")
st.VarStructTypes[jtyp] = vr
vn++
}
}
sy.NTensors = tensorIdx
}
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file is largely copied from the Go source,
// src/go/printer/nodes.go:
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements printing of AST nodes; specifically
// expressions, statements, declarations, and files. It uses
// the print functionality implemented in printer.go.
package gotosl
import (
"fmt"
"go/ast"
"go/token"
"go/types"
"math"
"path"
"slices"
"strconv"
"strings"
"unicode"
"unicode/utf8"
)
// Formatting issues:
// - better comment formatting for /*-style comments at the end of a line (e.g. a declaration)
// when the comment spans multiple lines; if such a comment is just two lines, formatting is
// not idempotent
// - formatting of expression lists
// - should use blank instead of tab to separate one-line function bodies from
// the function header unless there is a group of consecutive one-liners
// ----------------------------------------------------------------------------
// Common AST nodes.
// Print as many newlines as necessary (but at least min newlines) to get to
// the current line. ws is printed before the first line break. If newSection
// is set, the first line break is printed as formfeed. Returns 0 if no line
// breaks were printed, returns 1 if there was exactly one newline printed,
// and returns a value > 1 if there was a formfeed or more than one newline
// printed.
//
// TODO(gri): linebreak may add too many lines if the next statement at "line"
// is preceded by comments because the computation of n assumes
// the current position before the comment and the target position
// after the comment. Thus, after interspersing such comments, the
// space taken up by them is not considered to reduce the number of
// linebreaks. At the moment there is no easy way to know about
// future (not yet interspersed) comments in this function.
func (p *printer) linebreak(line, min int, ws whiteSpace, newSection bool) (nbreaks int) {
n := max(nlimit(line-p.pos.Line), min)
if n > 0 {
p.print(ws)
if newSection {
p.print(formfeed)
n--
nbreaks = 2
}
nbreaks += n
for ; n > 0; n-- {
p.print(newline)
}
}
return
}
// gosl: find any gosl directive in given comments, returns directive(s) and remaining docs
func (p *printer) findDirective(g *ast.CommentGroup) (dirs []string, docs string) {
if g == nil {
return
}
for _, c := range g.List {
if strings.HasPrefix(c.Text, "//gosl:") {
dirs = append(dirs, c.Text[7:])
} else {
docs += c.Text + " "
}
}
return
}
// gosl: hasDirective returns whether directive(s) contains string
func hasDirective(dirs []string, dir string) bool {
for _, d := range dirs {
if strings.Contains(d, dir) {
return true
}
}
return false
}
// gosl: directiveAfter returns the directive after given leading text,
// and a bool indicating if the string was found.
func directiveAfter(dirs []string, dir string) (string, bool) {
for _, d := range dirs {
if strings.HasPrefix(d, dir) {
return strings.TrimSpace(strings.TrimPrefix(d, dir)), true
}
}
return "", false
}
// setComment sets g as the next comment if g != nil and if node comments
// are enabled - this mode is used when printing source code fragments such
// as exports only. It assumes that there is no pending comment in p.comments
// and at most one pending comment in the p.comment cache.
func (p *printer) setComment(g *ast.CommentGroup) {
if g == nil || !p.useNodeComments {
return
}
if p.comments == nil {
// initialize p.comments lazily
p.comments = make([]*ast.CommentGroup, 1)
} else if p.cindex < len(p.comments) {
// for some reason there are pending comments; this
// should never happen - handle gracefully and flush
// all comments up to g, ignore anything after that
p.flush(p.posFor(g.List[0].Pos()), token.ILLEGAL)
p.comments = p.comments[0:1]
// in debug mode, report error
p.internalError("setComment found pending comments")
}
p.comments[0] = g
p.cindex = 0
// don't overwrite any pending comment in the p.comment cache
// (there may be a pending comment when a line comment is
// immediately followed by a lead comment with no other
// tokens between)
if p.commentOffset == infinity {
p.nextComment() // get comment ready for use
}
}
type exprListMode uint
const (
commaTerm exprListMode = 1 << iota // list is optionally terminated by a comma
noIndent // no extra indentation in multi-line lists
)
// If indent is set, a multi-line identifier list is indented after the
// first linebreak encountered.
func (p *printer) identList(list []*ast.Ident, indent bool) {
// convert into an expression list so we can re-use exprList formatting
xlist := make([]ast.Expr, len(list))
for i, x := range list {
xlist[i] = x
}
var mode exprListMode
if !indent {
mode = noIndent
}
p.exprList(token.NoPos, xlist, 1, mode, token.NoPos, false)
}
const filteredMsg = "contains filtered or unexported fields"
// Print a list of expressions. If the list spans multiple
// source lines, the original line breaks are respected between
// expressions.
//
// TODO(gri) Consider rewriting this to be independent of []ast.Expr
// so that we can use the algorithm for any kind of list
//
// (e.g., pass list via a channel over which to range).
func (p *printer) exprList(prev0 token.Pos, list []ast.Expr, depth int, mode exprListMode, next0 token.Pos, isIncomplete bool) {
if len(list) == 0 {
if isIncomplete {
prev := p.posFor(prev0)
next := p.posFor(next0)
if prev.IsValid() && prev.Line == next.Line {
p.print("/* " + filteredMsg + " */")
} else {
p.print(newline)
p.print(indent, "// "+filteredMsg, unindent, newline)
}
}
return
}
prev := p.posFor(prev0)
next := p.posFor(next0)
line := p.lineFor(list[0].Pos())
endLine := p.lineFor(list[len(list)-1].End())
if prev.IsValid() && prev.Line == line && line == endLine {
// all list entries on a single line
for i, x := range list {
if i > 0 {
// use position of expression following the comma as
// comma position for correct comment placement
p.setPos(x.Pos())
p.print(token.COMMA, blank)
}
p.expr0(x, depth)
}
if isIncomplete {
p.print(token.COMMA, blank, "/* "+filteredMsg+" */")
}
return
}
// list entries span multiple lines;
// use source code positions to guide line breaks
// Don't add extra indentation if noIndent is set;
// i.e., pretend that the first line is already indented.
ws := ignore
if mode&noIndent == 0 {
ws = indent
}
// The first linebreak is always a formfeed since this section must not
// depend on any previous formatting.
prevBreak := -1 // index of last expression that was followed by a linebreak
if prev.IsValid() && prev.Line < line && p.linebreak(line, 0, ws, true) > 0 {
ws = ignore
prevBreak = 0
}
// initialize expression/key size: a zero value indicates expr/key doesn't fit on a single line
size := 0
// We use the ratio between the geometric mean of the previous key sizes and
// the current size to determine if there should be a break in the alignment.
// To compute the geometric mean we accumulate the ln(size) values (lnsum)
// and the number of sizes included (count).
lnsum := 0.0
count := 0
// print all list elements
prevLine := prev.Line
for i, x := range list {
line = p.lineFor(x.Pos())
// Determine if the next linebreak, if any, needs to use formfeed:
// in general, use the entire node size to make the decision; for
// key:value expressions, use the key size.
// TODO(gri) for a better result, should probably incorporate both
// the key and the node size into the decision process
useFF := true
// Determine element size: All bets are off if we don't have
// position information for the previous and next token (likely
// generated code - simply ignore the size in this case by setting
// it to 0).
prevSize := size
const infinity = 1e6 // larger than any source line
size = p.nodeSize(x, infinity)
pair, isPair := x.(*ast.KeyValueExpr)
if size <= infinity && prev.IsValid() && next.IsValid() {
// x fits on a single line
if isPair {
size = p.nodeSize(pair.Key, infinity) // size <= infinity
}
} else {
// size too large or we don't have good layout information
size = 0
}
// If the previous line and the current line had single-
// line-expressions and the key sizes are small or the
// ratio between the current key and the geometric mean
// of the previous key sizes does not exceed a threshold,
// align columns and do not use formfeed.
if prevSize > 0 && size > 0 {
const smallSize = 40
if count == 0 || prevSize <= smallSize && size <= smallSize {
useFF = false
} else {
const r = 2.5 // threshold
geomean := math.Exp(lnsum / float64(count)) // count > 0
ratio := float64(size) / geomean
useFF = r*ratio <= 1 || r <= ratio
}
}
needsLinebreak := 0 < prevLine && prevLine < line
if i > 0 {
// Use position of expression following the comma as
// comma position for correct comment placement, but
// only if the expression is on the same line.
if !needsLinebreak {
p.setPos(x.Pos())
}
p.print(token.COMMA)
needsBlank := true
if needsLinebreak {
// Lines are broken using newlines so comments remain aligned
// unless useFF is set or there are multiple expressions on
// the same line in which case formfeed is used.
nbreaks := p.linebreak(line, 0, ws, useFF || prevBreak+1 < i)
if nbreaks > 0 {
ws = ignore
prevBreak = i
needsBlank = false // we got a line break instead
}
// If there was a new section or more than one new line
// (which means that the tabwriter will implicitly break
// the section), reset the geomean variables since we are
// starting a new group of elements with the next element.
if nbreaks > 1 {
lnsum = 0
count = 0
}
}
if needsBlank {
p.print(blank)
}
}
if len(list) > 1 && isPair && size > 0 && needsLinebreak {
// We have a key:value expression that fits onto one line
// and it's not on the same line as the prior expression:
// Use a column for the key such that consecutive entries
// can align if possible.
// (needsLinebreak is set if we started a new line before)
p.expr(pair.Key)
p.setPos(pair.Colon)
p.print(token.COLON, vtab)
p.expr(pair.Value)
} else {
p.expr0(x, depth)
}
if size > 0 {
lnsum += math.Log(float64(size))
count++
}
prevLine = line
}
if mode&commaTerm != 0 && next.IsValid() && p.pos.Line < next.Line {
// Print a terminating comma if the next token is on a new line.
p.print(token.COMMA)
if isIncomplete {
p.print(newline)
p.print("// " + filteredMsg)
}
if ws == ignore && mode&noIndent == 0 {
// unindent if we indented
p.print(unindent)
}
p.print(formfeed) // terminating comma needs a line break to look good
return
}
if isIncomplete {
p.print(token.COMMA, newline)
p.print("// "+filteredMsg, newline)
}
if ws == ignore && mode&noIndent == 0 {
// unindent if we indented
p.print(unindent)
}
}
type paramMode int
const (
funcParam paramMode = iota
funcTParam
typeTParam
)
func (p *printer) parameters(fields *ast.FieldList, mode paramMode) {
openTok, closeTok := token.LPAREN, token.RPAREN
if mode != funcParam {
openTok, closeTok = token.LBRACK, token.RBRACK
}
p.setPos(fields.Opening)
p.print(openTok)
if len(fields.List) > 0 {
prevLine := p.lineFor(fields.Opening)
ws := indent
for pi, par := range fields.List {
// determine par begin and end line (may be different
// if there are multiple parameter names for this par
// or the type is on a separate line)
parLineBeg := p.lineFor(par.Pos())
parLineEnd := p.lineFor(par.End())
// separating "," if needed
needsLinebreak := 0 < prevLine && prevLine < parLineBeg
if pi > 0 {
// use position of parameter following the comma as
// comma position for correct comma placement, but
// only if the next parameter is on the same line
if !needsLinebreak {
p.setPos(par.Pos())
}
p.print(token.COMMA)
}
// separator if needed (linebreak or blank)
if needsLinebreak && p.linebreak(parLineBeg, 0, ws, true) > 0 {
// break line if the opening "(" or previous parameter ended on a different line
ws = ignore
} else if pi > 0 {
p.print(blank)
}
// parameter names
if len(par.Names) > 1 {
nnm := len(par.Names)
for ni, nm := range par.Names {
p.print(nm.Name)
p.print(token.COLON)
p.print(blank)
atyp, isPtr := p.ptrParamType(stripParensAlways(par.Type))
p.expr(atyp)
if isPtr {
p.print(">")
p.curPtrArgs = append(p.curPtrArgs, par.Names[0])
}
if ni < nnm-1 {
p.print(token.COMMA)
}
}
} else if len(par.Names) > 0 {
// Very subtle: If we indented before (ws == ignore), identList
// won't indent again. If we didn't (ws == indent), identList will
// indent if the identList spans multiple lines, and it will outdent
// again at the end (and still ws == indent). Thus, a subsequent indent
// by a linebreak call after a type, or in the next multi-line identList
// will do the right thing.
p.identList(par.Names, ws == indent)
p.print(token.COLON)
p.print(blank)
if pi == 0 { // gosl: cannot have a "real" pointer arg as the first parameter
// because we assume all first parameters are method receivers.
atyp, isPtr := p.methRecvPtrType(stripParensAlways(par.Type), par.Names[0])
p.expr(atyp)
if isPtr {
p.print(">")
p.curPtrArgs = append(p.curPtrArgs, par.Names[0])
}
} else {
atyp, isPtr := p.ptrParamType(stripParensAlways(par.Type))
p.expr(atyp)
if isPtr {
p.print(">")
p.curPtrArgs = append(p.curPtrArgs, par.Names[0])
}
}
} else {
atyp, isPtr := p.ptrParamType(stripParensAlways(par.Type))
p.expr(atyp)
if isPtr {
p.print(">")
}
}
prevLine = parLineEnd
}
// if the closing ")" is on a separate line from the last parameter,
// print an additional "," and line break
if closing := p.lineFor(fields.Closing); 0 < prevLine && prevLine < closing {
p.print(token.COMMA)
p.linebreak(closing, 0, ignore, true)
} else if mode == typeTParam && fields.NumFields() == 1 && combinesWithName(fields.List[0].Type) {
// A type parameter list [P T] where the name P and the type expression T syntactically
// combine to another valid (value) expression requires a trailing comma, as in [P *T,]
// (or an enclosing interface as in [P interface(*T)]), so that the type parameter list
// is not parsed as an array length [P*T].
p.print(token.COMMA)
}
// unindent if we indented
if ws == ignore {
p.print(unindent)
}
}
p.setPos(fields.Closing)
p.print(closeTok)
}
type rwArg struct {
idx *ast.IndexExpr
tmpVar string
}
func (p *printer) assignRwArgs(rwargs []rwArg) {
nrw := len(rwargs)
if nrw == 0 {
return
}
p.print(token.SEMICOLON, blank, formfeed)
for i, rw := range rwargs {
p.expr(rw.idx)
p.print(token.ASSIGN)
tv := rw.tmpVar
if len(tv) > 0 && tv[0] == '&' {
tv = tv[1:]
}
p.print(tv)
if i < nrw-1 {
p.print(token.SEMICOLON, blank)
}
}
}
// gosl: ensure basic literals are properly cast
func (p *printer) goslFixArgs(args []ast.Expr, params *types.Tuple) ([]ast.Expr, []rwArg) {
ags := slices.Clone(args)
mx := min(len(args), params.Len())
var rwargs []rwArg
for i := 0; i < mx; i++ {
ag := args[i]
pr := params.At(i)
switch x := ag.(type) {
case *ast.BasicLit:
typ := pr.Type()
tnm := getLocalTypeName(typ)
nn := normalizedNumber(x)
nn.Value = tnm + "(" + nn.Value + ")"
ags[i] = nn
case *ast.Ident:
if gvar := p.GoToSL.GetTempVar(x.Name); gvar != nil {
if !(gvar.Var.ReadOnly && !gvar.ReadWrite) {
x.Name = "&" + x.Name
fmt.Println("fix amper", x.Name)
ags[i] = x
}
}
case *ast.IndexExpr:
isGlobal, tmpVar, _, _, isReadOnly := p.globalVar(x)
if isGlobal {
ags[i] = &ast.Ident{Name: tmpVar}
if !isReadOnly {
rwargs = append(rwargs, rwArg{idx: x, tmpVar: tmpVar})
}
}
case *ast.UnaryExpr:
if x.Op == token.AND {
if sel, ok := x.X.(*ast.SelectorExpr); ok {
_, _, _, _, bt, _ := p.selectorPath(sel)
if bt != nil {
ags[i] = sel // gosl: get rid of ampersand -- todo may need further qualification?
}
}
}
if idx, ok := x.X.(*ast.IndexExpr); ok {
isGlobal, tmpVar, _, _, isReadOnly := p.globalVar(idx)
if isGlobal {
ags[i] = &ast.Ident{Name: tmpVar}
if !isReadOnly {
rwargs = append(rwargs, rwArg{idx: idx, tmpVar: tmpVar})
}
}
}
}
}
return ags, rwargs
}
// gosl: ensure basic literals are properly cast
func (p *printer) matchLiteralType(x ast.Expr, typ *ast.Ident) bool {
if lit, ok := x.(*ast.BasicLit); ok {
p.print(typ.Name, token.LPAREN, normalizedNumber(lit), token.RPAREN)
return true
}
return false
}
// gosl: ensure basic literals are properly cast
func (p *printer) matchAssignType(lhs []ast.Expr, rhs []ast.Expr) bool {
if len(rhs) != 1 || len(lhs) != 1 {
return false
}
val := ""
lit, ok := rhs[0].(*ast.BasicLit)
if ok {
val = normalizedNumber(lit).Value
} else {
un, ok := rhs[0].(*ast.UnaryExpr)
if !ok || un.Op != token.SUB {
return false
}
lit, ok = un.X.(*ast.BasicLit)
if !ok {
return false
}
val = "-" + normalizedNumber(lit).Value
}
var err error
var typ types.Type
if id, ok := lhs[0].(*ast.Ident); ok {
typ = p.getIdType(id)
if typ == nil {
return false
}
} else if sl, ok := lhs[0].(*ast.SelectorExpr); ok {
typ, err = p.pathType(sl)
if err != nil {
return false
}
} else if st, ok := lhs[0].(*ast.StarExpr); ok {
if id, ok := st.X.(*ast.Ident); ok {
typ = p.getIdType(id)
if typ == nil {
return false
}
}
if err != nil {
return false
}
}
if typ == nil {
return false
}
tnm := getLocalTypeName(typ)
if tnm[0] == '*' {
tnm = tnm[1:]
}
p.print(tnm, "(", val, ")")
return true
}
// gosl: pathType returns the final type for the selector path.
// a.b.c -> sel.X = (a.b) Sel=c -- returns type of c by tracing
// through the path.
func (p *printer) pathType(x *ast.SelectorExpr) (types.Type, error) {
var paths []*ast.Ident
cur := x
for {
paths = append(paths, cur.Sel)
if sl, ok := cur.X.(*ast.SelectorExpr); ok { // path is itself a selector
cur = sl
continue
}
if id, ok := cur.X.(*ast.Ident); ok {
paths = append(paths, id)
break
}
return nil, fmt.Errorf("gosl pathType: path not a pure selector path")
}
np := len(paths)
idt := p.getIdType(paths[np-1])
if idt == nil {
err := fmt.Errorf("gosl pathType ERROR: cannot find type for name: %q", paths[np-1].Name)
p.userError(err)
return nil, err
}
bt, err := p.getStructType(idt)
if err != nil {
return nil, err
}
for pi := np - 2; pi >= 0; pi-- {
pt := paths[pi]
f := fieldByName(bt, pt.Name)
if f == nil {
return nil, fmt.Errorf("gosl pathType: field not found %q in type: %q:", pt, bt.String())
}
if pi == 0 {
return f.Type(), nil
} else {
bt, err = p.getStructType(f.Type())
if err != nil {
return nil, err
}
}
}
return nil, fmt.Errorf("gosl pathType: path not a pure selector path")
}
// gosl: check if identifier is a pointer arg
func (p *printer) isPtrArg(id *ast.Ident) bool {
for _, pt := range p.curPtrArgs {
if id.Name == pt.Name {
return true
}
}
return false
}
// gosl: dereference pointer vals
func (p *printer) derefPtrArgs(x ast.Expr, prec, depth int) {
if id, ok := x.(*ast.Ident); ok {
if p.isPtrArg(id) {
p.print(token.LPAREN, token.MUL, id, token.RPAREN)
} else {
p.expr1(x, prec, depth)
}
} else {
p.expr1(x, prec, depth)
}
}
// gosl: mark pointer param types (only for non-struct), returns true if pointer
func (p *printer) ptrParamType(x ast.Expr) (ast.Expr, bool) {
if u, ok := x.(*ast.StarExpr); ok {
switch pt := u.X.(type) {
case *ast.Ident:
typ := p.getIdType(pt)
if typ != nil {
if _, ok := typ.Underlying().(*types.Struct); ok {
tn := getLocalTypeName(typ)
pi := strings.Index(tn, ".")
if pi > 0 {
tn = tn[pi+1:]
}
// fmt.Printf("struct typ: %s\n", tn)
if _, ok := p.GoToSL.VarStructTypes[tn]; ok {
return u.X, false // no pointer, else ok
}
}
}
p.print("ptr<function", token.COMMA)
return u.X, true
case *ast.SelectorExpr:
if id, ok := pt.X.(*ast.Ident); ok {
if id.Name == "math32" {
p.print("ptr<function", token.COMMA)
return u.X, true
}
}
default:
fmt.Println("ERROR: unrecognized pointer type -- can only have pointers to structs and vector types", pt)
}
}
return x, false
}
// gosl: don't use pointers for method receivers
func (p *printer) methRecvPtrType(x ast.Expr, recvnm *ast.Ident) (ast.Expr, bool) {
if u, ok := x.(*ast.StarExpr); ok {
isptr := p.isPtrArg(recvnm)
if isptr {
p.print("ptr<function", token.COMMA)
}
return u.X, isptr
}
return x, false
}
// gosl: printMethRecv prints the method recv prefix for function. returns true if recv is ptr
func (p *printer) printMethRecv() (isPtr bool, typnm string) {
if u, ok := p.curMethRecv.Type.(*ast.StarExpr); ok {
typnm = u.X.(*ast.Ident).Name
isPtr = true
} else {
typnm = p.curMethRecv.Type.(*ast.Ident).Name
}
return
}
// combinesWithName reports whether a name followed by the expression x
// syntactically combines to another valid (value) expression. For instance
// using *T for x, "name *T" syntactically appears as the expression x*T.
// On the other hand, using P|Q or *P|~Q for x, "name P|Q" or "name *P|~Q"
// cannot be combined into a valid (value) expression.
func combinesWithName(x ast.Expr) bool {
switch x := x.(type) {
case *ast.StarExpr:
// name *x.X combines to name*x.X if x.X is not a type element
return !isTypeElem(x.X)
case *ast.BinaryExpr:
return combinesWithName(x.X) && !isTypeElem(x.Y)
case *ast.ParenExpr:
// name(x) combines but we are making sure at
// the call site that x is never parenthesized.
panic("unexpected parenthesized expression")
}
return false
}
// isTypeElem reports whether x is a (possibly parenthesized) type element expression.
// The result is false if x could be a type element OR an ordinary (value) expression.
func isTypeElem(x ast.Expr) bool {
switch x := x.(type) {
case *ast.ArrayType, *ast.StructType, *ast.FuncType, *ast.InterfaceType, *ast.MapType, *ast.ChanType:
return true
case *ast.UnaryExpr:
return x.Op == token.TILDE
case *ast.BinaryExpr:
return isTypeElem(x.X) || isTypeElem(x.Y)
case *ast.ParenExpr:
return isTypeElem(x.X)
}
return false
}
func (p *printer) signature(sig *ast.FuncType, recv *ast.FieldList) {
if sig.TypeParams != nil {
p.parameters(sig.TypeParams, funcTParam)
}
if sig.Params != nil {
if recv != nil {
flist := &ast.FieldList{}
*flist = *recv
flist.List = append(flist.List, sig.Params.List...)
p.parameters(flist, funcParam)
} else {
p.parameters(sig.Params, funcParam)
}
} else if recv != nil {
p.parameters(recv, funcParam)
} else {
p.print(token.LPAREN, token.RPAREN)
}
res := sig.Results
n := res.NumFields()
if n > 0 {
// res != nil
if id, ok := res.List[0].Type.(*ast.Ident); ok {
p.curReturnType = id
}
p.print(blank, "->", blank)
if n == 1 && res.List[0].Names == nil {
// single anonymous res; no ()'s
p.expr(stripParensAlways(res.List[0].Type))
return
}
p.parameters(res, funcParam)
}
}
func identListSize(list []*ast.Ident, maxSize int) (size int) {
for i, x := range list {
if i > 0 {
size += len(", ")
}
size += utf8.RuneCountInString(x.Name)
if size >= maxSize {
break
}
}
return
}
func (p *printer) isOneLineFieldList(list []*ast.Field) bool {
if len(list) != 1 {
return false // allow only one field
}
f := list[0]
if f.Tag != nil || f.Comment != nil {
return false // don't allow tags or comments
}
// only name(s) and type
const maxSize = 30 // adjust as appropriate, this is an approximate value
namesSize := identListSize(f.Names, maxSize)
if namesSize > 0 {
namesSize = 1 // blank between names and types
}
typeSize := p.nodeSize(f.Type, maxSize)
return namesSize+typeSize <= maxSize
}
func (p *printer) setLineComment(text string) {
p.setComment(&ast.CommentGroup{List: []*ast.Comment{{Slash: token.NoPos, Text: text}}})
}
func (p *printer) fieldList(fields *ast.FieldList, isStruct, isIncomplete bool) {
lbrace := fields.Opening
list := fields.List
rbrace := fields.Closing
hasComments := isIncomplete || p.commentBefore(p.posFor(rbrace))
srcIsOneLine := lbrace.IsValid() && rbrace.IsValid() && p.lineFor(lbrace) == p.lineFor(rbrace)
if !hasComments && srcIsOneLine {
// possibly a one-line struct/interface
if len(list) == 0 {
// no blank between keyword and {} in this case
p.setPos(lbrace)
p.print(token.LBRACE)
p.setPos(rbrace)
p.print(token.RBRACE)
return
} else if p.isOneLineFieldList(list) {
// small enough - print on one line
// (don't use identList and ignore source line breaks)
p.setPos(lbrace)
p.print(token.LBRACE, blank)
f := list[0]
if isStruct {
for i, x := range f.Names {
if i > 0 {
// no comments so no need for comma position
p.print(token.COMMA, blank)
}
p.expr(x)
}
p.print(token.COLON)
if len(f.Names) > 0 {
p.print(blank)
}
p.expr(f.Type)
} else { // interface
if len(f.Names) > 0 {
name := f.Names[0] // method name
p.expr(name)
p.print(token.COLON)
p.signature(f.Type.(*ast.FuncType), nil) // don't print "func"
} else {
// embedded interface
p.expr(f.Type)
}
}
p.print(blank)
p.setPos(rbrace)
p.print(token.RBRACE)
return
}
}
// hasComments || !srcIsOneLine
p.print(blank)
p.setPos(lbrace)
p.print(token.LBRACE, indent)
if hasComments || len(list) > 0 {
p.print(formfeed)
}
if isStruct {
sep := vtab
if len(list) == 1 {
sep = blank
}
var line int
for i, f := range list {
if i > 0 {
p.linebreak(p.lineFor(f.Pos()), 1, ignore, p.linesFrom(line) > 0)
}
extraTabs := 0
p.setComment(f.Doc)
p.recordLine(&line)
if len(f.Names) > 1 {
nnm := len(f.Names)
p.setPos(f.Type.Pos())
for ni, nm := range f.Names {
p.print(nm.Name)
p.print(token.COLON)
p.print(sep)
p.expr(f.Type)
if ni < nnm-1 {
p.print(token.COMMA)
p.print(formfeed)
}
}
extraTabs = 1
} else if len(f.Names) > 0 {
// named fields
p.identList(f.Names, false)
p.print(token.COLON)
p.print(sep)
p.expr(f.Type)
extraTabs = 1
} else {
// anonymous field
p.expr(f.Type)
extraTabs = 2
}
p.print(token.COMMA)
// if f.Tag != nil {
// if len(f.Names) > 0 && sep == vtab {
// p.print(sep)
// }
// p.print(sep)
// p.expr(f.Tag)
// extraTabs = 0
// }
if f.Comment != nil {
for ; extraTabs > 0; extraTabs-- {
p.print(sep)
}
p.setComment(f.Comment)
}
}
if isIncomplete {
if len(list) > 0 {
p.print(formfeed)
}
p.flush(p.posFor(rbrace), token.RBRACE) // make sure we don't lose the last line comment
p.setLineComment("// " + filteredMsg)
}
} else { // interface
var line int
var prev *ast.Ident // previous "type" identifier
for i, f := range list {
var name *ast.Ident // first name, or nil
if len(f.Names) > 0 {
name = f.Names[0]
}
if i > 0 {
// don't do a line break (min == 0) if we are printing a list of types
// TODO(gri) this doesn't work quite right if the list of types is
// spread across multiple lines
min := 1
if prev != nil && name == prev {
min = 0
}
p.linebreak(p.lineFor(f.Pos()), min, ignore, p.linesFrom(line) > 0)
}
p.setComment(f.Doc)
p.recordLine(&line)
if name != nil {
// method
p.expr(name)
p.signature(f.Type.(*ast.FuncType), nil) // don't print "func"
prev = nil
} else {
// embedded interface
p.expr(f.Type)
prev = nil
}
p.setComment(f.Comment)
}
if isIncomplete {
if len(list) > 0 {
p.print(formfeed)
}
p.flush(p.posFor(rbrace), token.RBRACE) // make sure we don't lose the last line comment
p.setLineComment("// contains filtered or unexported methods")
}
}
p.print(unindent, formfeed)
p.setPos(rbrace)
p.print(token.RBRACE)
}
// ----------------------------------------------------------------------------
// Expressions
func walkBinary(e *ast.BinaryExpr) (has4, has5 bool, maxProblem int) {
switch e.Op.Precedence() {
case 4:
has4 = true
case 5:
has5 = true
}
switch l := e.X.(type) {
case *ast.BinaryExpr:
if l.Op.Precedence() < e.Op.Precedence() {
// parens will be inserted.
// pretend this is an *ast.ParenExpr and do nothing.
break
}
h4, h5, mp := walkBinary(l)
has4 = has4 || h4
has5 = has5 || h5
maxProblem = max(maxProblem, mp)
}
switch r := e.Y.(type) {
case *ast.BinaryExpr:
if r.Op.Precedence() <= e.Op.Precedence() {
// parens will be inserted.
// pretend this is an *ast.ParenExpr and do nothing.
break
}
h4, h5, mp := walkBinary(r)
has4 = has4 || h4
has5 = has5 || h5
maxProblem = max(maxProblem, mp)
case *ast.StarExpr:
if e.Op == token.QUO { // `*/`
maxProblem = 5
}
case *ast.UnaryExpr:
switch e.Op.String() + r.Op.String() {
case "/*", "&&", "&^":
maxProblem = 5
case "++", "--":
maxProblem = max(maxProblem, 4)
}
}
return
}
func cutoff(e *ast.BinaryExpr, depth int) int {
has4, has5, maxProblem := walkBinary(e)
if maxProblem > 0 {
return maxProblem + 1
}
if has4 && has5 {
if depth == 1 {
return 5
}
return 4
}
if depth == 1 {
return 6
}
return 4
}
func diffPrec(expr ast.Expr, prec int) int {
x, ok := expr.(*ast.BinaryExpr)
if !ok || prec != x.Op.Precedence() {
return 1
}
return 0
}
func reduceDepth(depth int) int {
depth--
if depth < 1 {
depth = 1
}
return depth
}
// Format the binary expression: decide the cutoff and then format.
// Let's call depth == 1 Normal mode, and depth > 1 Compact mode.
// (Algorithm suggestion by Russ Cox.)
//
// The precedences are:
//
// 5 * / % << >> & &^
// 4 + - | ^
// 3 == != < <= > >=
// 2 &&
// 1 ||
//
// The only decision is whether there will be spaces around levels 4 and 5.
// There are never spaces at level 6 (unary), and always spaces at levels 3 and below.
//
// To choose the cutoff, look at the whole expression but excluding primary
// expressions (function calls, parenthesized exprs), and apply these rules:
//
// 1. If there is a binary operator with a right side unary operand
// that would clash without a space, the cutoff must be (in order):
//
// /* 6
// && 6
// &^ 6
// ++ 5
// -- 5
//
// (Comparison operators always have spaces around them.)
//
// 2. If there is a mix of level 5 and level 4 operators, then the cutoff
// is 5 (use spaces to distinguish precedence) in Normal mode
// and 4 (never use spaces) in Compact mode.
//
// 3. If there are no level 4 operators or no level 5 operators, then the
// cutoff is 6 (always use spaces) in Normal mode
// and 4 (never use spaces) in Compact mode.
func (p *printer) binaryExpr(x *ast.BinaryExpr, prec1, cutoff, depth int) {
prec := x.Op.Precedence()
if prec < prec1 {
// parenthesis needed
// Note: The parser inserts an ast.ParenExpr node; thus this case
// can only occur if the AST is created in a different way.
p.print(token.LPAREN)
p.expr0(x, reduceDepth(depth)) // parentheses undo one level of depth
p.print(token.RPAREN)
return
}
printBlank := prec < cutoff
ws := indent
p.expr1(x.X, prec, depth+diffPrec(x.X, prec))
if printBlank {
p.print(blank)
}
xline := p.pos.Line // before the operator (it may be on the next line!)
yline := p.lineFor(x.Y.Pos())
p.setPos(x.OpPos)
if x.Op == token.AND_NOT {
p.print(token.AND, blank, token.TILDE)
} else {
p.print(x.Op)
}
if xline != yline && xline > 0 && yline > 0 {
// at least one line break, but respect an extra empty line
// in the source
if p.linebreak(yline, 1, ws, true) > 0 {
ws = ignore
printBlank = false // no blank after line break
}
}
if printBlank {
p.print(blank)
}
p.expr1(x.Y, prec+1, depth+1)
if ws == ignore {
p.print(unindent)
}
}
func isBinary(expr ast.Expr) bool {
_, ok := expr.(*ast.BinaryExpr)
return ok
}
func (p *printer) expr1(expr ast.Expr, prec1, depth int) {
p.setPos(expr.Pos())
switch x := expr.(type) {
case *ast.BadExpr:
p.print("BadExpr")
case *ast.Ident:
if x.Name == "int" {
p.print("i32")
} else {
p.print(x)
}
case *ast.BinaryExpr:
if depth < 1 {
p.internalError("depth < 1:", depth)
depth = 1
}
p.binaryExpr(x, prec1, cutoff(x, depth), depth)
case *ast.KeyValueExpr:
p.expr(x.Key)
p.setPos(x.Colon)
p.print(token.COLON, blank)
p.expr(x.Value)
case *ast.StarExpr:
const prec = token.UnaryPrec
if prec < prec1 {
// parenthesis needed
p.print(token.LPAREN)
p.print(token.MUL)
p.expr(x.X)
p.print(token.RPAREN)
} else {
// no parenthesis needed
p.print(token.MUL)
p.expr(x.X)
}
case *ast.UnaryExpr:
const prec = token.UnaryPrec
if prec < prec1 {
// parenthesis needed
p.print(token.LPAREN)
p.expr(x)
p.print(token.RPAREN)
} else {
// no parenthesis needed
p.print(x.Op)
if x.Op == token.RANGE {
// TODO(gri) Remove this code if it cannot be reached.
p.print(blank)
}
p.expr1(x.X, prec, depth)
}
case *ast.BasicLit:
if p.PrintConfig.Mode&normalizeNumbers != 0 {
x = normalizedNumber(x)
}
p.print(x)
case *ast.FuncLit:
p.setPos(x.Type.Pos())
p.print(token.FUNC)
// See the comment in funcDecl about how the header size is computed.
startCol := p.out.Column - len("func")
p.signature(x.Type, nil)
p.funcBody(p.distanceFrom(x.Type.Pos(), startCol), blank, x.Body)
case *ast.ParenExpr:
if _, hasParens := x.X.(*ast.ParenExpr); hasParens {
// don't print parentheses around an already parenthesized expression
// TODO(gri) consider making this more general and incorporate precedence levels
p.expr0(x.X, depth)
} else {
p.print(token.LPAREN)
p.expr0(x.X, reduceDepth(depth)) // parentheses undo one level of depth
p.setPos(x.Rparen)
p.print(token.RPAREN)
}
case *ast.SelectorExpr:
p.selectorExpr(x, depth)
case *ast.TypeAssertExpr:
p.expr1(x.X, token.HighestPrec, depth)
p.print(token.PERIOD)
p.setPos(x.Lparen)
p.print(token.LPAREN)
if x.Type != nil {
p.expr(x.Type)
} else {
p.print(token.TYPE)
}
p.setPos(x.Rparen)
p.print(token.RPAREN)
case *ast.IndexExpr:
// TODO(gri): should treat [] like parentheses and undo one level of depth
p.globalVarBasic(x)
p.expr1(x.X, token.HighestPrec, 1)
p.setPos(x.Lbrack)
p.print(token.LBRACK)
p.expr0(x.Index, depth+1)
p.setPos(x.Rbrack)
p.print(token.RBRACK)
case *ast.IndexListExpr:
// TODO(gri): as for IndexExpr, should treat [] like parentheses and undo
// one level of depth
p.expr1(x.X, token.HighestPrec, 1)
p.setPos(x.Lbrack)
p.print(token.LBRACK)
p.exprList(x.Lbrack, x.Indices, depth+1, commaTerm, x.Rbrack, false)
p.setPos(x.Rbrack)
p.print(token.RBRACK)
case *ast.SliceExpr:
// TODO(gri): should treat [] like parentheses and undo one level of depth
p.expr1(x.X, token.HighestPrec, 1)
p.setPos(x.Lbrack)
p.print(token.LBRACK)
indices := []ast.Expr{x.Low, x.High}
if x.Max != nil {
indices = append(indices, x.Max)
}
// determine if we need extra blanks around ':'
var needsBlanks bool
if depth <= 1 {
var indexCount int
var hasBinaries bool
for _, x := range indices {
if x != nil {
indexCount++
if isBinary(x) {
hasBinaries = true
}
}
}
if indexCount > 1 && hasBinaries {
needsBlanks = true
}
}
for i, x := range indices {
if i > 0 {
if indices[i-1] != nil && needsBlanks {
p.print(blank)
}
p.print(token.COLON)
if x != nil && needsBlanks {
p.print(blank)
}
}
if x != nil {
p.expr0(x, depth+1)
}
}
p.setPos(x.Rbrack)
p.print(token.RBRACK)
case *ast.CallExpr:
if len(x.Args) > 1 {
depth++
}
fid, isid := x.Fun.(*ast.Ident)
// Conversions to literal function types or <-chan
// types require parentheses around the type.
paren := false
switch t := x.Fun.(type) {
case *ast.FuncType:
paren = true
case *ast.ChanType:
paren = t.Dir == ast.RECV
}
if paren {
p.print(token.LPAREN)
}
if _, ok := x.Fun.(*ast.SelectorExpr); ok {
p.methodExpr(x, depth)
break // handles everything, break out of case
}
args := x.Args
var rwargs []rwArg
if isid {
// fmt.Println("start call:", fid.Name, p.curFunc)
if p.curFunc != nil {
p.curFunc.Funcs[fid.Name] = p.GoToSL.RecycleFunc(fid.Name)
}
if obj, ok := p.pkg.TypesInfo.Uses[fid]; ok {
if ft, ok := obj.(*types.Func); ok {
sig := ft.Type().(*types.Signature)
args, rwargs = p.goslFixArgs(x.Args, sig.Params())
}
}
}
p.expr1(x.Fun, token.HighestPrec, depth)
if paren {
p.print(token.RPAREN)
}
p.setPos(x.Lparen)
p.print(token.LPAREN)
if x.Ellipsis.IsValid() {
p.exprList(x.Lparen, args, depth, 0, x.Ellipsis, false)
p.setPos(x.Ellipsis)
p.print(token.ELLIPSIS)
if x.Rparen.IsValid() && p.lineFor(x.Ellipsis) < p.lineFor(x.Rparen) {
p.print(token.COMMA, formfeed)
}
} else {
p.exprList(x.Lparen, args, depth, commaTerm, x.Rparen, false)
}
p.setPos(x.Rparen)
p.print(token.RPAREN)
p.assignRwArgs(rwargs)
// fmt.Println("call:", x.Fun, p.curFunc)
case *ast.CompositeLit:
// composite literal elements that are composite literals themselves may have the type omitted
lb := token.LBRACE
rb := token.RBRACE
if _, isAry := x.Type.(*ast.ArrayType); isAry {
lb = token.LPAREN
rb = token.RPAREN
}
if x.Type != nil {
p.expr1(x.Type, token.HighestPrec, depth)
}
p.level++
p.setPos(x.Lbrace)
p.print(lb)
p.exprList(x.Lbrace, x.Elts, 1, commaTerm, x.Rbrace, x.Incomplete)
// do not insert extra line break following a /*-style comment
// before the closing '}' as it might break the code if there
// is no trailing ','
mode := noExtraLinebreak
// do not insert extra blank following a /*-style comment
// before the closing '}' unless the literal is empty
if len(x.Elts) > 0 {
mode |= noExtraBlank
}
// need the initial indent to print lone comments with
// the proper level of indentation
p.print(indent, unindent, mode)
p.setPos(x.Rbrace)
p.print(rb, mode)
p.level--
case *ast.Ellipsis:
p.print(token.ELLIPSIS)
if x.Elt != nil {
p.expr(x.Elt)
}
case *ast.ArrayType:
p.print("array")
// p.print(token.LBRACK)
// if x.Len != nil {
// p.expr(x.Len)
// }
// p.print(token.RBRACK)
// p.expr(x.Elt)
case *ast.StructType:
// p.print(token.STRUCT)
p.fieldList(x.Fields, true, x.Incomplete)
case *ast.FuncType:
p.print(token.FUNC)
p.signature(x, nil)
case *ast.InterfaceType:
p.print(token.INTERFACE)
p.fieldList(x.Methods, false, x.Incomplete)
case *ast.MapType:
p.print(token.MAP, token.LBRACK)
p.expr(x.Key)
p.print(token.RBRACK)
p.expr(x.Value)
case *ast.ChanType:
switch x.Dir {
case ast.SEND | ast.RECV:
p.print(token.CHAN)
case ast.RECV:
p.print(token.ARROW, token.CHAN) // x.Arrow and x.Pos() are the same
case ast.SEND:
p.print(token.CHAN)
p.setPos(x.Arrow)
p.print(token.ARROW)
}
p.print(blank)
p.expr(x.Value)
default:
panic("unreachable")
}
}
// normalizedNumber rewrites base prefixes and exponents
// of numbers to use lower-case letters (0X123 to 0x123 and 1.2E3 to 1.2e3),
// and removes leading 0's from integer imaginary literals (0765i to 765i).
// It leaves hexadecimal digits alone.
//
// normalizedNumber doesn't modify the ast.BasicLit value lit points to.
// If lit is not a number or a number in canonical format already,
// lit is returned as is. Otherwise a new ast.BasicLit is created.
func normalizedNumber(lit *ast.BasicLit) *ast.BasicLit {
if lit.Kind != token.INT && lit.Kind != token.FLOAT && lit.Kind != token.IMAG {
return lit // not a number - nothing to do
}
if len(lit.Value) < 2 {
return lit // only one digit (common case) - nothing to do
}
// len(lit.Value) >= 2
// We ignore lit.Kind because for lit.Kind == token.IMAG the literal may be an integer
// or floating-point value, decimal or not. Instead, just consider the literal pattern.
x := lit.Value
switch x[:2] {
default:
// 0-prefix octal, decimal int, or float (possibly with 'i' suffix)
if i := strings.LastIndexByte(x, 'E'); i >= 0 {
x = x[:i] + "e" + x[i+1:]
break
}
// remove leading 0's from integer (but not floating-point) imaginary literals
if x[len(x)-1] == 'i' && !strings.ContainsAny(x, ".e") {
x = strings.TrimLeft(x, "0_")
if x == "i" {
x = "0i"
}
}
case "0X":
x = "0x" + x[2:]
// possibly a hexadecimal float
if i := strings.LastIndexByte(x, 'P'); i >= 0 {
x = x[:i] + "p" + x[i+1:]
}
case "0x":
// possibly a hexadecimal float
i := strings.LastIndexByte(x, 'P')
if i == -1 {
return lit // nothing to do
}
x = x[:i] + "p" + x[i+1:]
case "0O":
x = "0o" + x[2:]
case "0o":
return lit // nothing to do
case "0B":
x = "0b" + x[2:]
case "0b":
return lit // nothing to do
}
return &ast.BasicLit{ValuePos: lit.ValuePos, Kind: lit.Kind, Value: x}
}
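// Illustrative examples, derived from the rules above:
//
//	0X1FP-16  ->  0x1Fp-16
//	1.2E3     ->  1.2e3
//	0765i     ->  765i
//	0O17      ->  0o17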
// selectorExpr handles an *ast.SelectorExpr node and reports whether x spans
// multiple lines, and thus was indented.
func (p *printer) selectorExpr(x *ast.SelectorExpr, depth int) (wasIndented bool) {
p.derefPtrArgs(x.X, token.HighestPrec, depth)
p.print(token.PERIOD)
if line := p.lineFor(x.Sel.Pos()); p.pos.IsValid() && p.pos.Line < line {
p.print(indent, newline)
p.setPos(x.Sel.Pos())
p.print(x.Sel)
p.print(unindent)
return true
}
p.setPos(x.Sel.Pos())
if x.Sel.Name == "X" || x.Sel.Name == "Y" || x.Sel.Name == "Z" || x.Sel.Name == "W" {
p.print(strings.ToLower(x.Sel.Name))
} else {
p.print(x.Sel)
}
return false
}
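// Illustrative example (hypothetical names): a selector such as
//
//	pos.X
//
// is printed as
//
//	pos.x
//
// since the vector components X, Y, Z, W are lower-cased for the shader
// output; if pos were a pointer argument it would be dereferenced first,
// giving roughly (*pos).x.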
func (p *printer) selectorPath(x *ast.SelectorExpr) (recvPath, recvType string, pathType types.Type, paths []string, bt *types.Struct, err error) {
var baseRecv *ast.Ident // first receiver in path
cur := x
for {
paths = append(paths, cur.Sel.Name)
if sl, ok := cur.X.(*ast.SelectorExpr); ok { // path is itself a selector
cur = sl
continue
}
if id, ok := cur.X.(*ast.Ident); ok {
baseRecv = id
break
}
err = fmt.Errorf("gosl methodPath ERROR: path for method call must be simple list of fields, not %#v:", cur.X)
p.userError(err)
return
}
recvPath = baseRecv.Name
var idt types.Type
gvar := p.GoToSL.GetTempVar(baseRecv.Name)
if gvar != nil {
idt = p.getTypeNameType(gvar.Var.SLType())
} else {
idt = p.getIdType(baseRecv)
}
if idt == nil || typeIsInvalid(idt) {
err = fmt.Errorf("gosl methodPath ERROR: cannot find type for name: %q, gvar: %v", baseRecv.Name, gvar)
panic(err)
// p.userError(err)
return
}
bt, err = p.getStructType(idt)
if err != nil {
fmt.Println(baseRecv)
return
}
return
}
// gosl: methodPath resolves the possibly multi-level chain of selector exprs
// for a method call, to determine the actual type and name of the receiver.
// a.b.c() -> sel.X = (a.b) Sel=c
func (p *printer) methodPath(x *ast.SelectorExpr) (recvPath, recvType string, pathType types.Type, err error) {
var paths []string
var bt *types.Struct
recvPath, recvType, pathType, paths, bt, err = p.selectorPath(x)
curt := bt
np := len(paths)
for pi := np - 1; pi >= 0; pi-- {
pth := paths[pi]
recvPath += "." + pth
f := fieldByName(curt, pth)
if f == nil {
err = fmt.Errorf("gosl ERROR: field not found %q in type: %q:", pth, curt.String())
p.userError(err)
return
}
if pi == 0 {
pathType = f.Type()
recvType = getLocalTypeName(f.Type())
} else {
curt, err = p.getStructType(f.Type())
if err != nil {
return
}
}
}
return
}
func fieldByName(st *types.Struct, name string) *types.Var {
nf := st.NumFields()
for i := range nf {
f := st.Field(i)
if f.Name() == name {
return f
}
}
return nil
}
func (p *printer) getIdType(id *ast.Ident) types.Type {
if obj, ok := p.pkg.TypesInfo.Uses[id]; ok {
return obj.Type()
}
return nil
}
func (p *printer) getTypeNameType(typeName string) types.Type {
obj := p.pkg.Types.Scope().Lookup(typeName)
if obj != nil {
return obj.Type()
}
return nil
}
func getLocalTypeName(typ types.Type) string {
_, nm := path.Split(typ.String())
return nm
}
func typeIsInvalid(typ types.Type) bool {
if b, ok := typ.(*types.Basic); ok {
if b.Kind() == types.Invalid {
return true
}
}
return false
}
func (p *printer) getStructType(typ types.Type) (*types.Struct, error) {
utyp := typ.Underlying()
switch x := utyp.(type) {
case *types.Struct:
return x, nil
case *types.Pointer:
return p.getStructType(x.Elem())
case *types.Slice:
return p.getStructType(x.Elem())
case *types.Basic:
fmt.Println("basic kind:", x.String())
}
err := fmt.Errorf("gosl ERROR: type is not a struct and it should be: %q %+T %+T", typ.String(), typ, utyp)
panic(err)
// p.userError(err)
return nil, err
}
func (p *printer) getNamedType(typ types.Type) (*types.Named, error) {
if nmd, ok := typ.(*types.Named); ok {
return nmd, nil
}
typ = typ.Underlying()
if ptr, ok := typ.(*types.Pointer); ok {
typ = ptr.Elem()
if nmd, ok := typ.(*types.Named); ok {
return nmd, nil
}
}
if sl, ok := typ.(*types.Slice); ok {
typ = sl.Elem()
if nmd, ok := typ.(*types.Named); ok {
return nmd, nil
}
}
err := fmt.Errorf("gosl ERROR: type is not a named type: %q %+t", typ.String(), typ)
p.userError(err)
return nil, err
}
func (p *printer) globalVarBasic(idx *ast.IndexExpr) {
id, ok := idx.X.(*ast.Ident)
if !ok {
if sel, ok := idx.X.(*ast.SelectorExpr); ok {
if sel.Sel.Name != "Values" {
return
}
id, ok = sel.X.(*ast.Ident)
if !ok {
return
}
// fall through with this..
} else {
return
}
}
st := p.GoToSL
gvr := st.GlobalVar(id.Name)
if gvr == nil {
return
}
if p.curFunc != nil {
p.curFunc.AddVarUsed(gvr)
if p.curMethIsAtomic {
p.curFunc.AddAtomic(gvr)
}
}
}
// gosl: globalVar looks up whether the id in an IndexExpr is a global gosl variable;
// if so, it returns a temp variable name to use, and the type info.
func (p *printer) globalVar(idx *ast.IndexExpr) (isGlobal bool, tmpVar, typName string, vtyp types.Type, isReadOnly bool) {
id, ok := idx.X.(*ast.Ident)
if !ok {
return
}
st := p.GoToSL
gvr := st.GlobalVar(id.Name)
if gvr == nil {
return
}
if p.curFunc != nil {
p.curFunc.AddVarUsed(gvr)
}
isGlobal = true
isReadOnly = gvr.ReadOnly
if st.VarIsReadWrite(id.Name) {
isReadOnly = false
}
tmpVar = strings.ToLower(id.Name)
vtyp = p.getIdType(id)
if vtyp == nil {
err := fmt.Errorf("gosl globalVar ERROR: cannot find type for name: %q", id.Name)
p.userError(err)
return
}
nmd, err := p.getNamedType(vtyp)
if err == nil {
vtyp = nmd
}
typName = gvr.SLType()
p.print("let ", tmpVar, token.ASSIGN)
p.expr(idx)
p.print(token.SEMICOLON, blank)
if !isReadOnly {
tmpVar = "&" + tmpVar
}
return
}
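// Illustrative sketch (hypothetical variable names): for an index expression
// such as
//
//	Params[pi]
//
// where Params is a registered global, globalVar emits a local copy, roughly
//
//	let params = Params[pi];
//
// and returns "params" (or "&params" when the variable is writable) as the
// temp variable to use as the method receiver.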
// gosl: replace GetVar function call with assignment of local var
func (p *printer) getGlobalVar(ae *ast.AssignStmt, gvr *Var) {
st := p.GoToSL
tmpVar := ae.Lhs[0].(*ast.Ident).Name
cf := ae.Rhs[0].(*ast.CallExpr)
ro := gvr.ReadOnly
rwoverride := false
if ro {
if st.VarIsReadWrite(gvr.Name) {
ro = false
rwoverride = true
}
}
if ro {
p.print("let", blank, tmpVar, blank, token.ASSIGN, blank, gvr.Name, token.LBRACK)
} else {
p.print("var", blank, tmpVar, blank, token.ASSIGN, blank, gvr.Name, token.LBRACK)
}
p.expr(cf.Args[0])
p.print(token.RBRACK, token.SEMICOLON)
gvars := p.GoToSL.GetVarStack.Peek()
if gvars != nil {
gvars[tmpVar] = &GetGlobalVar{Var: gvr, TmpVar: tmpVar, IdxExpr: cf.Args[0], ReadWrite: rwoverride}
p.GoToSL.GetVarStack[len(p.GoToSL.GetVarStack)-1] = gvars
}
}
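// Illustrative sketch (GetLayer and Layers are hypothetical names): a Go
// assignment written as
//
//	ly := GetLayer(li)
//
// where Layers is the writable global behind GetLayer, is replaced by
//
//	var ly = Layers[li];
//
// (or "let ly = ..." when the variable is read-only), and the temp var is
// recorded so that setGlobalVars can write it back at the end of the block.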
// gosl: set non-read-only global vars back from temp var
func (p *printer) setGlobalVars(gvrs map[string]*GetGlobalVar) {
for _, gvr := range gvrs {
if gvr.Var.ReadOnly && !gvr.ReadWrite {
continue
}
p.print(formfeed, "\t")
p.print(gvr.Var.Name, token.LBRACK)
p.expr(gvr.IdxExpr)
p.print(token.RBRACK, blank, token.ASSIGN, blank)
p.print(gvr.TmpVar)
p.print(token.SEMICOLON, blank)
}
}
// gosl: methodIndex processes an index expression as receiver type of method call
func (p *printer) methodIndex(idx *ast.IndexExpr) (recvPath, recvType string, pathType types.Type, isReadOnly bool, err error) {
id, ok := idx.X.(*ast.Ident)
if !ok {
err = fmt.Errorf("gosl methodIndex ERROR: must have a recv variable identifier, not %#v:", idx.X)
p.userError(err)
return
}
isGlobal, tmpVar, typName, vtyp, isReadOnly := p.globalVar(idx)
if isGlobal {
recvPath = tmpVar
recvType = typName
pathType = vtyp
} else {
_ = id
// do above
}
return
}
func (p *printer) tensorMethod(x *ast.CallExpr, vr *Var, methName string) {
args := x.Args
gv := p.GoToSL.GlobalVar(vr.Name)
if gv != nil && p.curFunc != nil {
p.curFunc.AddVarUsed(vr)
}
stArg := 0
if strings.HasPrefix(methName, "Set") {
stArg = 1
}
if strings.HasSuffix(methName, "Ptr") {
p.print(token.AND)
gv := p.GoToSL.GlobalVar(vr.Name)
if gv != nil && p.curFunc != nil && p.curMethIsAtomic {
p.curFunc.AddAtomic(vr)
}
}
if vr.NBuffs > 1 {
if stArg == 0 {
p.print(vr.Name+"Get", token.LPAREN)
} else {
p.print(vr.Name+methName, token.LPAREN)
p.expr(args[0])
p.print(token.COMMA, blank)
}
} else {
p.print(vr.Name, token.LBRACK)
}
p.print(vr.IndexFunc(), token.LPAREN)
nd := vr.TensorDims
for d := range nd {
p.print(vr.IndexStride(d), token.COMMA, blank)
}
n := len(args)
for i := stArg; i < n; i++ {
ag := args[i]
p.print("u32", token.LPAREN)
if ce, ok := ag.(*ast.CallExpr); ok { // get rid of int() wrapper from goal n-dim index
if fn, ok := ce.Fun.(*ast.Ident); ok {
if fn.Name == "int" {
ag = ce.Args[0]
}
}
}
p.expr(ag)
p.print(token.RPAREN)
if i < n-1 {
p.print(token.COMMA, blank)
}
}
if vr.NBuffs > 1 {
p.print(token.RPAREN, token.RPAREN)
} else {
p.print(token.RPAREN, token.RBRACK)
if strings.HasPrefix(methName, "Set") {
opnm := strings.TrimPrefix(methName, "Set")
tok := token.ASSIGN
switch opnm {
case "Add":
tok = token.ADD_ASSIGN
case "Sub":
tok = token.SUB_ASSIGN
case "Mul":
tok = token.MUL_ASSIGN
case "Div":
tok = token.QUO_ASSIGN
}
p.print(blank, tok, blank)
p.expr(args[0])
}
}
}
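// Illustrative sketch (hypothetical names): a tensor accessor call such as
//
//	Data.Set(val, i, j)
//
// on a 2D single-buffer global tensor Data is emitted roughly as
//
//	Data[IndexFn(Stride0, Stride1, u32(i), u32(j))] = val
//
// where IndexFn and the strides stand in for whatever vr.IndexFunc() and
// vr.IndexStride(d) produce; SetAdd/SetSub/SetMul/SetDiv map to +=, -=, *=, /=.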
func (p *printer) methodExpr(x *ast.CallExpr, depth int) {
path := x.Fun.(*ast.SelectorExpr) // we know fun is selector
methName := path.Sel.Name
recvPath := ""
recvType := ""
var err error
pathIsPackage := false
var rwargs []rwArg
var pathType types.Type
if sl, ok := path.X.(*ast.SelectorExpr); ok { // path is itself a selector
recvPath, recvType, pathType, err = p.methodPath(sl)
if err != nil {
return
}
} else if id, ok := path.X.(*ast.Ident); ok {
gvr := p.GoToSL.GlobalVar(id.Name)
if gvr != nil && gvr.Tensor {
p.tensorMethod(x, gvr, methName)
return
}
recvPath = id.Name
typ := p.getIdType(id)
if typ != nil {
recvType = getLocalTypeName(typ)
if strings.HasPrefix(recvType, "invalid") {
if gvar := p.GoToSL.GetTempVar(id.Name); gvar != nil {
recvType = gvar.Var.SLType()
if !(gvar.Var.ReadOnly && !gvar.ReadWrite) {
recvPath = "&" + recvPath
}
pathType = p.getTypeNameType(gvar.Var.SLType())
} else {
pathIsPackage = true
recvType = id.Name // is a package path
}
} else {
if gvar := p.GoToSL.GetTempVar(id.Name); gvar != nil {
recvType = gvar.Var.SLType()
if !(gvar.Var.ReadOnly && !gvar.ReadWrite) {
recvPath = "&" + recvPath
}
pathType = p.getTypeNameType(gvar.Var.SLType())
} else {
pathType = typ
}
}
} else {
pathIsPackage = true
recvType = id.Name // is a package path
}
} else if idx, ok := path.X.(*ast.IndexExpr); ok {
isReadOnly := false
recvPath, recvType, pathType, isReadOnly, err = p.methodIndex(idx)
if err != nil {
return
}
if !isReadOnly {
rwargs = append(rwargs, rwArg{idx: idx, tmpVar: recvPath})
}
} else {
// fmt.Println("arg issue with:", methName)
// err := fmt.Errorf("gosl methodExpr ERROR: path expression for method call must be simple list of fields, not %#v:", path.X)
// p.userError(err)
// return
}
args := x.Args
if pathType != nil {
meth, _, _ := types.LookupFieldOrMethod(pathType, true, p.pkg.Types, methName)
if meth != nil {
if ft, ok := meth.(*types.Func); ok {
sig := ft.Type().(*types.Signature)
var rwa []rwArg
args, rwa = p.goslFixArgs(x.Args, sig.Params())
rwargs = append(rwargs, rwa...)
}
}
if len(rwargs) > 0 {
p.print(formfeed)
}
}
// fmt.Println(pathIsPackage, recvType, methName, recvPath)
if p.mathMeth(x, depth, methName, recvPath, recvType) {
return
}
if pathIsPackage {
if recvType == "atomic" || recvType == "atomicx" {
p.curMethIsAtomic = true
switch {
case strings.HasPrefix(methName, "Add"):
p.print("atomicAdd")
case strings.HasPrefix(methName, "Max"):
p.print("atomicMax")
}
} else {
p.print(recvType + "." + methName)
if p.curFunc != nil {
p.curFunc.Funcs[methName] = p.GoToSL.RecycleFunc(methName)
}
}
p.setPos(x.Lparen)
p.print(token.LPAREN)
} else {
recvType = strings.TrimPrefix(recvType, "imports.") // strip the local "imports." prefix from the type name
fname := recvType + "_" + methName
if p.curFunc != nil {
p.curFunc.Funcs[fname] = p.GoToSL.RecycleFunc(fname)
}
p.print(fname)
p.setPos(x.Lparen)
p.print(token.LPAREN)
p.print(recvPath)
if len(x.Args) > 0 {
p.print(token.COMMA, blank)
}
}
if x.Ellipsis.IsValid() {
p.exprList(x.Lparen, args, depth, 0, x.Ellipsis, false)
p.setPos(x.Ellipsis)
p.print(token.ELLIPSIS)
if x.Rparen.IsValid() && p.lineFor(x.Ellipsis) < p.lineFor(x.Rparen) {
p.print(token.COMMA, formfeed)
}
} else {
p.exprList(x.Lparen, args, depth, commaTerm, x.Rparen, false)
}
p.setPos(x.Rparen)
p.print(token.RPAREN)
p.curMethIsAtomic = false
p.assignRwArgs(rwargs) // gosl: assign temp var back to global var
}
// gosl: mathMeth translates math method calls into operator expressions
// (e.g., .Add() -> +), and slvec .V() calls into the corresponding math32 type.
func (p *printer) mathMeth(x *ast.CallExpr, depth int, methName, recvPath, recvType string) bool {
if strings.HasPrefix(recvType, "slvec.") && methName == "V" {
btyp := strings.TrimPrefix(recvType, "slvec.")
rtyp := "math32." + btyp
p.print(rtyp)
p.setPos(x.Lparen)
p.print(token.LPAREN)
switch btyp {
case "Vector2", "Vector2i":
p.print(recvPath+".x", token.COMMA, recvPath+".y")
case "Vector3":
p.print(recvPath+".x", token.COMMA, recvPath+".y", token.COMMA, recvPath+".z")
}
p.setPos(x.Rparen)
p.print(token.RPAREN)
p.curMethIsAtomic = false
return true
}
opr := token.ILLEGAL
switch methName {
case "Add":
opr = token.ADD
case "Sub":
opr = token.SUB
case "Mul", "MulVector", "MulVector3", "MulScalar":
opr = token.MUL
case "Div", "DivScalar":
opr = token.QUO
}
if opr == token.ILLEGAL {
return false
}
path := x.Fun.(*ast.SelectorExpr) // we know fun is selector
p.expr(path.X)
p.print(opr)
p.setPos(x.Lparen)
p.print(token.LPAREN)
p.exprList(x.Lparen, x.Args, depth, commaTerm, x.Rparen, false)
p.setPos(x.Rparen)
p.print(token.RPAREN)
p.curMethIsAtomic = false
return true
}
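// Illustrative examples (hypothetical names):
//
//	a.MulScalar(b)  ->  a*(b)
//	v.Add(w)        ->  v+(w)
//
// and an slvec conversion such as v.V() on an slvec.Vector3 becomes
// roughly math32.Vector3(v.x, v.y, v.z).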
func (p *printer) expr0(x ast.Expr, depth int) {
p.expr1(x, token.LowestPrec, depth)
}
func (p *printer) expr(x ast.Expr) {
const depth = 1
p.expr1(x, token.LowestPrec, depth)
}
// ----------------------------------------------------------------------------
// Statements
// Print the statement list indented, but without a newline after the last statement.
// Extra line breaks between statements in the source are respected but at most one
// empty line is printed between statements.
func (p *printer) stmtList(list []ast.Stmt, nindent int, nextIsRBrace bool) {
if nindent > 0 {
p.print(indent)
}
var line int
i := 0
for _, s := range list {
// ignore empty statements (was issue 3466)
if _, isEmpty := s.(*ast.EmptyStmt); !isEmpty {
// nindent == 0 only for lists of switch/select case clauses;
// in those cases each clause is a new section
if len(p.output) > 0 {
// only print line break if we are not at the beginning of the output
// (i.e., we are not printing only a partial program)
p.linebreak(p.lineFor(s.Pos()), 1, ignore, i == 0 || nindent == 0 || p.linesFrom(line) > 0)
}
p.recordLine(&line)
p.stmt(s, nextIsRBrace && i == len(list)-1, false)
// labeled statements put labels on a separate line, but here
// we only care about the start line of the actual statement
// without label - correct line for each label
for t := s; ; {
lt, _ := t.(*ast.LabeledStmt)
if lt == nil {
break
}
line++
t = lt.Stmt
}
i++
}
}
if nindent > 0 {
p.print(unindent)
}
}
// block prints an *ast.BlockStmt; it always spans at least two lines.
func (p *printer) block(b *ast.BlockStmt, nindent int) {
p.GoToSL.GetVarStack.Push(make(map[string]*GetGlobalVar))
p.setPos(b.Lbrace)
p.print(token.LBRACE)
nstmt := len(b.List)
retLast := false
if nstmt > 1 {
if _, ok := b.List[nstmt-1].(*ast.ReturnStmt); ok {
retLast = true
}
}
if retLast {
p.stmtList(b.List[:nstmt-1], nindent, true)
} else {
p.stmtList(b.List, nindent, true)
}
getVars := p.GoToSL.GetVarStack.Pop()
if len(getVars) > 0 { // gosl: set the get vars
p.setGlobalVars(getVars)
}
if retLast {
p.stmt(b.List[nstmt-1], true, false)
}
p.linebreak(p.lineFor(b.Rbrace), 1, ignore, true)
p.setPos(b.Rbrace)
p.print(token.RBRACE)
}
func isTypeName(x ast.Expr) bool {
switch t := x.(type) {
case *ast.Ident:
return true
case *ast.SelectorExpr:
return isTypeName(t.X)
}
return false
}
func stripParens(x ast.Expr) ast.Expr {
if px, strip := x.(*ast.ParenExpr); strip {
// parentheses must not be stripped if there are any
// unparenthesized composite literals starting with
// a type name
ast.Inspect(px.X, func(node ast.Node) bool {
switch x := node.(type) {
case *ast.ParenExpr:
// parentheses protect enclosed composite literals
return false
case *ast.CompositeLit:
if isTypeName(x.Type) {
strip = false // do not strip parentheses
}
return false
}
// in all other cases, keep inspecting
return true
})
if strip {
return stripParens(px.X)
}
}
return x
}
func stripParensAlways(x ast.Expr) ast.Expr {
if x, ok := x.(*ast.ParenExpr); ok {
return stripParensAlways(x.X)
}
return x
}
func (p *printer) controlClause(isForStmt bool, init ast.Stmt, expr ast.Expr, post ast.Stmt) {
p.print(blank)
p.print(token.LPAREN)
needsBlank := false
if init == nil && post == nil {
// no semicolons required
if expr != nil {
p.expr(stripParens(expr))
needsBlank = true
}
} else {
// all semicolons required
// (they are not separators, print them explicitly)
if init != nil {
p.stmt(init, false, false) // false = generate own semi
p.print(blank)
} else {
p.print(token.SEMICOLON, blank)
}
if expr != nil {
p.expr(stripParens(expr))
needsBlank = true
}
if isForStmt {
p.print(token.SEMICOLON, blank)
needsBlank = false
if post != nil {
p.stmt(post, false, true) // nosemi
needsBlank = true
}
}
}
p.print(token.RPAREN)
if needsBlank {
p.print(blank)
}
}
// indentList reports whether an expression list would look better if it
// were indented wholesale (starting with the very first element, rather
// than starting at the first line break).
func (p *printer) indentList(list []ast.Expr) bool {
// Heuristic: indentList reports whether there are more than one multi-
// line element in the list, or if there is any element that is not
// starting on the same line as the previous one ends.
if len(list) >= 2 {
var b = p.lineFor(list[0].Pos())
var e = p.lineFor(list[len(list)-1].End())
if 0 < b && b < e {
// list spans multiple lines
n := 0 // multi-line element count
line := b
for _, x := range list {
xb := p.lineFor(x.Pos())
xe := p.lineFor(x.End())
if line < xb {
// x is not starting on the same
// line as the previous one ended
return true
}
if xb < xe {
// x is a multi-line element
n++
}
line = xe
}
return n > 1
}
}
return false
}
func (p *printer) stmt(stmt ast.Stmt, nextIsRBrace bool, nosemi bool) {
p.setPos(stmt.Pos())
switch s := stmt.(type) {
case *ast.BadStmt:
p.print("BadStmt")
case *ast.DeclStmt:
p.decl(s.Decl)
case *ast.EmptyStmt:
// nothing to do
case *ast.LabeledStmt:
// a "correcting" unindent immediately following a line break
// is applied before the line break if there is no comment
// between (see writeWhitespace)
p.print(unindent)
p.expr(s.Label)
p.setPos(s.Colon)
p.print(token.COLON, indent)
if e, isEmpty := s.Stmt.(*ast.EmptyStmt); isEmpty {
if !nextIsRBrace {
p.print(newline)
p.setPos(e.Pos())
p.print(token.SEMICOLON)
break
}
} else {
p.linebreak(p.lineFor(s.Stmt.Pos()), 1, ignore, true)
}
p.stmt(s.Stmt, nextIsRBrace, nosemi)
case *ast.ExprStmt:
const depth = 1
p.expr0(s.X, depth)
if !nosemi {
p.print(token.SEMICOLON)
}
case *ast.SendStmt:
const depth = 1
p.expr0(s.Chan, depth)
p.print(blank)
p.setPos(s.Arrow)
p.print(token.ARROW, blank)
p.expr0(s.Value, depth)
case *ast.IncDecStmt:
const depth = 1
p.expr0(s.X, depth+1)
p.setPos(s.TokPos)
p.print(s.Tok)
if !nosemi {
p.print(token.SEMICOLON)
}
case *ast.AssignStmt:
var depth = 1
if len(s.Lhs) > 1 && len(s.Rhs) > 1 {
depth++
}
if s.Tok == token.DEFINE {
if ce, ok := s.Rhs[0].(*ast.CallExpr); ok {
if fid, ok := ce.Fun.(*ast.Ident); ok {
if strings.HasPrefix(fid.Name, "Get") {
if gvr, ok := p.GoToSL.GetFuncs[fid.Name]; ok {
if p.curFunc != nil {
p.curFunc.AddVarUsed(gvr)
}
p.getGlobalVar(s, gvr) // replace GetVar function call with assignment of local var
return
}
}
}
}
p.print("var", blank) // we don't know if it is var or let..
}
p.exprList(s.Pos(), s.Lhs, depth, 0, s.TokPos, false)
p.print(blank)
p.setPos(s.TokPos)
switch s.Tok {
case token.DEFINE:
p.print(token.ASSIGN, blank)
case token.AND_NOT_ASSIGN:
p.print(token.AND_ASSIGN, blank, "~")
default:
p.print(s.Tok, blank)
}
if p.matchAssignType(s.Lhs, s.Rhs) {
} else {
p.exprList(s.TokPos, s.Rhs, depth, 0, token.NoPos, false)
}
if !nosemi {
p.print(token.SEMICOLON)
}
p.print(newline)
case *ast.GoStmt:
p.print(token.GO, blank)
p.expr(s.Call)
case *ast.DeferStmt:
p.print(token.DEFER, blank)
p.expr(s.Call)
case *ast.ReturnStmt:
p.print(token.RETURN)
if s.Results != nil {
p.print(blank)
if !p.matchLiteralType(s.Results[0], p.curReturnType) {
// Use indentList heuristic to make corner cases look
// better (issue 1207). A more systematic approach would
// always indent, but this would cause significant
// reformatting of the code base and not necessarily
// lead to more nicely formatted code in general.
if p.indentList(s.Results) {
p.print(indent)
// Use NoPos so that a newline never goes before
// the results (see issue #32854).
p.exprList(token.NoPos, s.Results, 1, noIndent, token.NoPos, false)
p.print(unindent)
} else {
p.exprList(token.NoPos, s.Results, 1, 0, token.NoPos, false)
}
}
}
if !nosemi {
p.print(token.SEMICOLON)
}
case *ast.BranchStmt:
p.print(s.Tok)
if s.Label != nil {
p.print(blank)
p.expr(s.Label)
}
p.print(token.SEMICOLON)
case *ast.BlockStmt:
p.block(s, 1)
case *ast.IfStmt:
p.print(token.IF)
p.controlClause(false, s.Init, s.Cond, nil)
p.block(s.Body, 1)
if s.Else != nil {
p.print(blank, token.ELSE, blank)
switch s.Else.(type) {
case *ast.BlockStmt, *ast.IfStmt:
p.stmt(s.Else, nextIsRBrace, false)
default:
// This can only happen with an incorrectly
// constructed AST. Permit it but print so
// that it can be gotosld without errors.
p.print(token.LBRACE, indent, formfeed)
p.stmt(s.Else, true, false)
p.print(unindent, formfeed, token.RBRACE)
}
}
case *ast.CaseClause:
if s.List != nil {
p.print(token.CASE, blank)
p.exprList(s.Pos(), s.List, 1, 0, s.Colon, false)
} else {
p.print(token.DEFAULT)
}
p.setPos(s.Colon)
p.print(token.COLON, blank, token.LBRACE) // Go implies new context, C doesn't
p.stmtList(s.Body, 1, nextIsRBrace)
p.print(formfeed, token.RBRACE)
case *ast.SwitchStmt:
p.print(token.SWITCH)
p.controlClause(false, s.Init, s.Tag, nil)
p.block(s.Body, 0)
case *ast.TypeSwitchStmt:
p.print(token.SWITCH)
if s.Init != nil {
p.print(blank)
p.stmt(s.Init, false, false)
p.print(token.SEMICOLON)
}
p.print(blank)
p.stmt(s.Assign, false, false)
p.print(blank)
p.block(s.Body, 0)
case *ast.CommClause:
if s.Comm != nil {
p.print(token.CASE, blank)
p.stmt(s.Comm, false, false)
} else {
p.print(token.DEFAULT)
}
p.setPos(s.Colon)
p.print(token.COLON)
p.stmtList(s.Body, 1, nextIsRBrace)
case *ast.SelectStmt:
p.print(token.SELECT, blank)
body := s.Body
if len(body.List) == 0 && !p.commentBefore(p.posFor(body.Rbrace)) {
// print empty select statement w/o comments on one line
p.setPos(body.Lbrace)
p.print(token.LBRACE)
p.setPos(body.Rbrace)
p.print(token.RBRACE)
} else {
p.block(body, 0)
}
case *ast.ForStmt:
p.print(token.FOR)
p.controlClause(true, s.Init, s.Cond, s.Post)
p.block(s.Body, 1)
case *ast.RangeStmt:
// gosl: only the "for i := range n" kind of range loop is supported
p.print(token.FOR, blank)
if s.Key != nil {
p.print(token.LPAREN, "var", blank)
p.expr(s.Key)
p.print(token.ASSIGN, "0", token.SEMICOLON, blank)
p.expr(s.Key)
p.print(token.LSS)
p.expr(stripParens(s.X))
p.print(token.SEMICOLON, blank)
p.expr(s.Key)
p.print(token.INC, token.RPAREN)
// if s.Value != nil {
// // use position of value following the comma as
// // comma position for correct comment placement
// p.setPos(s.Value.Pos())
// p.print(token.COMMA, blank)
// p.expr(s.Value)
// }
// p.print(blank)
// p.setPos(s.TokPos)
// p.print(s.Tok, blank)
} else {
p.print(token.LPAREN, "var", blank)
p.print("i")
p.print(token.ASSIGN, "0", token.SEMICOLON, blank)
p.print("i")
p.print(token.LSS)
p.expr(stripParens(s.X))
p.print(token.SEMICOLON, blank)
p.print("i")
p.print(token.INC, token.RPAREN)
}
// p.print(token.RANGE, blank)
// p.expr(stripParens(s.X))
p.print(blank)
p.block(s.Body, 1)
default:
panic("unreachable")
}
}
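// Illustrative example of the RangeStmt translation above (n is a
// hypothetical variable): a Go loop written as
//
//	for i := range n { ... }
//
// is emitted roughly as
//
//	for (var i = 0; i < n; i++) { ... }
//
// and a keyless "for range n" gets a generated index variable named i.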
// ----------------------------------------------------------------------------
// Declarations
// The keepTypeColumn function determines if the type column of a series of
// consecutive const or var declarations must be kept, or if initialization
// values (V) can be placed in the type column (T) instead. The i'th entry
// in the result slice is true if the type column in spec[i] must be kept.
//
// For example, the declaration:
//
// const (
// foobar int = 42 // comment
// x = 7 // comment
// foo
// bar = 991
// )
//
// leads to the type/values matrix below. A run of value columns (V) can
// be moved into the type column if there is no type for any of the values
// in that column (we only move entire columns so that they align properly).
//
// matrix        formatted     result
//               matrix
// T  V    ->    T  V     ->   true      there is a T and so the type
// -  V          -  V          true      column must be kept
// -  -          -  -          false
// -  V          V  -          false     V is moved into T column
func keepTypeColumn(specs []ast.Spec) []bool {
m := make([]bool, len(specs))
populate := func(i, j int, keepType bool) {
if keepType {
for ; i < j; i++ {
m[i] = true
}
}
}
i0 := -1 // if i0 >= 0 we are in a run and i0 is the start of the run
var keepType bool
for i, s := range specs {
t := s.(*ast.ValueSpec)
if t.Values != nil {
if i0 < 0 {
// start of a run of ValueSpecs with non-nil Values
i0 = i
keepType = false
}
} else {
if i0 >= 0 {
// end of a run
populate(i0, i, keepType)
i0 = -1
}
}
if t.Type != nil {
keepType = true
}
}
if i0 >= 0 {
// end of a run
populate(i0, len(specs), keepType)
}
return m
}
func (p *printer) valueSpec(s *ast.ValueSpec, keepType bool, tok token.Token, firstSpec *ast.ValueSpec, isIota bool, idx int) {
p.setComment(s.Doc)
// gosl: key to use Pos() as first arg to trigger emitting of comments!
switch tok {
case token.CONST:
p.setPos(s.Pos())
p.print(tok, blank)
case token.TYPE:
p.setPos(s.Pos())
p.print("alias", blank)
}
p.print(vtab)
extraTabs := 3
p.identList(s.Names, false) // always present
if isIota {
if s.Type != nil {
p.print(token.COLON, blank)
p.expr(s.Type)
} else if firstSpec.Type != nil {
p.print(token.COLON, blank)
p.expr(firstSpec.Type)
}
p.print(vtab, token.ASSIGN, blank)
p.print(fmt.Sprintf("%d", idx))
} else if s.Type != nil || keepType {
p.print(token.COLON, blank)
p.expr(s.Type)
extraTabs--
} else if tok == token.CONST && firstSpec.Type != nil {
p.expr(firstSpec.Type)
extraTabs--
}
if !(isIota && s == firstSpec) && s.Values != nil {
p.print(vtab, token.ASSIGN, blank)
p.exprList(token.NoPos, s.Values, 1, 0, token.NoPos, false)
extraTabs--
}
p.print(token.SEMICOLON)
if s.Comment != nil {
for ; extraTabs > 0; extraTabs-- {
p.print(vtab)
}
p.setComment(s.Comment)
}
}
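// Illustrative example (MyEnum is a hypothetical type): a const block using
// iota, such as
//
//	const (
//		A MyEnum = iota
//		B
//	)
//
// is emitted with explicit index values, roughly
//
//	const A: MyEnum = 0;
//	const B: MyEnum = 1;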
func sanitizeImportPath(lit *ast.BasicLit) *ast.BasicLit {
// Note: An unmodified AST generated by go/parser will already
// contain a backward- or double-quoted path string that does
// not contain any invalid characters, and most of the work
// here is not needed. However, a modified or generated AST
// may possibly contain non-canonical paths. Do the work in
// all cases since it's not too hard and not speed-critical.
// if we don't have a proper string, be conservative and return whatever we have
if lit.Kind != token.STRING {
return lit
}
s, err := strconv.Unquote(lit.Value)
if err != nil {
return lit
}
// if the string is an invalid path, return whatever we have
//
// spec: "Implementation restriction: A compiler may restrict
// ImportPaths to non-empty strings using only characters belonging
// to Unicode's L, M, N, P, and S general categories (the Graphic
// characters without spaces) and may also exclude the characters
// !"#$%&'()*,:;<=>?[\]^`{|} and the Unicode replacement character
// U+FFFD."
if s == "" {
return lit
}
const illegalChars = `!"#$%&'()*,:;<=>?[\]^{|}` + "`\uFFFD"
for _, r := range s {
if !unicode.IsGraphic(r) || unicode.IsSpace(r) || strings.ContainsRune(illegalChars, r) {
return lit
}
}
// otherwise, return the double-quoted path
s = strconv.Quote(s)
if s == lit.Value {
return lit // nothing wrong with lit
}
return &ast.BasicLit{ValuePos: lit.ValuePos, Kind: token.STRING, Value: s}
}
// The parameter n is the number of specs in the group. If doIndent is set,
// multi-line identifier lists in the spec are indented when the first
// linebreak is encountered.
func (p *printer) spec(spec ast.Spec, n int, doIndent bool, tok token.Token) {
switch s := spec.(type) {
case *ast.ImportSpec:
p.setComment(s.Doc)
if s.Name != nil {
p.expr(s.Name)
p.print(blank)
}
p.expr(sanitizeImportPath(s.Path))
p.setComment(s.Comment)
p.setPos(s.EndPos)
case *ast.ValueSpec:
if n != 1 {
p.internalError("expected n = 1; got", n)
}
p.setComment(s.Doc)
if len(s.Names) > 1 {
nnm := len(s.Names)
for ni, nm := range s.Names {
p.print(tok, blank)
p.print(nm.Name)
if s.Type != nil {
p.print(token.COLON, blank)
p.expr(s.Type)
}
if s.Values != nil {
p.print(blank, token.ASSIGN, blank)
p.exprList(token.NoPos, s.Values, 1, 0, token.NoPos, false)
}
p.print(token.SEMICOLON)
if ni < nnm-1 {
p.print(formfeed)
}
}
} else {
p.print(tok, blank)
p.identList(s.Names, doIndent) // always present
if s.Type != nil {
p.print(token.COLON, blank)
p.expr(s.Type)
}
if s.Values != nil {
p.print(blank, token.ASSIGN, blank)
p.exprList(token.NoPos, s.Values, 1, 0, token.NoPos, false)
}
p.print(token.SEMICOLON)
p.setComment(s.Comment)
}
case *ast.TypeSpec:
p.setComment(s.Doc)
st, isStruct := s.Type.(*ast.StructType)
if isStruct {
p.setPos(st.Pos())
p.print(token.STRUCT, blank)
} else {
p.print("alias", blank)
}
p.expr(s.Name)
if !isStruct {
p.print(blank, token.ASSIGN, blank)
}
if s.TypeParams != nil {
p.parameters(s.TypeParams, typeTParam)
}
// if n == 1 {
// p.print(blank)
// } else {
// p.print(vtab)
// }
if s.Assign.IsValid() {
p.print(token.ASSIGN, blank)
}
p.expr(s.Type)
if !isStruct {
p.print(token.SEMICOLON)
}
p.setComment(s.Comment)
default:
panic("unreachable")
}
}
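// Illustrative example (hypothetical types): a struct type declaration
//
//	type Neuron struct {
//		Act float32
//	}
//
// is emitted as a shader-style struct, roughly
//
//	struct Neuron { Act: float32, }
//
// while a non-struct type such as "type Index uint32" becomes
//
//	alias Index = uint32;
//
// (with basic type names assumed to be mapped to their WGSL equivalents
// elsewhere).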
// gosl: process system global vars
func (p *printer) systemVars(d *ast.GenDecl, sysname string) {
if !p.GoToSL.GetFuncGraph {
return
}
sy := p.GoToSL.System(sysname)
var gp *Group
var err error
for _, s := range d.Specs {
vs := s.(*ast.ValueSpec)
dirs, docs := p.findDirective(vs.Doc)
readOnly := false
readOrWrite := false
if hasDirective(dirs, "read-only") {
readOnly = true
} else if hasDirective(dirs, "read-or-write") {
readOnly = true
readOrWrite = true
}
if gpnm, ok := directiveAfter(dirs, "group"); ok {
if gpnm == "" {
gp = &Group{Name: fmt.Sprintf("Group_%d", len(sy.Groups)), Doc: docs}
sy.Groups = append(sy.Groups, gp)
} else {
gps := strings.Fields(gpnm)
gp = &Group{Doc: docs}
if gps[0] == "-uniform" {
gp.Uniform = true
if len(gps) > 1 {
gp.Name = gps[1]
}
} else {
gp.Name = gps[0]
}
sy.Groups = append(sy.Groups, gp)
}
}
if gp == nil {
gp = &Group{Name: fmt.Sprintf("Group_%d", len(sy.Groups)), Doc: docs}
sy.Groups = append(sy.Groups, gp)
}
if len(vs.Names) != 1 {
err = fmt.Errorf("gosl: system %q: vars must have only 1 variable per line", sysname)
p.userError(err)
}
nm := vs.Names[0].Name
typ := ""
if sl, ok := vs.Type.(*ast.ArrayType); ok {
id, ok := sl.Elt.(*ast.Ident)
if !ok {
err = fmt.Errorf("gosl: system %q: Var type not recognized: %#v", sysname, sl.Elt)
p.userError(err)
continue
}
// by the time this happens, all types have been moved to imports and show up there,
// so we've lost the original package of origin. And we'd have to make up an
// incompatible type name anyway, so the bottom line is: all var types need to be defined locally.
// tt := p.getIdType(id)
// if tt != nil {
// fmt.Println("idtyp:", tt.String())
// }
typ = "[]" + id.Name
} else {
sel, ok := vs.Type.(*ast.SelectorExpr)
if !ok {
st, ok := vs.Type.(*ast.StarExpr)
if !ok {
err = fmt.Errorf("gosl: system %q: Var types must be []slices or tensor.Float32, tensor.Uint32", sysname)
p.userError(err)
continue
}
sel, ok = st.X.(*ast.SelectorExpr)
if !ok {
err = fmt.Errorf("gosl: system %q: Var types must be []slices or tensor.Float32, tensor.Uint32", sysname)
p.userError(err)
continue
}
}
sid, ok := sel.X.(*ast.Ident)
if !ok {
err = fmt.Errorf("gosl: system %q: Var type selector is not recognized: %#v", sysname, sel.X)
p.userError(err)
continue
}
typ = sid.Name + "." + sel.Sel.Name
}
vr := &Var{Name: nm, Type: typ, ReadOnly: readOnly, ReadOrWrite: readOrWrite}
if strings.HasPrefix(typ, "tensor.") {
vr.Tensor = true
dstr, ok := directiveAfter(dirs, "dims")
if !ok {
err = fmt.Errorf("gosl: system %q: variable %q tensor vars require //gosl:dims <n> to specify number of dimensions", sysname, nm)
p.userError(err)
continue
}
if dims, err := strconv.Atoi(dstr); err == nil {
vr.SetTensorKind()
vr.TensorDims = dims
} else {
err = fmt.Errorf("gosl: system %q: variable %q tensor dims parse error: %s", sysname, nm, err.Error())
p.userError(err)
}
if nbufstr, ok := directiveAfter(dirs, "nbuffs"); ok {
if nbuf, err := strconv.Atoi(nbufstr); err == nil {
vr.NBuffs = nbuf
} else {
err = fmt.Errorf("gosl: system %q: variable %q tensor nbuffs parse error: %s", sysname, nm, err.Error())
p.userError(err)
}
}
}
gp.Vars = append(gp.Vars, vr)
if p.GoToSL.Config.Debug {
fmt.Println("\tAdded var:", nm, typ, "to group:", gp.Name)
}
}
p.GoToSL.VarsAdded()
}
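// Illustrative example (MySystem, Params, ParamStruct, and Data are
// hypothetical names): a global var block processed here might look like
//
//	//gosl:vars MySystem
//	var (
//		//gosl:group -uniform Params
//		//gosl:read-only
//		Params []ParamStruct
//
//		//gosl:group Data
//		//gosl:dims 2
//		Data *tensor.Float32
//	)
//
// each variable becomes a Var in the named group, with tensor vars requiring
// //gosl:dims (and optionally //gosl:nbuffs).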
func (p *printer) genDecl(d *ast.GenDecl) {
p.setComment(d.Doc)
// note: critical to print here to trigger comment generation in right place
p.setPos(d.Pos())
if d.Tok == token.IMPORT {
return
}
// p.print(d.Pos(), d.Tok, blank)
p.print(ignore) // don't print
if d.Lparen.IsValid() || len(d.Specs) != 1 {
// group of parenthesized declarations
// p.setPos(d.Lparen)
// p.print(token.LPAREN)
if n := len(d.Specs); n > 0 {
// p.print(indent, formfeed)
if n > 1 && (d.Tok == token.CONST || d.Tok == token.VAR) {
// two or more grouped const/var declarations:
if d.Tok == token.VAR {
dirs, _ := p.findDirective(d.Doc)
if sysname, ok := directiveAfter(dirs, "vars"); ok {
p.systemVars(d, sysname)
return
}
}
// determine if the type column must be kept
keepType := keepTypeColumn(d.Specs)
firstSpec := d.Specs[0].(*ast.ValueSpec)
isIota := false
if d.Tok == token.CONST {
if id, isId := firstSpec.Values[0].(*ast.Ident); isId {
if id.Name == "iota" {
isIota = true
}
}
}
var line int
for i, s := range d.Specs {
vs := s.(*ast.ValueSpec)
if i > 0 {
p.linebreak(p.lineFor(s.Pos()), 1, ignore, p.linesFrom(line) > 0)
}
p.recordLine(&line)
p.valueSpec(vs, keepType[i], d.Tok, firstSpec, isIota, i)
}
} else {
tok := d.Tok
if p.curFunc == nil && tok == token.VAR { // only system vars are supported at global scope
// could add further comment-directive logic
// to specify <private> or <workgroup> scope if needed
tok = token.CONST
}
var line int
for i, s := range d.Specs {
if i > 0 {
p.linebreak(p.lineFor(s.Pos()), 1, ignore, p.linesFrom(line) > 0)
}
p.recordLine(&line)
p.spec(s, n, false, tok)
}
}
// p.print(unindent, formfeed)
}
// p.setPos(d.Rparen)
// p.print(token.RPAREN)
} else if len(d.Specs) > 0 {
tok := d.Tok
if p.curFunc == nil && tok == token.VAR { // only system vars are supported at global scope
tok = token.CONST
}
// single declaration
p.spec(d.Specs[0], 1, true, tok)
}
}
// sizeCounter is an io.Writer which counts the number of bytes written,
// as well as whether a newline character was seen.
type sizeCounter struct {
hasNewline bool
size int
}
func (c *sizeCounter) Write(p []byte) (int, error) {
if !c.hasNewline {
for _, b := range p {
if b == '\n' || b == '\f' {
c.hasNewline = true
break
}
}
}
c.size += len(p)
return len(p), nil
}
// nodeSize determines the size of n in chars after formatting.
// The result is <= maxSize if the node fits on one line with at
// most maxSize chars and the formatted output doesn't contain
// any control chars. Otherwise, the result is > maxSize.
func (p *printer) nodeSize(n ast.Node, maxSize int) (size int) {
// nodeSize invokes the printer, which may invoke nodeSize
// recursively. For deep composite literal nests, this can
// lead to an exponential algorithm. Remember previous
// results to prune the recursion (was issue 1628).
if size, found := p.nodeSizes[n]; found {
return size
}
size = maxSize + 1 // assume n doesn't fit
p.nodeSizes[n] = size
// nodeSize computation must be independent of particular
// style so that we always get the same decision; print
// in RawFormat
cfg := PrintConfig{Mode: RawFormat}
var counter sizeCounter
if err := cfg.fprint(&counter, p.pkg, n, p.nodeSizes); err != nil {
return
}
if counter.size <= maxSize && !counter.hasNewline {
// n fits in a single line
size = counter.size
p.nodeSizes[n] = size
}
return
}
// numLines returns the number of lines spanned by node n in the original source.
func (p *printer) numLines(n ast.Node) int {
if from := n.Pos(); from.IsValid() {
if to := n.End(); to.IsValid() {
return p.lineFor(to) - p.lineFor(from) + 1
}
}
return infinity
}
// bodySize is like nodeSize but it is specialized for *ast.BlockStmt's.
func (p *printer) bodySize(b *ast.BlockStmt, maxSize int) int {
pos1 := b.Pos()
pos2 := b.Rbrace
if pos1.IsValid() && pos2.IsValid() && p.lineFor(pos1) != p.lineFor(pos2) {
// opening and closing brace are on different lines - don't make it a one-liner
return maxSize + 1
}
if len(b.List) > 5 {
// too many statements - don't make it a one-liner
return maxSize + 1
}
// otherwise, estimate body size
bodySize := p.commentSizeBefore(p.posFor(pos2))
for i, s := range b.List {
if bodySize > maxSize {
break // no need to continue
}
if i > 0 {
bodySize += 2 // space for a semicolon and blank
}
bodySize += p.nodeSize(s, maxSize)
}
return bodySize
}
// funcBody prints a function body following a function header of given headerSize.
// If the header's and block's size are "small enough" and the block is "simple enough",
// the block is printed on the current line, without line breaks, spaced from the header
// by sep. Otherwise the block's opening "{" is printed on the current line, followed by
// lines for the block's statements and its closing "}".
func (p *printer) funcBody(headerSize int, sep whiteSpace, b *ast.BlockStmt) {
if b == nil {
return
}
// save/restore composite literal nesting level
defer func(level int) {
p.level = level
}(p.level)
p.level = 0
const maxSize = 100
if headerSize+p.bodySize(b, maxSize) <= maxSize {
p.print(sep)
p.setPos(b.Lbrace)
p.print(token.LBRACE)
if len(b.List) > 0 {
p.print(blank)
for i, s := range b.List {
if i > 0 {
p.print(token.SEMICOLON, blank)
}
p.stmt(s, i == len(b.List)-1, false)
}
p.print(blank)
}
p.print(noExtraLinebreak)
p.setPos(b.Rbrace)
p.print(token.RBRACE, noExtraLinebreak)
return
}
if sep != ignore {
p.print(blank) // always use blank
}
p.block(b, 1)
}
// distanceFrom returns the column difference between p.out (the current output
// position) and startOutCol. If the start position is on a different line from
// the current position (or either is unknown), the result is infinity.
func (p *printer) distanceFrom(startPos token.Pos, startOutCol int) int {
if startPos.IsValid() && p.pos.IsValid() && p.posFor(startPos).Line == p.pos.Line {
return p.out.Column - startOutCol
}
return infinity
}
func (p *printer) methRecvType(typ ast.Expr) string {
switch x := typ.(type) {
case *ast.StarExpr:
return p.methRecvType(x.X)
case *ast.Ident:
return x.Name
default:
return fmt.Sprintf("recv type unknown: %+T", x)
}
return ""
}
func (p *printer) funcDecl(d *ast.FuncDecl) {
fname := ""
if d.Recv != nil {
for ex := range p.ExcludeFunctions {
if d.Name.Name == ex {
return
}
}
if d.Recv.List[0].Names != nil {
dirs, _ := p.findDirective(d.Doc)
pointerRecv := false
if hasDirective(dirs, "pointer-receiver") {
pointerRecv = true
}
p.curMethRecv = d.Recv.List[0]
isptr, typnm := p.printMethRecv()
if isptr && !pointerRecv {
isptr = false
}
if isptr {
p.curPtrArgs = []*ast.Ident{p.curMethRecv.Names[0]}
}
fname = typnm + "_" + d.Name.Name
// fmt.Printf("cur func recv: %v\n", p.curMethRecv)
}
// p.parameters(d.Recv, funcParam) // method: print receiver
// p.print(blank)
} else {
fname = d.Name.Name
}
if p.GoToSL.GetFuncGraph {
p.curFunc = p.GoToSL.RecycleFunc(fname)
} else {
fn, ok := p.GoToSL.KernelFuncs[fname]
if !ok {
return
}
p.curFunc = fn
}
p.setComment(d.Doc)
p.setPos(d.Pos())
// We have to save startCol only after emitting FUNC; otherwise it can be on a
// different line (all whitespace preceding the FUNC is emitted only when the
// FUNC is emitted).
startCol := p.out.Column - len("func ")
p.print("fn", blank, fname)
p.signature(d.Type, d.Recv)
p.funcBody(p.distanceFrom(d.Pos(), startCol), vtab, d.Body)
p.curPtrArgs = nil
p.curMethRecv = nil
if p.GoToSL.GetFuncGraph {
p.GoToSL.FuncGraph[fname] = p.curFunc
}
p.curFunc = nil
}
func (p *printer) decl(decl ast.Decl) {
switch d := decl.(type) {
case *ast.BadDecl:
p.setPos(d.Pos())
p.print("BadDecl")
case *ast.GenDecl:
p.genDecl(d)
case *ast.FuncDecl:
p.funcDecl(d)
default:
panic("unreachable")
}
}
// ----------------------------------------------------------------------------
// Files
func declToken(decl ast.Decl) (tok token.Token) {
tok = token.ILLEGAL
switch d := decl.(type) {
case *ast.GenDecl:
tok = d.Tok
case *ast.FuncDecl:
tok = token.FUNC
}
return
}
func (p *printer) declList(list []ast.Decl) {
tok := token.ILLEGAL
for _, d := range list {
prev := tok
tok = declToken(d)
// If the declaration token changed (e.g., from CONST to TYPE)
// or the next declaration has documentation associated with it,
// print an empty line between top-level declarations.
// (because p.linebreak is called with the position of d, which
// is past any documentation, the minimum requirement is satisfied
// even w/o the extra getDoc(d) nil-check - leave it in case the
// linebreak logic improves - there's already a TODO).
if len(p.output) > 0 {
// only print line break if we are not at the beginning of the output
// (i.e., we are not printing only a partial program)
min := 1
if prev != tok || getDoc(d) != nil {
min = 2
}
// start a new section if the next declaration is a function
// that spans multiple lines (see also issue #19544)
p.linebreak(p.lineFor(d.Pos()), min, ignore, tok == token.FUNC && p.numLines(d) > 1)
}
p.decl(d)
}
}
func (p *printer) file(src *ast.File) {
p.setComment(src.Doc)
p.setPos(src.Pos())
p.print(token.PACKAGE, blank)
p.expr(src.Name)
p.declList(src.Decls)
p.print(newline)
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file is largely copied from the Go source,
// src/go/printer/printer.go:
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gotosl
import (
"fmt"
"go/ast"
"go/build/constraint"
"go/token"
"io"
"os"
"strings"
"sync"
"text/tabwriter"
"unicode"
"cogentcore.org/core/base/fsx"
"golang.org/x/tools/go/packages"
)
const (
maxNewlines = 2 // max. number of newlines between source text
debug = false // enable for debugging
infinity = 1 << 30
)
type whiteSpace byte
const (
ignore = whiteSpace(0)
blank = whiteSpace(' ')
vtab = whiteSpace('\v')
newline = whiteSpace('\n')
formfeed = whiteSpace('\f')
indent = whiteSpace('>')
unindent = whiteSpace('<')
)
// A pmode value represents the current printer mode.
type pmode int
const (
noExtraBlank pmode = 1 << iota // disables extra blank after /*-style comment
noExtraLinebreak // disables extra line break after /*-style comment
)
type commentInfo struct {
cindex int // current comment index
comment *ast.CommentGroup // = printer.comments[cindex]; or nil
commentOffset int // = printer.posFor(printer.comments[cindex].List[0].Pos()).Offset; or infinity
commentNewline bool // true if the comment group contains newlines
}
type printer struct {
// Configuration (does not change after initialization)
PrintConfig
fset *token.FileSet
pkg *packages.Package // gosl: extra
// Current state
output []byte // raw printer result
indent int // current indentation
level int // level == 0: outside composite literal; level > 0: inside composite literal
mode pmode // current printer mode
endAlignment bool // if set, terminate alignment immediately
impliedSemi bool // if set, a linebreak implies a semicolon
lastTok token.Token // last token printed (token.ILLEGAL if it's whitespace)
prevOpen token.Token // previous non-brace "open" token (, [, or token.ILLEGAL
wsbuf []whiteSpace // delayed white space
goBuild []int // start index of all //go:build comments in output
plusBuild []int // start index of all // +build comments in output
// Positions
// The out position differs from the pos position when the result
// formatting differs from the source formatting (in the amount of
// white space). If there's a difference and SourcePos is set in
// ConfigMode, //line directives are used in the output to restore
// original source positions for a reader.
pos token.Position // current position in AST (source) space
out token.Position // current position in output space
last token.Position // value of pos after calling writeString
linePtr *int // if set, record out.Line for the next token in *linePtr
sourcePosErr error // if non-nil, the first error emitting a //line directive
// The list of all source comments, in order of appearance.
comments []*ast.CommentGroup // may be nil
useNodeComments bool // if not set, ignore lead and line comments of nodes
// Information about p.comments[p.cindex]; set up by nextComment.
commentInfo
// Cache of already computed node sizes.
nodeSizes map[ast.Node]int
// Cache of most recently computed line position.
cachedPos token.Pos
cachedLine int // line corresponding to cachedPos
// current arguments to function that are pointers and thus need dereferencing
// when accessing fields
curPtrArgs []*ast.Ident
curFunc *Function
curMethRecv *ast.Field // current method receiver, also included in curPtrArgs if ptr
curReturnType *ast.Ident
curMethIsAtomic bool // current method an atomic.* function -- marks arg as atomic
}
func (p *printer) internalError(msg ...any) {
if debug {
fmt.Print(p.pos.String() + ": ")
fmt.Println(msg...)
panic("go/printer")
}
}
func (p *printer) userError(err error) {
fname := fsx.DirAndFile(p.pos.String())
fmt.Print(fname + ": ")
fmt.Println(err.Error())
}
// commentsHaveNewline reports whether a list of comments belonging to
// an *ast.CommentGroup contains newlines. Because the position information
// may only be partially correct, we also have to read the comment text.
func (p *printer) commentsHaveNewline(list []*ast.Comment) bool {
// len(list) > 0
line := p.lineFor(list[0].Pos())
for i, c := range list {
if i > 0 && p.lineFor(list[i].Pos()) != line {
// not all comments on the same line
return true
}
if t := c.Text; len(t) >= 2 && (t[1] == '/' || strings.Contains(t, "\n")) {
return true
}
}
_ = line
return false
}
func (p *printer) nextComment() {
for p.cindex < len(p.comments) {
c := p.comments[p.cindex]
p.cindex++
if list := c.List; len(list) > 0 {
p.comment = c
p.commentOffset = p.posFor(list[0].Pos()).Offset
p.commentNewline = p.commentsHaveNewline(list)
return
}
// we should not reach here (correct ASTs don't have empty
// ast.CommentGroup nodes), but be conservative and try again
}
// no more comments
p.commentOffset = infinity
}
// commentBefore reports whether the current comment group occurs
// before the next position in the source code and printing it does
// not introduce implicit semicolons.
func (p *printer) commentBefore(next token.Position) bool {
return p.commentOffset < next.Offset && (!p.impliedSemi || !p.commentNewline)
}
// commentSizeBefore returns the estimated size of the
// comments on the same line before the next position.
func (p *printer) commentSizeBefore(next token.Position) int {
// save/restore current p.commentInfo (p.nextComment() modifies it)
defer func(info commentInfo) {
p.commentInfo = info
}(p.commentInfo)
size := 0
for p.commentBefore(next) {
for _, c := range p.comment.List {
size += len(c.Text)
}
p.nextComment()
}
return size
}
// recordLine records the output line number for the next non-whitespace
// token in *linePtr. It is used to compute an accurate line number for a
// formatted construct, independent of pending (not yet emitted) whitespace
// or comments.
func (p *printer) recordLine(linePtr *int) {
p.linePtr = linePtr
}
// linesFrom returns the number of output lines between the current
// output line and the line argument, ignoring any pending (not yet
// emitted) whitespace or comments. It is used to compute an accurate
// size (in number of lines) for a formatted construct.
func (p *printer) linesFrom(line int) int {
return p.out.Line - line
}
func (p *printer) posFor(pos token.Pos) token.Position {
// not used frequently enough to cache entire token.Position
return p.fset.PositionFor(pos, false /* absolute position */)
}
func (p *printer) lineFor(pos token.Pos) int {
if pos != p.cachedPos {
p.cachedPos = pos
p.cachedLine = p.fset.PositionFor(pos, false /* absolute position */).Line
}
return p.cachedLine
}
// writeLineDirective writes a //line directive if necessary.
func (p *printer) writeLineDirective(pos token.Position) {
if pos.IsValid() && (p.out.Line != pos.Line || p.out.Filename != pos.Filename) {
if strings.ContainsAny(pos.Filename, "\r\n") {
if p.sourcePosErr == nil {
p.sourcePosErr = fmt.Errorf("go/printer: source filename contains unexpected newline character: %q", pos.Filename)
}
return
}
p.output = append(p.output, tabwriter.Escape) // protect '\n' in //line from tabwriter interpretation
p.output = append(p.output, fmt.Sprintf("//line %s:%d\n", pos.Filename, pos.Line)...)
p.output = append(p.output, tabwriter.Escape)
// p.out must match the //line directive
p.out.Filename = pos.Filename
p.out.Line = pos.Line
}
}
// writeIndent writes indentation.
func (p *printer) writeIndent() {
// use "hard" htabs - indentation columns
// must not be discarded by the tabwriter
n := p.PrintConfig.Indent + p.indent // include base indentation
for i := 0; i < n; i++ {
p.output = append(p.output, '\t')
}
// update positions
p.pos.Offset += n
p.pos.Column += n
p.out.Column += n
}
// writeByte writes ch n times to p.output and updates p.pos.
// Only used to write formatting (white space) characters.
func (p *printer) writeByte(ch byte, n int) {
if p.endAlignment {
// Ignore any alignment control character;
// and at the end of the line, break with
// a formfeed to indicate termination of
// existing columns.
switch ch {
case '\t', '\v':
ch = ' '
case '\n', '\f':
ch = '\f'
p.endAlignment = false
}
}
if p.out.Column == 1 {
// no need to write line directives before white space
p.writeIndent()
}
for i := 0; i < n; i++ {
p.output = append(p.output, ch)
}
// update positions
p.pos.Offset += n
if ch == '\n' || ch == '\f' {
p.pos.Line += n
p.out.Line += n
p.pos.Column = 1
p.out.Column = 1
return
}
p.pos.Column += n
p.out.Column += n
}
// writeString writes the string s to p.output and updates p.pos, p.out,
// and p.last. If isLit is set, s is escaped w/ tabwriter.Escape characters
// to protect s from being interpreted by the tabwriter.
//
// Note: writeString is only used to write Go tokens, literals, and
// comments, all of which must be written literally. Thus, it is correct
// to always set isLit = true. However, setting it explicitly only when
// needed (i.e., when we don't know that s contains no tabs or line breaks)
// avoids processing extra escape characters and reduces run time of the
// printer benchmark by up to 10%.
func (p *printer) writeString(pos token.Position, s string, isLit bool) {
if p.out.Column == 1 {
if p.PrintConfig.Mode&SourcePos != 0 {
p.writeLineDirective(pos)
}
p.writeIndent()
}
if pos.IsValid() {
// update p.pos (if pos is invalid, continue with existing p.pos)
// Note: Must do this after handling line beginnings because
// writeIndent updates p.pos if there's indentation, but p.pos
// is the position of s.
p.pos = pos
}
if isLit {
// Protect s such that it passes through the tabwriter
// unchanged. Note that valid Go programs cannot contain
// tabwriter.Escape bytes since they do not appear in legal
// UTF-8 sequences.
p.output = append(p.output, tabwriter.Escape)
}
if debug {
p.output = append(p.output, fmt.Sprintf("/*%s*/", pos)...) // do not update p.pos!
}
p.output = append(p.output, s...)
// update positions
nlines := 0
var li int // index of last newline; valid if nlines > 0
for i := 0; i < len(s); i++ {
// Raw string literals may contain any character except back quote (`).
if ch := s[i]; ch == '\n' || ch == '\f' {
// account for line break
nlines++
li = i
// A line break inside a literal will break whatever column
// formatting is in place; ignore any further alignment through
// the end of the line.
p.endAlignment = true
}
}
p.pos.Offset += len(s)
if nlines > 0 {
p.pos.Line += nlines
p.out.Line += nlines
c := len(s) - li
p.pos.Column = c
p.out.Column = c
} else {
p.pos.Column += len(s)
p.out.Column += len(s)
}
if isLit {
p.output = append(p.output, tabwriter.Escape)
}
p.last = p.pos
}
// writeCommentPrefix writes the whitespace before a comment.
// If there is any pending whitespace, it consumes as much of
// it as is likely to help position the comment nicely.
// pos is the comment position, next the position of the item
// after all pending comments, prev is the previous comment in
// a group of comments (or nil), and tok is the next token.
func (p *printer) writeCommentPrefix(pos, next token.Position, prev *ast.Comment, tok token.Token) {
if len(p.output) == 0 {
// the comment is the first item to be printed - don't write any whitespace
return
}
if pos.IsValid() && pos.Filename != p.last.Filename {
// comment in a different file - separate with newlines
p.writeByte('\f', maxNewlines)
return
}
if pos.Line == p.last.Line && (prev == nil || prev.Text[1] != '/') {
// comment on the same line as last item:
// separate with at least one separator
hasSep := false
if prev == nil {
// first comment of a comment group
j := 0
for i, ch := range p.wsbuf {
switch ch {
case blank:
// ignore any blanks before a comment
p.wsbuf[i] = ignore
continue
case vtab:
// respect existing tabs - important
// for proper formatting of commented structs
hasSep = true
continue
case indent:
// apply pending indentation
continue
}
j = i
break
}
p.writeWhitespace(j)
}
// make sure there is at least one separator
if !hasSep {
sep := byte('\t')
if pos.Line == next.Line {
// next item is on the same line as the comment
// (which must be a /*-style comment): separate
// with a blank instead of a tab
sep = ' '
}
p.writeByte(sep, 1)
}
} else {
// comment on a different line:
// separate with at least one line break
droppedLinebreak := false
j := 0
for i, ch := range p.wsbuf {
switch ch {
case blank, vtab:
// ignore any horizontal whitespace before line breaks
p.wsbuf[i] = ignore
continue
case indent:
// apply pending indentation
continue
case unindent:
// if this is not the last unindent, apply it
// as it is (likely) belonging to the last
// construct (e.g., a multi-line expression list)
// and is not part of closing a block
if i+1 < len(p.wsbuf) && p.wsbuf[i+1] == unindent {
continue
}
// if the next token is not a closing }, apply the unindent
// if it appears that the comment is aligned with the
// token; otherwise assume the unindent is part of a
// closing block and stop (this scenario appears with
// comments before a case label where the comments
// apply to the next case instead of the current one)
if tok != token.RBRACE && pos.Column == next.Column {
continue
}
case newline, formfeed:
p.wsbuf[i] = ignore
droppedLinebreak = prev == nil // record only if first comment of a group
}
j = i
break
}
p.writeWhitespace(j)
// determine number of linebreaks before the comment
n := 0
if pos.IsValid() && p.last.IsValid() {
n = pos.Line - p.last.Line
if n < 0 { // should never happen
n = 0
}
}
// at the package scope level only (p.indent == 0),
// add an extra newline if we dropped one before:
// this preserves a blank line before documentation
// comments at the package scope level (issue 2570)
if p.indent == 0 && droppedLinebreak {
n++
}
// make sure there is at least one line break
// if the previous comment was a line comment
if n == 0 && prev != nil && prev.Text[1] == '/' {
n = 1
}
if n > 0 {
// use formfeeds to break columns before a comment;
// this is analogous to using formfeeds to separate
// individual lines of /*-style comments
p.writeByte('\f', nlimit(n))
}
}
}
// Returns true if s contains only white space
// (only tabs and blanks can appear in the printer's context).
func isBlank(s string) bool {
for i := 0; i < len(s); i++ {
if s[i] > ' ' {
return false
}
}
return true
}
// commonPrefix returns the common prefix of a and b.
func commonPrefix(a, b string) string {
i := 0
for i < len(a) && i < len(b) && a[i] == b[i] && (a[i] <= ' ' || a[i] == '*') {
i++
}
return a[0:i]
}
// trimRight returns s with trailing whitespace removed.
func trimRight(s string) string {
return strings.TrimRightFunc(s, unicode.IsSpace)
}
// stripCommonPrefix removes a common prefix from /*-style comment lines (unless no
// comment line is indented, all but the first line have some form of space prefix).
// The prefix is computed using heuristics such that it is likely that the comment
// contents are nicely laid out after re-printing each line using the printer's
// current indentation.
func stripCommonPrefix(lines []string) {
if len(lines) <= 1 {
return // at most one line - nothing to do
}
// len(lines) > 1
// The heuristic in this function tries to handle a few
// common patterns of /*-style comments: Comments where
// the opening /* and closing */ are aligned and the
// rest of the comment text is aligned and indented with
// blanks or tabs, cases with a vertical "line of stars"
// on the left, and cases where the closing */ is on the
// same line as the last comment text.
// Compute maximum common white prefix of all but the first,
// last, and blank lines, and replace blank lines with empty
// lines (the first line starts with /* and has no prefix).
// In cases where only the first and last lines are not blank,
// such as two-line comments, or comments where all inner lines
// are blank, consider the last line for the prefix computation
// since otherwise the prefix would be empty.
//
// Note that the first and last line are never empty (they
// contain the opening /* and closing */ respectively) and
// thus they can be ignored by the blank line check.
prefix := ""
prefixSet := false
if len(lines) > 2 {
for i, line := range lines[1 : len(lines)-1] {
if isBlank(line) {
lines[1+i] = "" // range starts with lines[1]
} else {
if !prefixSet {
prefix = line
prefixSet = true
}
prefix = commonPrefix(prefix, line)
}
}
}
// If we don't have a prefix yet, consider the last line.
if !prefixSet {
line := lines[len(lines)-1]
prefix = commonPrefix(line, line)
}
/*
* Check for vertical "line of stars" and correct prefix accordingly.
*/
lineOfStars := false
if p, _, ok := strings.Cut(prefix, "*"); ok {
// remove trailing blank from prefix so stars remain aligned
prefix = strings.TrimSuffix(p, " ")
lineOfStars = true
} else {
// No line of stars present.
// Determine the white space on the first line after the /*
// and before the beginning of the comment text, assume two
// blanks instead of the /* unless the first character after
// the /* is a tab. If the first comment line is empty but
// for the opening /*, assume up to 3 blanks or a tab. This
// whitespace may be found as suffix in the common prefix.
first := lines[0]
if isBlank(first[2:]) {
// no comment text on the first line:
// reduce prefix by up to 3 blanks or a tab
// if present - this keeps comment text indented
// relative to the /* and */'s if it was indented
// in the first place
i := len(prefix)
for n := 0; n < 3 && i > 0 && prefix[i-1] == ' '; n++ {
i--
}
if i == len(prefix) && i > 0 && prefix[i-1] == '\t' {
i--
}
prefix = prefix[0:i]
} else {
// comment text on the first line
suffix := make([]byte, len(first))
n := 2 // start after opening /*
for n < len(first) && first[n] <= ' ' {
suffix[n] = first[n]
n++
}
if n > 2 && suffix[2] == '\t' {
// assume the '\t' compensates for the /*
suffix = suffix[2:n]
} else {
// otherwise assume two blanks
suffix[0], suffix[1] = ' ', ' '
suffix = suffix[0:n]
}
// Shorten the computed common prefix by the length of
// suffix, if it is found as suffix of the prefix.
prefix = strings.TrimSuffix(prefix, string(suffix))
}
}
// Handle last line: If it only contains a closing */, align it
// with the opening /*, otherwise align the text with the other
// lines.
last := lines[len(lines)-1]
closing := "*/"
before, _, _ := strings.Cut(last, closing) // closing always present
if isBlank(before) {
// last line only contains closing */
if lineOfStars {
closing = " */" // add blank to align final star
}
lines[len(lines)-1] = prefix + closing
} else {
// last line contains more comment text - assume
// it is aligned like the other lines and include
// in prefix computation
prefix = commonPrefix(prefix, last)
}
// Remove the common prefix from all but the first and empty lines.
for i, line := range lines {
if i > 0 && line != "" {
lines[i] = line[len(prefix):]
}
}
}
func (p *printer) writeComment(comment *ast.Comment) {
text := comment.Text
pos := p.posFor(comment.Pos())
const linePrefix = "//line "
if strings.HasPrefix(text, linePrefix) && (!pos.IsValid() || pos.Column == 1) {
// Possibly a //-style line directive.
// Suspend indentation temporarily to keep line directive valid.
defer func(indent int) { p.indent = indent }(p.indent)
p.indent = 0
}
// shortcut common case of //-style comments
if text[1] == '/' {
if constraint.IsGoBuild(text) {
p.goBuild = append(p.goBuild, len(p.output))
} else if constraint.IsPlusBuild(text) {
p.plusBuild = append(p.plusBuild, len(p.output))
}
p.writeString(pos, trimRight(text), true)
return
}
// for /*-style comments, print line by line and let the
// write function take care of the proper indentation
lines := strings.Split(text, "\n")
// The comment started in the first column but is going
// to be indented. For an idempotent result, add indentation
// to all lines such that they look like they were indented
// before - this will make sure the common prefix computation
// is the same independent of how many times formatting is
// applied (was issue 1835).
if pos.IsValid() && pos.Column == 1 && p.indent > 0 {
for i, line := range lines[1:] {
lines[1+i] = " " + line
}
}
stripCommonPrefix(lines)
// write comment lines, separated by formfeed,
// without a line break after the last line
for i, line := range lines {
if i > 0 {
p.writeByte('\f', 1)
pos = p.pos
}
if len(line) > 0 {
p.writeString(pos, trimRight(line), true)
}
}
}
// writeCommentSuffix writes a line break after a comment if indicated
// and processes any leftover indentation information. If a line break
// is needed, the kind of break (newline vs formfeed) depends on the
// pending whitespace. The writeCommentSuffix result indicates if a
// newline was written or if a formfeed was dropped from the whitespace
// buffer.
func (p *printer) writeCommentSuffix(needsLinebreak bool) (wroteNewline, droppedFF bool) {
for i, ch := range p.wsbuf {
switch ch {
case blank, vtab:
// ignore trailing whitespace
p.wsbuf[i] = ignore
case indent, unindent:
// don't lose indentation information
case newline, formfeed:
// if we need a line break, keep exactly one
// but remember if we dropped any formfeeds
if needsLinebreak {
needsLinebreak = false
wroteNewline = true
} else {
if ch == formfeed {
droppedFF = true
}
p.wsbuf[i] = ignore
}
}
}
p.writeWhitespace(len(p.wsbuf))
// make sure we have a line break
if needsLinebreak {
p.writeByte('\n', 1)
wroteNewline = true
}
return
}
// containsLinebreak reports whether the whitespace buffer contains any line breaks.
func (p *printer) containsLinebreak() bool {
for _, ch := range p.wsbuf {
if ch == newline || ch == formfeed {
return true
}
}
return false
}
// intersperseComments consumes all comments that appear before the next token
// tok and prints it together with the buffered whitespace (i.e., the whitespace
// that needs to be written before the next token). A heuristic is used to mix
// the comments and whitespace. The intersperseComments result indicates if a
// newline was written or if a formfeed was dropped from the whitespace buffer.
func (p *printer) intersperseComments(next token.Position, tok token.Token) (wroteNewline, droppedFF bool) {
var last *ast.Comment
for p.commentBefore(next) {
list := p.comment.List
changed := false
if p.lastTok != token.IMPORT && // do not rewrite cgo's import "C" comments
p.posFor(p.comment.Pos()).Column == 1 &&
p.posFor(p.comment.End()+1) == next {
// Unindented comment abutting next token position:
// a top-level doc comment.
list = formatDocComment(list)
changed = true
if len(p.comment.List) > 0 && len(list) == 0 {
// The doc comment was removed entirely.
// Keep preceding whitespace.
p.writeCommentPrefix(p.posFor(p.comment.Pos()), next, last, tok)
// Change print state to continue at next.
p.pos = next
p.last = next
// There can't be any more comments.
p.nextComment()
return p.writeCommentSuffix(false)
}
}
for _, c := range list {
p.writeCommentPrefix(p.posFor(c.Pos()), next, last, tok)
p.writeComment(c)
last = c
}
// In case list was rewritten, change print state to where
// the original list would have ended.
if len(p.comment.List) > 0 && changed {
last = p.comment.List[len(p.comment.List)-1]
p.pos = p.posFor(last.End())
p.last = p.pos
}
p.nextComment()
}
if last != nil {
// If the last comment is a /*-style comment and the next item
// follows on the same line but is not a comma, and not a "closing"
// token immediately following its corresponding "opening" token,
// add an extra separator unless explicitly disabled. Use a blank
// as separator unless we have pending linebreaks, they are not
// disabled, and we are outside a composite literal, in which case
// we want a linebreak (issue 15137).
// TODO(gri) This has become overly complicated. We should be able
// to track whether we're inside an expression or statement and
// use that information to decide more directly.
needsLinebreak := false
if p.mode&noExtraBlank == 0 &&
last.Text[1] == '*' && p.lineFor(last.Pos()) == next.Line &&
tok != token.COMMA &&
(tok != token.RPAREN || p.prevOpen == token.LPAREN) &&
(tok != token.RBRACK || p.prevOpen == token.LBRACK) {
if p.containsLinebreak() && p.mode&noExtraLinebreak == 0 && p.level == 0 {
needsLinebreak = true
} else {
p.writeByte(' ', 1)
}
}
// Ensure that there is a line break after a //-style comment,
// before EOF, and before a closing '}' unless explicitly disabled.
if last.Text[1] == '/' ||
tok == token.EOF ||
tok == token.RBRACE && p.mode&noExtraLinebreak == 0 {
needsLinebreak = true
}
return p.writeCommentSuffix(needsLinebreak)
}
// no comment was written - we should never reach here since
// intersperseComments should not be called in that case
p.internalError("intersperseComments called without pending comments")
return
}
// writeWhitespace writes the first n whitespace entries.
func (p *printer) writeWhitespace(n int) {
// write entries
for i := 0; i < n; i++ {
switch ch := p.wsbuf[i]; ch {
case ignore:
// ignore!
case indent:
p.indent++
case unindent:
p.indent--
if p.indent < 0 {
p.internalError("negative indentation:", p.indent)
p.indent = 0
}
case newline, formfeed:
// A line break immediately followed by a "correcting"
// unindent is swapped with the unindent - this permits
// proper label positioning. If a comment is between
// the line break and the label, the unindent is not
// part of the comment whitespace prefix and the comment
// will be positioned correctly indented.
if i+1 < n && p.wsbuf[i+1] == unindent {
// Use a formfeed to terminate the current section.
// Otherwise, a long label name on the next line leading
// to a wide column may increase the indentation column
// of lines before the label; effectively leading to wrong
// indentation.
p.wsbuf[i], p.wsbuf[i+1] = unindent, formfeed
i-- // do it again
continue
}
fallthrough
default:
p.writeByte(byte(ch), 1)
}
}
// shift remaining entries down
l := copy(p.wsbuf, p.wsbuf[n:])
p.wsbuf = p.wsbuf[:l]
}
// ----------------------------------------------------------------------------
// Printing interface
// nlimit limits n to maxNewlines.
func nlimit(n int) int {
return min(n, maxNewlines)
}
func mayCombine(prev token.Token, next byte) (b bool) {
switch prev {
case token.INT:
b = next == '.' // 1.
case token.ADD:
b = next == '+' // ++
case token.SUB:
b = next == '-' // --
case token.QUO:
b = next == '*' // /*
case token.LSS:
b = next == '-' || next == '<' // <- or <<
case token.AND:
b = next == '&' || next == '^' // && or &^
}
return
}
func (p *printer) setPos(pos token.Pos) {
if pos.IsValid() {
p.pos = p.posFor(pos) // accurate position of next item
}
}
// print prints a list of "items" (roughly corresponding to syntactic
// tokens, but also including whitespace and formatting information).
// It is the only print function that should be called directly from
// any of the AST printing functions in nodes.go.
//
// Whitespace is accumulated until a non-whitespace token appears. Any
// comments that need to appear before that token are printed first,
// taking into account the amount and structure of any pending white-
// space for best comment placement. Then, any leftover whitespace is
// printed, followed by the actual token.
func (p *printer) print(args ...any) {
for _, arg := range args {
// information about the current arg
var data string
var isLit bool
var impliedSemi bool // value for p.impliedSemi after this arg
// record previous opening token, if any
switch p.lastTok {
case token.ILLEGAL:
// ignore (white space)
case token.LPAREN, token.LBRACK:
p.prevOpen = p.lastTok
default:
// other tokens followed any opening token
p.prevOpen = token.ILLEGAL
}
switch x := arg.(type) {
case pmode:
// toggle printer mode
p.mode ^= x
continue
case whiteSpace:
if x == ignore {
// don't add ignore's to the buffer; they
// may screw up "correcting" unindents (see
// LabeledStmt)
continue
}
i := len(p.wsbuf)
if i == cap(p.wsbuf) {
// Whitespace sequences are very short so this should
// never happen. Handle gracefully (but possibly with
// bad comment placement) if it does happen.
p.writeWhitespace(i)
i = 0
}
p.wsbuf = p.wsbuf[0 : i+1]
p.wsbuf[i] = x
if x == newline || x == formfeed {
// newlines affect the current state (p.impliedSemi)
// and not the state after printing arg (impliedSemi)
// because comments can be interspersed before the arg
// in this case
p.impliedSemi = false
}
p.lastTok = token.ILLEGAL
continue
case *ast.Ident:
data = x.Name
impliedSemi = true
p.lastTok = token.IDENT
case *ast.BasicLit:
data = x.Value
isLit = true
impliedSemi = true
p.lastTok = x.Kind
case token.Token:
s := x.String()
if mayCombine(p.lastTok, s[0]) {
// the previous and the current token must be
// separated by a blank otherwise they combine
// into a different incorrect token sequence
// (except for token.INT followed by a '.' this
// should never happen because it is taken care
// of via binary expression formatting)
if len(p.wsbuf) != 0 {
p.internalError("whitespace buffer not empty")
}
p.wsbuf = p.wsbuf[0:1]
p.wsbuf[0] = ' '
}
data = s
// some keywords followed by a newline imply a semicolon
switch x {
case token.BREAK, token.CONTINUE, token.FALLTHROUGH, token.RETURN,
token.INC, token.DEC, token.RPAREN, token.RBRACK, token.RBRACE:
impliedSemi = true
}
p.lastTok = x
case string:
// incorrect AST - print error message
data = x
isLit = true
impliedSemi = true
p.lastTok = token.STRING
default:
fmt.Fprintf(os.Stderr, "print: unsupported argument %v (%T)\n", arg, arg)
panic("go/printer type")
}
// data != ""
next := p.pos // estimated/accurate position of next item
wroteNewline, droppedFF := p.flush(next, p.lastTok)
// intersperse extra newlines if present in the source and
// if they don't cause extra semicolons (don't do this in
// flush as it will cause extra newlines at the end of a file)
if !p.impliedSemi {
n := nlimit(next.Line - p.pos.Line)
// don't exceed maxNewlines if we already wrote one
if wroteNewline && n == maxNewlines {
n = maxNewlines - 1
}
if n > 0 {
ch := byte('\n')
if droppedFF {
ch = '\f' // use formfeed since we dropped one before
}
p.writeByte(ch, n)
impliedSemi = false
}
}
// the next token starts now - record its line number if requested
if p.linePtr != nil {
*p.linePtr = p.out.Line
p.linePtr = nil
}
p.writeString(next, data, isLit)
p.impliedSemi = impliedSemi
}
}
// flush prints any pending comments and whitespace occurring textually
// before the position of the next token tok. The flush result indicates
// if a newline was written or if a formfeed was dropped from the whitespace
// buffer.
func (p *printer) flush(next token.Position, tok token.Token) (wroteNewline, droppedFF bool) {
if p.commentBefore(next) {
// if there are comments before the next item, intersperse them
wroteNewline, droppedFF = p.intersperseComments(next, tok)
} else {
// otherwise, write any leftover whitespace
p.writeWhitespace(len(p.wsbuf))
}
return
}
// getDoc returns the ast.CommentGroup associated with n, if any.
func getDoc(n ast.Node) *ast.CommentGroup {
switch n := n.(type) {
case *ast.Field:
return n.Doc
case *ast.ImportSpec:
return n.Doc
case *ast.ValueSpec:
return n.Doc
case *ast.TypeSpec:
return n.Doc
case *ast.GenDecl:
return n.Doc
case *ast.FuncDecl:
return n.Doc
case *ast.File:
return n.Doc
}
return nil
}
func getLastComment(n ast.Node) *ast.CommentGroup {
switch n := n.(type) {
case *ast.Field:
return n.Comment
case *ast.ImportSpec:
return n.Comment
case *ast.ValueSpec:
return n.Comment
case *ast.TypeSpec:
return n.Comment
case *ast.GenDecl:
if len(n.Specs) > 0 {
return getLastComment(n.Specs[len(n.Specs)-1])
}
case *ast.File:
if len(n.Comments) > 0 {
return n.Comments[len(n.Comments)-1]
}
}
return nil
}
func (p *printer) printNode(node any) error {
// unpack *CommentedNode, if any
var comments []*ast.CommentGroup
if cnode, ok := node.(*CommentedNode); ok {
node = cnode.Node
comments = cnode.Comments
}
if comments != nil {
// commented node - restrict comment list to relevant range
n, ok := node.(ast.Node)
if !ok {
goto unsupported
}
beg := n.Pos()
end := n.End()
// if the node has associated documentation,
// include that commentgroup in the range
// (the comment list is sorted in the order
// of the comment appearance in the source code)
if doc := getDoc(n); doc != nil {
beg = doc.Pos()
}
if com := getLastComment(n); com != nil {
if e := com.End(); e > end {
end = e
}
}
// token.Pos values are global offsets, we can
// compare them directly
i := 0
for i < len(comments) && comments[i].End() < beg {
i++
}
j := i
for j < len(comments) && comments[j].Pos() < end {
j++
}
if i < j {
p.comments = comments[i:j]
}
} else if n, ok := node.(*ast.File); ok {
// use ast.File comments, if any
p.comments = n.Comments
}
// if there are no comments, use node comments
p.useNodeComments = p.comments == nil
// get comments ready for use
p.nextComment()
p.print(pmode(0))
// format node
switch n := node.(type) {
case ast.Expr:
p.expr(n)
case ast.Stmt:
// A labeled statement will un-indent to position the label.
// Set p.indent to 1 so we don't get indent "underflow".
if _, ok := n.(*ast.LabeledStmt); ok {
p.indent = 1
}
p.stmt(n, false, false)
case ast.Decl:
p.decl(n)
case ast.Spec:
p.spec(n, 1, false, token.EOF)
case []ast.Stmt:
// A labeled statement will un-indent to position the label.
// Set p.indent to 1 so we don't get indent "underflow".
for _, s := range n {
if _, ok := s.(*ast.LabeledStmt); ok {
p.indent = 1
}
}
p.stmtList(n, 0, false)
case []ast.Decl:
p.declList(n)
case *ast.File:
p.file(n)
default:
goto unsupported
}
return p.sourcePosErr
unsupported:
return fmt.Errorf("go/printer: unsupported node type %T", node)
}
// ----------------------------------------------------------------------------
// Trimmer
// A trimmer is an io.Writer filter for stripping tabwriter.Escape
// characters, trailing blanks and tabs, and for converting formfeed
// and vtab characters into newlines and htabs (in case no tabwriter
// is used). Text bracketed by tabwriter.Escape characters is passed
// through unchanged.
type trimmer struct {
output io.Writer
state int
space []byte
}
// trimmer is implemented as a state machine.
// It can be in one of the following states:
const (
inSpace = iota // inside space
inEscape // inside text bracketed by tabwriter.Escapes
inText // inside text
)
func (p *trimmer) resetSpace() {
p.state = inSpace
p.space = p.space[0:0]
}
// Design note: It is tempting to eliminate extra blanks occurring in
// whitespace in this function as it could simplify some
// of the blanks logic in the node printing functions.
// However, this would mess up any formatting done by
// the tabwriter.
var aNewline = []byte("\n")
func (p *trimmer) Write(data []byte) (n int, err error) {
// invariants:
// p.state == inSpace:
// p.space is unwritten
// p.state == inEscape, inText:
// data[m:n] is unwritten
m := 0
var b byte
for n, b = range data {
if b == '\v' {
b = '\t' // convert to htab
}
switch p.state {
case inSpace:
switch b {
case '\t', ' ':
p.space = append(p.space, b)
case '\n', '\f':
p.resetSpace() // discard trailing space
_, err = p.output.Write(aNewline)
case tabwriter.Escape:
_, err = p.output.Write(p.space)
p.state = inEscape
m = n + 1 // +1: skip tabwriter.Escape
default:
_, err = p.output.Write(p.space)
p.state = inText
m = n
}
case inEscape:
if b == tabwriter.Escape {
_, err = p.output.Write(data[m:n])
p.resetSpace()
}
case inText:
switch b {
case '\t', ' ':
_, err = p.output.Write(data[m:n])
p.resetSpace()
p.space = append(p.space, b)
case '\n', '\f':
_, err = p.output.Write(data[m:n])
p.resetSpace()
if err == nil {
_, err = p.output.Write(aNewline)
}
case tabwriter.Escape:
_, err = p.output.Write(data[m:n])
p.state = inEscape
m = n + 1 // +1: skip tabwriter.Escape
}
default:
panic("unreachable")
}
if err != nil {
return
}
}
n = len(data)
switch p.state {
case inEscape, inText:
_, err = p.output.Write(data[m:n])
p.resetSpace()
}
return
}
// ----------------------------------------------------------------------------
// Public interface
// A Mode value is a set of flags (or 0). They control printing.
type Mode uint
const (
RawFormat Mode = 1 << iota // do not use a tabwriter; if set, UseSpaces is ignored
TabIndent // use tabs for indentation independent of UseSpaces
UseSpaces // use spaces instead of tabs for alignment
SourcePos // emit //line directives to preserve original source positions
)
// The mode below is not included in printer's public API because
// editing code text is deemed out of scope. Because this mode is
// unexported, it's also possible to modify or remove it based on
// the evolving needs of go/format and cmd/gofmt without breaking
// users. See discussion in CL 240683.
const (
// normalizeNumbers means to canonicalize number
// literal prefixes and exponents while printing.
//
// This value is known in and used by go/format and cmd/gofmt.
// It is currently more convenient and performant for those
// packages to apply number normalization during printing,
// rather than by modifying the AST in advance.
normalizeNumbers Mode = 1 << 30
)
// A PrintConfig node controls the output of Fprint.
type PrintConfig struct {
Mode Mode // default: 0
Tabwidth int // default: 8
Indent int // default: 0 (all code is indented at least by this much)
GoToSL *State // gosl:
ExcludeFunctions map[string]bool
}
var printerPool = sync.Pool{
New: func() any {
return &printer{
// Whitespace sequences are short.
wsbuf: make([]whiteSpace, 0, 16),
// We start the printer with a 16K output buffer, which is currently
// larger than about 80% of Go files in the standard library.
output: make([]byte, 0, 16<<10),
}
},
}
func newPrinter(cfg *PrintConfig, pkg *packages.Package, nodeSizes map[ast.Node]int) *printer {
p := printerPool.Get().(*printer)
*p = printer{
PrintConfig: *cfg,
pkg: pkg,
fset: pkg.Fset,
pos: token.Position{Line: 1, Column: 1},
out: token.Position{Line: 1, Column: 1},
wsbuf: p.wsbuf[:0],
nodeSizes: nodeSizes,
cachedPos: -1,
output: p.output[:0],
}
return p
}
func (p *printer) free() {
// Hard limit on buffer size; see https://golang.org/issue/23199.
if cap(p.output) > 64<<10 {
return
}
printerPool.Put(p)
}
// fprint implements Fprint and takes a nodeSizes map for setting up the printer state.
func (cfg *PrintConfig) fprint(output io.Writer, pkg *packages.Package, node any, nodeSizes map[ast.Node]int) (err error) {
// print node
p := newPrinter(cfg, pkg, nodeSizes)
defer p.free()
if err = p.printNode(node); err != nil {
return
}
// print outstanding comments
p.impliedSemi = false // EOF acts like a newline
p.flush(token.Position{Offset: infinity, Line: infinity}, token.EOF)
// output is buffered in p.output now.
// fix //go:build and // +build comments if needed.
p.fixGoBuildLines()
// redirect output through a trimmer to eliminate trailing whitespace
// (Input to a tabwriter must be untrimmed since trailing tabs provide
// formatting information. The tabwriter could provide trimming
// functionality but no tabwriter is used when RawFormat is set.)
output = &trimmer{output: output}
// redirect output through a tabwriter if necessary
if cfg.Mode&RawFormat == 0 {
minwidth := cfg.Tabwidth
padchar := byte('\t')
if cfg.Mode&UseSpaces != 0 {
padchar = ' '
}
twmode := tabwriter.DiscardEmptyColumns
if cfg.Mode&TabIndent != 0 {
minwidth = 0
twmode |= tabwriter.TabIndent
}
output = tabwriter.NewWriter(output, minwidth, cfg.Tabwidth, 1, padchar, twmode)
}
// write printer result via tabwriter/trimmer to output
if _, err = output.Write(p.output); err != nil {
return
}
// flush tabwriter, if any
if tw, _ := output.(*tabwriter.Writer); tw != nil {
err = tw.Flush()
}
return
}
// A CommentedNode bundles an AST node and corresponding comments.
// It may be provided as argument to any of the [Fprint] functions.
type CommentedNode struct {
Node any // *ast.File, or ast.Expr, ast.Decl, ast.Spec, or ast.Stmt
Comments []*ast.CommentGroup
}
// Fprint "pretty-prints" an AST node to output for a given configuration cfg.
// Position information is interpreted relative to the file set fset.
// The node type must be *[ast.File], *[CommentedNode], [][ast.Decl], [][ast.Stmt],
// or assignment-compatible to [ast.Expr], [ast.Decl], [ast.Spec], or [ast.Stmt].
func (cfg *PrintConfig) Fprint(output io.Writer, pkg *packages.Package, node any) error {
return cfg.fprint(output, pkg, node, make(map[ast.Node]int))
}
// Fprint "pretty-prints" an AST node to output.
// It calls [PrintConfig.Fprint] with default settings.
// Note that gofmt uses tabs for indentation but spaces for alignment;
// use format.Node (package go/format) for output that matches gofmt.
func Fprint(output io.Writer, pkg *packages.Package, node any) error {
return (&PrintConfig{Tabwidth: 8}).Fprint(output, pkg, node)
}
// Copyright 2024 Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gotosl
import (
"bytes"
)
// MoveLines moves the lines in the half-open region [st,ed) of lines
// so that they start at line index to, shifting the intervening lines down.
func MoveLines(lines *[][]byte, to, st, ed int) {
mvln := (*lines)[st:ed]
btwn := (*lines)[to:st]
aft := (*lines)[ed:len(*lines)]
nln := make([][]byte, to, len(*lines))
copy(nln, (*lines)[:to])
nln = append(nln, mvln...)
nln = append(nln, btwn...)
nln = append(nln, aft...)
*lines = nln
}
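// moveLinesSketch is an illustrative sketch (not part of the original source)
// showing what MoveLines does: the half-open region [st,ed) is moved so that
// it starts at index to, and the lines that were between to and st shift down.
func moveLinesSketch() {
lines := [][]byte{
[]byte("a"), []byte("b"), []byte("c"),
[]byte("d"), []byte("e"), []byte("f"),
}
MoveLines(&lines, 1, 3, 5) // move "d", "e" up to index 1
// lines is now: a, d, e, b, c, f
}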
// SlEdits performs post-generation edits for WGSL,
// replacing type names, slbool methods, function calls, etc.
// It returns whether a slrand. or sltype. prefix was found,
// driving copying of those support files.
func SlEdits(src []byte) (lines [][]byte, hasSlrand bool, hasSltype bool) {
nl := []byte("\n")
lines = bytes.Split(src, nl)
hasSlrand, hasSltype = SlEditsReplace(lines)
return
}
type Replace struct {
From, To []byte
}
var Replaces = []Replace{
{[]byte("sltype.Uint32Vec2"), []byte("vec2<u32>")},
{[]byte("sltype.Float32Vec2"), []byte("vec2<f32>")},
{[]byte("slvec.Vector2i"), []byte("vec4<i32>")},
{[]byte("slvec.Vector2"), []byte("vec4<f32>")},
{[]byte("slvec.Vector3"), []byte("vec4<f32>")},
{[]byte("math32.Vector2i"), []byte("vec2<i32>")},
{[]byte("math32.Vector2"), []byte("vec2<f32>")},
{[]byte("math32.Vector3"), []byte("vec3<f32>")},
{[]byte("math32.Vector4"), []byte("vec4<f32>")},
{[]byte("math32.Matrix2"), []byte("mat2x3f")},
{[]byte("math32.Matrix3"), []byte("mat3x3f")},
{[]byte("math32.Matrix4"), []byte("mat4x4f")},
{[]byte("math32.Quat"), []byte("vec4<f32>")},
{[]byte("math32.Vec2i"), []byte("vec2<i32>")},
{[]byte("math32.Vec2"), []byte("vec2<f32>")},
{[]byte("math32.Vec3"), []byte("vec3<f32>")},
{[]byte("math32.Vec4"), []byte("vec4<f32>")},
{[]byte("math32.NewQuat"), []byte("vec4<f32>")},
{[]byte("math32.Mat3"), []byte("mat3x3f")},
{[]byte(".Values["), []byte("[")},
{[]byte("float32"), []byte("f32")},
{[]byte("float64"), []byte("f64")}, // TODO: not yet supported
{[]byte("uint32"), []byte("u32")},
{[]byte("uint64"), []byte("su64")},
{[]byte("int32"), []byte("i32")},
{[]byte("math32.FastExp("), []byte("FastExp(")}, // FastExp about same speed, numerically identical
// {[]byte("math32.FastExp("), []byte("exp(")}, // exp is slightly faster it seems
{[]byte("math.Float32frombits("), []byte("bitcast<f32>(")},
{[]byte("math.Float32bits("), []byte("bitcast<u32>(")},
{[]byte("shaders."), []byte("")},
{[]byte("slrand."), []byte("Rand")},
{[]byte("RandUi32"), []byte("RandUint32")}, // fix int32 -> i32
{[]byte(".SetFromVector2("), []byte("=(")},
{[]byte(".SetFrom2("), []byte("=(")},
{[]byte(".IsTrue()"), []byte("==1")},
{[]byte(".IsFalse()"), []byte("==0")},
{[]byte(".SetBool(true)"), []byte("=1")},
{[]byte(".SetBool(false)"), []byte("=0")},
{[]byte(".SetBool("), []byte("=i32(")},
{[]byte("slbool.Bool"), []byte("i32")},
{[]byte("slbool.True"), []byte("1")},
{[]byte("slbool.False"), []byte("0")},
{[]byte("slbool.IsTrue("), []byte("(1 == ")},
{[]byte("slbool.IsFalse("), []byte("(0 == ")},
{[]byte("slbool.FromBool("), []byte("i32(")},
{[]byte("bools.ToFloat32("), []byte("f32(")},
{[]byte("bools.FromFloat32("), []byte("bool(")},
{[]byte("num.FromBool[f32]("), []byte("f32(")},
{[]byte("num.ToBool("), []byte("bool(")},
{[]byte("sltype."), []byte("")},
}
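// MathReplaceAll strips each occurrence of the given math package prefix mat
// (e.g., "math32." or "math.") from ln and lowercases the first letter of the
// identifier that follows, so that, for example, math32.Cos(x) becomes cos(x),
// matching the WGSL builtin function names.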
func MathReplaceAll(mat, ln []byte) []byte {
ml := len(mat)
st := 0
for {
sln := ln[st:]
i := bytes.Index(sln, mat)
if i < 0 {
return ln
}
fl := ln[st+i+ml : st+i+ml+1]
dl := bytes.ToLower(fl)
el := ln[st+i+ml+1:]
ln = append(ln[:st+i], dl...)
ln = append(ln, el...)
st += i + 1
}
}
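// SlRemoveComments removes blank lines and lines that consist entirely of
// //-style comments from the given lines.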
func SlRemoveComments(lines [][]byte) [][]byte {
comm := []byte("//")
olns := make([][]byte, 0, len(lines))
for _, ln := range lines {
ts := bytes.TrimSpace(ln)
if len(ts) == 0 {
continue
}
if bytes.HasPrefix(ts, comm) {
continue
}
olns = append(olns, ln)
}
return olns
}
// SlEditsReplace replaces Go syntax with equivalent WGSL code in lines.
// It returns whether slrand. or sltype. prefixes were found, so that the
// corresponding header files can be automatically included.
func SlEditsReplace(lines [][]byte) (bool, bool) {
mt32 := []byte("math32.")
mth := []byte("math.")
slr := []byte("slrand.")
styp := []byte("sltype.")
include := []byte("#include")
hasSlrand := false
hasSltype := false
for li, ln := range lines {
if bytes.Contains(ln, include) {
continue
}
if !hasSlrand && bytes.Contains(ln, slr) {
hasSlrand = true
}
if !hasSltype && bytes.Contains(ln, styp) {
hasSltype = true
}
for _, r := range Replaces {
ln = bytes.ReplaceAll(ln, r.From, r.To)
}
ln = MathReplaceAll(mt32, ln)
ln = MathReplaceAll(mth, ln)
lines[li] = ln
}
return hasSlrand, hasSltype
}
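// slEditsSketch is an illustrative sketch (not part of the original source)
// of the rewriting that SlEdits performs on generated code: Go type names,
// slbool method calls, and math32 / math function calls are replaced with
// their WGSL equivalents.
func slEditsSketch() {
src := []byte("var x float32 = math32.Cos(ang)\nif b.IsTrue() { x = 1 }\n")
lines, hasSlrand, hasSltype := SlEdits(src)
// lines[0]: "var x f32 = cos(ang)"
// lines[1]: "if b==1 { x = 1 }"
_, _, _ = lines, hasSlrand, hasSltype
}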
var SLBools = []Replace{
{[]byte(".IsTrue()"), []byte("==1")},
{[]byte(".IsFalse()"), []byte("==0")},
{[]byte(".SetBool(true)"), []byte("=1")},
{[]byte(".SetBool(false)"), []byte("=0")},
{[]byte(".SetBool("), []byte("=int32(")},
{[]byte("slbool.Bool"), []byte("int32")},
{[]byte("slbool.True"), []byte("1")},
{[]byte("slbool.False"), []byte("0")},
{[]byte("slbool.IsTrue("), []byte("(1 == ")},
{[]byte("slbool.IsFalse("), []byte("(0 == ")},
{[]byte("slbool.FromBool("), []byte("int32(")},
{[]byte("bools.ToFloat32("), []byte("float32(")},
{[]byte("bools.FromFloat32("), []byte("bool(")},
{[]byte("num.FromBool[f32]("), []byte("float32(")},
{[]byte("num.ToBool("), []byte("bool(")},
}
// SlBoolReplace replaces all the slbool methods with literal int32 expressions.
func SlBoolReplace(lines [][]byte) {
for li, ln := range lines {
for _, r := range SLBools {
ln = bytes.ReplaceAll(ln, r.From, r.To)
}
lines[li] = ln
}
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gotosl
import (
"bytes"
"fmt"
"go/ast"
"go/token"
"log"
"os"
"os/exec"
"path/filepath"
"sort"
"cogentcore.org/core/base/errors"
"cogentcore.org/lab/gosl/alignsl"
"golang.org/x/exp/maps"
"golang.org/x/tools/go/packages"
)
// TranslateDir translates all .go files in the given directory to WGSL.
func (st *State) TranslateDir(pf string) error {
pkgs, err := packages.Load(&packages.Config{Mode: packages.NeedName | packages.NeedFiles | packages.NeedTypes | packages.NeedSyntax | packages.NeedTypesSizes | packages.NeedTypesInfo}, pf)
// pkgs, err := packages.Load(&packages.Config{Mode: packages.LoadAllSyntax}, pf)
if err != nil {
return errors.Log(err)
}
if len(pkgs) != 1 {
err := fmt.Errorf("More than one package for path: %v", pf)
return errors.Log(err)
}
pkg := pkgs[0]
if len(pkg.GoFiles) == 0 {
err := fmt.Errorf("No Go files found in package: %+v", pkg)
return errors.Log(err)
}
// fmt.Printf("go files: %+v", pkg.GoFiles)
// return nil, err
files := pkg.GoFiles
st.FuncGraph = make(map[string]*Function)
st.GetFuncGraph = true
doFile := func(gofp string, buf *bytes.Buffer) {
_, gofn := filepath.Split(gofp)
if st.Config.Debug {
fmt.Printf("###################################\nTranslating Go file: %s\n", gofn)
}
var afile *ast.File
var fpos token.Position
for _, sy := range pkg.Syntax {
pos := pkg.Fset.Position(sy.Package)
_, posfn := filepath.Split(pos.Filename)
if posfn == gofn {
fpos = pos
afile = sy
break
}
}
if afile == nil {
fmt.Printf("Warning: File named: %s not found in Loaded package\n", gofn)
return
}
pcfg := PrintConfig{GoToSL: st, Mode: printerMode, Tabwidth: tabWidth, ExcludeFunctions: st.ExcludeMap}
pcfg.Fprint(buf, pkg, afile)
if !st.GetFuncGraph && !st.Config.Keep {
os.Remove(fpos.Filename)
}
}
// first pass is just to get the call graph:
for fn := range st.GoVarsFiles { // do varsFiles first!!
var buf bytes.Buffer
doFile(fn, &buf)
}
st.GenGPU(true) // generate an initial gosl.go in imports, so Go doesn't get confused
pkgs, err = packages.Load(&packages.Config{Mode: packages.NeedName | packages.NeedFiles | packages.NeedTypes | packages.NeedSyntax | packages.NeedTypesSizes | packages.NeedTypesInfo}, pf)
pkg = pkgs[0]
files = pkg.GoFiles
for _, gofp := range files {
_, gofn := filepath.Split(gofp)
if _, ok := st.GoVarsFiles[gofn]; ok {
continue
}
if gofn == "gosl.go" {
if !st.Config.Keep {
os.Remove(gofp)
}
continue
}
var buf bytes.Buffer
doFile(gofp, &buf)
}
// st.PrintFuncGraph()
doKernelFile := func(fname string, lines [][]byte) ([][]byte, bool, bool) {
_, gofn := filepath.Split(fname)
var buf bytes.Buffer
doFile(fname, &buf)
slfix, hasSlrand, hasSltype := SlEdits(buf.Bytes())
slfix = SlRemoveComments(slfix)
exsl := st.ExtractWGSL(slfix)
lines = append(lines, []byte(""))
lines = append(lines, []byte(fmt.Sprintf("//////// import: %q", gofn)))
lines = append(lines, exsl...)
return lines, hasSlrand, hasSltype
}
// next pass is per kernel
st.GetFuncGraph = false
maxVarsUsed := 0
nOverBase := 0
sys := maps.Keys(st.Systems)
sort.Strings(sys)
for _, snm := range sys {
sy := st.Systems[snm]
kns := maps.Keys(sy.Kernels)
sort.Strings(kns)
for _, knm := range kns {
kn := sy.Kernels[knm]
st.KernelFuncs = st.AllFuncs(kn.Name)
if st.KernelFuncs == nil {
continue
}
st.CurKernel = kn
var hasSlrand, hasSltype, hasR, hasT bool
nvars := 0
kn.Atomics, kn.VarsUsed, nvars = st.VarsUsed(st.KernelFuncs)
maxVarsUsed = max(maxVarsUsed, nvars)
fmt.Printf("###################################\nTranslating Kernel file: %s NVars: %d (atomic: %d)\n", kn.Name, nvars, len(kn.Atomics))
if nvars > 10 { // todo: change when limit is raised to 16
fmt.Println("WARNING: NVars exceeds maxStorageBuffersPerShaderStage min of 10")
nOverBase++
}
hdr := st.GenKernelHeader(sy, kn)
lines := bytes.Split([]byte(hdr), []byte("\n"))
for fn := range st.GoVarsFiles { // do varsFiles first!!
lines, hasR, hasT = doKernelFile(fn, lines)
if hasR {
hasSlrand = true
}
if hasT {
hasSltype = true
}
}
for _, gofp := range files {
_, gofn := filepath.Split(gofp)
if _, ok := st.GoVarsFiles[gofn]; ok {
continue
}
if gofn == "gosl.go" {
continue
}
lines, hasR, hasT = doKernelFile(gofp, lines)
if hasR {
hasSlrand = true
}
if hasT {
hasSltype = true
}
}
if hasSlrand {
st.CopyPackageFile("slrand.wgsl", "cogentcore.org/lab/gosl/slrand")
hasSltype = true
}
if hasSltype {
st.CopyPackageFile("sltype.wgsl", "cogentcore.org/lab/gosl/sltype")
}
for _, im := range st.SLImportFiles {
if im.Name == "gosl.go" {
continue
}
lines = append(lines, []byte(""))
lines = append(lines, []byte(fmt.Sprintf("//////// import: %q", im.Name)))
lines = append(lines, im.Lines...)
}
kn.Lines = lines
kfn := kn.Name + ".wgsl"
fn := filepath.Join(st.Config.Output, kfn)
kn.Filename = fn
WriteFileLines(fn, lines)
st.CompileFile(kfn)
}
}
fmt.Println("\n###################################\nMaximum number of variables used per shader:", maxVarsUsed)
if nOverBase > 0 {
fmt.Printf("WARNING: %d shaders exceed maxStorageBuffersPerShaderStage min of 10\n", nOverBase)
}
//////// check types
structTypes := make(map[string]bool)
for nm := range st.VarStructTypes {
structTypes[nm] = true
}
serr := alignsl.CheckPackage(pkg, structTypes)
if serr != nil {
fmt.Println(serr)
}
return nil
}
var (
nagaWarned = false
tintWarned = false
)
func (st *State) CompileFile(fn string) error {
dir, _ := filepath.Abs(st.Config.Output)
if _, err := exec.LookPath("naga"); err == nil {
// cmd := exec.Command("naga", "--compact", fn, fn) // produces some pretty weird code actually
cmd := exec.Command("naga", fn)
cmd.Dir = dir
out, err := cmd.CombinedOutput()
fmt.Printf("\n-----------------------------------------------------\nnaga output for: %s\n%s", fn, out)
if err != nil {
log.Println(err)
return err
}
} else {
if !nagaWarned {
fmt.Println("\nImportant: you should install the 'naga' WGSL compiler from https://github.com/gfx-rs/wgpu to get immediate validation")
nagaWarned = true
}
}
if _, err := exec.LookPath("tint"); err == nil {
cmd := exec.Command("tint", "--validate", "--format", "wgsl", "-o", "/dev/null", fn)
cmd.Dir = dir
out, err := cmd.CombinedOutput()
fmt.Printf("\n-----------------------------------------------------\ntint output for: %s\n%s", fn, out)
if err != nil {
log.Println(err)
return err
}
} else {
if !tintWarned {
fmt.Println("\nImportant: you should install the 'tint' WGSL compiler from https://dawn.googlesource.com/dawn/ to get immediate validation")
tintWarned = true
}
}
return nil
}
// Copyright (c) 2019, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package slbool defines a WGSL-friendly int32 Bool type.
The standard WGSL bool type causes obscure errors,
and int32 obeys the 4-byte basic alignment requirements.
gosl automatically converts this Go code into appropriate WGSL code.
*/
package slbool
// Bool is a WGSL friendly int32 Bool type.
type Bool int32
const (
// False is the [Bool] false value
False Bool = 0
// True is the [Bool] true value
True Bool = 1
)
// Bool returns the Bool as a standard Go bool
func (b Bool) Bool() bool {
return b == True
}
// IsTrue returns whether the bool is true
func (b Bool) IsTrue() bool {
return b == True
}
// IsFalse returns whether the bool is false
func (b Bool) IsFalse() bool {
return b == False
}
// SetBool sets the Bool from a standard Go bool
func (b *Bool) SetBool(bb bool) {
*b = FromBool(bb)
}
// String returns the bool as a string ("true"/"false")
func (b Bool) String() string {
if b.IsTrue() {
return "true"
}
return "false"
}
// FromString sets the bool from the given string
func (b *Bool) FromString(s string) {
if s == "true" || s == "True" {
b.SetBool(true)
} else {
b.SetBool(false)
}
}
// MarshalText implements the [encoding.TextMarshaler] interface
func (b Bool) MarshalText() ([]byte, error) { return []byte(b.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface
func (b *Bool) UnmarshalText(s []byte) error { b.FromString(string(s)); return nil }
// IsTrue returns whether the given bool is true
func IsTrue(b Bool) bool {
return b == True
}
// IsFalse returns whether the given bool is false
func IsFalse(b Bool) bool {
return b == False
}
// FromBool returns the given Go bool as a [Bool]
func FromBool(b bool) Bool {
if b {
return True
}
return False
}
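// Illustrative sketch (not part of the original source): typical use of
// slbool.Bool from a client package in a GPU-compatible struct. The Params
// type and field names here are hypothetical.
//
//	type Params struct {
//		Rate   float32
//		Active slbool.Bool // int32 under the hood, 4-byte aligned for WGSL
//	}
//
//	func (p *Params) Update() {
//		if p.Active.IsTrue() {
//			p.Rate *= 0.5
//		}
//		p.Active.SetBool(false)
//	}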
// Copyright (c) 2023, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package slboolcore
import (
"cogentcore.org/core/core"
"cogentcore.org/lab/gosl/slbool"
)
func init() {
core.AddValueType[slbool.Bool, core.Switch]()
}
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package slmath
//gosl:start
const Pi = 3.141592653589793
// MinAngleDiff returns the minimum difference between two angles
// (in radians): a-b, dealing with the wrap-around issues with angles.
func MinAngleDiff(a, b float32) float32 {
d := a - b
if d > Pi {
d -= 2 * Pi
}
if d < -Pi {
d += 2 * Pi
}
return d
}
//gosl:end
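// Worked example (illustrative, not part of the original source): for angles
// just on either side of the wrap point, MinAngleDiff returns the short way
// around rather than nearly a full circle:
//
//	a := float32(0.1)
//	b := float32(2*Pi - 0.1)
//	d := MinAngleDiff(a, b) // a-b = 0.2 - 2*Pi < -Pi, wrapped to ~= +0.2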
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package slmath
import "cogentcore.org/core/math32"
//gosl:start
// QuatLength returns the length of this quaternion.
func QuatLength(q math32.Quat) float32 {
return math32.Sqrt(q.X*q.X + q.Y*q.Y + q.Z*q.Z + q.W*q.W)
}
// QuatNormalize normalizes the quaternion.
func QuatNormalize(q math32.Quat) math32.Quat {
nq := q
l := QuatLength(q)
if l == 0 {
nq.X = 0
nq.Y = 0
nq.Z = 0
nq.W = 1
} else {
l = 1 / l
nq.X *= l
nq.Y *= l
nq.Z *= l
nq.W *= l
}
return nq
}
// MulQuatVector applies the rotation encoded in the [math32.Quat]
// to the [math32.Vector3].
func MulQuatVector(q math32.Quat, v math32.Vector3) math32.Vector3 {
xyz := math32.Vec3(q.X, q.Y, q.Z)
t := Cross3(xyz, v).MulScalar(2)
return v.Add(t.MulScalar(q.W)).Add(Cross3(xyz, t))
}
// MulQuatVectorInverse applies the inverse of the rotation encoded
// in the [math32.Quat] to the [math32.Vector3].
func MulQuatVectorInverse(q math32.Quat, v math32.Vector3) math32.Vector3 {
xyz := math32.Vec3(q.X, q.Y, q.Z)
t := Cross3(xyz, v).MulScalar(2)
return v.Sub(t.MulScalar(q.W)).Add(Cross3(xyz, t))
}
// MulQuats returns multiplication of a by b quaternions.
func MulQuats(a, b math32.Quat) math32.Quat {
// from http://www.euclideanspace.com/maths/algebra/realNormedAlgebra/quaternions/code/index.htm
var q math32.Quat
q.X = a.X*b.W + a.W*b.X + a.Y*b.Z - a.Z*b.Y
q.Y = a.Y*b.W + a.W*b.Y + a.Z*b.X - a.X*b.Z
q.Z = a.Z*b.W + a.W*b.Z + a.X*b.Y - a.Y*b.X
q.W = a.W*b.W - a.X*b.X - a.Y*b.Y - a.Z*b.Z
return q
}
// MulSpatialTransforms computes the equivalent of matrix multiplication for
// two quat-point spatial transforms: o = a * b
func MulSpatialTransforms(aP math32.Vector3, aQ math32.Quat, bP math32.Vector3, bQ math32.Quat, oP *math32.Vector3, oQ *math32.Quat) {
// rotate b by a and add a
*oP = MulQuatVector(aQ, bP).Add(aP)
*oQ = MulQuats(aQ, bQ)
}
// MulSpatialPoint applies quat-point spatial transform to given 3D point.
func MulSpatialPoint(xP math32.Vector3, xQ math32.Quat, p math32.Vector3) math32.Vector3 {
dp := MulQuatVector(xQ, p)
return dp.Add(xP)
}
func SpatialTransformInverse(p math32.Vector3, q math32.Quat, oP *math32.Vector3, oQ *math32.Quat) {
qi := QuatInverse(q)
*oP = Negate3(MulQuatVector(qi, p))
*oQ = qi
}
func QuatInverse(q math32.Quat) math32.Quat {
nq := q
nq.X *= -1
nq.Y *= -1
nq.Z *= -1
return QuatNormalize(nq)
}
func QuatDot(q, o math32.Quat) float32 {
return q.X*o.X + q.Y*o.Y + q.Z*o.Z + q.W*o.W
}
func QuatAdd(q math32.Quat, o math32.Quat) math32.Quat {
nq := q
nq.X += o.X
nq.Y += o.Y
nq.Z += o.Z
nq.W += o.W
return nq
}
func QuatMulScalar(q math32.Quat, s float32) math32.Quat {
nq := q
nq.X *= s
nq.Y *= s
nq.Z *= s
nq.W *= s
return nq
}
func QuatDim(v math32.Quat, dim int32) float32 {
if dim == 0 {
return v.X
}
if dim == 1 {
return v.Y
}
if dim == 2 {
return v.Z
}
return v.W
}
func QuatSetDim(v math32.Quat, dim int32, val float32) math32.Quat {
nv := v
if dim == 0 {
nv.X = val
}
if dim == 1 {
nv.Y = val
}
if dim == 2 {
nv.Z = val
}
if dim == 3 {
nv.W = val
}
return nv
}
func QuatToMatrix3(q math32.Quat) math32.Matrix3 {
var m math32.Matrix3
x := q.X
y := q.Y
z := q.Z
w := q.W
x2 := x + x
y2 := y + y
z2 := z + z
xx := x * x2
xy := x * y2
xz := x * z2
yy := y * y2
yz := y * z2
zz := z * z2
wx := w * x2
wy := w * y2
wz := w * z2
m[0] = 1 - (yy + zz)
m[3] = xy - wz
m[6] = xz + wy
m[1] = xy + wz
m[4] = 1 - (xx + zz)
m[7] = yz - wx
m[2] = xz - wy
m[5] = yz + wx
m[8] = 1 - (xx + yy)
return m
}
//gosl:end
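// Illustrative sketch (not part of the original source): rotating a vector by
// a quaternion and composing two quat-point spatial transforms, assuming the
// exported math32.Quat X, Y, Z, W fields used above.
//
//	// 90 degree rotation about Z: q = (0, 0, sin(45deg), cos(45deg))
//	s, c := math32.Sincos(Pi / 4)
//	q := math32.Quat{X: 0, Y: 0, Z: s, W: c}
//	v := MulQuatVector(q, math32.Vec3(1, 0, 0)) // ~= (0, 1, 0)
//
//	// compose transform a (offset (0,0,1), rotation q) with b, then apply:
//	var oP math32.Vector3
//	var oQ math32.Quat
//	MulSpatialTransforms(math32.Vec3(0, 0, 1), q, math32.Vec3(1, 0, 0), q, &oP, &oQ)
//	p := MulSpatialPoint(oP, oQ, math32.Vec3(0, 0, 0))
//	_, _ = v, p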
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package slmath
import "cogentcore.org/core/math32"
//gosl:start
// DivSafe2 divides v by o elementwise, only where o != 0
func DivSafe2(v math32.Vector2, o math32.Vector2) math32.Vector2 {
nv := v
if o.X != 0 {
nv.X /= o.X
}
if o.Y != 0 {
nv.Y /= o.Y
}
return nv
}
func Negate2(v math32.Vector2) math32.Vector2 {
return math32.Vec2(-v.X, -v.Y)
}
// Length2 returns the length (magnitude) of this vector.
func Length2(v math32.Vector2) float32 {
return math32.Sqrt(v.X*v.X + v.Y*v.Y)
}
// LengthSquared2 returns the length squared of this vector.
func LengthSquared2(v math32.Vector2) float32 {
return v.X*v.X + v.Y*v.Y
}
func Dot2(v, o math32.Vector2) float32 {
return v.X*o.X + v.Y*o.Y
}
// Max2 returns max of this vector components vs. other vector.
func Max2(v, o math32.Vector2) math32.Vector2 {
return math32.Vec2(max(v.X, o.X), max(v.Y, o.Y))
}
// Min2 returns min of this vector components vs. other vector.
func Min2(v, o math32.Vector2) math32.Vector2 {
return math32.Vec2(min(v.X, o.X), min(v.Y, o.Y))
}
// Abs2 returns abs of this vector components.
func Abs2(v math32.Vector2) math32.Vector2 {
return math32.Vec2(math32.Abs(v.X), math32.Abs(v.Y))
}
func Clamp2(v, min, max math32.Vector2) math32.Vector2 {
r := v
if r.X < min.X {
r.X = min.X
} else if r.X > max.X {
r.X = max.X
}
if r.Y < min.Y {
r.Y = min.Y
} else if r.Y > max.Y {
r.Y = max.Y
}
return r
}
// Normal2 returns this vector divided by its length (its unit vector).
func Normal2(v math32.Vector2) math32.Vector2 {
return v.DivScalar(Length2(v))
}
// Cross2 returns the cross product of this vector with other.
func Cross2(v, o math32.Vector2) float32 {
return v.X*o.Y - v.Y*o.X
}
func Dim2(v math32.Vector2, dim int32) float32 {
if dim == 0 {
return v.X
}
if dim == 1 {
return v.Y
}
return 0
}
func SetDim2(v math32.Vector2, dim int32, val float32) math32.Vector2 {
nv := v
if dim == 0 {
nv.X = val
}
if dim == 1 {
nv.Y = val
}
return nv
}
//gosl:end
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package slmath
import "cogentcore.org/core/math32"
//gosl:start
// DivSafe3 divides v by o elementwise, only where o != 0
func DivSafe3(v math32.Vector3, o math32.Vector3) math32.Vector3 {
nv := v
if o.X != 0 {
nv.X /= o.X
}
if o.Y != 0 {
nv.Y /= o.Y
}
if o.Z != 0 {
nv.Z /= o.Z
}
return nv
}
func Negate3(v math32.Vector3) math32.Vector3 {
return math32.Vec3(-v.X, -v.Y, -v.Z)
}
// Length3 returns the length (magnitude) of this vector.
func Length3(v math32.Vector3) float32 {
return math32.Sqrt(v.X*v.X + v.Y*v.Y + v.Z*v.Z)
}
// LengthSquared3 returns the length squared of this vector.
func LengthSquared3(v math32.Vector3) float32 {
return v.X*v.X + v.Y*v.Y + v.Z*v.Z
}
func Dot3(v, o math32.Vector3) float32 {
return v.X*o.X + v.Y*o.Y + v.Z*o.Z
}
// Max3 returns max of this vector components vs. other vector.
func Max3(v, o math32.Vector3) math32.Vector3 {
return math32.Vec3(max(v.X, o.X), max(v.Y, o.Y), max(v.Z, o.Z))
}
// Min3 returns min of this vector components vs. other vector.
func Min3(v, o math32.Vector3) math32.Vector3 {
return math32.Vec3(min(v.X, o.X), min(v.Y, o.Y), min(v.Z, o.Z))
}
// Abs3 returns abs of this vector components.
func Abs3(v math32.Vector3) math32.Vector3 {
return math32.Vec3(math32.Abs(v.X), math32.Abs(v.Y), math32.Abs(v.Z))
}
func Clamp3(v, min, max math32.Vector3) math32.Vector3 {
r := v
if r.X < min.X {
r.X = min.X
} else if r.X > max.X {
r.X = max.X
}
if r.Y < min.Y {
r.Y = min.Y
} else if r.Y > max.Y {
r.Y = max.Y
}
if r.Z < min.Z {
r.Z = min.Z
} else if r.Z > max.Z {
r.Z = max.Z
}
return r
}
// ClampMagnitude3 clamps the magnitude of the components below given value.
func ClampMagnitude3(v math32.Vector3, mag float32) math32.Vector3 {
r := v
if r.X < -mag {
r.X = -mag
} else if r.X > mag {
r.X = mag
}
if r.Y < -mag {
r.Y = -mag
} else if r.Y > mag {
r.Y = mag
}
if r.Z < -mag {
r.Z = -mag
} else if r.Z > mag {
r.Z = mag
}
return r
}
// Normal3 returns this vector divided by its length (its unit vector).
func Normal3(v math32.Vector3) math32.Vector3 {
return v.DivScalar(Length3(v))
}
// Cross3 returns the cross product of this vector with other.
func Cross3(v, o math32.Vector3) math32.Vector3 {
return math32.Vec3(v.Y*o.Z-v.Z*o.Y, v.Z*o.X-v.X*o.Z, v.X*o.Y-v.Y*o.X)
}
func Dim3(v math32.Vector3, dim int32) float32 {
if dim == 0 {
return v.X
}
if dim == 1 {
return v.Y
}
return v.Z
}
func SetDim3(v math32.Vector3, dim int32, val float32) math32.Vector3 {
nv := v
if dim == 0 {
nv.X = val
}
if dim == 1 {
nv.Y = val
}
if dim == 2 {
nv.Z = val
}
return nv
}
//gosl:end
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package slrand
import (
"cogentcore.org/core/math32"
"cogentcore.org/lab/gosl/sltype"
)
// These are Go versions of the same Philox2x32 based random number generator
// functions available in .WGSL.
// Philox2x32round does one round of updating of the counter.
func Philox2x32round(counter uint64, key uint32) uint64 {
ctr := sltype.Uint64ToLoHi(counter)
mul := sltype.Uint64ToLoHi(sltype.Uint32Mul64(0xD256D193, ctr.X))
ctr.X = mul.Y ^ key ^ ctr.Y
ctr.Y = mul.X
return sltype.Uint64FromLoHi(ctr)
}
// Philox2x32bumpkey does one round of updating of the key
func Philox2x32bumpkey(key uint32) uint32 {
return key + 0x9E3779B9
}
// Philox2x32 implements the stateless counter-based RNG algorithm
// returning a random number as two uint32 values, given a
// counter and key input that determine the result.
func Philox2x32(counter uint64, key uint32) sltype.Uint32Vec2 {
// this is an unrolled loop of 10 updates based on initial counter and key,
// which produces the random deviation deterministically based on these inputs.
counter = Philox2x32round(counter, key) // 1
key = Philox2x32bumpkey(key)
counter = Philox2x32round(counter, key) // 2
key = Philox2x32bumpkey(key)
counter = Philox2x32round(counter, key) // 3
key = Philox2x32bumpkey(key)
counter = Philox2x32round(counter, key) // 4
key = Philox2x32bumpkey(key)
counter = Philox2x32round(counter, key) // 5
key = Philox2x32bumpkey(key)
counter = Philox2x32round(counter, key) // 6
key = Philox2x32bumpkey(key)
counter = Philox2x32round(counter, key) // 7
key = Philox2x32bumpkey(key)
counter = Philox2x32round(counter, key) // 8
key = Philox2x32bumpkey(key)
counter = Philox2x32round(counter, key) // 9
key = Philox2x32bumpkey(key)
return sltype.Uint64ToLoHi(Philox2x32round(counter, key)) // 10
}
/////////
// Methods below provide a standard interface with more
// readable names, mapping onto the Go rand methods.
//
// They assume a global shared counter, which is then
// incremented by a function index, defined for each function
// consuming random numbers that _could_ be called within a parallel
// processing loop. At the end of the loop, the global counter should
// be incremented by the total possible number of such functions.
// This results in fully reproducible results, invariant to
// specific processing order, and invariant to whether any one function
// actually calls the random number generator.
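// For example (illustrative sketch, not in the original source), a routine
// that needs one uniform value and one pair of normal deviates for item i
// would use distinct funcIndex values with the same shared counter:
//
//	u := Float32(counter, 0, i)         // funcIndex 0: uniform in (0,1)
//	g := Float32NormVec2(counter, 1, i) // funcIndex 1: two normal(0,1) values
//	_, _ = u, g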
// Uint32Vec2 returns two uniformly distributed 32 bit unsigned integers,
// based on given global shared counter, function index offset from that
// counter for this specific random number call, and key as unique
// index of the item being processed.
func Uint32Vec2(counter uint64, funcIndex uint32, key uint32) sltype.Uint32Vec2 {
return Philox2x32(sltype.Uint64Add32(counter, funcIndex), key)
}
// Uint32 returns a uniformly distributed 32 bit unsigned integer,
// based on given global shared counter, function index offset from that
// counter for this specific random number call, and key as unique
// index of the item being processed.
func Uint32(counter uint64, funcIndex uint32, key uint32) uint32 {
return Philox2x32(sltype.Uint64Add32(counter, funcIndex), key).X
}
// Float32Vec2 returns two uniformly distributed float32 values in range (0,1),
// based on given global shared counter, function index offset from that
// counter for this specific random number call, and key as unique
// index of the item being processed.
func Float32Vec2(counter uint64, funcIndex uint32, key uint32) sltype.Float32Vec2 {
return sltype.Uint32ToFloat32Vec2(Uint32Vec2(counter, funcIndex, key))
}
// Float32 returns a uniformly distributed float32 value in range (0,1),
// based on given global shared counter, function index offset from that
// counter for this specific random number call, and key as unique
// index of the item being processed.
func Float32(counter uint64, funcIndex uint32, key uint32) float32 {
return sltype.Uint32ToFloat32(Uint32(counter, funcIndex, key))
}
// Float32Range11Vec2 returns two uniformly distributed float32 values in range [-1,1],
// based on given global shared counter, function index offset from that
// counter for this specific random number call, and key as unique
// index of the item being processed.
func Float32Range11Vec2(counter uint64, funcIndex uint32, key uint32) sltype.Float32Vec2 {
return sltype.Uint32ToFloat32Range11Vec2(Uint32Vec2(counter, funcIndex, key))
}
// Float32Range11 returns a uniformly distributed float32 value in range [-1,1],
// based on given global shared counter, function index offset from that
// counter for this specific random number call, and key as unique
// index of the item being processed.
func Float32Range11(counter uint64, funcIndex uint32, key uint32) float32 {
return sltype.Uint32ToFloat32Range11(Uint32(counter, funcIndex, key))
}
// BoolP returns a bool true value with probability p
// based on given global shared counter, function index offset from that
// counter for this specific random number call, and key as unique
// index of the item being processed.
func BoolP(counter uint64, funcIndex uint32, key uint32, p float32) bool {
return (Float32(counter, funcIndex, key) < p)
}
func SincosPi(x float32) sltype.Float32Vec2 {
const PIf = 3.1415926535897932
var r sltype.Float32Vec2
r.Y, r.X = math32.Sincos(PIf * x)
return r
}
// Float32NormVec2 returns two random float32 numbers
// distributed according to the normal, Gaussian distribution
// with zero mean and unit variance.
// This is done very efficiently using the Box-Muller algorithm
// that consumes two random 32 bit uint values.
// Uses given global shared counter, function index offset from that
// counter for this specific random number call, and key as unique
// index of the item being processed.
func Float32NormVec2(counter uint64, funcIndex uint32, key uint32) sltype.Float32Vec2 {
ur := Uint32Vec2(counter, funcIndex, key)
f := SincosPi(sltype.Uint32ToFloat32Range11(ur.X))
r := math32.Sqrt(-2.0 * math32.Log(sltype.Uint32ToFloat32(ur.Y))) // guaranteed to avoid 0.
return f.MulScalar(r)
}
// Float32Norm returns a random float32 number
// distributed according to the normal, Gaussian distribution
// with zero mean and unit variance.
// Uses given global shared counter, function index offset from that
// counter for this specific random number call, and key as unique
// index of the item being processed.
func Float32Norm(counter uint64, funcIndex uint32, key uint32) float32 {
return Float32NormVec2(counter, funcIndex, key).X
}
// Uint32N returns a uint32 in the range [0,N).
// Uses given global shared counter, function index offset from that
// counter for this specific random number call, and key as unique
// index of the item being processed.
func Uint32N(counter uint64, funcIndex uint32, key uint32, n uint32) uint32 {
v := Float32(counter, funcIndex, key)
return uint32(v * float32(n))
}
// Counter is used for storing the random counter using aligned 16 byte storage,
// with convenience methods for typical use cases.
// It retains a copy of the last Seed value, which is applied to the Hi uint32 value.
type Counter struct {
// Counter value
Counter uint64
// last seed value set by Seed method, restored by Reset()
HiSeed uint32
pad uint32
}
// Reset resets counter to last set Seed state
func (ct *Counter) Reset() {
ct.Counter = sltype.Uint64FromLoHi(sltype.Uint32Vec2{0, ct.HiSeed})
}
// Seed sets the Hi uint32 value from given seed, saving it in HiSeed field.
// Each increment in seed generates a unique sequence of over 4 billion numbers,
// so it is reasonable to just use incremental values there, but more widely
// spaced numbers will result in longer unique sequences.
// Resets Lo to 0.
// This same seed will be restored during Reset
func (ct *Counter) Seed(seed uint32) {
ct.HiSeed = seed
ct.Reset()
}
// Add increments the counter by given amount.
// Call this after completing a pass of computation
// where the value passed here is the max of funcIndex+1
// used for any possible random calls during that pass.
func (ct *Counter) Add(inc uint32) {
ct.Counter = sltype.Uint64Add32(ct.Counter, inc)
}
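// Illustrative sketch (not part of the original source): typical Counter
// lifecycle across one pass over n items, where at most 2 funcIndex slots
// (0 and 1) are consumed per item:
//
//	var ct Counter
//	ct.Seed(42)
//	n := uint32(100)
//	for i := uint32(0); i < n; i++ { // may run in parallel per i
//		u := Float32(ct.Counter, 0, i)
//		g := Float32Norm(ct.Counter, 1, i)
//		_, _ = u, g
//	}
//	ct.Add(2) // advance by max funcIndex+1 used during the pass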
// Copyright (c) 2022, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sltype
import "cogentcore.org/core/math32"
// Int32Vec2 is a length 2 vector of int32
type Int32Vec2 = math32.Vector2i
// Int32Vec3 is a length 3 vector of int32
type Int32Vec3 = math32.Vector3i
// Int32Vec4 is a length 4 vector of int32
type Int32Vec4 struct {
X int32
Y int32
Z int32
W int32
}
// Add returns the vector p+q.
func (p Int32Vec4) Add(q Int32Vec4) Int32Vec4 {
return Int32Vec4{p.X + q.X, p.Y + q.Y, p.Z + q.Z, p.W + q.W}
}
// Sub returns the vector p-q.
func (p Int32Vec4) Sub(q Int32Vec4) Int32Vec4 {
return Int32Vec4{p.X - q.X, p.Y - q.Y, p.Z - q.Z, p.W - q.W}
}
// MulScalar returns the vector p*k.
func (p Int32Vec4) MulScalar(k int32) Int32Vec4 {
return Int32Vec4{p.X * k, p.Y * k, p.Z * k, p.W * k}
}
// DivScalar returns the vector p/k.
func (p Int32Vec4) DivScalar(k int32) Int32Vec4 {
return Int32Vec4{p.X / k, p.Y / k, p.Z / k, p.W / k}
}
//////// Unsigned
// Uint32Vec2 is a length 2 vector of uint32
type Uint32Vec2 struct {
X uint32
Y uint32
}
// Uint32Vec3 is a length 3 vector of uint32
type Uint32Vec3 struct {
X uint32
Y uint32
Z uint32
}
// Uint32Vec4 is a length 4 vector of uint32
type Uint32Vec4 struct {
X uint32
Y uint32
Z uint32
W uint32
}
func (u *Uint32Vec4) SetFromVec2(u2 Uint32Vec2) {
u.X = u2.X
u.Y = u2.Y
u.Z = 0
u.W = 1
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sltype
import (
"math"
)
// Uint32Mul64 multiplies two uint32 numbers into a uint64.
func Uint32Mul64(a, b uint32) uint64 {
return uint64(a) * uint64(b)
}
// Uint64ToLoHi splits a uint64 number into lo and hi uint32 components.
func Uint64ToLoHi(a uint64) Uint32Vec2 {
var r Uint32Vec2
r.Y = uint32(a >> 32)
r.X = uint32(a)
return r
}
// Uint64FromLoHi combines lo and hi uint32 components into a uint64 value.
func Uint64FromLoHi(a Uint32Vec2) uint64 {
return uint64(a.X) + uint64(a.Y)<<32
}
// Uint64Add32 adds given uint32 number to given uint64.
func Uint64Add32(a uint64, b uint32) uint64 {
return a + uint64(b)
}
// Uint64Incr returns increment of the given uint64.
func Uint64Incr(a uint64) uint64 {
return a + 1
}
// Uint32ToFloat32 converts a uint32 integer into a float32
// in the (0,1) interval (i.e., exclusive of 1).
// This differs from the Go standard by excluding 0, which is handy for passing
// directly to Log function, and from the reference Philox code by excluding 1
// which is in the Go standard and most other standard RNGs.
func Uint32ToFloat32(val uint32) float32 {
const factor = float32(1.) / (float32(0xffffffff) + float32(1.))
const halffactor = float32(0.5) * factor
f := float32(val)*factor + halffactor
if f == 1 { // exclude 1
return math.Float32frombits(0x3F7FFFFF)
}
return f
}
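// Worked endpoints (illustrative, not in the original source):
// with factor = 1/2^32 and halffactor = 1/2^33,
//
//	val = 0          -> 0*factor + halffactor ~= 1.16e-10 (never exactly 0)
//	val = 0xffffffff -> float32(val)*factor + halffactor rounds to 1.0 in
//	                    float32, so the f == 1 branch returns the largest
//	                    float32 below 1 (~0.99999994) instead.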
// Uint32ToFloat32Vec2 converts two uint32 bit integers
// into two corresponding 32 bit f32 values
// in the (0,1) interval (i.e., exclusive of 1).
func Uint32ToFloat32Vec2(val Uint32Vec2) Float32Vec2 {
var r Float32Vec2
r.X = Uint32ToFloat32(val.X)
r.Y = Uint32ToFloat32(val.Y)
return r
}
// Uint32ToFloat32Range11 converts a uint32 integer into a float32
// in the [-1,1] interval (inclusive of -1 and 1, never identically == 0).
func Uint32ToFloat32Range11(val uint32) float32 {
const factor = float32(1.) / (float32(0x7fffffff) + float32(1.))
const halffactor = float32(0.5) * factor
return (float32(int32(val))*factor + halffactor)
}
// Uint32ToFloat32Range11Vec2 converts two uint32 integers into two float32
// in the [-1,1] interval (inclusive of -1 and 1, never identically == 0)
func Uint32ToFloat32Range11Vec2(val Uint32Vec2) Float32Vec2 {
var r Float32Vec2
r.X = Uint32ToFloat32Range11(val.X)
r.Y = Uint32ToFloat32Range11(val.Y)
return r
}
// Copyright 2025 Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package slvec
import "cogentcore.org/core/math32"
//gosl:start
// Vector2 is a 2D vector/point with X and Y components,
// with padding values so it works in a GPU struct. Use the
// V() method to get a math32.Vector2 that supports standard
// math operations, which are converted to direct ops in WGSL.
type Vector2 struct {
X float32
Y float32
pad, pad1 float32
}
func (v *Vector2) V() math32.Vector2 {
return math32.Vec2(v.X, v.Y)
}
func (v *Vector2) Set(x, y float32) {
v.X = x
v.Y = y
}
func (v *Vector2) SetV(mv math32.Vector2) {
v.X = mv.X
v.Y = mv.Y
}
// Vector2i is a 2D vector/point with X and Y integer components.
// with padding values so it works in a GPU struct. Use the
// V() method to get a math32.Vector2i that supports standard
// math operations. Cannot use those math ops in gosl GPU
// code at this point, unfortunately.
type Vector2i struct {
X int32
Y int32
pad, pad1 int32
}
func (v *Vector2i) V() math32.Vector2i {
return math32.Vec2i(int(v.X), int(v.Y))
}
func (v *Vector2i) Set(x, y int) {
v.X = int32(x)
v.Y = int32(y)
}
func (v *Vector2i) SetV(mv math32.Vector2i) {
v.X = mv.X
v.Y = mv.Y
}
// Vector3 is a 3D vector/point with X, Y, Z components,
// with padding values so it works in a GPU struct. Use the
// V() method to get a math32.Vector3 that supports standard
// math operations, which are converted to direct ops in WGSL.
type Vector3 struct {
X float32
Y float32
Z float32
pad float32
}
func (v *Vector3) V() math32.Vector3 {
return math32.Vec3(v.X, v.Y, v.Z)
}
func (v *Vector3) Set(x, y, z float32) {
v.X = x
v.Y = y
v.Z = z
}
func (v *Vector3) SetV(mv math32.Vector3) {
v.X = mv.X
v.Y = mv.Y
v.Z = mv.Z
}
//gosl:end
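// Illustrative sketch (not part of the original source): doing math on the
// padded GPU-struct vectors by converting through math32 and writing back,
// assuming the standard math32.Vector3 Add method:
//
//	var a, b Vector3
//	a.Set(1, 2, 3)
//	b.Set(4, 5, 6)
//	sum := a.V().Add(b.V()) // math32.Vector3 supports standard math ops
//	a.SetV(sum)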
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package threading provides a simple parallel run function. This will be moved elsewhere.
package threading
import (
"math"
"sync"
)
// ParallelRun maps the given function across the [0, total) range of items,
// using nThreads goroutines.
func ParallelRun(fun func(st, ed int), total int, nThreads int) {
itemsPerThr := int(math.Ceil(float64(total) / float64(nThreads)))
waitGroup := sync.WaitGroup{}
for start := 0; start < total; start += itemsPerThr {
start := start // be extra sure with closure
end := min(start+itemsPerThr, total)
waitGroup.Add(1) // todo: move out of loop
go func() {
fun(start, end)
waitGroup.Done()
}()
}
waitGroup.Wait()
}
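// Illustrative usage (not part of the original source): computing a per-item
// transform over a slice with 4 worker goroutines:
//
//	data := make([]float64, 1000)
//	out := make([]float64, 1000)
//	threading.ParallelRun(func(st, ed int) {
//		for i := st; i < ed; i++ {
//			out[i] = data[i] * 2
//		}
//	}, len(data), 4)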
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package lab
import (
"io/fs"
"os"
"path/filepath"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/fsx"
"cogentcore.org/core/core"
"cogentcore.org/core/events"
"cogentcore.org/core/styles"
"cogentcore.org/core/tree"
)
// Basic is a basic data browser with the files as the left panel,
// and the Tabber as the right panel.
type Basic struct {
core.Frame
Browser
}
// Init initializes with the data and script directories
func (br *Basic) Init() {
br.Frame.Init()
br.Styler(func(s *styles.Style) {
s.Grow.Set(1, 1)
})
br.OnShow(func(e events.Event) {
br.UpdateFiles()
})
tree.AddChildAt(br, "splits", func(w *core.Splits) {
br.Splits = w
w.SetSplits(.15, .85)
tree.AddChildAt(w, "fileframe", func(w *core.Frame) {
w.Styler(func(s *styles.Style) {
s.Direction = styles.Column
s.Overflow.Set(styles.OverflowAuto)
s.Grow.Set(1, 1)
})
tree.AddChildAt(w, "filetree", func(w *DataTree) {
br.Files = w
})
})
tree.AddChildAt(w, "tabs", func(w *Tabs) {
br.Tabs = w
Lab = br.Tabs.AsLab()
})
})
br.Updater(func() {
if br.Files != nil {
br.Files.Tabber = br.Tabs
}
})
}
// NewBasicWindow returns a new Lab Browser window for given
// file system (nil for os files) and data directory.
// Call RunWindow on the resulting [core.Body] to open the window.
func NewBasicWindow(fsys fs.FS, dataDir string) (*core.Body, *Basic) {
startDir, _ := os.Getwd()
startDir = errors.Log1(filepath.Abs(startDir))
b := core.NewBody("Cogent Lab: " + fsx.DirAndFile(startDir))
br := NewBasic(b)
b.AddTopBar(func(bar *core.Frame) {
tb := core.NewToolbar(bar)
br.Toolbar = tb
tb.Maker(br.MakeToolbar)
})
br.FS = fsys
ddr := dataDir
if fsys == nil {
ddr = errors.Log1(filepath.Abs(dataDir))
}
br.SetDataRoot(ddr)
br.SetScriptsDir(filepath.Join(ddr, "labscripts"))
LabBrowser = &br.Browser
return b, br
}
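// Illustrative sketch (not part of the original source): opening a basic
// browser window on an OS data directory, following the doc comment above.
// Running the Cogent Core app main loop (the usual entry point in a main
// function) is assumed and not shown here.
//
//	b, br := lab.NewBasicWindow(nil, "path/to/data")
//	br.UpdateScripts()
//	b.RunWindow()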
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package lab
//go:generate core generate
import (
"fmt"
"io/fs"
"log/slog"
"os"
"path/filepath"
"slices"
"strconv"
"strings"
"unicode"
"cogentcore.org/core/base/fsx"
"cogentcore.org/core/core"
"cogentcore.org/core/events"
"cogentcore.org/core/icons"
"cogentcore.org/core/styles"
"cogentcore.org/core/tree"
"cogentcore.org/lab/goal/goalib"
"golang.org/x/exp/maps"
)
var (
// LabBrowser is the current Lab Browser, for yaegi / Go consistent access.
LabBrowser *Browser
// RunScript is set if labscripts is imported.
// Runs given script name on browser's interpreter.
RunScript func(br *Browser, scriptName string) error
// RunScriptCode is set if labscripts is imported.
// Runs given code on browser's interpreter.
RunScriptCode func(br *Browser, code string) error
)
// Browser holds all the elements of a data browser, for browsing data
// either on an OS filesystem or as a tensorfs virtual data filesystem.
// It supports the automatic loading of [goal] scripts as toolbar actions to
// perform pre-programmed tasks on the data, to create app-like functionality.
// Scripts are ordered alphabetically and any leading #- prefix is automatically
// removed from the label, so you can use numbers to specify a custom order.
// It is not a [core.Widget] itself, and is intended to be incorporated into
// a [core.Frame] widget, potentially along with other custom elements.
// See [Basic] for a basic implementation.
type Browser struct { //types:add -setters
// FS is the filesystem, if browsing an FS.
FS fs.FS
// DataRoot is the path to the root of the data to browse.
DataRoot string
// StartDir is the starting directory, where the app was originally started.
StartDir string
// ScriptsDir is the directory containing scripts for toolbar actions.
// It defaults to DataRoot/dbscripts
ScriptsDir string
// Scripts are interpreted goal scripts (via yaegi) to automate
// routine tasks.
Scripts map[string]string `set:"-"`
// Interpreter is the interpreter to use for running Browser scripts.
// is of type: *goal/interpreter.Interpreter but can't use that directly
// to avoid importing goal unless needed. Import [labscripts] if needed.
Interpreter any `set:"-"`
// Files is the [DataTree] tree browser of the tensorfs or files.
Files *DataTree
// Tabs is the [Tabs] element managing tabs of data views.
Tabs *Tabs
// Toolbar is the top-level toolbar for the browser, if used.
Toolbar *core.Toolbar
// Splits is the overall [core.Splits] for the browser.
Splits *core.Splits
}
// UpdateFiles updates the files list.
func (br *Browser) UpdateFiles() { //types:add
if br.Files == nil {
return
}
files := br.Files
if br.FS != nil {
files.SortByModTime = true
files.OpenPathFS(br.FS, br.DataRoot)
} else {
files.OpenPath(br.DataRoot)
}
}
// UpdateScripts updates the Scripts and updates the toolbar.
func (br *Browser) UpdateScripts() { //types:add
redo := (br.Scripts != nil)
scr := fsx.Filenames(br.ScriptsDir, ".goal")
br.Scripts = make(map[string]string)
for _, s := range scr {
snm := strings.TrimSuffix(s, ".goal")
sc, err := os.ReadFile(filepath.Join(br.ScriptsDir, s))
if err == nil {
if unicode.IsLower(rune(snm[0])) {
if !redo {
fmt.Println("run init script:", snm)
if RunScriptCode != nil {
RunScriptCode(br, string(sc))
}
}
} else {
ssc := string(sc)
br.Scripts[snm] = ssc
}
} else {
slog.Error(err.Error())
}
}
if br.Toolbar != nil {
br.Toolbar.Update()
}
}
// MakeToolbar makes a default toolbar for the browser, with update files
// and update scripts buttons, followed by MakeScriptsToolbar for the scripts.
func (br *Browser) MakeToolbar(p *tree.Plan) {
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(br.UpdateFiles).SetText("").SetIcon(icons.Refresh).SetShortcut("Command+U")
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(br.UpdateScripts).SetText("").SetIcon(icons.Code)
})
br.MakeScriptsToolbar(p)
}
// MakeScriptsToolbar is a maker for adding buttons for each uppercase script
// to the toolbar.
func (br *Browser) MakeScriptsToolbar(p *tree.Plan) {
scr := maps.Keys(br.Scripts)
slices.Sort(scr)
for _, s := range scr {
lbl := TrimOrderPrefix(s)
tree.AddAt(p, lbl, func(w *core.Button) {
w.SetText(lbl).SetIcon(icons.RunCircle).
OnClick(func(e events.Event) {
if RunScript != nil {
RunScript(br, s)
}
})
sc := br.Scripts[s]
tt := FirstComment(sc)
if tt == "" {
tt = "Run Script (add a comment to top of script to provide more useful info here)"
}
w.SetTooltip(tt)
})
}
}
//////// Helpers
// FirstComment returns the first comment lines from given .goal file,
// which is used to set the tooltip for scripts.
func FirstComment(sc string) string {
sl := goalib.SplitLines(sc)
cmt := ""
for _, l := range sl {
if !strings.HasPrefix(l, "// ") {
return cmt
}
cmt += strings.TrimSpace(l[3:]) + " "
}
return cmt
}
// TrimOrderPrefix trims any optional #- prefix from given string,
// used for ordering items by name.
func TrimOrderPrefix(s string) string {
i := strings.Index(s, "-")
if i < 0 {
return s
}
ds := s[:i]
if _, err := strconv.Atoi(ds); err != nil {
return s
}
return s[i+1:]
}
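// Worked examples (illustrative, not part of the original source):
//
//	TrimOrderPrefix("10-Analyze") // "Analyze"
//	TrimOrderPrefix("Analyze")    // "Analyze" (no '-' present)
//	TrimOrderPrefix("Run-All")    // "Run-All" ("Run" is not a number)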
// PromptOKCancel prompts the user for whether to do something,
// calling the given function if the user clicks OK.
func PromptOKCancel(ctx core.Widget, prompt string, fun func()) {
d := core.NewBody(prompt)
d.AddBottomBar(func(bar *core.Frame) {
d.AddCancel(bar)
d.AddOK(bar).OnClick(func(e events.Event) {
if fun != nil {
fun()
}
})
})
d.RunDialog(ctx)
}
// PromptString prompts the user for a string value (initial value given),
// calling the given function if the user clicks OK.
func PromptString(ctx core.Widget, str string, prompt string, fun func(s string)) {
d := core.NewBody(prompt)
tf := core.NewTextField(d).SetText(str)
tf.Styler(func(s *styles.Style) {
s.Min.X.Ch(60)
})
d.AddBottomBar(func(bar *core.Frame) {
d.AddCancel(bar)
d.AddOK(bar).OnClick(func(e events.Event) {
if fun != nil {
fun(tf.Text())
}
})
})
d.RunDialog(ctx)
}
// PromptStruct prompts the user for the values in given struct (pass a pointer),
// calling the given function if the user clicks OK.
func PromptStruct(ctx core.Widget, str any, prompt string, fun func()) {
d := core.NewBody(prompt)
core.NewForm(d).SetStruct(str)
d.AddBottomBar(func(bar *core.Frame) {
d.AddCancel(bar)
d.AddOK(bar).OnClick(func(e events.Event) {
if fun != nil {
fun()
}
})
})
d.RunDialog(ctx)
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package lab
import (
"image"
"strings"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/fileinfo"
"cogentcore.org/core/base/fsx"
"cogentcore.org/core/core"
"cogentcore.org/core/events"
"cogentcore.org/core/filetree"
"cogentcore.org/core/icons"
"cogentcore.org/core/styles"
"cogentcore.org/core/styles/states"
"cogentcore.org/core/text/diffbrowser"
"cogentcore.org/core/tree"
"cogentcore.org/core/types"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
"cogentcore.org/lab/tensorfs"
)
// Treer is an interface for getting the Root node as a DataTree struct.
type Treer interface {
AsDataTree() *DataTree
}
// AsDataTree returns the given value as a [DataTree] if it has
// an AsDataTree() method, or nil otherwise.
func AsDataTree(n tree.Node) *DataTree {
if t, ok := n.(Treer); ok {
return t.AsDataTree()
}
return nil
}
// DataTree is the databrowser version of [filetree.Tree],
// which provides the Tabber to show data editors.
type DataTree struct {
filetree.Tree
// Tabber is the [Tabber] for this tree.
Tabber Tabber
}
func (ft *DataTree) AsDataTree() *DataTree {
return ft
}
func (ft *DataTree) Init() {
ft.Tree.Init()
ft.Root = ft
ft.FileNodeType = types.For[FileNode]()
}
// FileNode is databrowser version of FileNode for FileTree
type FileNode struct {
filetree.Node
}
func (fn *FileNode) Init() {
fn.Node.Init()
fn.AddContextMenu(fn.ContextMenu)
}
// Tabber returns the [Tabber] for this filenode, from root tree.
func (fn *FileNode) Tabber() Tabber {
fr := AsDataTree(fn.Root)
if fr != nil {
return fr.Tabber
}
return nil
}
func (fn *FileNode) WidgetTooltip(pos image.Point) (string, image.Point) {
res := fn.Tooltip
if fn.Info.Cat == fileinfo.Data {
ofn := fn.AsNode()
switch fn.Info.Known {
case fileinfo.Number, fileinfo.String:
dv := TensorFS(ofn)
v := dv.String()
if res != "" {
res += " "
}
res += v
}
}
return res, fn.DefaultTooltipPos()
}
// TensorFS returns the tensorfs representation of this item.
// It returns nil if this is not a tensorfs item.
func TensorFS(fn *filetree.Node) *tensorfs.Node {
dfs, ok := fn.FileRoot().FS.(*tensorfs.Node)
if !ok {
return nil
}
dfi, err := dfs.Stat(string(fn.Filepath))
if errors.Log(err) != nil {
return nil
}
return dfi.(*tensorfs.Node)
}
func (fn *FileNode) GetFileInfo() error {
err := fn.InitFileInfo()
if fn.FileRoot().FS == nil {
return err
}
d := TensorFS(fn.AsNode())
if d != nil {
fn.Info.Known = d.KnownFileInfo()
fn.Info.Cat = fileinfo.Data
switch fn.Info.Known {
case fileinfo.Tensor:
fn.Info.Ic = icons.BarChart
case fileinfo.Table:
fn.Info.Ic = icons.BarChart4Bars
case fileinfo.Number:
fn.Info.Ic = icons.Tag
case fileinfo.String:
fn.Info.Ic = icons.Title
default:
fn.Info.Ic = icons.BarChart
}
}
return err
}
func (fn *FileNode) OpenFile() error {
ofn := fn.AsNode()
ts := fn.Tabber()
if ts == nil {
return nil
}
df := fsx.DirAndFile(string(fn.Filepath))
switch {
case fn.IsDir():
d := TensorFS(ofn)
dt := tensorfs.DirTable(d, nil)
ts.AsLab().TensorTable(df, dt)
case fn.Info.Cat == fileinfo.Data:
switch fn.Info.Known {
case fileinfo.Tensor:
d := TensorFS(ofn)
ts.AsLab().TensorEditor(df, d.Tensor)
case fileinfo.Number:
dv := TensorFS(ofn)
if dv.Tensor.Len() == 0 {
core.MessageSnackbar(fn, "No data in tensor")
break
}
v := dv.Tensor.Float1D(0)
d := core.NewBody(df)
core.NewText(d).SetType(core.TextSupporting).SetText(df)
sp := core.NewSpinner(d).SetValue(float32(v))
d.AddBottomBar(func(bar *core.Frame) {
d.AddCancel(bar)
d.AddOK(bar).OnClick(func(e events.Event) {
dv.Tensor.SetFloat1D(float64(sp.Value), 0)
})
})
d.RunDialog(fn)
case fileinfo.String:
dv := TensorFS(ofn)
if dv.Tensor.Len() == 0 {
core.MessageSnackbar(fn, "No data in tensor")
break
}
v := dv.Tensor.String1D(0)
d := core.NewBody(df)
core.NewText(d).SetType(core.TextSupporting).SetText(df)
tf := core.NewTextField(d).SetText(v)
d.AddBottomBar(func(bar *core.Frame) {
d.AddCancel(bar)
d.AddOK(bar).OnClick(func(e events.Event) {
dv.Tensor.SetString1D(tf.Text(), 0)
})
})
d.RunDialog(fn)
case fileinfo.Toml:
ts.AsLab().EditorFile(df, string(fn.Filepath))
default:
dt := table.New()
err := dt.OpenCSV(fsx.Filename(fn.Filepath), tensor.Tab) // todo: need more flexible data handling mode
if err != nil {
core.ErrorSnackbar(fn, err)
} else {
ts.AsLab().TensorTable(df, dt)
}
}
case fn.IsExec(): // todo: use exec?
fn.OpenFilesDefault()
case fn.Info.Cat == fileinfo.Video: // todo: use our video viewer
fn.OpenFilesDefault()
case fn.Info.Cat == fileinfo.Audio: // todo: use our audio viewer
fn.OpenFilesDefault()
case fn.Info.Cat == fileinfo.Image: // todo: use our image viewer
fn.OpenFilesDefault()
case fn.Info.Cat == fileinfo.Model: // todo: use xyz
fn.OpenFilesDefault()
case fn.Info.Cat == fileinfo.Sheet: // todo: use our spreadsheet :)
fn.OpenFilesDefault()
case fn.Info.Cat == fileinfo.Bin: // don't edit
fn.OpenFilesDefault()
case fn.Info.Cat == fileinfo.Archive || fn.Info.Cat == fileinfo.Backup: // don't edit
fn.OpenFilesDefault()
default:
ts.AsLab().EditorFile(df, string(fn.Filepath))
}
return nil
}
// EditFiles calls EditFile on selected files
func (fn *FileNode) EditFiles() { //types:add
fn.SelectedFunc(func(sn *filetree.Node) {
sn.This.(*FileNode).EditFile()
})
}
// EditFile pulls up this file in a texteditor
func (fn *FileNode) EditFile() {
if fn.IsDir() {
fn.OpenFile()
return
}
ts := fn.Tabber()
if ts == nil {
return
}
if fn.Info.Cat == fileinfo.Data {
fn.OpenFile()
return
}
df := fsx.DirAndFile(string(fn.Filepath))
ts.AsLab().EditorFile(df, string(fn.Filepath))
}
// PlotFiles calls PlotFile on selected files
func (fn *FileNode) PlotFiles() { //types:add
fn.SelectedFunc(func(sn *filetree.Node) {
if sfn, ok := sn.This.(*FileNode); ok {
sfn.PlotFile()
}
})
}
// PlotFile creates a plot of data.
func (fn *FileNode) PlotFile() {
ts := fn.Tabber()
if ts == nil {
return
}
d := TensorFS(fn.AsNode())
if d != nil {
ts.AsLab().PlotTensorFS(d)
return
}
if fn.Info.Cat != fileinfo.Data {
return
}
df := fsx.DirAndFile(string(fn.Filepath))
ptab := df + " Plot"
dt := table.New(df)
err := dt.OpenCSV(fsx.Filename(fn.Filepath), tensor.Tab) // todo: need more flexible data handling mode
if err != nil {
core.ErrorSnackbar(fn, err)
return
}
ts.AsLab().PlotTable(ptab, dt)
}
// todo: this is too redundant -- need a better soln
// GridFiles calls GridFile on selected files
func (fn *FileNode) GridFiles() { //types:add
fn.SelectedFunc(func(sn *filetree.Node) {
if sfn, ok := sn.This.(*FileNode); ok {
sfn.GridFile()
}
})
}
// GridFile creates a grid view of data.
func (fn *FileNode) GridFile() {
ts := fn.Tabber()
if ts == nil {
return
}
d := TensorFS(fn.AsNode())
if d != nil {
ts.AsLab().GridTensorFS(d)
return
}
}
// DiffDirs displays a browser with differences between two selected directories
func (fn *FileNode) DiffDirs() { //types:add
var da, db *filetree.Node
fn.SelectedFunc(func(sn *filetree.Node) {
if sn.IsDir() {
if da == nil {
da = sn
} else if db == nil {
db = sn
}
}
})
if da == nil || db == nil {
core.MessageSnackbar(fn, "DiffDirs requires two selected directories")
return
}
NewDiffBrowserDirs(string(da.Filepath), string(db.Filepath))
}
// NewDiffBrowserDirs returns a new diff browser for files that differ
// within the two given directories. Excludes Job and .tsv data files.
func NewDiffBrowserDirs(pathA, pathB string) {
brow, b := diffbrowser.NewBrowserWindow()
brow.DiffDirs(pathA, pathB, func(fname string) bool {
if IsTableFile(fname) {
return true
}
if strings.HasPrefix(fname, "job.") || fname == "dbmeta.toml" {
return true
}
return false
})
b.RunWindow()
}
func IsTableFile(fname string) bool {
return strings.HasSuffix(fname, ".tsv") || strings.HasSuffix(fname, ".csv")
}
func (fn *FileNode) ContextMenu(m *core.Scene) {
core.NewFuncButton(m).SetFunc(fn.EditFiles).SetText("Edit").SetIcon(icons.Edit).
Styler(func(s *styles.Style) {
s.SetState(!fn.HasSelection(), states.Disabled)
})
core.NewFuncButton(m).SetFunc(fn.PlotFiles).SetText("Plot").SetIcon(icons.Edit).
Styler(func(s *styles.Style) {
s.SetState(!fn.HasSelection() || fn.Info.Cat != fileinfo.Data, states.Disabled)
})
core.NewFuncButton(m).SetFunc(fn.GridFiles).SetText("Grid").SetIcon(icons.Edit).
Styler(func(s *styles.Style) {
s.SetState(!fn.HasSelection() || fn.Info.Cat != fileinfo.Data, states.Disabled)
})
core.NewFuncButton(m).SetFunc(fn.DiffDirs).SetText("Diff Dirs").SetIcon(icons.Edit).
Styler(func(s *styles.Style) {
s.SetState(!fn.HasSelection() || !fn.IsDir(), states.Disabled)
})
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package labscripts
import (
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/logx"
"cogentcore.org/lab/goal/interpreter"
"cogentcore.org/lab/lab"
"github.com/cogentcore/yaegi/interp"
)
func init() {
lab.RunScript = RunScript
lab.RunScriptCode = RunScriptCode
}
// br.Interpreter = in
// if br.Interpreter == nil {
// br.InitInterp()
// in = br.Interpreter
// }
// in.Interp.Use(coresymbols.Symbols) // gui imports
// Interpreter returns the interpreter for given browser,
// or nil and an error message if not set.
func Interpreter(br *lab.Browser) (*interpreter.Interpreter, error) {
if br.Interpreter == nil {
return nil, errors.New("No interpreter has been set for the Browser, cannot run script")
}
return br.Interpreter.(*interpreter.Interpreter), nil
}
// InitInterpreter initializes a new interpreter if not already set.
func InitInterpreter(br *lab.Browser) {
if br.Interpreter == nil {
br.Interpreter = interpreter.NewInterpreter(interp.Options{})
}
// logx.UserLevel = slog.LevelDebug // for debugging of init loading
}
// RunScript runs given script from list of Scripts in the Browser.
func RunScript(br *lab.Browser, scriptName string) error {
in, err := Interpreter(br)
if err != nil {
return errors.Log(err)
}
sc, ok := br.Scripts[scriptName]
if !ok {
err := errors.New("script name not found: " + scriptName)
return errors.Log(err)
}
logx.PrintlnDebug("\n################\nrunning script:\n", sc, "\n")
_, _, err = in.Eval(sc)
if err == nil {
err = in.Goal.TrState.DepthError()
}
in.Goal.TrState.ResetDepth()
return err
}
// RunScriptCode runs given script code string in Browser's interpreter.
func RunScriptCode(br *lab.Browser, code string) error {
in, err := Interpreter(br)
if err != nil {
return err
}
_, _, err = in.Eval(code)
return err
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package lab
import (
"cogentcore.org/core/tree"
"cogentcore.org/lab/plot"
"cogentcore.org/lab/plotcore"
)
// NewPlot is a simple helper function that does [plot.New] and [plotcore.NewPlot],
// only returning the [plot.Plot] for convenient use in lab plots. See [NewPlotWidget]
// for a version that also returns the [plotcore.Plot]. See also [NewPlotFrom].
func NewPlot(parent ...tree.Node) *plot.Plot {
plt, _ := NewPlotWidget(parent...)
return plt
}
// NewPlotWidget is a simple helper function that does [plot.New] and [plotcore.NewPlot],
// returning both the [plot.Plot] and [plotcore.Plot] for convenient use in lab plots.
// See [NewPlot] for a version that only returns the more commonly useful [plot.Plot].
func NewPlotWidget(parent ...tree.Node) (*plot.Plot, *plotcore.Plot) {
plt := plot.New()
pw := plotcore.NewPlot(parent...).SetPlot(plt)
return plt, pw
}
// NewPlotFrom is a version of [NewPlot] that copies plot data from the given starting plot.
func NewPlotFrom(from *plot.Plot, parent ...tree.Node) *plot.Plot {
plt := NewPlot(parent...)
*plt = *from
return plt
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package lab
import (
"fmt"
"path/filepath"
"strings"
"cogentcore.org/core/base/fsx"
"cogentcore.org/core/core"
"cogentcore.org/core/styles"
"cogentcore.org/core/text/textcore"
"cogentcore.org/lab/plot"
"cogentcore.org/lab/plotcore"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
"cogentcore.org/lab/tensorcore"
"cogentcore.org/lab/tensorfs"
)
// Lab is the current Tabs, for yaegi / Go consistent access.
var Lab *Tabs
// Tabber is a [core.Tabs] based widget that has support for opening
// tabs for [plotcore.Editor] and [tensorcore.Table] editors,
// among others.
type Tabber interface {
core.Tabber
// AsLab returns the [lab.Tabs] widget with all the tabs methods.
AsLab() *Tabs
}
// NewTab recycles a tab with the given label, or returns the existing widget
// of the given type within it. The existing widget that is returned
// is the last child in the frame, allowing for a toolbar at the top.
// The mkfun function is called to create and configure a new widget
// if one does not already exist.
func NewTab[T any](tb Tabber, label string, mkfun func(tab *core.Frame) T) T {
tab := tb.AsLab().RecycleTab(label)
var zv T
if tab.HasChildren() {
nc := tab.NumChildren()
lc := tab.Child(nc - 1)
if tt, ok := lc.(T); ok {
return tt
}
err := fmt.Errorf("Name / Type conflict: tab %q does not have the expected type of content: is %T", label, lc)
core.ErrorSnackbar(tb.AsLab(), err)
return zv
}
w := mkfun(tab)
return w
}
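// Illustrative usage (not part of the original source): recycling a tab that
// holds a text editor, mirroring what the Tabs helper methods below do.
// Assumes ts is a *Tabs:
//
//	ed := NewTab(ts, "Notes", func(tab *core.Frame) *textcore.Editor {
//		return textcore.NewEditor(tab)
//	})
//	ed.Lines.SetText([]byte("hello"))
//	ts.Update()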
// TabAt returns widget of given type at tab of given name, nil if tab not found.
func TabAt[T any](tb Tabber, label string) T {
var zv T
tab := tb.AsLab().TabByName(label)
if tab == nil {
return zv
}
if !tab.HasChildren() { // shouldn't happen
return zv
}
nc := tab.NumChildren()
lc := tab.Child(nc - 1)
if tt, ok := lc.(T); ok {
return tt
}
err := fmt.Errorf("Name / Type conflict: tab %q does not have the expected type of content: %T", label, lc)
core.ErrorSnackbar(tb.AsLab(), err)
return zv
}
// Tabs implements the [Tabber] interface.
type Tabs struct {
core.Tabs
}
func (ts *Tabs) Init() {
ts.Tabs.Init()
ts.Type = core.FunctionalTabs
}
func (ts *Tabs) AsLab() *Tabs {
return ts
}
// TensorTable recycles a tab with a tensorcore.Table widget
// to view given table.Table, using its own table.Table as tv.Table.
// Use tv.Table.Table to get the underlying *table.Table
// Use tv.Table.Sequential to update the Indexed to view
// all of the rows when done updating the Table, and then call br.Update()
func (ts *Tabs) TensorTable(label string, dt *table.Table) *tensorcore.Table {
tv := NewTab(ts, label, func(tab *core.Frame) *tensorcore.Table {
tb := core.NewToolbar(tab)
tv := tensorcore.NewTable(tab)
tb.Maker(tv.MakeToolbar)
return tv
})
tv.SetTable(dt)
ts.Update()
return tv
}
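// Illustrative usage (not part of the original source), following the doc
// comment above: after appending rows to dt, reset the indexed view to
// sequential and refresh the display:
//
//	tv := ts.TensorTable("Results", dt)
//	// ... add rows to dt ...
//	tv.Table.Sequential()
//	ts.Update()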
// TensorEditor recycles a tab with a tensorcore.TensorEditor widget
// to view given Tensor.
func (ts *Tabs) TensorEditor(label string, tsr tensor.Tensor) *tensorcore.TensorEditor {
tv := NewTab(ts, label, func(tab *core.Frame) *tensorcore.TensorEditor {
tb := core.NewToolbar(tab)
tv := tensorcore.NewTensorEditor(tab)
tb.Maker(tv.MakeToolbar)
return tv
})
tv.SetTensor(tsr)
ts.Update()
return tv
}
// TensorGrid recycles a tab with a tensorcore.TensorGrid widget
// to view given Tensor.
func (ts *Tabs) TensorGrid(label string, tsr tensor.Tensor) *tensorcore.TensorGrid {
tv := NewTab(ts, label, func(tab *core.Frame) *tensorcore.TensorGrid {
tb := core.NewToolbar(tab)
tv := tensorcore.NewTensorGrid(tab)
tb.Maker(tv.MakeToolbar)
return tv
})
tv.SetTensor(tsr)
ts.Update()
return tv
}
// DirAndFileNoSlash returns [fsx.DirAndFile] with slashes replaced with spaces.
// Slashes are also used in core Widget paths, so spaces are safer.
func DirAndFileNoSlash(fpath string) string {
return strings.ReplaceAll(fsx.DirAndFile(fpath), string(filepath.Separator), " ")
}
// GridTensorFS recycles a tab with a Grid of given [tensorfs.Node].
func (ts *Tabs) GridTensorFS(dfs *tensorfs.Node) *tensorcore.TensorGrid {
label := DirAndFileNoSlash(dfs.Path()) + " Grid"
if dfs.IsDir() {
core.MessageSnackbar(ts, "Use Edit instead of Grid to view a directory")
return nil
}
tsr := dfs.Tensor
return ts.TensorGrid(label, tsr)
}
// PlotTable recycles a tab with a Plot of given table.Table.
func (ts *Tabs) PlotTable(label string, dt *table.Table) *plotcore.Editor {
pl := NewTab(ts, label, func(tab *core.Frame) *plotcore.Editor {
tb := core.NewToolbar(tab)
pl := plotcore.NewEditor(tab)
tab.Styler(func(s *styles.Style) {
s.Direction = styles.Column
s.Grow.Set(1, 1)
})
tb.Maker(pl.MakeToolbar)
return pl
})
if pl != nil {
pl.SetTable(dt)
}
return pl
}
// NewPlot recycles a tab with a plotcore.Editor.
func (ts *Tabs) NewPlot(label string) *plotcore.Editor {
pl := NewTab(ts, label, func(tab *core.Frame) *plotcore.Editor {
tb := core.NewToolbar(tab)
pl := plotcore.NewEditor(tab)
tab.Styler(func(s *styles.Style) {
s.Direction = styles.Column
s.Grow.Set(1, 1)
})
tb.Maker(pl.MakeToolbar)
return pl
})
return pl
}
// PlotTensorFS recycles a tab with a Plot of given [tensorfs.Node].
func (ts *Tabs) PlotTensorFS(dfs *tensorfs.Node) *plotcore.Editor {
label := DirAndFileNoSlash(dfs.Path()) + " Plot"
if dfs.IsDir() {
return ts.PlotTable(label, tensorfs.DirTable(dfs, nil))
}
tsr := dfs.Tensor
dt := table.New(label)
dt.Columns.Rows = tsr.DimSize(0)
if ix, ok := tsr.(*tensor.Rows); ok {
dt.Indexes = ix.Indexes
}
rc := dt.AddIntColumn("Row")
for r := range dt.Columns.Rows {
rc.Values[r] = r
}
dt.AddColumn(dfs.Name(), tsr.AsValues())
return ts.PlotTable(label, dt)
}
// Plot recycles a tab with given Plot using given label.
func (ts *Tabs) Plot(label string, plt *plot.Plot) *plotcore.Plot {
pl := NewTab(ts, label, func(tab *core.Frame) *plotcore.Plot {
pl := plotcore.NewPlot(tab)
pl.Styler(func(s *styles.Style) {
s.Direction = styles.Column
s.Grow.Set(1, 1)
})
pl.SetPlot(plt)
return pl
})
if pl != nil {
ts.Update()
}
return pl
}
// GoUpdatePlot calls GoUpdatePlot on plot at tab with given name.
// Does nothing if tab name doesn't exist (returns nil).
func (ts *Tabs) GoUpdatePlot(label string) *plotcore.Editor {
pl := TabAt[*plotcore.Editor](ts, label)
if pl != nil {
pl.GoUpdatePlot()
}
return pl
}
// UpdatePlot calls UpdatePlot on plot at tab with given name.
// Does nothing if tab name doesn't exist (returns nil).
func (ts *Tabs) UpdatePlot(label string) *plotcore.Editor {
pl := TabAt[*plotcore.Editor](ts, label)
if pl != nil {
pl.UpdatePlot()
}
return pl
}
// SliceTable recycles a tab with a core.Table widget
// to view the given slice of structs.
func (ts *Tabs) SliceTable(label string, slc any) *core.Table {
tv := NewTab(ts, label, func(tab *core.Frame) *core.Table {
return core.NewTable(tab)
})
tv.SetSlice(slc)
ts.Update()
return tv
}
// EditorString recycles a [textcore.Editor] tab, displaying given string.
func (ts *Tabs) EditorString(label, content string) *textcore.Editor {
ed := NewTab(ts, label, func(tab *core.Frame) *textcore.Editor {
ed := textcore.NewEditor(tab)
ed.Styler(func(s *styles.Style) {
s.Grow.Set(1, 1)
})
return ed
})
if content != "" {
ed.Lines.SetText([]byte(content))
}
ts.Update()
return ed
}
// EditorFile opens an editor tab for given file.
func (ts *Tabs) EditorFile(label, filename string) *textcore.Editor {
ed := NewTab(ts, label, func(tab *core.Frame) *textcore.Editor {
ed := textcore.NewEditor(tab)
ed.Styler(func(s *styles.Style) {
s.Grow.Set(1, 1)
})
return ed
})
ed.Lines.Open(filename)
ts.Update()
return ed
}
// TabUpdateRender calls UpdateRender on content of given tab.
// This is the best way to update display widgets during running
// an ongoing computation.
func (ts *Tabs) TabUpdateRender(label string) core.Widget {
tab := ts.TabByName(label)
if tab == nil {
return nil
}
tab.UpdateRender()
return tab
}
// Code generated by "core generate"; DO NOT EDIT.
package lab
import (
"io/fs"
"cogentcore.org/core/core"
"cogentcore.org/core/tree"
"cogentcore.org/core/types"
)
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/lab.Basic", IDName: "basic", Doc: "Basic is a basic data browser with the files as the left panel,\nand the Tabber as the right panel.", Embeds: []types.Field{{Name: "Frame"}, {Name: "Browser"}}})
// NewBasic returns a new [Basic] with the given optional parent:
// Basic is a basic data browser with the files as the left panel,
// and the Tabber as the right panel.
func NewBasic(parent ...tree.Node) *Basic { return tree.New[Basic](parent...) }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/lab.Browser", IDName: "browser", Doc: "Browser holds all the elements of a data browser, for browsing data\neither on an OS filesystem or as a tensorfs virtual data filesystem.\nIt supports the automatic loading of [goal] scripts as toolbar actions to\nperform pre-programmed tasks on the data, to create app-like functionality.\nScripts are ordered alphabetically and any leading #- prefix is automatically\nremoved from the label, so you can use numbers to specify a custom order.\nIt is not a [core.Widget] itself, and is intended to be incorporated into\na [core.Frame] widget, potentially along with other custom elements.\nSee [Basic] for a basic implementation.", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"-setters"}}}, Methods: []types.Method{{Name: "UpdateFiles", Doc: "UpdateFiles Updates the files list.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "UpdateScripts", Doc: "UpdateScripts updates the Scripts and updates the toolbar.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}}, Fields: []types.Field{{Name: "FS", Doc: "FS is the filesystem, if browsing an FS."}, {Name: "DataRoot", Doc: "DataRoot is the path to the root of the data to browse."}, {Name: "StartDir", Doc: "StartDir is the starting directory, where the app was originally started."}, {Name: "ScriptsDir", Doc: "ScriptsDir is the directory containing scripts for toolbar actions.\nIt defaults to DataRoot/dbscripts"}, {Name: "Scripts", Doc: "Scripts are interpreted goal scripts (via yaegi) to automate\nroutine tasks."}, {Name: "Interpreter", Doc: "Interpreter is the interpreter to use for running Browser scripts.\nis of type: *goal/interpreter.Interpreter but can't use that directly\nto avoid importing goal unless needed. Import [labscripts] if needed."}, {Name: "Files", Doc: "Files is the [DataTree] tree browser of the tensorfs or files."}, {Name: "Tabs", Doc: "Tabs is the [Tabs] element managing tabs of data views."}, {Name: "Toolbar", Doc: "Toolbar is the top-level toolbar for the browser, if used."}, {Name: "Splits", Doc: "Splits is the overall [core.Splits] for the browser."}}})
// SetFS sets the [Browser.FS]:
// FS is the filesystem, if browsing an FS.
func (t *Browser) SetFS(v fs.FS) *Browser { t.FS = v; return t }
// SetDataRoot sets the [Browser.DataRoot]:
// DataRoot is the path to the root of the data to browse.
func (t *Browser) SetDataRoot(v string) *Browser { t.DataRoot = v; return t }
// SetStartDir sets the [Browser.StartDir]:
// StartDir is the starting directory, where the app was originally started.
func (t *Browser) SetStartDir(v string) *Browser { t.StartDir = v; return t }
// SetScriptsDir sets the [Browser.ScriptsDir]:
// ScriptsDir is the directory containing scripts for toolbar actions.
// It defaults to DataRoot/dbscripts
func (t *Browser) SetScriptsDir(v string) *Browser { t.ScriptsDir = v; return t }
// SetFiles sets the [Browser.Files]:
// Files is the [DataTree] tree browser of the tensorfs or files.
func (t *Browser) SetFiles(v *DataTree) *Browser { t.Files = v; return t }
// SetTabs sets the [Browser.Tabs]:
// Tabs is the [Tabs] element managing tabs of data views.
func (t *Browser) SetTabs(v *Tabs) *Browser { t.Tabs = v; return t }
// SetToolbar sets the [Browser.Toolbar]:
// Toolbar is the top-level toolbar for the browser, if used.
func (t *Browser) SetToolbar(v *core.Toolbar) *Browser { t.Toolbar = v; return t }
// SetSplits sets the [Browser.Splits]:
// Splits is the overall [core.Splits] for the browser.
func (t *Browser) SetSplits(v *core.Splits) *Browser { t.Splits = v; return t }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/lab.DataTree", IDName: "data-tree", Doc: "DataTree is the databrowser version of [filetree.Tree],\nwhich provides the Tabber to show data editors.", Embeds: []types.Field{{Name: "Tree"}}, Fields: []types.Field{{Name: "Tabber", Doc: "Tabber is the [Tabber] for this tree."}}})
// NewDataTree returns a new [DataTree] with the given optional parent:
// DataTree is the databrowser version of [filetree.Tree],
// which provides the Tabber to show data editors.
func NewDataTree(parent ...tree.Node) *DataTree { return tree.New[DataTree](parent...) }
// SetTabber sets the [DataTree.Tabber]:
// Tabber is the [Tabber] for this tree.
func (t *DataTree) SetTabber(v Tabber) *DataTree { t.Tabber = v; return t }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/lab.FileNode", IDName: "file-node", Doc: "FileNode is databrowser version of FileNode for FileTree", Methods: []types.Method{{Name: "EditFiles", Doc: "EditFiles calls EditFile on selected files", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "PlotFiles", Doc: "PlotFiles calls PlotFile on selected files", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "GridFiles", Doc: "GridFiles calls GridFile on selected files", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "DiffDirs", Doc: "DiffDirs displays a browser with differences between two selected directories", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}}, Embeds: []types.Field{{Name: "Node"}}})
// NewFileNode returns a new [FileNode] with the given optional parent:
// FileNode is databrowser version of FileNode for FileTree
func NewFileNode(parent ...tree.Node) *FileNode { return tree.New[FileNode](parent...) }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/lab.Tabs", IDName: "tabs", Doc: "Tabs implements the [Tabber] interface.", Embeds: []types.Field{{Name: "Tabs"}}})
// NewTabs returns a new [Tabs] with the given optional parent:
// Tabs implements the [Tabber] interface.
func NewTabs(parent ...tree.Node) *Tabs { return tree.New[Tabs](parent...) }
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package matrix
import (
"cogentcore.org/core/base/errors"
"cogentcore.org/lab/stats/stats"
"cogentcore.org/lab/tensor"
"cogentcore.org/lab/tensor/tmath"
"gonum.org/v1/gonum/mat"
)
// Eig performs the eigen decomposition of the given square matrix,
// which need not be symmetric. See [EigSym] for a symmetric square matrix.
// In this non-symmetric case, the results are typically complex valued,
// so the outputs should be complex tensors. TODO: need complex support!
// The vectors are the same size as the input. Each vector is a column
// in this 2D square matrix, ordered *lowest* to *highest* across the columns,
// i.e., the maximum vector is the last column.
// The values are the size of one row, ordered *lowest* to *highest*.
// If the input tensor is > 2D, it is treated as a list of 2D matrices,
// and parallel threading is used where beneficial.
func Eig(a tensor.Tensor) (vecs, vals *tensor.Float64) {
vecs = tensor.NewFloat64()
vals = tensor.NewFloat64()
errors.Log(EigOut(a, vecs, vals))
return
}
// EigOut performs the eigen decomposition of the given square matrix,
// which need not be symmetric. See [EigSym] for a symmetric square matrix.
// In this non-symmetric case, the results are typically complex valued,
// so the outputs should be complex tensors. TODO: need complex support!
// The vectors are the same size as the input. Each vector is a column
// in this 2D square matrix, ordered *lowest* to *highest* across the columns,
// i.e., the maximum vector is the last column.
// The values are the size of one row, ordered *lowest* to *highest*.
// If the input tensor is > 2D, it is treated as a list of 2D matrices,
// and parallel threading is used where beneficial.
func EigOut(a tensor.Tensor, vecs, vals *tensor.Float64) error {
if err := StringCheck(a); err != nil {
return err
}
na := a.NumDims()
if na == 1 {
return mat.ErrShape
}
var asz []int
ea := a
if na > 2 {
asz = tensor.SplitAtInnerDims(a, 2)
if asz[0] == 1 {
ea = tensor.Reshape(a, asz[1:]...)
na = 2
}
}
if na == 2 {
if a.DimSize(0) != a.DimSize(1) {
return mat.ErrShape
}
ma, _ := NewMatrix(a)
vecs.SetShapeSizes(a.DimSize(0), a.DimSize(1))
vals.SetShapeSizes(a.DimSize(0))
do, _ := NewDense(vecs)
var eig mat.Eigen
ok := eig.Factorize(ma, mat.EigenRight)
if !ok {
return errors.New("gonum mat.Eigen Factorize failed")
}
_ = do
// eig.VectorsTo(do) // todo: requires complex!
// eig.Values(vals.Values)
return nil
}
ea = tensor.Reshape(a, asz...)
if ea.DimSize(1) != ea.DimSize(2) {
return mat.ErrShape
}
nr := ea.DimSize(0)
sz := ea.DimSize(1)
vecs.SetShapeSizes(nr, sz, sz)
vals.SetShapeSizes(nr, sz)
var errs []error
tensor.VectorizeThreaded(ea.DimSize(1)*ea.DimSize(2)*1000,
func(tsr ...tensor.Tensor) int { return nr },
func(r int, tsr ...tensor.Tensor) {
sa := tensor.Reslice(ea, r, tensor.FullAxis, tensor.FullAxis)
ma, _ := NewMatrix(sa)
do, _ := NewDense(vecs.RowTensor(r).(*tensor.Float64))
var eig mat.Eigen
ok := eig.Factorize(ma, mat.EigenRight)
if !ok {
errs = append(errs, errors.New("gonum mat.Eigen Factorize failed"))
}
_ = do
// eig.VectorsTo(do) // todo: requires complex!
// eig.Values(vals.Values[r*sz : (r+1)*sz])
})
return errors.Join(errs...)
}
// EigSym performs the eigen decomposition of the given symmetric square matrix,
// which produces real-valued results. When the input is the [metric.CovarianceMatrix],
// this is known as Principal Components Analysis (PCA).
// The vectors are the same size as the input. Each vector is a column
// in this 2D square matrix, ordered *lowest* to *highest* across the columns,
// i.e., the maximum vector is the last column.
// The values are the size of one row, ordered *lowest* to *highest*.
// Note that EigSym produces results in the *opposite* order of [SVD] (which is much faster).
// If the input tensor is > 2D, it is treated as a list of 2D matrices,
// and parallel threading is used where beneficial.
func EigSym(a tensor.Tensor) (vecs, vals *tensor.Float64) {
vecs = tensor.NewFloat64()
vals = tensor.NewFloat64()
errors.Log(EigSymOut(a, vecs, vals))
return
}
// EigSymOut performs the eigen decomposition of the given symmetric square matrix,
// which produces real-valued results. When the input is the [metric.CovarianceMatrix],
// this is known as Principal Components Analysis (PCA).
// The vectors are the same size as the input. Each vector is a column
// in this 2D square matrix, ordered *lowest* to *highest* across the columns,
// i.e., the maximum vector is the last column.
// The values are the size of one row, ordered *lowest* to *highest*.
// Note that EigSym produces results in the *opposite* order of [SVD] (which is much faster).
// If the input tensor is > 2D, it is treated as a list of 2D matrices,
// and parallel threading is used where beneficial.
func EigSymOut(a tensor.Tensor, vecs, vals *tensor.Float64) error {
if err := StringCheck(a); err != nil {
return err
}
na := a.NumDims()
if na == 1 {
return mat.ErrShape
}
var asz []int
ea := a
if na > 2 {
asz = tensor.SplitAtInnerDims(a, 2)
if asz[0] == 1 {
ea = tensor.Reshape(a, asz[1:]...)
na = 2
}
}
if na == 2 {
if a.DimSize(0) != a.DimSize(1) {
return mat.ErrShape
}
ma, _ := NewSymmetric(a)
vecs.SetShapeSizes(a.DimSize(0), a.DimSize(1))
vals.SetShapeSizes(a.DimSize(0))
do, _ := NewDense(vecs)
var eig mat.EigenSym
ok := eig.Factorize(ma, true)
if !ok {
return errors.New("gonum mat.EigenSym Factorize failed")
}
eig.VectorsTo(do)
eig.Values(vals.Values)
return nil
}
ea = tensor.Reshape(a, asz...)
if ea.DimSize(1) != ea.DimSize(2) {
return mat.ErrShape
}
nr := ea.DimSize(0)
sz := ea.DimSize(1)
vecs.SetShapeSizes(nr, sz, sz)
vals.SetShapeSizes(nr, sz)
var errs []error
tensor.VectorizeThreaded(ea.DimSize(1)*ea.DimSize(2)*1000,
func(tsr ...tensor.Tensor) int { return nr },
func(r int, tsr ...tensor.Tensor) {
sa := tensor.Reslice(ea, r, tensor.FullAxis, tensor.FullAxis)
ma, _ := NewSymmetric(sa)
do, _ := NewDense(vecs.RowTensor(r).(*tensor.Float64))
var eig mat.EigenSym
ok := eig.Factorize(ma, true)
if !ok {
errs = append(errs, errors.New("gonum mat.Eigen Factorize failed"))
}
eig.VectorsTo(do)
eig.Values(vals.Values[r*sz : (r+1)*sz])
})
return errors.Join(errs...)
}
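// A minimal usage sketch for [EigSym], doing PCA on a covariance matrix
// (hypothetical: cov is assumed to be a symmetric 2D [tensor.Float64],
// e.g., from [metric.CovarianceMatrix]):
//
//	vecs, vals := EigSym(cov)
//	// vals are ordered lowest to highest, so the last column of vecs
//	// is the principal component with the largest eigenvalue.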
// SVD performs the singular value decomposition of the given symmetric square matrix,
// which produces real-valued results, and is generally much faster than [EigSym],
// while producing the same results.
// The vectors are the same size as the input. Each vector is a column
// in this 2D square matrix, ordered *highest* to *lowest* across the columns,
// i.e., the maximum vector is the first column.
// The values are the size of one row, ordered in alignment with the vectors.
// Note that SVD produces results in the *opposite* order of [EigSym].
// If the input tensor is > 2D, it is treated as a list of 2D matrices,
// and parallel threading is used where beneficial.
func SVD(a tensor.Tensor) (vecs, vals *tensor.Float64) {
vecs = tensor.NewFloat64()
vals = tensor.NewFloat64()
errors.Log(SVDOut(a, vecs, vals))
return
}
// SVDOut performs the singular value decomposition of the given symmetric square matrix,
// which produces real-valued results, and is generally much faster than [EigSym],
// while producing the same results.
// The vectors are the same size as the input. Each vector is a column
// in this 2D square matrix, ordered *highest* to *lowest* across the columns,
// i.e., the maximum vector is the first column.
// The values are the size of one row, ordered in alignment with the vectors.
// Note that SVD produces results in the *opposite* order of [EigSym].
// If the input tensor is > 2D, it is treated as a list of 2D matrices,
// and parallel threading is used where beneficial.
func SVDOut(a tensor.Tensor, vecs, vals *tensor.Float64) error {
if err := StringCheck(a); err != nil {
return err
}
na := a.NumDims()
if na == 1 {
return mat.ErrShape
}
var asz []int
ea := a
if na > 2 {
asz = tensor.SplitAtInnerDims(a, 2)
if asz[0] == 1 {
ea = tensor.Reshape(a, asz[1:]...)
na = 2
}
}
if na == 2 {
if a.DimSize(0) != a.DimSize(1) {
return mat.ErrShape
}
ma, _ := NewSymmetric(a)
vecs.SetShapeSizes(a.DimSize(0), a.DimSize(1))
vals.SetShapeSizes(a.DimSize(0))
do, _ := NewDense(vecs)
var eig mat.SVD
ok := eig.Factorize(ma, mat.SVDFull)
if !ok {
return errors.New("gonum mat.SVD Factorize failed")
}
eig.UTo(do)
eig.Values(vals.Values)
return nil
}
ea = tensor.Reshape(a, asz...)
if ea.DimSize(1) != ea.DimSize(2) {
return mat.ErrShape
}
nr := ea.DimSize(0)
sz := ea.DimSize(1)
vecs.SetShapeSizes(nr, sz, sz)
vals.SetShapeSizes(nr, sz)
var errs []error
tensor.VectorizeThreaded(ea.DimSize(1)*ea.DimSize(2)*1000,
func(tsr ...tensor.Tensor) int { return nr },
func(r int, tsr ...tensor.Tensor) {
sa := tensor.Reslice(ea, r, tensor.FullAxis, tensor.FullAxis)
ma, _ := NewSymmetric(sa)
do, _ := NewDense(vecs.RowTensor(r).(*tensor.Float64))
var eig mat.SVD
ok := eig.Factorize(ma, mat.SVDFull)
if !ok {
errs = append(errs, errors.New("gonum mat.SVD Factorize failed"))
}
eig.UTo(do)
eig.Values(vals.Values[r*sz : (r+1)*sz])
})
return errors.Join(errs...)
}
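// A minimal sketch contrasting [SVD] with [EigSym] on the same symmetric
// matrix (hypothetical: cov is assumed to be a symmetric 2D tensor):
//
//	svecs, svals := SVD(cov)    // values ordered highest to lowest; top component is the first column
//	evecs, evals := EigSym(cov) // values ordered lowest to highest; top component is the last column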
// SVDValues performs the singular value decomposition of the given
// symmetric square matrix, which produces real-valued results,
// and is generally much faster than [EigSym], while producing the same results.
// This version only generates eigenvalues, not vectors: see [SVD].
// The values are the size of one row, ordered highest to lowest,
// which is the opposite of [EigSym].
// If the input tensor is > 2D, it is treated as a list of 2D matrices,
// and parallel threading is used where beneficial.
func SVDValues(a tensor.Tensor) *tensor.Float64 {
vals := tensor.NewFloat64()
errors.Log(SVDValuesOut(a, vals))
return vals
}
// SVDValuesOut performs the singular value decomposition of the given
// symmetric square matrix, which produces real-valued results,
// and is generally much faster than [EigSym], while producing the same results.
// This version only generates eigenvalues, not vectors: see [SVDOut].
// The values are the size of one row, ordered highest to lowest,
// which is the opposite of [EigSym].
// If the input tensor is > 2D, it is treated as a list of 2D matrices,
// and parallel threading is used where beneficial.
func SVDValuesOut(a tensor.Tensor, vals *tensor.Float64) error {
if err := StringCheck(a); err != nil {
return err
}
na := a.NumDims()
if na == 1 {
return mat.ErrShape
}
var asz []int
ea := a
if na > 2 {
asz = tensor.SplitAtInnerDims(a, 2)
if asz[0] == 1 {
ea = tensor.Reshape(a, asz[1:]...)
na = 2
}
}
if na == 2 {
if a.DimSize(0) != a.DimSize(1) {
return mat.ErrShape
}
ma, _ := NewSymmetric(a)
vals.SetShapeSizes(a.DimSize(0))
var eig mat.SVD
ok := eig.Factorize(ma, mat.SVDNone)
if !ok {
return errors.New("gonum mat.SVD Factorize failed")
}
eig.Values(vals.Values)
return nil
}
ea = tensor.Reshape(a, asz...)
if ea.DimSize(1) != ea.DimSize(2) {
return mat.ErrShape
}
nr := ea.DimSize(0)
sz := ea.DimSize(1)
vals.SetShapeSizes(nr, sz)
var errs []error
tensor.VectorizeThreaded(ea.DimSize(1)*ea.DimSize(2)*1000,
func(tsr ...tensor.Tensor) int { return nr },
func(r int, tsr ...tensor.Tensor) {
sa := tensor.Reslice(ea, r, tensor.FullAxis, tensor.FullAxis)
ma, _ := NewSymmetric(sa)
var eig mat.SVD
ok := eig.Factorize(ma, mat.SVDNone)
if !ok {
errs = append(errs, errors.New("gonum mat.SVD Factorize failed"))
}
eig.Values(vals.Values[r*sz : (r+1)*sz])
})
return errors.Join(errs...)
}
// ProjectOnMatrixColumn is a convenience function for projecting a given vector
// of values along a specific column (2nd dimension) of the given 2D matrix,
// specified by the scalar colindex, putting the results into out.
// If vec is more than 1-dimensional, it is treated as rows x cells,
// and each row of cells is projected through the matrix column, producing a
// 1D output with the number of rows. Otherwise a single number is produced.
// This is typically done with results from SVD or EigSym (PCA).
func ProjectOnMatrixColumn(mtx, vec, colindex tensor.Tensor) tensor.Values {
out := tensor.NewOfType(vec.DataType())
errors.Log(ProjectOnMatrixColumnOut(mtx, vec, colindex, out))
return out
}
// ProjectOnMatrixColumnOut is a convenience function for projecting a given vector
// of values along a specific column (2nd dimension) of the given 2D matrix,
// specified by the scalar colindex, putting the results into out.
// If vec is more than 1-dimensional, it is treated as rows x cells,
// and each row of cells is projected through the matrix column, producing a
// 1D output with the number of rows. Otherwise a single number is produced.
// This is typically done with results from SVD or EigSym (PCA).
func ProjectOnMatrixColumnOut(mtx, vec, colindex tensor.Tensor, out tensor.Values) error {
ci := int(colindex.Float1D(0))
col := tensor.As1D(tensor.Reslice(mtx, tensor.Slice{}, ci))
// fmt.Println(mtx.String(), col.String())
rows, cells := vec.Shape().RowCellSize()
if rows > 0 && cells > 0 {
msum := tensor.NewFloat64Scalar(0)
out.SetShapeSizes(rows)
mout := tensor.NewFloat64(cells)
for i := range rows {
err := tmath.MulOut(tensor.Cells1D(vec, i), col, mout)
if err != nil {
return err
}
stats.SumOut(mout, msum)
out.SetFloat1D(msum.Float1D(0), i)
}
} else {
mout := tensor.NewFloat64(1)
tmath.MulOut(vec, col, mout)
stats.SumOut(mout, out)
}
return nil
}
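// A minimal usage sketch for [ProjectOnMatrixColumn] (hypothetical: vecs is
// from [SVD], so its columns are ordered highest to lowest, and data is a
// rows x cells tensor with cells matching the matrix column length):
//
//	scores := ProjectOnMatrixColumn(vecs, data, tensor.NewFloat64Scalar(0))
//	// scores is 1D with one value per row of data: the projection
//	// of that row onto the top component (column 0 of vecs).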
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package matrix
import (
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/num"
"cogentcore.org/lab/tensor"
"cogentcore.org/lab/vector"
)
// offCols is a helper function to process the optional offset_cols args
func offCols(size int, offset_cols ...int) (off, cols int) {
off = 0
cols = size
if len(offset_cols) >= 1 {
off = offset_cols[0]
}
if len(offset_cols) == 2 {
cols = offset_cols[1]
}
return
}
// Identity returns a new 2D Float64 tensor with 1s along the diagonal and
// 0s elsewhere, with the given row and column size.
// - If one additional parameter is passed, it is the offset,
// to set values above (positive) or below (negative) the diagonal.
// - If a second additional parameter is passed, it is the number of columns
// for a non-square matrix (first size parameter = number of rows).
func Identity(size int, offset_cols ...int) *tensor.Float64 {
off, cols := offCols(size, offset_cols...)
tsr := tensor.NewFloat64(size, cols)
for r := range size {
c := r + off
if c < 0 || c >= cols {
continue
}
tsr.SetFloat(1, r, c)
}
return tsr
}
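// Illustrative calls for [Identity] showing the optional offset and
// column-size arguments:
//
//	Identity(3)       // 3x3 identity matrix
//	Identity(3, 1)    // 1s on the first superdiagonal: (0,1), (1,2)
//	Identity(3, 0, 5) // 3x5 matrix with 1s at (0,0), (1,1), (2,2)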
// DiagonalN returns the number of elements along the diagonal
// of a 2D matrix of given row and column size.
// - If one additional parameter is passed, it is the offset,
// to include values above (positive) or below (negative) the diagonal.
// - If a second additional parameter is passed, it is the number of columns
// for a non-square matrix (first size parameter = number of rows).
func DiagonalN(size int, offset_cols ...int) int {
off, cols := offCols(size, offset_cols...)
rows := size
if num.Abs(off) > 0 {
oa := num.Abs(off)
if off > 0 {
if cols > rows {
return DiagonalN(rows, 0, cols-oa)
} else {
return DiagonalN(rows-oa, 0, cols-oa)
}
} else {
if rows > cols {
return DiagonalN(rows-oa, 0, cols)
} else {
return DiagonalN(rows-oa, 0, cols-oa)
}
}
}
n := min(rows, cols)
return n
}
// DiagonalIndices returns a list of indices for the diagonal elements of
// a 2D matrix of given row and column size.
// The result is a 2D list of indices, where the outer (row) dimension
// is the number of indices, and the inner dimension is 2 for the r, c coords.
// - If one additional parameter is passed, it is the offset,
// to set values above (positive) or below (negative) the diagonal.
// - If a second additional parameter is passed, it is the number of columns
// for a non-square matrix (first size parameter = number of rows).
func DiagonalIndices(size int, offset_cols ...int) *tensor.Int {
off, cols := offCols(size, offset_cols...)
dn := DiagonalN(size, off, cols)
tsr := tensor.NewInt(dn, 2)
idx := 0
for r := range size {
c := r + off
if c < 0 || c >= cols {
continue
}
tsr.SetInt(r, idx, 0)
tsr.SetInt(c, idx, 1)
idx++
}
return tsr
}
// Diagonal returns an [Indexed] view of the given tensor for the diagonal
// values, as a 1D list. An error is logged if the tensor is not 2D.
// Use the optional offset parameter to get values above (positive) or
// below (negative) the diagonal.
func Diagonal(tsr tensor.Tensor, offset ...int) *tensor.Indexed {
if tsr.NumDims() != 2 {
errors.Log(errors.New("matrix.TriLView requires a 2D tensor"))
return nil
}
off := 0
if len(offset) == 1 {
off = offset[0]
}
return tensor.NewIndexed(tsr, DiagonalIndices(tsr.DimSize(0), off, tsr.DimSize(1)))
}
// Trace returns the sum of the [Diagonal] elements of the given
// tensor, as a tensor scalar.
// An error is logged if the tensor is not 2D.
// Use the optional offset parameter to get values above (positive) or
// below (negative) the diagonal.
func Trace(tsr tensor.Tensor, offset ...int) tensor.Values {
return vector.Sum(Diagonal(tsr, offset...))
}
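// A minimal usage sketch for [Diagonal] and [Trace] (hypothetical: m is
// assumed to be a 3x3 2D tensor):
//
//	d := Diagonal(m)     // 1D view of m[0,0], m[1,1], m[2,2]
//	tr := Trace(m)       // scalar tensor: m[0,0] + m[1,1] + m[2,2]
//	d1 := Diagonal(m, 1) // first superdiagonal: m[0,1], m[1,2]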
// Tri returns a new 2D Float64 tensor with 1s along the diagonal and
// below it, and 0s elsewhere (i.e., a filled lower triangle).
// - If one additional parameter is passed, it is the offset,
// to include values above (positive) or below (negative) the diagonal.
// - If a second additional parameter is passed, it is the number of columns
// for a non-square matrix (first size parameter = number of rows).
func Tri(size int, offset_cols ...int) *tensor.Float64 {
off, cols := offCols(size, offset_cols...)
tsr := tensor.NewFloat64(size, cols)
for r := range size {
for c := range cols {
if c <= r+off {
tsr.SetFloat(1, r, c)
}
}
}
return tsr
}
// TriUpper returns a new 2D Float64 tensor with 1s along the diagonal and
// above it, and 0s elsewhere (i.e., a filled upper triangle).
// - If one additional parameter is passed, it is the offset,
// to include values above (positive) or below (negative) the diagonal.
// - If a second additional parameter is passed, it is the number of columns
// for a non-square matrix (first size parameter = number of rows).
func TriUpper(size int, offset_cols ...int) *tensor.Float64 {
off, cols := offCols(size, offset_cols...)
tsr := tensor.NewFloat64(size, cols)
for r := range size {
for c := range cols {
if c >= r+off {
tsr.SetFloat(1, r, c)
}
}
}
return tsr
}
// TriUNum returns the number of elements in the upper triangular region
// of a 2D matrix of given row and column size, where the triangle includes the
// elements along the diagonal.
// - If one additional parameter is passed, it is the offset,
// to include values above (positive) or below (negative) the diagonal.
// - If a second additional parameter is passed, it is the number of columns
// for a non-square matrix (first size parameter = number of rows).
func TriUNum(size int, offset_cols ...int) int {
off, cols := offCols(size, offset_cols...)
rows := size
if off > 0 {
if cols > rows {
return TriUNum(rows, 0, cols-off)
} else {
return TriUNum(rows-off, 0, cols-off)
}
} else if off < 0 { // invert
return cols*rows - TriUNum(cols, -(off-1), rows)
}
if cols <= size {
return cols + (cols*(cols-1))/2
}
return rows + (rows*(2*cols-rows-1))/2
}
// TriLNum returns the number of elements in the lower triangular region
// of a 2D matrix of given row and column size, where the triangle includes the
// elements along the diagonal.
// - If one additional parameter is passed, it is the offset,
// to include values above (positive) or below (negative) the diagonal.
// - If a second additional parameter is passed, it is the number of columns
// for a non-square matrix (first size parameter = number of rows).
func TriLNum(size int, offset_cols ...int) int {
off, cols := offCols(size, offset_cols...)
return TriUNum(cols, -off, size)
}
// TriLIndicies returns the list of r, c indexes for the lower triangular
// portion of a square matrix of size n, including the diagonal.
// The result is a 2D list of indices, where the outer (row) dimension
// is the number of indices, and the inner dimension is 2 for the r, c coords.
// - If one additional parameter is passed, it is the offset,
// to include values above (positive) or below (negative) the diagonal.
// - If a second additional parameter is passed, it is the number of columns
// for a non-square matrix.
func TriLIndicies(size int, offset_cols ...int) *tensor.Int {
off, cols := offCols(size, offset_cols...)
trin := TriLNum(size, off, cols)
coords := tensor.NewInt(trin, 2)
i := 0
for r := range size {
for c := range cols {
if c <= r+off {
coords.SetInt(r, i, 0)
coords.SetInt(c, i, 1)
i++
}
}
}
return coords
}
// TriUIndicies returns the list of r, c indexes for the upper triangular
// portion of a square matrix of size n, including the diagonal.
// The result is a 2D list of indices, where the outer (row) dimension
// is the number of indices, and the inner dimension is 2 for the r, c coords.
// - If one additional parameter is passed, it is the offset,
// to include values above (positive) or below (negative) the diagonal.
// - If a second additional parameter is passed, it is the number of columns
// for a non-square matrix.
func TriUIndicies(size int, offset_cols ...int) *tensor.Int {
off, cols := offCols(size, offset_cols...)
trin := TriUNum(size, off, cols)
coords := tensor.NewInt(trin, 2)
i := 0
for r := range size {
for c := range cols {
if c >= r+off {
coords.SetInt(r, i, 0)
coords.SetInt(c, i, 1)
i++
}
}
}
return coords
}
// TriLView returns an [Indexed] view of the given tensor for the lower triangular
// region of values, as a 1D list. An error is logged if the tensor is not 2D.
// Use the optional offset parameter to get values above (positive) or
// below (negative) the diagonal.
func TriLView(tsr tensor.Tensor, offset ...int) *tensor.Indexed {
if tsr.NumDims() != 2 {
errors.Log(errors.New("matrix.TriLView requires a 2D tensor"))
return nil
}
off := 0
if len(offset) == 1 {
off = offset[0]
}
return tensor.NewIndexed(tsr, TriLIndicies(tsr.DimSize(0), off, tsr.DimSize(1)))
}
// TriUView returns an [Indexed] view of the given tensor for the upper triangular
// region of values, as a 1D list. An error is logged if the tensor is not 2D.
// Use the optional offset parameter to get values above (positive) or
// below (negative) the diagonal.
func TriUView(tsr tensor.Tensor, offset ...int) *tensor.Indexed {
if tsr.NumDims() != 2 {
errors.Log(errors.New("matrix.TriUView requires a 2D tensor"))
return nil
}
off := 0
if len(offset) == 1 {
off = offset[0]
}
return tensor.NewIndexed(tsr, TriUIndicies(tsr.DimSize(0), off, tsr.DimSize(1)))
}
// TriL returns a copy of the given tensor containing the lower triangular
// region of values (including the diagonal), with the upper triangular region
// zeroed. An error is logged if the tensor is not 2D.
// Use the optional offset parameter to include values above (positive) or
// below (negative) the diagonal.
func TriL(tsr tensor.Tensor, offset ...int) tensor.Tensor {
if tsr.NumDims() != 2 {
errors.Log(errors.New("matrix.TriL requires a 2D tensor"))
return nil
}
off := 0
if len(offset) == 1 {
off = offset[0]
}
off += 1
tc := tensor.Clone(tsr)
tv := TriUView(tc, off) // opposite
tensor.SetAllFloat64(tv, 0)
return tc
}
// TriU returns a copy of the given tensor containing the upper triangular
// region of values (including the diagonal), with the lower triangular region
// zeroed. An error is logged if the tensor is not 2D.
// Use the optional offset parameter to include values above (positive) or
// below (negative) the diagonal.
func TriU(tsr tensor.Tensor, offset ...int) tensor.Tensor {
if tsr.NumDims() != 2 {
errors.Log(errors.New("matrix.TriU requires a 2D tensor"))
return nil
}
off := 0
if len(offset) == 1 {
off = offset[0]
}
off -= 1
tc := tensor.Clone(tsr)
tv := TriLView(tc, off) // opposite
tensor.SetAllFloat64(tv, 0)
return tc
}
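// A minimal usage sketch for [TriL] and [TriU] (hypothetical: m is assumed
// to be a 2D tensor):
//
//	lower := TriL(m)    // copy with everything above the diagonal zeroed
//	upper := TriU(m, 1) // copy keeping only elements strictly above the diagonal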
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package matrix
import (
"errors"
"cogentcore.org/lab/tensor"
"gonum.org/v1/gonum/mat"
)
// Matrix provides a view of the given [tensor.Tensor] as a [gonum]
// [mat.Matrix] interface type.
type Matrix struct {
Tensor tensor.Tensor
}
// StringCheck returns an error if the given tensor has string values,
// which cannot be used in numeric matrix operations.
func StringCheck(tsr tensor.Tensor) error {
if tsr.IsString() {
return errors.New("matrix: tensor has string values; must be numeric")
}
return nil
}
// NewMatrix returns given [tensor.Tensor] as a [gonum] [mat.Matrix].
// It returns an error if the tensor is not 2D.
func NewMatrix(tsr tensor.Tensor) (*Matrix, error) {
if err := StringCheck(tsr); err != nil {
return nil, err
}
nd := tsr.NumDims()
if nd != 2 {
err := errors.New("matrix.NewMatrix: tensor is not 2D")
return nil, err
}
return &Matrix{Tensor: tsr}, nil
}
// Dims is the gonum/mat.Matrix interface method for returning the
// dimension sizes of the 2D Matrix. Assumes Row-major ordering.
func (mx *Matrix) Dims() (r, c int) {
return mx.Tensor.DimSize(0), mx.Tensor.DimSize(1)
}
// At is the gonum/mat.Matrix interface method for returning 2D
// matrix element at given row, column index. Assumes Row-major ordering.
func (mx *Matrix) At(i, j int) float64 {
return mx.Tensor.Float(i, j)
}
// T is the gonum/mat.Matrix transpose method.
// It performs an implicit transpose by returning the receiver inside a Transpose.
func (mx *Matrix) T() mat.Matrix {
return mat.Transpose{mx}
}
//////// Symmetric
// Symmetric provides a view of the given [tensor.Tensor] as a [gonum]
// [mat.Symmetric] matrix interface type.
type Symmetric struct {
Matrix
}
// NewSymmetric returns given [tensor.Tensor] as a [gonum] [mat.Symmetric] matrix.
// It returns an error if the tensor is not 2D or not symmetric.
func NewSymmetric(tsr tensor.Tensor) (*Symmetric, error) {
if tsr.IsString() {
err := errors.New("matrix.NewSymmetric: tensor has string values; must be numeric")
return nil, err
}
nd := tsr.NumDims()
if nd != 2 {
err := errors.New("matrix.NewSymmetric: tensor is not 2D")
return nil, err
}
if tsr.DimSize(0) != tsr.DimSize(1) {
err := errors.New("matrix.NewSymmetric: tensor is not symmetric")
return nil, err
}
sy := &Symmetric{}
sy.Tensor = tsr
return sy, nil
}
// SymmetricDim is the gonum/mat.Symmetric interface method for returning the
// dimensionality of a symmetric 2D Matrix.
func (sy *Symmetric) SymmetricDim() (r int) {
return sy.Tensor.DimSize(0)
}
// NewDense returns given [tensor.Float64] as a [gonum] [mat.Dense]
// Matrix, on which many of the matrix operations are defined.
// It functions similarly to the [tensor.Values] type, as the output
// of matrix operations. The Dense type serves as a view onto
// the tensor's data, so operations directly modify it.
func NewDense(tsr *tensor.Float64) (*mat.Dense, error) {
nd := tsr.NumDims()
if nd != 2 {
err := errors.New("matrix.NewDense: tensor is not 2D")
return nil, err
}
return mat.NewDense(tsr.DimSize(0), tsr.DimSize(1), tsr.Values), nil
}
// CopyFromDense copies a gonum mat.Dense matrix into the given Tensor,
// using the standard Float64 interface.
func CopyFromDense(to tensor.Values, dm *mat.Dense) {
nr, nc := dm.Dims()
to.SetShapeSizes(nr, nc)
idx := 0
for ri := 0; ri < nr; ri++ {
for ci := 0; ci < nc; ci++ {
v := dm.At(ri, ci)
to.SetFloat1D(v, idx)
idx++
}
}
}
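// A minimal sketch of using these views with gonum directly (hypothetical:
// a is assumed to be a 2x2 [tensor.Float64]). [NewDense] wraps the tensor's
// Values slice, so gonum writes land directly in the output tensor:
//
//	out := tensor.NewFloat64(2, 2)
//	ma, _ := NewMatrix(a)
//	do, _ := NewDense(out)
//	do.Mul(ma, ma) // out now holds a*a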
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package matrix
import (
"slices"
"cogentcore.org/core/base/errors"
"cogentcore.org/lab/tensor"
"gonum.org/v1/gonum/mat"
)
// CallOut1 calls an Out function with 1 input arg. All matrix functions
// require *tensor.Float64 outputs.
func CallOut1(fun func(a tensor.Tensor, out *tensor.Float64) error, a tensor.Tensor) *tensor.Float64 {
out := tensor.NewFloat64()
errors.Log(fun(a, out))
return out
}
// CallOut2 calls an Out function with 2 input args. All matrix functions
// require *tensor.Float64 outputs.
func CallOut2(fun func(a, b tensor.Tensor, out *tensor.Float64) error, a, b tensor.Tensor) *tensor.Float64 {
out := tensor.NewFloat64()
errors.Log(fun(a, b, out))
return out
}
// Mul performs matrix multiplication, using the following rules based
// on the shapes of the relevant tensors. If the tensor shapes are not
// suitable, an error is logged (see [MulOut] for a version returning the error).
// N > 2 dimensional cases use parallel threading where beneficial.
// - If both arguments are 2-D they are multiplied like conventional matrices.
// - If either argument is N-D, N > 2, it is treated as a stack of matrices
// residing in the last two indexes and broadcast accordingly.
// - If the first argument is 1-D, it is promoted to a matrix by prepending
// a 1 to its dimensions. After matrix multiplication the prepended 1 is removed.
// - If the second argument is 1-D, it is promoted to a matrix by appending
// a 1 to its dimensions. After matrix multiplication the appended 1 is removed.
func Mul(a, b tensor.Tensor) *tensor.Float64 {
return CallOut2(MulOut, a, b)
}
// MulOut performs matrix multiplication, into the given output tensor,
// using the following rules based on the shapes of the relevant tensors.
// If the tensor shapes are not suitable, a [gonum] [mat.ErrShape] error is returned.
// N > 2 dimensional cases use parallel threading where beneficial.
// - If both arguments are 2-D they are multiplied like conventional matrices.
// The result has shape a.Rows, b.Columns.
// - If either argument is N-D, N > 2, it is treated as a stack of matrices
// residing in the last two indexes and broadcast accordingly. Both cannot
// be > 2 dimensional, unless their outer dimension size is 1 or the same.
// - If the first argument is 1-D, it is promoted to a matrix by prepending
// a 1 to its dimensions. After matrix multiplication the prepended 1 is removed.
// - If the second argument is 1-D, it is promoted to a matrix by appending
// a 1 to its dimensions. After matrix multiplication the appended 1 is removed.
func MulOut(a, b tensor.Tensor, out *tensor.Float64) error {
if err := StringCheck(a); err != nil {
return err
}
if err := StringCheck(b); err != nil {
return err
}
na := a.NumDims()
nb := b.NumDims()
ea := a
eb := b
collapse := false
colDim := 0
if na == 1 {
ea = tensor.Reshape(a, 1, a.DimSize(0))
collapse = true
colDim = -2
na = 2
}
if nb == 1 {
eb = tensor.Reshape(b, b.DimSize(0), 1)
collapse = true
colDim = -1
nb = 2
}
if na > 2 {
asz := tensor.SplitAtInnerDims(a, 2)
if asz[0] == 1 {
ea = tensor.Reshape(a, asz[1:]...)
na = 2
} else {
ea = tensor.Reshape(a, asz...)
}
}
if nb > 2 {
bsz := tensor.SplitAtInnerDims(b, 2)
if bsz[0] == 1 {
eb = tensor.Reshape(b, bsz[1:]...)
nb = 2
} else {
eb = tensor.Reshape(b, bsz...)
}
}
switch {
case na == nb && na == 2:
if ea.DimSize(1) != eb.DimSize(0) {
return mat.ErrShape
}
ma, _ := NewMatrix(ea)
mb, _ := NewMatrix(eb)
out.SetShapeSizes(ea.DimSize(0), eb.DimSize(1))
do, _ := NewDense(out)
do.Mul(ma, mb)
case na > 2 && nb == 2:
if ea.DimSize(2) != eb.DimSize(0) {
return mat.ErrShape
}
mb, _ := NewMatrix(eb)
nr := ea.DimSize(0)
out.SetShapeSizes(nr, ea.DimSize(1), eb.DimSize(1))
tensor.VectorizeThreaded(ea.DimSize(1)*ea.DimSize(2)*eb.Len()*100, // always beneficial
func(tsr ...tensor.Tensor) int { return nr },
func(r int, tsr ...tensor.Tensor) {
sa := tensor.Reslice(ea, r, tensor.FullAxis, tensor.FullAxis)
ma, _ := NewMatrix(sa)
do, _ := NewDense(out.RowTensor(r).(*tensor.Float64))
do.Mul(ma, mb)
})
case nb > 2 && na == 2:
if ea.DimSize(1) != eb.DimSize(1) {
return mat.ErrShape
}
ma, _ := NewMatrix(ea)
nr := eb.DimSize(0)
out.SetShapeSizes(nr, ea.DimSize(0), eb.DimSize(2))
tensor.VectorizeThreaded(ea.Len()*eb.DimSize(1)*eb.DimSize(2)*100,
func(tsr ...tensor.Tensor) int { return nr },
func(r int, tsr ...tensor.Tensor) {
sb := tensor.Reslice(eb, r, tensor.FullAxis, tensor.FullAxis)
mb, _ := NewMatrix(sb)
do, _ := NewDense(out.RowTensor(r).(*tensor.Float64))
do.Mul(ma, mb)
})
case na > 2 && nb > 2:
if ea.DimSize(0) != eb.DimSize(0) {
return errors.New("matrix.Mul: a and b input matricies are > 2 dimensional; must have same outer dimension sizes")
}
if ea.DimSize(2) != eb.DimSize(1) {
return mat.ErrShape
}
nr := ea.DimSize(0)
out.SetShapeSizes(nr, ea.DimSize(1), eb.DimSize(2))
tensor.VectorizeThreaded(ea.DimSize(1)*ea.DimSize(2)*eb.DimSize(1)*eb.DimSize(2),
func(tsr ...tensor.Tensor) int { return nr },
func(r int, tsr ...tensor.Tensor) {
sa := tensor.Reslice(ea, r, tensor.FullAxis, tensor.FullAxis)
ma, _ := NewMatrix(sa)
sb := tensor.Reslice(eb, r, tensor.FullAxis, tensor.FullAxis)
mb, _ := NewMatrix(sb)
do, _ := NewDense(out.RowTensor(r).(*tensor.Float64))
do.Mul(ma, mb)
})
default:
return mat.ErrShape
}
if collapse {
nd := out.NumDims()
sz := slices.Clone(out.Shape().Sizes)
if colDim == -1 {
out.SetShapeSizes(sz[:nd-1]...)
} else {
out.SetShapeSizes(append(sz[:nd-2], sz[nd-1])...)
}
}
return nil
}
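// A minimal usage sketch for [Mul] showing the 1D promotion rule
// (hypothetical shapes; values are assumed to be filled in elsewhere):
//
//	m := tensor.NewFloat64(2, 3)
//	v := tensor.NewFloat64(3)
//	r := Mul(m, v) // v is promoted to 3x1; the result is a 1D tensor of length 2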
// todo: following should handle N>2 dim case.
// Det returns the determinant of the given tensor.
// For a 2D matrix [[a, b], [c, d]], this is ad - bc.
// See also [LogDet] for a version that is more numerically
// stable for large matrices.
func Det(a tensor.Tensor) *tensor.Float64 {
m, err := NewMatrix(a)
if errors.Log(err) != nil {
return tensor.NewFloat64Scalar(0)
}
return tensor.NewFloat64Scalar(mat.Det(m))
}
// LogDet returns the determinant of the given tensor,
// as the log and sign of the value, which is more
// numerically stable. The return is a 1D vector of length 2,
// with the first value being the log, and the second the sign.
func LogDet(a tensor.Tensor) *tensor.Float64 {
m, err := NewMatrix(a)
if errors.Log(err) != nil {
return tensor.NewFloat64Scalar(0)
}
l, s := mat.LogDet(m)
return tensor.NewFloat64FromValues(l, s)
}
// Inverse performs matrix inversion of a square matrix,
// logging an error for non-invertible cases.
// See [InverseOut] for a version that returns an error.
// If the input tensor is > 2D, it is treated as a list of 2D matrices
// which are each inverted.
func Inverse(a tensor.Tensor) *tensor.Float64 {
return CallOut1(InverseOut, a)
}
// InverseOut performs matrix inversion of a square matrix,
// returning an error for non-invertible cases. If the input tensor
// is > 2D, it is treated as a list of 2D matrices which are each inverted.
func InverseOut(a tensor.Tensor, out *tensor.Float64) error {
if err := StringCheck(a); err != nil {
return err
}
na := a.NumDims()
if na == 1 {
return mat.ErrShape
}
var asz []int
ea := a
if na > 2 {
asz = tensor.SplitAtInnerDims(a, 2)
if asz[0] == 1 {
ea = tensor.Reshape(a, asz[1:]...)
na = 2
}
}
if na == 2 {
if a.DimSize(0) != a.DimSize(1) {
return mat.ErrShape
}
ma, _ := NewMatrix(a)
out.SetShapeSizes(a.DimSize(0), a.DimSize(1))
do, _ := NewDense(out)
return do.Inverse(ma)
}
ea = tensor.Reshape(a, asz...)
if ea.DimSize(1) != ea.DimSize(2) {
return mat.ErrShape
}
nr := ea.DimSize(0)
out.SetShapeSizes(nr, ea.DimSize(1), ea.DimSize(2))
var errs []error
tensor.VectorizeThreaded(ea.DimSize(1)*ea.DimSize(2)*100,
func(tsr ...tensor.Tensor) int { return nr },
func(r int, tsr ...tensor.Tensor) {
sa := tensor.Reslice(ea, r, tensor.FullAxis, tensor.FullAxis)
ma, _ := NewMatrix(sa)
do, _ := NewDense(out.RowTensor(r).(*tensor.Float64))
err := do.Inverse(ma)
if err != nil {
errs = append(errs, err)
}
})
return errors.Join(errs...)
}
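// A minimal usage sketch for [Inverse] (hypothetical: a is assumed to be an
// invertible square 2D tensor):
//
//	ai := Inverse(a)
//	chk := Mul(a, ai) // approximately the identity matrix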
// Copyright (c) 2019, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package patterns
import (
"cogentcore.org/lab/base/randx"
"cogentcore.org/lab/tensor"
)
// FlipBits turns nOff bits that are currently On to Off and
// nOn bits that are currently Off to On, using permuted lists.
func FlipBits(tsr tensor.Values, nOff, nOn int, onVal, offVal float64) {
ln := tsr.Len()
if ln == 0 {
return
}
var ons, offs []int
for i := range ln {
vl := tsr.Float1D(i)
if vl == offVal {
offs = append(offs, i)
} else {
ons = append(ons, i)
}
}
randx.PermuteInts(ons, RandSource)
randx.PermuteInts(offs, RandSource)
if nOff > len(ons) {
nOff = len(ons)
}
if nOn > len(offs) {
nOn = len(offs)
}
for i := range nOff {
tsr.SetFloat1D(offVal, ons[i])
}
for i := range nOn {
tsr.SetFloat1D(onVal, offs[i])
}
}
// FlipBitsRows turns nOff bits that are currently On to Off and
// nOn bits that are currently Off to On, using permuted lists.
// Iterates over the outer-most tensor dimension as rows.
func FlipBitsRows(tsr tensor.Values, nOff, nOn int, onVal, offVal float64) {
rows, _ := tsr.Shape().RowCellSize()
for i := range rows {
trow := tsr.SubSpace(i)
FlipBits(trow, nOff, nOn, onVal, offVal)
}
}
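// A minimal usage sketch combining [PermutedBinaryRows] and [FlipBitsRows]
// to make noisy variations of binary patterns (shapes and counts are
// illustrative):
//
//	pats := tensor.NewFloat64(10, 25) // 10 rows of 25-unit patterns
//	PermutedBinaryRows(pats, 6, 1, 0) // 6 active bits per row
//	FlipBitsRows(pats, 2, 2, 1, 0)    // flip 2 on->off and 2 off->on in each row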
// Copyright (c) 2019, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package patterns
import (
"fmt"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/reflectx"
"cogentcore.org/lab/tensor"
)
// Mix mixes patterns from different tensors into a combined set of patterns,
// over the outermost row dimension (i.e., each source is a list of patterns over rows).
// The source tensors must have the same cell size, and the existing shape of the destination
// will be used if compatible; otherwise it is reshaped as a linear list of sub-tensors.
// Each source list wraps around if shorter than the total number of rows specified.
func Mix(dest tensor.Values, rows int, srcs ...tensor.Values) error {
var cells int
for i, src := range srcs {
_, c := src.Shape().RowCellSize()
if i == 0 {
cells = c
} else {
if c != cells {
err := errors.Log(fmt.Errorf("MixPatterns: cells size of source number %d, %d != first source: %d", i, c, cells))
return err
}
}
}
totlen := len(srcs) * cells * rows
if dest.Len() != totlen {
_, dcells := dest.Shape().RowCellSize()
if dcells == cells*len(srcs) {
dest.SetNumRows(rows)
} else {
sz := append([]int{rows}, len(srcs), cells)
dest.SetShapeSizes(sz...)
}
}
dtype := dest.DataType()
for i, src := range srcs {
si := i * cells
srows := src.DimSize(0)
for row := range rows {
srow := row % srows
for ci := range cells {
switch {
case reflectx.KindIsFloat(dtype):
dest.SetFloatRow(src.FloatRow(srow, ci), row, si+ci)
}
}
}
}
return nil
}
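// A minimal usage sketch for [Mix] (shapes are illustrative; the shorter
// source wraps around over rows):
//
//	a := tensor.NewFloat64(4, 5) // 4 rows of 5-cell patterns
//	b := tensor.NewFloat64(2, 5) // 2 rows; reused cyclically
//	dest := tensor.NewFloat64()
//	err := Mix(dest, 6, a, b) // dest becomes 6 rows x 2 sources x 5 cells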
// Copyright (c) 2019, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package patterns
import (
"fmt"
"math"
"cogentcore.org/core/base/errors"
"cogentcore.org/lab/base/randx"
"cogentcore.org/lab/stats/metric"
"cogentcore.org/lab/tensor"
)
// NFromPct returns the number of bits for given pct (proportion 0-1),
// relative to total n: just int(math.Round(pct * n))
func NFromPct(pct float64, n int) int {
return int(math.Round(pct * float64(n)))
}
// PermutedBinary sets the given tensor to contain nOn onVal values and the
// remainder are offVal values, using a permuted order of tensor elements (i.e.,
// randomly shuffled or permuted).
func PermutedBinary(tsr tensor.Values, nOn int, onVal, offVal float64) {
ln := tsr.Len()
if ln == 0 {
return
}
pord := RandSource.Perm(ln)
for i := range ln {
if i < nOn {
tsr.SetFloat1D(onVal, pord[i])
} else {
tsr.SetFloat1D(offVal, pord[i])
}
}
}
// PermutedBinaryRows uses the [tensor.RowMajor] view of a tensor as a column of rows
// as in a [table.Table], setting each row to contain nOn onVal values with the
// remainder being offVal values, using a permuted order of tensor elements
// (i.e., randomly shuffled or permuted). See also [PermutedBinaryMinDiff].
func PermutedBinaryRows(tsr tensor.Values, nOn int, onVal, offVal float64) {
rows, cells := tsr.Shape().RowCellSize()
if rows == 0 || cells == 0 {
return
}
pord := RandSource.Perm(cells)
for rw := range rows {
stidx := rw * cells
for i := 0; i < cells; i++ {
if i < nOn {
tsr.SetFloat1D(onVal, stidx+pord[i])
} else {
tsr.SetFloat1D(offVal, stidx+pord[i])
}
}
randx.PermuteInts(pord, RandSource)
}
}
// MinDiffPrintIterations, when set to true, prints the iteration stats for
// [PermutedBinaryMinDiff], which is useful for large, long-running cases.
var MinDiffPrintIterations = false
// PermutedBinaryMinDiff uses the [tensor.RowMajor] view of a tensor as a column of rows
// as in a [table.Table], setting each row to contain nOn onVal values, with the
// remainder being offVal values, using a permuted order of tensor elements
// (i.e., randomly shuffled or permuted). This version (see also [PermutedBinaryRows])
// ensures that all patterns have at least a given minimum distance
// from each other, expressed using minDiff = number of bits that must be different
// (can't be > nOn). If the mindiff constraint cannot be met within 100 iterations,
// an error is returned and automatically logged.
func PermutedBinaryMinDiff(tsr tensor.Values, nOn int, onVal, offVal float64, minDiff int) error {
rows, cells := tsr.Shape().RowCellSize()
if rows == 0 || cells == 0 {
return errors.New("empty tensor")
}
pord := RandSource.Perm(cells)
iters := 100
nunder := make([]int, rows) // per row
fails := 0
for itr := range iters {
for rw := range rows {
if itr > 0 && nunder[rw] == 0 {
continue
}
stidx := rw * cells
for i := range cells {
if i < nOn {
tsr.SetFloat1D(onVal, stidx+pord[i])
} else {
tsr.SetFloat1D(offVal, stidx+pord[i])
}
}
randx.PermuteInts(pord, RandSource)
}
for i := range nunder {
nunder[i] = 0
}
nbad := 0
mxnun := 0
for r1 := range rows {
r1v := tsr.SubSpace(r1)
for r2 := r1 + 1; r2 < rows; r2++ {
r2v := tsr.SubSpace(r2)
dst := metric.Hamming(tensor.As1D(r1v), tensor.As1D(r2v)).Float1D(0)
df := int(math.Round(0.5 * dst)) // dst is already float64
if df < minDiff {
nunder[r1]++
mxnun = max(mxnun, nunder[r1])
nunder[r2]++
mxnun = max(mxnun, nunder[r2])
nbad++
}
}
}
if nbad == 0 {
break
}
fails++
if MinDiffPrintIterations {
fmt.Printf("PermutedBinaryMinDiff: Itr: %d NBad: %d MaxN: %d\n", itr, nbad, mxnun)
}
}
if fails == iters {
err := errors.Log(fmt.Errorf("PermutedBinaryMinDiff: minimum difference of: %d was not met: %d times, rows: %d", minDiff, fails, rows))
return err
}
return nil
}
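// A minimal usage sketch for [PermutedBinaryMinDiff] (sizes are illustrative):
//
//	pats := tensor.NewFloat64(20, 36)
//	err := PermutedBinaryMinDiff(pats, 8, 1, 0, 4)
//	// each of the 20 rows has 8 bits on, and every pair of rows
//	// differs by at least 4 bits (or err reports failure after 100 iterations).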
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package patterns
import "cogentcore.org/lab/base/randx"
var (
// RandSource is a random source to use for all random numbers used in patterns.
// By default it just uses the standard Go math/rand source.
// If initialized, e.g., by calling NewRand(seed), then a separate stream of
// random numbers will be generated for all calls, and the seed is saved as
// RandSeed. It can be reinstated by calling RestoreSeed.
// Can also set RandSource to another existing randx.Rand source to use it.
RandSource = &randx.SysRand{}
// RandSeed is the random seed last set by NewRand or SetRandSeed.
RandSeed int64
)
// NewRand sets RandSource to a new separate random number stream
// using given seed, which is saved as RandSeed -- see RestoreSeed.
func NewRand(seed int64) {
RandSource = randx.NewSysRand(seed)
RandSeed = seed
}
// SetRandSeed sets existing random number stream to use given random
// seed, starting from the next call. Saves the seed in RandSeed -- see RestoreSeed.
func SetRandSeed(seed int64) {
RandSeed = seed
RestoreSeed()
}
// RestoreSeed restores the random seed last used -- random number sequence
// will repeat what was generated from that point onward.
func RestoreSeed() {
RandSource.Seed(RandSeed)
}
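// A minimal sketch of reproducible pattern generation using these seed
// functions (sizes are illustrative):
//
//	NewRand(42)
//	a := tensor.NewFloat64(25)
//	PermutedBinary(a, 5, 1, 0)
//	RestoreSeed() // rewind to seed 42
//	b := tensor.NewFloat64(25)
//	PermutedBinary(b, 5, 1, 0) // b gets the same pattern as a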
// Copyright (c) 2019, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package patterns
//go:generate core generate -add-types
import (
"fmt"
"strconv"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/metadata"
"cogentcore.org/lab/stats/stats"
"cogentcore.org/lab/tensor"
"cogentcore.org/lab/tensorfs"
)
// NOnInTensor returns the number of bits active in given tensor
func NOnInTensor(trow tensor.Values) int {
return stats.Sum(trow).Int1D(0)
}
// PctActInTensor returns the percent activity in given tensor (NOn / size)
func PctActInTensor(trow tensor.Values) float32 {
return float32(NOnInTensor(trow)) / float32(trow.Len())
}
// Note: AppendFrom can be used to concatenate tensors.
// NameRows sets strings as prefix + row number with given number
// of leading zeros.
func NameRows(tsr tensor.Values, prefix string, nzeros int) {
ft := fmt.Sprintf("%s%%0%dd", prefix, nzeros)
rows := tsr.DimSize(0)
for i := range rows {
tsr.SetString1D(fmt.Sprintf(ft, i), i)
}
}
// Shuffle returns a [tensor.Rows] view of the given source tensor
// with the outer row-wise dimension randomly shuffled (permuted).
func Shuffle(src tensor.Values) *tensor.Rows {
idx := RandSource.Perm(src.DimSize(0))
return tensor.NewRows(src, idx...)
}
// ReplicateRows adds nCopies rows of the source tensor pattern into
// the destination tensor. The destination shape is set to ensure
// it can contain the results, preserving any existing rows of data.
func ReplicateRows(dest, src tensor.Values, nCopies int) {
curRows := 0
if dest.NumDims() > 0 {
curRows = dest.DimSize(0)
}
totRows := curRows + nCopies
dshp := append([]int{totRows}, src.Shape().Sizes...)
dest.SetShapeSizes(dshp...)
for rw := range nCopies {
dest.SetRowTensor(src, curRows+rw)
}
}
// SplitRows splits a source tensor into a set of tensors in the given
// tensorfs directory, with the given list of names, splitting at given
// rows. There should be 1 more name than rows. If names are omitted then
// the source name + incrementing counter will be used.
func SplitRows(dir *tensorfs.Node, src tensor.Values, names []string, rows ...int) error {
hasNames := len(names) != 0
if hasNames && len(names) != len(rows)+1 {
err := errors.Log(fmt.Errorf("patterns.SplitRows: must pass one more name than number of rows to split on"))
return err
}
all := append(rows, src.DimSize(0)) // final row
srcName := metadata.Name(src)
srcShape := src.ShapeSizes()
dtype := src.DataType()
prev := 0
for i, cur := range all {
if prev >= cur {
err := errors.Log(fmt.Errorf("patterns.SplitRows: rows must increase progressively"))
return err
}
name := ""
switch {
case hasNames:
name = names[i]
case len(srcName) > 0:
name = fmt.Sprintf("%s_%d", srcName, i)
default:
name = strconv.Itoa(i)
}
nrows := cur - prev
srcShape[0] = nrows
spl := tensorfs.ValueType(dir, name, dtype, srcShape...)
for rw := range nrows {
spl.SubSpace(rw).CopyFrom(src.SubSpace(prev + rw))
}
prev = cur
}
return nil
}
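// A minimal usage sketch for [SplitRows] making a train/test split
// (hypothetical: dir is an existing *tensorfs.Node directory and src has
// 100 rows):
//
//	err := SplitRows(dir, src, []string{"train", "test"}, 80)
//	// dir now contains "train" (rows 0-79) and "test" (rows 80-99).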
// AddVocabDrift adds a row-by-row drifting pool to the vocabulary,
// starting from the given row in existing vocabulary item
// (which becomes starting row in this one -- drift starts in second row).
// The current row patterns are generated by taking the previous row
// pattern and flipping pctDrift percent of active bits (min of 1 bit).
// func AddVocabDrift(mp Vocab, name string, rows int, pctDrift float32, copyFrom string, copyRow int) (tensor.Values, error) {
// cp, err := mp.ByName(copyFrom)
// if err != nil {
// return nil, err
// }
// tsr := &tensor.Float32{}
// cpshp := cp.Shape().Sizes
// cpshp[0] = rows
// tsr.SetShapeSizes(cpshp...)
// mp[name] = tsr
// cprow := cp.SubSpace(copyRow).(tensor.Values)
// trow := tsr.SubSpace(0)
// trow.CopyFrom(cprow)
// nOn := NOnInTensor(cprow)
// rmdr := 0.0 // remainder carryover in drift
// drift := float64(nOn) * float64(pctDrift) // precise fractional amount of drift
// for i := 1; i < rows; i++ {
// srow := tsr.SubSpace(i - 1)
// trow := tsr.SubSpace(i)
// trow.CopyFrom(srow)
// curDrift := math.Round(drift + rmdr) // integer amount
// nDrift := int(curDrift)
// if nDrift > 0 {
// FlipBits(trow, nDrift, nDrift, 1, 0)
// }
// rmdr += drift - curDrift // accumulate remainder
// }
// return tsr, nil
// }
// Code generated by "goal build"; DO NOT EDIT.
//line body.goal:1
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package physics
import (
"math"
"cogentcore.org/core/math32"
)
//gosl:start
// BodyVars are body state variables stored in tensor.Float32
type BodyVars int32 //enums:enum
const (
// BodyShape is the shape type of the object, as a Shapes type.
BodyShape BodyVars = iota
// BodyDynamic is the index into Dynamics for this body,
// which is -1 for static bodies. Use this to get current
// Pos and Quat values for a dynamic body.
BodyDynamic
// BodyWorld partitions bodies into different worlds for
// collision detection: Global bodies = -1 can collide with
// everything; otherwise only items within the same world collide.
// NewBody uses [World.CurrentWorld] to initialize.
BodyWorld
// BodyGroup partitions bodies within worlds into different groups
// for collision detection. 0 does not collide with anything.
// Negative numbers are global within a world, except they don't
// collide amongst themselves (all non-dynamic bodies should go
// in -1 because they don't collide amongst each-other, but do
// potentially collide with dynamics).
// Positive numbers only collide amongst themselves, and with
// negative groups, but not other positive groups. To avoid
// unwanted collisions, put bodies into separate groups.
// There is an automatic constraint that the two objects
// within a single joint do not collide with each other, so this
// does not need to be handled here.
BodyGroup
// BodyHSize is the half-size (e.g., radius) of the body.
// Values depend on shape type: X is generally radius,
// Y is half-height.
BodyHSizeX
BodyHSizeY
BodyHSizeZ
// BodyThick is the thickness of the body, as a hollow shape.
// If 0, then it is a solid shape (default).
BodyThick
// physical properties
// BodyMass is the mass of the object.
BodyMass
// BodyInvMass is 1/mass of the object or 0 if no mass.
BodyInvMass
// BodyBounce specifies the COR or coefficient of restitution (0..1),
// which determines how elastic the collision is,
// i.e., final velocity / initial velocity.
BodyBounce
// BodyFriction is the standard coefficient for linear friction (mu).
BodyFriction
// BodyFrictionTortion is resistance to spinning at the contact point.
BodyFrictionTortion
// BodyFrictionRolling is resistance to rolling motion at contact.
BodyFrictionRolling
// 3D position of body (structural center).
BodyPosX
BodyPosY
BodyPosZ
// Quaternion rotation of body.
BodyQuatX
BodyQuatY
BodyQuatZ
BodyQuatW
// Relative center-of-mass offset from 3D position of body.
BodyComX
BodyComY
BodyComZ
// Inertia 3x3 matrix (column matrix organization, r,c labels).
BodyInertiaXX
BodyInertiaYX
BodyInertiaZX
BodyInertiaXY
BodyInertiaYY
BodyInertiaZY
BodyInertiaXZ
BodyInertiaYZ
BodyInertiaZZ
// Inverse inertia 3x3 matrix (column matrix organization, r,c labels).
BodyInvInertiaXX
BodyInvInertiaYX
BodyInvInertiaZX
BodyInvInertiaXY
BodyInvInertiaYY
BodyInvInertiaZY
BodyInvInertiaXZ
BodyInvInertiaYZ
BodyInvInertiaZZ
// radius for broadphase collision
BodyRadius
)
func GetBodyShape(idx int32) Shapes {
return Shapes(math.Float32bits(Bodies.Value(int(idx), int(BodyShape))))
}
func SetBodyShape(idx int32, shape Shapes) {
Bodies.Set(math.Float32frombits(uint32(shape)), int(idx), int(BodyShape))
}
func SetBodyDynamic(idx, dynIdx int32) {
Bodies.Set(math.Float32frombits(uint32(dynIdx)), int(idx), int(BodyDynamic))
}
func GetBodyDynamic(idx int32) int32 {
return int32(math.Float32bits(Bodies.Value(int(idx), int(BodyDynamic))))
}
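// Note: integer-valued body fields (shape, dynamic index, world, group) are
// stored bit-for-bit in the float32 Bodies tensor via math.Float32frombits
// and recovered with math.Float32bits, so that integer values can round-trip
// exactly rather than going through a numeric float conversion.
// Illustrative sketch (hypothetical body index 0):
//
//	SetBodyDynamic(0, 3)
//	di := GetBodyDynamic(0) // di == 3 exactly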
// SetBodyWorld partitions bodies into different worlds for
// collision detection: Global bodies = -1 can collide with
// everything; otherwise only items within the same world collide.
func SetBodyWorld(idx, w int32) {
Bodies.Set(math.Float32frombits(uint32(w)), int(idx), int(BodyWorld))
}
func GetBodyWorld(idx int32) int32 {
return int32(math.Float32bits(Bodies.Value(int(idx), int(BodyWorld))))
}
// SetBodyGroup partitions bodies within worlds into different groups
// for collision detection. 0 does not collide with anything.
// Negative numbers are global within a world, except they don't
// collide amongst themselves (all non-dynamic bodies should go
// in -1 because they don't collide amongst each-other, but do
// potentially collide with dynamics).
// Positive numbers only collide amongst themselves, and with
// negative groups, but not other positive groups. To avoid
// unwanted collisions, put bodies into separate groups.
// There is an automatic constraint that the two objects
// within a single joint do not collide with each other, so this
// does not need to be handled here.
func SetBodyGroup(idx, w int32) {
Bodies.Set(math.Float32frombits(uint32(w)), int(idx), int(BodyGroup))
}
func GetBodyGroup(idx int32) int32 {
return int32(math.Float32bits(Bodies.Value(int(idx), int(BodyGroup))))
}
func BodyHSize(idx int32) math32.Vector3 {
return math32.Vec3(Bodies.Value(int(idx), int(BodyHSizeX)), Bodies.Value(int(idx), int(BodyHSizeY)), Bodies.Value(int(idx), int(BodyHSizeZ)))
}
func SetBodyHSize(idx int32, size math32.Vector3) {
Bodies.Set(size.X, int(idx), int(BodyHSizeX))
Bodies.Set(size.Y, int(idx), int(BodyHSizeY))
Bodies.Set(size.Z, int(idx), int(BodyHSizeZ))
}
func BodyPos(idx int32) math32.Vector3 {
return math32.Vec3(Bodies.Value(int(idx), int(BodyPosX)), Bodies.Value(int(idx), int(BodyPosY)), Bodies.Value(int(idx), int(BodyPosZ)))
}
func SetBodyPos(idx int32, pos math32.Vector3) {
Bodies.Set(pos.X, int(idx), int(BodyPosX))
Bodies.Set(pos.Y, int(idx), int(BodyPosY))
Bodies.Set(pos.Z, int(idx), int(BodyPosZ))
}
func BodyQuat(idx int32) math32.Quat {
return math32.NewQuat(Bodies.Value(int(idx), int(BodyQuatX)), Bodies.Value(int(idx), int(BodyQuatY)), Bodies.Value(int(idx), int(BodyQuatZ)), Bodies.Value(int(idx), int(BodyQuatW)))
}
func SetBodyQuat(idx int32, rot math32.Quat) {
Bodies.Set(rot.X, int(idx), int(BodyQuatX))
Bodies.Set(rot.Y, int(idx), int(BodyQuatY))
Bodies.Set(rot.Z, int(idx), int(BodyQuatZ))
Bodies.Set(rot.W, int(idx), int(BodyQuatW))
}
// BodyDynamicPos gets the position for dynamic bodies or
// static position if not dynamic. cni is the current / next index.
func BodyDynamicPos(idx, cni int32) math32.Vector3 {
didx := GetBodyDynamic(idx)
if didx < 0 {
return BodyPos(idx)
}
return DynamicPos(didx, cni)
}
// BodyDynamicQuat gets the quat rotation for dynamic bodies or
// static rotation if not dynamic. cni is the current / next index.
func BodyDynamicQuat(idx, cni int32) math32.Quat {
didx := GetBodyDynamic(idx)
if didx < 0 {
return BodyQuat(idx)
}
return DynamicQuat(didx, cni)
}
func BodyCom(idx int32) math32.Vector3 {
return math32.Vec3(Bodies.Value(int(idx), int(BodyComX)), Bodies.Value(int(idx), int(BodyComY)), Bodies.Value(int(idx), int(BodyComZ)))
}
func SetBodyCom(idx int32, pos math32.Vector3) {
Bodies.Set(pos.X, int(idx), int(BodyComX))
Bodies.Set(pos.Y, int(idx), int(BodyComY))
Bodies.Set(pos.Z, int(idx), int(BodyComZ))
}
func BodyInertia(idx int32) math32.Matrix3 {
return math32.Mat3(Bodies.Value(int(idx), int(BodyInertiaXX)), Bodies.Value(int(idx), int(BodyInertiaYX)), Bodies.Value(int(idx), int(BodyInertiaZX)),
Bodies.Value(int(idx), int(BodyInertiaXY)), Bodies.Value(int(idx), int(BodyInertiaYY)), Bodies.Value(int(idx), int(BodyInertiaZY)),
Bodies.Value(int(idx), int(BodyInertiaXZ)), Bodies.Value(int(idx), int(BodyInertiaYZ)), Bodies.Value(int(idx), int(BodyInertiaZZ)))
}
func BodyInvInertia(idx int32) math32.Matrix3 {
return math32.Mat3(Bodies.Value(int(idx), int(BodyInvInertiaXX)), Bodies.Value(int(idx), int(BodyInvInertiaYX)), Bodies.Value(int(idx), int(BodyInvInertiaZX)),
Bodies.Value(int(idx), int(BodyInvInertiaXY)), Bodies.Value(int(idx), int(BodyInvInertiaYY)), Bodies.Value(int(idx), int(BodyInvInertiaZY)),
Bodies.Value(int(idx), int(BodyInvInertiaXZ)), Bodies.Value(int(idx), int(BodyInvInertiaYZ)), Bodies.Value(int(idx), int(BodyInvInertiaZZ)))
}
func SetBodyInertia(idx int32, inertia math32.Matrix3) {
for i := range 9 {
Bodies.Set(inertia[i], int(idx), int(int(BodyInertiaXX)+i))
}
}
func SetBodyInvInertia(idx int32, invInertia math32.Matrix3) {
for i := range 9 {
Bodies.Set(invInertia[i], int(idx), int(int(BodyInvInertiaXX)+i))
}
}
// SetBodyThick specifies the thickness of the body, as a hollow shape.
// if 0, then it is solid.
func SetBodyThick(idx int32, val float32) {
Bodies.Set(val, int(idx), int(BodyThick))
}
// SetBodyBounce specifies the COR or coefficient of restitution (0..1),
// which determines how elastic the collision is,
// i.e., final velocity / initial velocity.
func SetBodyBounce(idx int32, val float32) {
Bodies.Set(val, int(idx), int(BodyBounce))
}
// SetBodyFriction sets the standard coefficient of linear friction (mu).
func SetBodyFriction(idx int32, val float32) {
Bodies.Set(val, int(idx), int(BodyFriction))
}
// SetBodyFrictionTortion sets the resistance to spinning (torsional friction) at the contact point.
func SetBodyFrictionTortion(idx int32, val float32) {
Bodies.Set(val, int(idx), int(BodyFrictionTortion))
}
// SetBodyFrictionRolling sets the resistance to rolling motion at the contact point.
func SetBodyFrictionRolling(idx int32, val float32) {
Bodies.Set(val, int(idx), int(BodyFrictionRolling))
}
//gosl:end
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package builder
import (
"cogentcore.org/core/math32"
"cogentcore.org/lab/physics"
"cogentcore.org/lab/physics/phyxyz"
)
// Body is a rigid body.
type Body struct {
// World is the world number for physics: -1 = globals, else positive
// are distinct non-interacting worlds.
World int
// WorldIndex is the index of world within builder Worlds list.
WorldIndex int
// Object is the index within World's Objects list.
Object int
// ObjectBody is the index within the Object's Bodies list.
ObjectBody int
// Shape of the body.
Shape physics.Shapes
// Dynamic makes this a dynamic body.
Dynamic bool
// Group partitions bodies within worlds into different groups
// for collision detection. 0 does not collide with anything.
// Negative numbers are global within a world, except they don't
// collide amongst themselves (all non-dynamic bodies should go
// in -1 because they don't collide amongst each other, but do
// potentially collide with dynamics).
// Positive numbers only collide amongst themselves, and with
// negative groups, but not other positive groups. To avoid
// unwanted collisions, put bodies into separate groups.
// There is an automatic constraint that the two objects
// within a single joint do not collide with each other, so this
// does not need to be handled here.
Group int
// HSize is the half-size (e.g., radius) of the body.
// Values depend on shape type: X is generally radius,
// Y is half-height.
HSize math32.Vector3
// Thick is the thickness of the body, as a hollow shape.
// If 0, then it is a solid shape (default).
Thick float32
// Mass of the object. Only relevant for Dynamic bodies.
Mass float32
// Pose has the position and rotation.
Pose Pose
// Com is the center-of-mass offset from the Pose.Pos.
Com math32.Vector3
// Bounce specifies the COR or coefficient of restitution (0..1),
// which determines how elastic the collision is,
// i.e., final velocity / initial velocity.
Bounce float32
// Friction is the standard coefficient for linear friction (mu).
Friction float32
// FrictionTortion is resistance to spinning at the contact point.
FrictionTortion float32
// FrictionRolling is resistance to rolling motion at contact.
FrictionRolling float32
// Optional [phyxyz.Skin] for visualizing the body.
Skin *phyxyz.Skin
// BodyIndex is the index of this body in the [physics.Model] Bodies list,
// once built.
BodyIndex int32
// DynamicIndex is the index of this dynamic body in the
// [physics.Model] Dynamics list, once built.
DynamicIndex int32
}
// NewBody adds a new body with given parameters.
// Returns the [Body] which can then be further customized.
// Use this for Static elements; NewDynamic for dynamic elements.
func (ob *Object) NewBody(shape physics.Shapes, hsize, pos math32.Vector3, rot math32.Quat) *Body {
idx := len(ob.Bodies)
bd := &Body{World: ob.World, WorldIndex: ob.WorldIndex, Object: ob.Object, ObjectBody: idx, Shape: shape, HSize: hsize}
bd.Pose.Pos = pos
bd.Pose.Quat = rot
bd.Group = -1 // default static
ob.Bodies = append(ob.Bodies, bd)
return ob.Bodies[idx]
}
// NewDynamic adds a new dynamic body with given parameters.
// Returns the [Body] which can then be further customized.
func (ob *Object) NewDynamic(shape physics.Shapes, mass float32, hsize, pos math32.Vector3, rot math32.Quat) *Body {
bd := ob.NewBody(shape, hsize, pos, rot)
bd.Dynamic = true
bd.Mass = mass
bd.Group = 1
return bd
}
// NewBodySkin adds a new body with given parameters, including name and
// color parameters used for initializing a [phyxyz.Skin] in the given [phyxyz.Scene].
// Returns the [Body] which can then be further customized.
// Use this for Static elements; NewDynamicSkin for dynamic elements.
func (ob *Object) NewBodySkin(sc *phyxyz.Scene, name string, shape physics.Shapes, clr string, hsize, pos math32.Vector3, rot math32.Quat) *Body {
bd := ob.NewBody(shape, hsize, pos, rot)
bd.Group = -1 // default static
bd.NewSkin(sc, name, clr)
return bd
}
// NewSkin adds a new skin for body with given name and color parameters.
func (bd *Body) NewSkin(sc *phyxyz.Scene, name string, clr string) *phyxyz.Skin {
sk := sc.NewSkin(bd.Shape, name, clr, bd.HSize, bd.Pose.Pos, bd.Pose.Quat)
bd.Skin = sk
return sk
}
// NewDynamicSkin adds a new dynamic body with given parameters,
// including name and color parameters used for initializing a [phyxyz.Skin]
// in given [phyxyz.Scene].
// Returns the [Body] which can then be further customized.
func (ob *Object) NewDynamicSkin(sc *phyxyz.Scene, name string, shape physics.Shapes, clr string, mass float32, hsize, pos math32.Vector3, rot math32.Quat) *Body {
bd := ob.NewBodySkin(sc, name, shape, clr, hsize, pos, rot)
bd.Dynamic = true
bd.Mass = mass
bd.Group = 1
return bd
}
func (bd *Body) Copy(sb *Body) {
*bd = *sb
bd.Skin = nil // skins are unique
}
/////// Physics functions
func (bd *Body) NewPhysicsBody(ml *physics.Model, world int) {
var bi, di int32
if bd.Dynamic {
bi, di = ml.NewDynamic(bd.Shape, bd.Mass, bd.HSize, bd.Pose.Pos, bd.Pose.Quat)
} else {
bi = ml.NewBody(bd.Shape, bd.HSize, bd.Pose.Pos, bd.Pose.Quat)
di = -1
}
bd.BodyIndex = bi
bd.DynamicIndex = di
physics.SetBodyWorld(bi, int32(world))
physics.SetBodyGroup(bi, int32(bd.Group))
// fmt.Println("\t\t", bi, di, bd.Pose.Pos, bd.Pose.Quat)
if bd.Skin != nil {
bd.Skin.BodyIndex = bi
bd.Skin.DynamicIndex = di
}
physics.SetBodyThick(bi, bd.Thick)
physics.SetBodyCom(bi, bd.Com)
physics.SetBodyBounce(bi, bd.Bounce)
physics.SetBodyFriction(bi, bd.Friction)
physics.SetBodyFrictionTortion(bi, bd.FrictionTortion)
physics.SetBodyFrictionRolling(bi, bd.FrictionRolling)
}
// PoseToPhysics sets the current body poses to the physics current state.
// For Dynamic bodies, sets dynamic state. Also updates world-anchored joints.
func (bd *Body) PoseToPhysics() {
if bd.DynamicIndex >= 0 {
params := physics.GetParams(0)
physics.SetDynamicPos(bd.DynamicIndex, params.Next, bd.Pose.Pos)
physics.SetDynamicQuat(bd.DynamicIndex, params.Next, bd.Pose.Quat)
} else {
physics.SetBodyPos(bd.BodyIndex, bd.Pose.Pos)
physics.SetBodyQuat(bd.BodyIndex, bd.Pose.Quat)
}
}
// PoseFromPhysics gets the current body poses from the physics current state.
func (bd *Body) PoseFromPhysics() {
if bd.DynamicIndex >= 0 {
params := physics.GetParams(0)
bd.Pose.Pos = physics.DynamicPos(bd.DynamicIndex, params.Next)
bd.Pose.Quat = physics.DynamicQuat(bd.DynamicIndex, params.Next)
} else {
bd.Pose.Pos = physics.BodyPos(bd.BodyIndex)
bd.Pose.Quat = physics.BodyQuat(bd.BodyIndex)
}
}
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package builder
//go:generate core generate -add-types -setters
import (
"cogentcore.org/core/math32"
"cogentcore.org/lab/physics"
"cogentcore.org/lab/physics/phyxyz"
)
// Builder is the global container of [physics.Model] elements,
// organized into worlds that are independently updated.
type Builder struct {
// Worlds are the independent world elements.
Worlds []*World
// ReplicasStart is the starting Worlds index for replicated world bodies.
// Set by ReplicateWorld, and used to set corresponding value in Model.
ReplicasStart int
// ReplicasN is the total number of replicated Worlds (including source).
// Set by ReplicateWorld, and used to set corresponding value in Model.
ReplicasN int
}
func NewBuilder() *Builder {
return &Builder{}
}
// Reset starts over making a new model.
func (bl *Builder) Reset() {
bl.Worlds = nil
}
func (bl *Builder) World(idx int) *World {
return bl.Worlds[idx]
}
// NewGlobalWorld creates a new world with World number = -1,
// containing globals that collide with bodies in all worlds.
func (bl *Builder) NewGlobalWorld() *World {
idx := len(bl.Worlds)
bl.Worlds = append(bl.Worlds, &World{World: -1, WorldIndex: idx})
return bl.Worlds[idx]
}
// NewWorld creates a new standard (non-global) world, with
// World number = number of the last world + 1.
func (bl *Builder) NewWorld() *World {
wn := 0
idx := len(bl.Worlds)
if idx > 0 {
wn = bl.Worlds[idx-1].World + 1
}
bl.Worlds = append(bl.Worlds, &World{World: wn, WorldIndex: idx})
return bl.Worlds[idx]
}
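// Usage sketch (illustrative): a typical setup has one global world for the
// shared static environment plus one or more independent worlds:
//
//	bl := NewBuilder()
//	glob := bl.NewGlobalWorld() // World = -1: collides with everything
//	w0 := bl.NewWorld()         // World = 0: first independent world
//	_, _ = glob, w0             // populate with Objects, then call Build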
// Build builds a physics model, with optional [phyxyz.Scene] for
// visualization (using Skin elements created for bodies).
func (bl *Builder) Build(ml *physics.Model, sc *phyxyz.Scene) {
bSt := int32(-1)
bN := int32(0)
jSt := int32(-1)
jN := int32(0)
for wi, wl := range bl.Worlds {
// fmt.Println("\n######## World:", wl.World)
for _, ob := range wl.Objects {
// fmt.Println("\n\t#### Object")
for _, bd := range ob.Bodies {
bd.NewPhysicsBody(ml, wl.World)
if bl.ReplicasN > 0 && wi == bl.ReplicasStart {
bN++
if bSt < 0 {
bSt = bd.BodyIndex
}
}
}
if len(ob.Joints) == 0 {
continue
}
ml.NewObject()
for _, jd := range ob.Joints {
jd.NewPhysicsJoint(ml, ob)
if bl.ReplicasN > 0 && wi == bl.ReplicasStart {
jN++
if jSt < 0 {
jSt = jd.JointIndex
}
}
}
}
}
if bN > 0 {
ml.ReplicasN = int32(bl.ReplicasN)
ml.ReplicaBodiesStart = bSt
ml.ReplicaBodiesN = bN
ml.ReplicaJointsStart = jSt
ml.ReplicaJointsN = jN
}
}
// InitState initializes the current state variables in the builder.
// This does not call InitState in physics, because that depends on
// whether the Scene is being used.
func (bl *Builder) InitState() {
for _, wl := range bl.Worlds {
for _, ob := range wl.Objects {
ob.InitState()
}
}
}
// RunSensors runs the sensor functions for this Builder.
func (bl *Builder) RunSensors() {
for _, wl := range bl.Worlds {
wl.RunSensors()
}
}
// ReplicateWorld makes copies of given world to form an X,Y grid of
// worlds with given optional offsets (Y, X) added between world objects.
// Note that worldIdx is the index in Worlds, not the world number.
// Because different worlds do not interact, offsets are not necessary,
// and large offsets can reduce numerical accuracy.
// If the given [phyxyz.Scene] is non-nil, new skins are made
// for the replicated bodies. Otherwise, a single [phyxyz.Scene] can be
// used to view different replicas.
func (bl *Builder) ReplicateWorld(sc *phyxyz.Scene, worldIdx, nY, nX int, offs ...math32.Vector3) {
src := bl.World(worldIdx)
var Yoff, Xoff math32.Vector3
if len(offs) > 0 {
Yoff = offs[0]
}
if len(offs) > 1 {
Xoff = offs[1]
}
for y := range nY {
for x := range nX {
if x == 0 && y == 0 {
continue
}
nw := bl.NewWorld()
wi := nw.WorldIndex
nw.Copy(src)
nw.SetWorldIndex(wi)
off := Yoff.MulScalar(float32(y)).Add(Xoff.MulScalar(float32(x)))
nw.Move(off)
if sc != nil {
nw.CopySkins(sc, src)
}
}
}
bl.ReplicasStart = worldIdx
bl.ReplicasN = nY * nX
}
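// Usage sketch (illustrative): replicate the world at Worlds index 0 into a
// 4x4 grid of independent copies, spaced 2 units apart along Z (rows) and
// X (columns). sc may be nil if no new skins are needed:
//
//	bl.ReplicateWorld(sc, 0, 4, 4, math32.Vec3(0, 0, 2), math32.Vec3(2, 0, 0))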
// CloneSkins copies existing Body skins into the given [phyxyz.Scene],
// thereby configuring the given scene to view the physics model for this builder.
func (bl *Builder) CloneSkins(sc *phyxyz.Scene) {
for _, wl := range bl.Worlds {
for _, ob := range wl.Objects {
for _, bd := range ob.Bodies {
if bd.Skin == nil {
continue
}
sc.AddSkinClone(bd.Skin)
}
}
}
}
// ReplicaWorld returns the replica World at the given replica index,
// where replica is the index into replicated worlds (0 = original).
func (bl *Builder) ReplicaWorld(replica int) *World {
return bl.Worlds[bl.ReplicasStart+replica]
}
// ReplicaObject returns the replica corresponding to the given [Object],
// where replica is the index into replicated worlds (0 = original).
func (bl *Builder) ReplicaObject(ob *Object, replica int) *Object {
wl := bl.ReplicaWorld(replica)
return wl.Object(ob.Object)
}
// ReplicaBody returns the replica corresponding to the given [Body],
// where replica is the index into replicated worlds (0 = original).
func (bl *Builder) ReplicaBody(bd *Body, replica int) *Body {
wl := bl.ReplicaWorld(replica)
return wl.Object(bd.Object).Body(bd.ObjectBody)
}
// ReplicaJoint returns the replica corresponding to the given [Joint],
// where replica is the index into replicated worlds (0 = original).
func (bl *Builder) ReplicaJoint(bd *Joint, replica int) *Joint {
wl := bl.ReplicaWorld(replica)
return wl.Object(bd.Object).Joint(bd.ObjectJoint)
}
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package builder
import (
"cogentcore.org/core/math32"
"cogentcore.org/core/math32/minmax"
"cogentcore.org/lab/physics"
)
// Joint describes a joint between two bodies.
type Joint struct {
// World is the world number for physics: -1 = globals, else positive
// are distinct non-interacting worlds.
World int
// WorldIndex is the index of world within builder Worlds list.
WorldIndex int
// Object is the index within World's Objects list.
Object int
// ObjectJoint is the index within Object's Joints list.
ObjectJoint int
// Parent is index within an Object for parent body.
// -1 for world-anchored parent.
Parent int
// Child is the index within an Object for the child body.
Child int
// Type is the type of the joint.
Type physics.JointTypes
// PPose is the parent position and orientation of the joint
// in the parent's body-centered coordinates.
PPose Pose
// CPose is the child position and orientation of the joint
// in the child's body-centered coordinates.
CPose Pose
// ParentFixed prevents updating the parent side of the joint.
ParentFixed bool
// NoLinearRotation ignores the rotational (angular) effects of
// linear joint position constraints (i.e., Coriolis and centrifugal forces)
// which can otherwise interfere with rotational position constraints in
// joints with both linear and angular DoFs
// (e.g., [PlaneXZ], for which this is on by default).
NoLinearRotation bool
// LinearDoFN is the number of linear degrees of freedom (3 max).
LinearDoFN int
// AngularDoFN is the number of angular degrees of freedom (3 max).
AngularDoFN int
// DoFs are the degrees-of-freedom for this joint.
DoFs []*DoF
// JointIndex is the index of this joint in [physics.Joints] when built.
JointIndex int32
}
// Controls are the per degrees-of-freedom (DoF) joint control inputs.
type Controls struct {
// Force is the force input driving the joint.
Force float32
// Pos is the position target value, where 0 is the initial
// position. For angular joints, this is in radians.
Pos float32
// Stiff determines how strongly the target position
// is enforced: 0 = not at all; larger = stronger (e.g., 1000 or higher).
// Set to 0 to allow the joint to be fully flexible.
Stiff float32
// Vel is the velocity target value. For example, 0
// effectively damps joint movement in proportion to Damp parameter.
Vel float32
// Damp determines how strongly the target velocity is enforced:
// 0 = not at all; larger = stronger (e.g., 1 is reasonable).
// Set to 0 to allow the joint to be fully flexible.
Damp float32
}
func (ct *Controls) Defaults() {
ct.Stiff = 1000
ct.Damp = 20
}
// DoF is a degree-of-freedom for a [Joint].
type DoF struct {
// Axis is the axis of articulation.
Axis math32.Vector3
// Limit has the limits for motion of this DoF.
Limit minmax.F32
// Init are the initial control values.
Init Controls
// Current are the current control values (based on method calls).
Current Controls
}
func (df *DoF) Defaults() {
df.Limit.Min = -physics.JointLimitUnlimited
df.Limit.Max = physics.JointLimitUnlimited
df.Init.Defaults()
df.Current.Defaults()
}
func (df *DoF) InitState() {
df.Current = df.Init
}
func (jd *Joint) DoF(idx int) *DoF {
return jd.DoFs[idx]
}
func (jd *Joint) Copy(sj *Joint) {
*jd = *sj
jd.DoFs = make([]*DoF, len(sj.DoFs))
for i := range jd.DoFs {
jd.DoFs[i] = &DoF{}
jd.DoF(i).Copy(sj.DoF(i))
}
}
func (df *DoF) Copy(sd *DoF) {
*df = *sd
}
// newJoint adds a new joint of given type.
func (ob *Object) newJoint(typ physics.JointTypes, parent, child *Body, ppos, cpos math32.Vector3, linDoF, angDoF int) *Joint {
pidx := -1
if parent != nil {
pidx = parent.ObjectBody
}
idx := len(ob.Joints)
ob.Joints = append(ob.Joints, &Joint{World: ob.World, WorldIndex: ob.WorldIndex, Object: ob.Object, ObjectJoint: idx, Parent: pidx, Child: child.ObjectBody, Type: typ, LinearDoFN: linDoF, AngularDoFN: angDoF})
jd := ob.Joint(idx)
jd.PPose.Pos = ppos
jd.PPose.Quat = math32.NewQuatIdentity()
jd.CPose.Pos = cpos
jd.CPose.Quat = math32.NewQuatIdentity()
ndof := linDoF + angDoF
if ndof > 0 {
jd.DoFs = make([]*DoF, linDoF+angDoF)
for i := range ndof {
dof := &DoF{}
jd.DoFs[i] = dof
dof.Defaults()
}
}
return jd
}
// NewJointFixed adds a new Fixed joint as a child of given parent.
// Use nil for parent to add a world-anchored joint.
// ppos, cpos are the relative positions from the parent, child.
// These are for the non-rotated body (i.e., body rotation is applied
// to these positions as well).
// Sets relative rotation matrices to identity by default.
func (ob *Object) NewJointFixed(parent, child *Body, ppos, cpos math32.Vector3) *Joint {
jd := ob.newJoint(physics.Fixed, parent, child, ppos, cpos, 0, 0)
jd.NoLinearRotation = true
return jd
}
// NewJointPrismatic adds a new Prismatic (slider) joint as a child
// of given parent. Use nil for parent to add a world-anchored joint.
// ppos, cpos are the relative positions from the parent, child.
// These are for the non-rotated body (i.e., body rotation is applied
// to these positions as well).
// Sets relative rotation matrices to identity by default.
// axis is the axis of articulation for the joint.
// Use [SetJointDoF] to set the remaining DoF parameters.
func (ob *Object) NewJointPrismatic(parent, child *Body, ppos, cpos, axis math32.Vector3) *Joint {
jd := ob.newJoint(physics.Prismatic, parent, child, ppos, cpos, 1, 0)
jd.DoFs[0].Axis = axis
return jd
}
// NewJointRevolute adds a new Revolute (hinge, axel) joint as a child
// of given parent. Use nil for parent to add a world-anchored joint.
// ppos, cpos are the relative positions from the parent, child.
// These are for the non-rotated body (i.e., body rotation is applied
// to these positions as well).
// Sets relative rotation matrices to identity by default.
// axis is the axis of articulation for the joint.
// Use [SetJointDoF] to set the remaining DoF parameters.
func (ob *Object) NewJointRevolute(parent, child *Body, ppos, cpos, axis math32.Vector3) *Joint {
jd := ob.newJoint(physics.Revolute, parent, child, ppos, cpos, 0, 1)
jd.DoFs[0].Axis = axis
return jd
}
// NewJointBall adds a new Ball joint (3 angular DoF) as a child
// of given parent. Use nil for parent to add a world-anchored joint.
// ppos, cpos are the relative positions from the parent, child.
// These are for the non-rotated body (i.e., body rotation is applied
// to these positions as well).
// Sets relative rotation matrices to identity by default.
// Use [SetJointDoF] to set the remaining DoF parameters.
func (ob *Object) NewJointBall(parent, child *Body, ppos, cpos math32.Vector3) *Joint {
jd := ob.newJoint(physics.Ball, parent, child, ppos, cpos, 0, 3)
return jd
}
// NewJointDistance adds a new Distance joint (6 DoF),
// with distance constrained only on the first linear X axis,
// as a child of given parent. Use nil for parent to add a world-anchored joint.
// ppos, cpos are the relative positions from the parent, child.
// These are for the non-rotated body (i.e., body rotation is applied
// to these positions as well).
// Sets relative rotation matrices to identity by default.
// Use [SetJointDoF] to set the remaining DoF parameters.
func (ob *Object) NewJointDistance(parent, child *Body, ppos, cpos math32.Vector3, minDist, maxDist float32) *Joint {
jd := ob.newJoint(physics.Distance, parent, child, ppos, cpos, 3, 3)
jd.DoFs[0].Limit.Min = minDist
jd.DoFs[0].Limit.Max = maxDist
return jd
}
// NewJointFree adds a new Free joint as a child
// of given parent. Use nil for parent to add a world-anchored joint.
// ppos, cpos are the relative positions from the parent, child.
// These are for the non-rotated body (i.e., body rotation is applied
// to these positions as well).
// Sets relative rotation matrices to identity by default.
// Use [SetJointDoF] to set the remaining DoF parameters.
func (ob *Object) NewJointFree(parent, child *Body, ppos, cpos math32.Vector3) *Joint {
jd := ob.newJoint(physics.Free, parent, child, ppos, cpos, 0, 0)
return jd
}
// NewJointPlaneXZ adds a new 3 DoF Planar motion joint suitable for
// controlling the motion of a body on the standard X-Z plane (Y = up).
// The two linear DoFs control position in X and Z, and the third,
// angular DoF controls rotation around the Y axis.
// Use nil for parent to add a world-anchored joint (typical).
// ppos, cpos are the relative positions from the parent, child.
// Sets relative rotation matrices to identity by default.
// Use [SetJointDoF] to set the remaining DoF parameters.
func (ob *Object) NewJointPlaneXZ(parent, child *Body, ppos, cpos math32.Vector3) *Joint {
jd := ob.newJoint(physics.PlaneXZ, parent, child, ppos, cpos, 2, 1)
jd.NoLinearRotation = true
return jd
}
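// Usage sketch (illustrative): a hinge between two dynamic bodies, assuming
// parent and child are *Body values previously added to this Object:
//
//	hinge := ob.NewJointRevolute(parent, child,
//		math32.Vec3(0, -0.5, 0), // anchor relative to the parent body
//		math32.Vec3(0, 0.5, 0),  // anchor relative to the child body
//		math32.Vec3(0, 0, 1))    // hinge axis
//	hinge.DoFs[0].Limit.Min = -math32.Pi / 2 // limit swing to +/- 90 degrees
//	hinge.DoFs[0].Limit.Max = math32.Pi / 2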
// NewPhysicsJoint makes the physics joint for this joint in the given [physics.Model].
func (jd *Joint) NewPhysicsJoint(ml *physics.Model, ob *Object) int32 {
pi := jd.Parent
pdi := int32(-1)
if pi >= 0 {
pb := ob.Body(pi)
pdi = pb.DynamicIndex // todo: validate
}
cb := ob.Body(jd.Child)
cdi := cb.DynamicIndex
ji := int32(0)
switch jd.Type {
case physics.Prismatic:
ji = ml.NewJointPrismatic(pdi, cdi, jd.PPose.Pos, jd.CPose.Pos, jd.DoFs[0].Axis)
case physics.Revolute:
ji = ml.NewJointRevolute(pdi, cdi, jd.PPose.Pos, jd.CPose.Pos, jd.DoFs[0].Axis)
case physics.Ball:
ji = ml.NewJointBall(pdi, cdi, jd.PPose.Pos, jd.CPose.Pos)
case physics.Fixed:
ji = ml.NewJointFixed(pdi, cdi, jd.PPose.Pos, jd.CPose.Pos)
case physics.Distance:
ji = ml.NewJointBall(pdi, cdi, jd.PPose.Pos, jd.CPose.Pos)
case physics.Free:
ji = ml.NewJointFree(pdi, cdi, jd.PPose.Pos, jd.CPose.Pos)
case physics.PlaneXZ:
ji = ml.NewJointPlaneXZ(pdi, cdi, jd.PPose.Pos, jd.CPose.Pos)
}
physics.SetJointParentFixed(ji, jd.ParentFixed)
physics.SetJointNoLinearRotation(ji, jd.NoLinearRotation)
for i := range jd.LinearDoFN {
d := jd.DoF(i)
di := int32(i)
physics.SetJointDoF(ji, di, physics.JointLimitLower, d.Limit.Min)
physics.SetJointDoF(ji, di, physics.JointLimitUpper, d.Limit.Max)
physics.SetJointTargetPos(ji, di, d.Init.Pos, d.Init.Stiff)
physics.SetJointTargetVel(ji, di, d.Init.Vel, d.Init.Damp)
d.Axis = physics.JointAxis(ji, di)
}
for i := range jd.AngularDoFN {
di := int32(i + jd.LinearDoFN)
d := jd.DoF(int(di))
physics.SetJointDoF(ji, di, physics.JointLimitLower, d.Limit.Min)
physics.SetJointDoF(ji, di, physics.JointLimitUpper, d.Limit.Max)
physics.SetJointTargetPos(ji, di, d.Init.Pos, d.Init.Stiff)
physics.SetJointTargetVel(ji, di, d.Init.Vel, d.Init.Damp)
d.Axis = physics.JointAxis(ji, di)
}
jd.JointIndex = ji
// fmt.Printf("\tjoint: %p %d\n", jd, jd.JointIndex)
// if pdi < 0 {
// fmt.Println("\t\t\t", jd.PPose.Pos)
// }
return ji
}
// IsGlobal returns true if this joint has a global world anchor parent.
func (jd *Joint) IsGlobal() bool {
return jd.Parent < 0
}
// InitState initializes current state variables in the Joint.
func (jd *Joint) InitState() {
ji := jd.JointIndex
for di := range jd.DoFs {
d := jd.DoF(di)
d.InitState()
physics.SetJointTargetPos(ji, int32(di), d.Init.Pos, d.Init.Stiff)
physics.SetJointTargetVel(ji, int32(di), d.Init.Vel, d.Init.Damp)
}
}
// PoseToPhysics sets the current world-anchored joint pose
// to the physics current state.
func (jd *Joint) PoseToPhysics() {
if !jd.IsGlobal() {
return
}
physics.SetJointPPos(jd.JointIndex, jd.PPose.Pos)
physics.SetJointPQuat(jd.JointIndex, jd.PPose.Quat)
}
// PoseFromPhysics gets the current world-anchored joint pose
// from the physics current state.
func (jd *Joint) PoseFromPhysics() {
if !jd.IsGlobal() {
return
}
jd.PPose.Pos = physics.JointPPos(jd.JointIndex)
jd.PPose.Quat = physics.JointPQuat(jd.JointIndex)
}
// SetTargetVel sets the target velocity and damping for given DoF for
// this joint in the physics model. Records into [DoF.Current].
func (jd *Joint) SetTargetVel(dof int32, vel, damp float32) {
d := jd.DoF(int(dof))
d.Current.Vel = vel
d.Current.Damp = damp
physics.SetJointTargetVel(jd.JointIndex, dof, vel, damp)
}
// SetTargetPos sets the target position for given DoF for
// this joint in the physics model. Records into [DoF.Current].
func (jd *Joint) SetTargetPos(dof int32, pos, stiff float32) {
d := jd.DoF(int(dof))
d.Current.Pos = pos
d.Current.Stiff = stiff
physics.SetJointTargetPos(jd.JointIndex, dof, pos, stiff)
}
// AddTargetPos adds to the Current target position for given DoF for
// this joint in the physics model, setting stiffness.
func (jd *Joint) AddTargetPos(dof int32, pos, stiff float32) {
d := jd.DoF(int(dof))
d.Current.Pos += pos
d.Current.Stiff = stiff
physics.SetJointTargetPos(jd.JointIndex, dof, d.Current.Pos, stiff)
}
// SetTargetAngle sets the target angular position
// and stiffness for the given DoF of this joint.
// Stiffness determines how strongly the joint constraint is enforced
// (0 = not at all; 1000+ = strongly).
// Angle is in degrees, not radians. The usable range is -180..180,
// which is enforced by wrapping; values near the edges can be
// unstable at higher stiffness levels.
func (jd *Joint) SetTargetAngle(dof int32, angDeg, stiff float32) {
pos := math32.WrapPi(math32.DegToRad(angDeg))
// pos := math32.DegToRad(angDeg)
d := jd.DoF(int(dof))
d.Current.Pos = pos
d.Current.Stiff = stiff
physics.SetJointTargetPos(jd.JointIndex, dof, pos, stiff)
}
// AddTargetAngle adds to the Current target angular position,
// and sets stiffness for the given DoF of this joint.
// Stiffness determines how strongly the joint constraint is enforced
// (0 = not at all; 1000+ = strongly).
// Angle is in degrees, not radians. The usable range is -180..180,
// which is enforced by wrapping; values near the edges can be
// unstable at higher stiffness levels.
func (jd *Joint) AddTargetAngle(dof int32, angDeg, stiff float32) {
d := jd.DoF(int(dof))
d.Current.Pos = math32.WrapPi(d.Current.Pos + math32.DegToRad(angDeg))
// d.Current.Pos = d.Current.Pos + math32.DegToRad(angDeg)
d.Current.Stiff = stiff
physics.SetJointTargetPos(jd.JointIndex, dof, d.Current.Pos, stiff)
}
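// Usage sketch (illustrative): after Build, drive a revolute joint toward a
// target angle, or nudge it incrementally each control step. Assumes hinge is
// a *Joint created with NewJointRevolute (one angular DoF at index 0):
//
//	hinge.SetTargetAngle(0, 45, 2000) // 45 degrees, strong stiffness
//	hinge.AddTargetAngle(0, 5, 2000)  // add 5 degrees to the current target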
// AddPlaneXZPos adds to the Current target X and Z axis positions for
// a PlaneXZ joint, projecting the given delta distance along the
// direction given by the ang Y axis rotation angle (in radians).
func (jd *Joint) AddPlaneXZPos(ang, delta, stiff float32) {
dx := delta * math32.Cos(ang)
dz := delta * math32.Sin(ang)
jd.AddTargetPos(0, dx, stiff)
jd.AddTargetPos(1, dz, stiff)
}
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package builder
import (
"slices"
"cogentcore.org/core/math32"
"cogentcore.org/lab/physics/phyxyz"
)
// Object is an object within the [World].
// Each object is a coherent collection of bodies, typically
// connected by joints. This is an organizational convenience
// for positioning elements; has no physical implications.
type Object struct {
// World is the world number for physics: -1 = globals, else positive
// are distinct non-interacting worlds.
World int
// WorldIndex is the index of world within builder Worlds list.
WorldIndex int
// Object is the index within World's Objects list.
Object int
// Bodies are the bodies in the object.
Bodies []*Body
// Joints are joints connecting object bodies.
// Body indexes in these joints refer strictly to bodies within this object.
Joints []*Joint
// Sensors are functions that can be configured to report arbitrary values
// on given body element. The output must be stored directly somewhere via
// the closure function: the utility of the sensor function is being able
// to capture all the configuration-time parameters needed to make it work,
// and to have it automatically called on replicated objects.
Sensors []func(obj *Object)
}
func (ob *Object) Body(idx int) *Body {
return ob.Bodies[idx]
}
func (ob *Object) Joint(idx int) *Joint {
return ob.Joints[idx]
}
// Copy copies all bodies, joints, and sensors from the given source object
// into this one. (The objects will be identical afterward, regardless of
// the current starting state of this object.)
func (ob *Object) Copy(so *Object) {
ob.World = so.World
ob.Object = so.Object
ob.Bodies = make([]*Body, len(so.Bodies))
ob.Joints = make([]*Joint, len(so.Joints))
ob.Sensors = make([]func(obj *Object), len(so.Sensors))
for i := range ob.Bodies {
ob.Bodies[i] = &Body{}
ob.Body(i).Copy(so.Body(i))
}
for i := range ob.Joints {
ob.Joints[i] = &Joint{}
ob.Joint(i).Copy(so.Joint(i))
}
copy(ob.Sensors, so.Sensors)
}
// CopySkins makes new skins for bodies based on those in the source object,
// which must have the same number of bodies.
func (ob *Object) CopySkins(sc *phyxyz.Scene, so *Object) {
for i := range ob.Bodies {
bd := ob.Body(i)
sb := so.Body(i)
bd.NewSkin(sc, sb.Skin.Name, sb.Skin.Color)
}
}
// InitState initializes current state variables in the object.
func (ob *Object) InitState() {
for _, jd := range ob.Joints {
jd.InitState()
}
}
// HasBodyIndex returns true if any body in the object has one of
// the given body index(es).
func (ob *Object) HasBodyIndex(bodyIndex ...int32) bool {
for _, bd := range ob.Bodies {
if slices.Contains(bodyIndex, bd.BodyIndex) {
return true
}
}
return false
}
//////// Transforms
// PoseToPhysics sets the current body poses to the physics current state.
// For Dynamic bodies, sets dynamic state. Also updates world-anchored joints.
func (ob *Object) PoseToPhysics() {
for _, bd := range ob.Bodies {
bd.PoseToPhysics()
}
for _, jd := range ob.Joints {
jd.PoseToPhysics()
}
}
// PoseFromPhysics gets the current body poses from the physics current state.
// Also updates world-anchored joints.
func (ob *Object) PoseFromPhysics() {
for _, bd := range ob.Bodies {
bd.PoseFromPhysics()
}
for _, jd := range ob.Joints {
jd.PoseFromPhysics()
}
}
// Move applies a positional (translation) offset to all bodies
// and world-anchored joints.
func (ob *Object) Move(pos math32.Vector3) {
for _, bd := range ob.Bodies {
bd.Pose.Move(pos)
}
for _, jd := range ob.Joints {
if jd.IsGlobal() {
jd.PPose.Move(pos)
}
}
}
// RotateAround rotates all bodies and world-anchored joints around the given point.
func (ob *Object) RotateAround(rot math32.Quat, around math32.Vector3) {
for _, bd := range ob.Bodies {
bd.Pose.RotateAround(rot, around)
}
for _, jd := range ob.Joints {
if jd.IsGlobal() {
jd.PPose.RotateAround(rot, around)
}
}
}
// RotateAroundBody rotates around a given body in object.
func (ob *Object) RotateAroundBody(body int, rot math32.Quat) {
bd := ob.Body(body)
ob.RotateAround(rot, bd.Pose.Pos)
}
// MoveOnAxisBody moves (translates) the specified distance on the
// specified local axis, relative to the given body in the object.
// The axis is normalized prior to applying the distance factor.
func (ob *Object) MoveOnAxisBody(body int, x, y, z, dist float32) {
bd := ob.Body(body)
delta := bd.Pose.Quat.MulVector(math32.Vec3(x, y, z).Normal()).MulScalar(dist)
ob.Move(delta)
}
// RotateOnAxisBody rotates around the specified local axis the
// specified angle in degrees, relative to the given body in the object.
func (ob *Object) RotateOnAxisBody(body int, x, y, z, angle float32) {
rot := math32.NewQuatAxisAngle(math32.Vec3(x, y, z), math32.DegToRad(angle))
ob.RotateAroundBody(body, rot)
}
//////// Sensors
// NewSensor adds a new sensor function for this object.
// The closure function can capture local variables at the time
// of configuration, and write results wherever and however it is useful.
func (ob *Object) NewSensor(fun func(obj *Object)) {
ob.Sensors = append(ob.Sensors, fun)
}
// RunSensors runs the sensor functions for this object.
func (ob *Object) RunSensors() {
for _, sf := range ob.Sensors {
sf(ob)
}
}
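// Usage sketch (illustrative): a sensor that records the pose of body 0 on
// every RunSensors pass; positions is hypothetical storage captured by the
// closure:
//
//	var positions []math32.Vector3
//	ob.NewSensor(func(obj *Object) {
//		obj.Body(0).PoseFromPhysics()
//		positions = append(positions, obj.Body(0).Pose.Pos)
//	})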
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package builder
import (
"cogentcore.org/lab/physics"
"cogentcore.org/lab/physics/phyxyz"
)
// Physics provides a container and manager for the main physics elements:
// [Builder], [physics.Model], and [phyxyz.Scene]. This is helpful for
// models used within other apps (e.g., an AI simulation), whereas
// [phyxyz.Editor] provides a standalone GUI interface for testing models.
type Physics struct {
// Model has the physics Model.
Model *physics.Model
// Builder for configuring the Model.
Builder *Builder
// Scene for visualizing the Model
Scene *phyxyz.Scene
}
// Build calls Builder.Build with Model and Scene args,
// and then Init on the Scene.
func (ph *Physics) Build() {
ph.Builder.Build(ph.Model, ph.Scene)
if ph.Scene != nil {
ph.Scene.Init(ph.Model)
}
}
// InitState calls Scene.InitState (or Model.InitState if there is no Scene), and then Builder.InitState.
func (ph *Physics) InitState() {
if ph.Scene != nil {
ph.Scene.InitState(ph.Model)
} else {
ph.Model.InitState()
}
if ph.Builder != nil {
ph.Builder.InitState()
}
}
// Step advances the physics world n steps, updating the scene every time.
func (ph *Physics) Step(n int) {
for range n {
ph.Model.Step()
if ph.Scene != nil {
ph.Scene.Update()
}
}
}
// StepQuiet advances the physics world n steps.
func (ph *Physics) StepQuiet(n int) {
for range n {
ph.Model.Step()
}
}
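// Usage sketch (illustrative): given a Physics value with Model, Builder, and
// (optionally) Scene already assigned, the typical lifecycle is Build once,
// InitState to reset, then Step in a loop:
//
//	ph.Build()     // builds all bodies and joints into ph.Model
//	ph.InitState() // resets scene/model state and builder joint controls
//	for i := 0; i < 100; i++ {
//		ph.Step(1) // advance one physics step, updating the Scene if present
//	}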
// Copyright (c) 2019, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package builder
import (
"cogentcore.org/core/core"
"cogentcore.org/core/icons"
"cogentcore.org/core/math32"
"cogentcore.org/core/tree"
)
// Pose represents the 3D position and rotation.
type Pose struct {
// Pos is the position of the center of mass of the object.
Pos math32.Vector3
// Quat is the rotation specified as a quaternion.
Quat math32.Quat
}
// Defaults sets defaults only if current values are nil
func (ps *Pose) Defaults() {
if ps.Quat.IsNil() {
ps.Quat.SetIdentity()
}
}
// Transform applies positional and rotational transform to pose.
func (ps *Pose) Transform(pos math32.Vector3, rot math32.Quat) {
ps.Pos = rot.MulVector(ps.Pos).Add(pos)
ps.Quat = rot.Mul(ps.Quat)
}
//////// Moving
// Move moves (translates) Pos by the given delta.
func (ps *Pose) Move(delta math32.Vector3) {
ps.Pos.SetAdd(delta)
}
// MoveOnAxis moves (translates) the specified distance on the specified local axis,
// relative to the current rotation orientation.
// The axis is normalized prior to applying the distance factor.
func (ps *Pose) MoveOnAxis(x, y, z, dist float32) { //types:add
delta := ps.Quat.MulVector(math32.Vec3(x, y, z).Normal()).MulScalar(dist)
ps.Pos.SetAdd(delta)
}
// MoveOnAxisAbs moves (translates) the specified distance on the specified local axis,
// in absolute X,Y,Z coordinates (does not apply the Quat rotation factor).
// The axis is normalized prior to applying the distance factor.
func (ps *Pose) MoveOnAxisAbs(x, y, z, dist float32) { //types:add
delta := math32.Vec3(x, y, z).Normal().MulScalar(dist)
ps.Pos.SetAdd(delta)
}
//////// Rotating
func (ps *Pose) RotateAround(rot math32.Quat, around math32.Vector3) {
ps.Pos = rot.MulVector(ps.Pos.Sub(around)).Add(around)
ps.Quat = rot.Mul(ps.Quat)
}
// SetEulerRotation sets the rotation in Euler angles (degrees).
func (ps *Pose) SetEulerRotation(x, y, z float32) { //types:add
ps.Quat.SetFromEuler(math32.Vec3(x, y, z).MulScalar(math32.DegToRadFactor))
}
// SetEulerRotationRad sets the rotation in Euler angles (radians).
func (ps *Pose) SetEulerRotationRad(x, y, z float32) {
ps.Quat.SetFromEuler(math32.Vec3(x, y, z))
}
// EulerRotation returns the current rotation in Euler angles (degrees).
func (ps *Pose) EulerRotation() math32.Vector3 { //types:add
return ps.Quat.ToEuler().MulScalar(math32.RadToDegFactor)
}
// EulerRotationRad returns the current rotation in Euler angles (radians).
func (ps *Pose) EulerRotationRad() math32.Vector3 {
return ps.Quat.ToEuler()
}
// SetAxisRotation sets rotation from local axis and angle in degrees.
func (ps *Pose) SetAxisRotation(x, y, z, angle float32) { //types:add
ps.Quat.SetFromAxisAngle(math32.Vec3(x, y, z), math32.DegToRad(angle))
}
// SetAxisRotationRad sets rotation from local axis and angle in radians.
func (ps *Pose) SetAxisRotationRad(x, y, z, angle float32) {
ps.Quat.SetFromAxisAngle(math32.Vec3(x, y, z), angle)
}
// RotateOnAxis rotates around the specified local axis the specified angle in degrees.
func (ps *Pose) RotateOnAxis(x, y, z, angle float32) { //types:add
ps.Quat.SetMul(math32.NewQuatAxisAngle(math32.Vec3(x, y, z), math32.DegToRad(angle)))
}
// RotateOnAxisRad rotates around the specified local axis the specified angle in radians.
func (ps *Pose) RotateOnAxisRad(x, y, z, angle float32) {
ps.Quat.SetMul(math32.NewQuatAxisAngle(math32.Vec3(x, y, z), angle))
}
// RotateEuler rotates by given Euler angles (in degrees) relative to existing rotation.
func (ps *Pose) RotateEuler(x, y, z float32) { //types:add
ps.Quat.SetMul(math32.NewQuatEuler(math32.Vec3(x, y, z).MulScalar(math32.DegToRadFactor)))
}
// RotateEulerRad rotates by given Euler angles (in radians) relative to existing rotation.
func (ps *Pose) RotateEulerRad(x, y, z, angle float32) {
ps.Quat.SetMul(math32.NewQuatEuler(math32.Vec3(x, y, z)))
}
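// Usage sketch (illustrative): tilt a Pose 30 degrees around X and then move
// it 1 unit along its own (rotated) local Z axis:
//
//	var ps Pose
//	ps.Defaults()                 // identity rotation
//	ps.SetEulerRotation(30, 0, 0) // degrees
//	ps.MoveOnAxis(0, 0, 1, 1)     // 1 unit along local Z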
// MakePoseToolbar returns a toolbar function for physics state updates,
// calling the given updt function after making the change.
func MakePoseToolbar(ps *Pose, updt func()) func(p *tree.Plan) {
return func(p *tree.Plan) {
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(ps.SetEulerRotation).SetAfterFunc(updt).SetIcon(icons.Rotate90DegreesCcw)
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(ps.SetAxisRotation).SetAfterFunc(updt).SetIcon(icons.Rotate90DegreesCcw)
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(ps.RotateEuler).SetAfterFunc(updt).SetIcon(icons.Rotate90DegreesCcw)
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(ps.RotateOnAxis).SetAfterFunc(updt).SetIcon(icons.Rotate90DegreesCcw)
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(ps.EulerRotation).SetAfterFunc(updt).SetShowReturn(true).SetIcon(icons.Rotate90DegreesCcw)
})
tree.Add(p, func(w *core.Separator) {})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(ps.MoveOnAxis).SetAfterFunc(updt).SetIcon(icons.MoveItem)
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(ps.MoveOnAxisAbs).SetAfterFunc(updt).SetIcon(icons.MoveItem)
})
}
}
// Code generated by "core generate -add-types -setters"; DO NOT EDIT.
package builder
import (
"cogentcore.org/core/math32"
"cogentcore.org/core/math32/minmax"
"cogentcore.org/core/types"
"cogentcore.org/lab/physics"
"cogentcore.org/lab/physics/phyxyz"
)
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/physics/builder.Body", IDName: "body", Doc: "Body is a rigid body.", Fields: []types.Field{{Name: "World", Doc: "World is the world number for physics: -1 = globals, else positive\nare distinct non-interacting worlds."}, {Name: "WorldIndex", Doc: "WorldIndex is the index of world within builder Worlds list."}, {Name: "Object", Doc: "Object is the index within World's Objects list."}, {Name: "ObjectBody", Doc: "ObjectBody is the index within the Object's Bodies list."}, {Name: "Shape", Doc: "Shape of the body."}, {Name: "Dynamic", Doc: "Dynamic makes this a dynamic body."}, {Name: "Group", Doc: "Group partitions bodies within worlds into different groups\nfor collision detection. 0 does not collide with anything.\nNegative numbers are global within a world, except they don't\ncollide amongst themselves (all non-dynamic bodies should go\nin -1 because they don't collide amongst each-other, but do\npotentially collide with dynamics).\nPositive numbers only collide amongst themselves, and with\nnegative groups, but not other positive groups. To avoid\nunwanted collisions, put bodies into separate groups.\nThere is an automatic constraint that the two objects\nwithin a single joint do not collide with each other, so this\ndoes not need to be handled here."}, {Name: "HSize", Doc: "HSize is the half-size (e.g., radius) of the body.\nValues depend on shape type: X is generally radius,\nY is half-height."}, {Name: "Thick", Doc: "Thick is the thickness of the body, as a hollow shape.\nIf 0, then it is a solid shape (default)."}, {Name: "Mass", Doc: "Mass of the object. Only relevant for Dynamic bodies."}, {Name: "Pose", Doc: "Pose has the position and rotation."}, {Name: "Com", Doc: "Com is the center-of-mass offset from the Pose.Pos."}, {Name: "Bounce", Doc: "Bounce specifies the COR or coefficient of restitution (0..1),\nwhich determines how elastic the collision is,\ni.e., final velocity / initial velocity."}, {Name: "Friction", Doc: "Friction is the standard coefficient for linear friction (mu)."}, {Name: "FrictionTortion", Doc: "FrictionTortion is resistance to spinning at the contact point."}, {Name: "FrictionRolling", Doc: "FrictionRolling is resistance to rolling motion at contact."}, {Name: "Skin", Doc: "Optional [phyxyz.Skin] for visualizing the body."}, {Name: "BodyIndex", Doc: "BodyIndex is the index of this body in the [physics.Model] Bodies list,\nonce built."}, {Name: "DynamicIndex", Doc: "DynamicIndex is the index of this dynamic body in the\n[physics.Model] Dynamics list, once built."}}})
// SetWorld sets the [Body.World]:
// World is the world number for physics: -1 = globals, else positive
// are distinct non-interacting worlds.
func (t *Body) SetWorld(v int) *Body { t.World = v; return t }
// SetWorldIndex sets the [Body.WorldIndex]:
// WorldIndex is the index of world within builder Worlds list.
func (t *Body) SetWorldIndex(v int) *Body { t.WorldIndex = v; return t }
// SetObject sets the [Body.Object]:
// Object is the index within World's Objects list.
func (t *Body) SetObject(v int) *Body { t.Object = v; return t }
// SetObjectBody sets the [Body.ObjectBody]:
// ObjectBody is the index within the Object's Bodies list.
func (t *Body) SetObjectBody(v int) *Body { t.ObjectBody = v; return t }
// SetShape sets the [Body.Shape]:
// Shape of the body.
func (t *Body) SetShape(v physics.Shapes) *Body { t.Shape = v; return t }
// SetDynamic sets the [Body.Dynamic]:
// Dynamic makes this a dynamic body.
func (t *Body) SetDynamic(v bool) *Body { t.Dynamic = v; return t }
// SetGroup sets the [Body.Group]:
// Group partitions bodies within worlds into different groups
// for collision detection. 0 does not collide with anything.
// Negative numbers are global within a world, except they don't
// collide amongst themselves (all non-dynamic bodies should go
// in -1 because they don't collide amongst each-other, but do
// potentially collide with dynamics).
// Positive numbers only collide amongst themselves, and with
// negative groups, but not other positive groups. To avoid
// unwanted collisions, put bodies into separate groups.
// There is an automatic constraint that the two objects
// within a single joint do not collide with each other, so this
// does not need to be handled here.
func (t *Body) SetGroup(v int) *Body { t.Group = v; return t }
// SetHSize sets the [Body.HSize]:
// HSize is the half-size (e.g., radius) of the body.
// Values depend on shape type: X is generally radius,
// Y is half-height.
func (t *Body) SetHSize(v math32.Vector3) *Body { t.HSize = v; return t }
// SetThick sets the [Body.Thick]:
// Thick is the thickness of the body, as a hollow shape.
// If 0, then it is a solid shape (default).
func (t *Body) SetThick(v float32) *Body { t.Thick = v; return t }
// SetMass sets the [Body.Mass]:
// Mass of the object. Only relevant for Dynamic bodies.
func (t *Body) SetMass(v float32) *Body { t.Mass = v; return t }
// SetPose sets the [Body.Pose]:
// Pose has the position and rotation.
func (t *Body) SetPose(v Pose) *Body { t.Pose = v; return t }
// SetCom sets the [Body.Com]:
// Com is the center-of-mass offset from the Pose.Pos.
func (t *Body) SetCom(v math32.Vector3) *Body { t.Com = v; return t }
// SetBounce sets the [Body.Bounce]:
// Bounce specifies the COR or coefficient of restitution (0..1),
// which determines how elastic the collision is,
// i.e., final velocity / initial velocity.
func (t *Body) SetBounce(v float32) *Body { t.Bounce = v; return t }
// SetFriction sets the [Body.Friction]:
// Friction is the standard coefficient for linear friction (mu).
func (t *Body) SetFriction(v float32) *Body { t.Friction = v; return t }
// SetFrictionTortion sets the [Body.FrictionTortion]:
// FrictionTortion is resistance to spinning at the contact point.
func (t *Body) SetFrictionTortion(v float32) *Body { t.FrictionTortion = v; return t }
// SetFrictionRolling sets the [Body.FrictionRolling]:
// FrictionRolling is resistance to rolling motion at contact.
func (t *Body) SetFrictionRolling(v float32) *Body { t.FrictionRolling = v; return t }
// SetSkin sets the [Body.Skin]:
// Optional [phyxyz.Skin] for visualizing the body.
func (t *Body) SetSkin(v *phyxyz.Skin) *Body { t.Skin = v; return t }
// SetBodyIndex sets the [Body.BodyIndex]:
// BodyIndex is the index of this body in the [physics.Model] Bodies list,
// once built.
func (t *Body) SetBodyIndex(v int32) *Body { t.BodyIndex = v; return t }
// SetDynamicIndex sets the [Body.DynamicIndex]:
// DynamicIndex is the index of this dynamic body in the
// [physics.Model] Dynamics list, once built.
func (t *Body) SetDynamicIndex(v int32) *Body { t.DynamicIndex = v; return t }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/physics/builder.Builder", IDName: "builder", Doc: "Builder is the global container of [physics.Model] elements,\norganized into worlds that are independently updated.", Fields: []types.Field{{Name: "Worlds", Doc: "Worlds are the independent world elements."}, {Name: "ReplicasStart", Doc: "ReplicasStart is the starting Worlds index for replicated world bodies.\nSet by ReplicateWorld, and used to set corresponding value in Model."}, {Name: "ReplicasN", Doc: "ReplicasN is the total number of replicated Worlds (including source).\nSet by ReplicateWorld, and used to set corresponding value in Model."}}})
// SetWorlds sets the [Builder.Worlds]:
// Worlds are the independent world elements.
func (t *Builder) SetWorlds(v ...*World) *Builder { t.Worlds = v; return t }
// SetReplicasStart sets the [Builder.ReplicasStart]:
// ReplicasStart is the starting Worlds index for replicated world bodies.
// Set by ReplicateWorld, and used to set corresponding value in Model.
func (t *Builder) SetReplicasStart(v int) *Builder { t.ReplicasStart = v; return t }
// SetReplicasN sets the [Builder.ReplicasN]:
// ReplicasN is the total number of replicated Worlds (including source).
// Set by ReplicateWorld, and used to set corresponding value in Model.
func (t *Builder) SetReplicasN(v int) *Builder { t.ReplicasN = v; return t }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/physics/builder.Joint", IDName: "joint", Doc: "Joint describes a joint between two bodies.", Fields: []types.Field{{Name: "World", Doc: "World is the world number for physics: -1 = globals, else positive\nare distinct non-interacting worlds."}, {Name: "WorldIndex", Doc: "WorldIndex is the index of world within builder Worlds list."}, {Name: "Object", Doc: "Object is the index within World's Objects list."}, {Name: "ObjectJoint", Doc: "ObjectJoint is the index within Object's Joints list."}, {Name: "Parent", Doc: "Parent is index within an Object for parent body.\n-1 for world-anchored parent."}, {Name: "Child", Doc: "Parent is index within an Object for parent body."}, {Name: "Type", Doc: "Type is the type of the joint."}, {Name: "PPose", Doc: "PPose is the parent position and orientation of the joint\nin the parent's body-centered coordinates."}, {Name: "CPose", Doc: "CPose is the child position and orientation of the joint\nin the parent's body-centered coordinates."}, {Name: "ParentFixed", Doc: "ParentFixed does not update the parent side of the joint."}, {Name: "NoLinearRotation", Doc: "NoLinearRotation ignores the rotational (angular) effects of\nlinear joint position constraints (i.e., Coriolis and centrifugal forces)\nwhich can otherwise interfere with rotational position constraints in\njoints with both linear and angular DoFs\n(e.g., [PlaneXZ], for which this is on by default)."}, {Name: "LinearDoFN", Doc: "LinearDoFN is the number of linear degrees of freedom (3 max)."}, {Name: "AngularDoFN", Doc: "AngularDoFN is the number of linear degrees of freedom (3 max)."}, {Name: "DoFs", Doc: "DoFs are the degrees-of-freedom for this joint."}, {Name: "JointIndex", Doc: "JointIndex is the index of this joint in [physics.Joints] when built."}}})
// SetWorld sets the [Joint.World]:
// World is the world number for physics: -1 = globals, else positive
// are distinct non-interacting worlds.
func (t *Joint) SetWorld(v int) *Joint { t.World = v; return t }
// SetWorldIndex sets the [Joint.WorldIndex]:
// WorldIndex is the index of world within builder Worlds list.
func (t *Joint) SetWorldIndex(v int) *Joint { t.WorldIndex = v; return t }
// SetObject sets the [Joint.Object]:
// Object is the index within World's Objects list.
func (t *Joint) SetObject(v int) *Joint { t.Object = v; return t }
// SetObjectJoint sets the [Joint.ObjectJoint]:
// ObjectJoint is the index within Object's Joints list.
func (t *Joint) SetObjectJoint(v int) *Joint { t.ObjectJoint = v; return t }
// SetParent sets the [Joint.Parent]:
// Parent is index within an Object for parent body.
// -1 for world-anchored parent.
func (t *Joint) SetParent(v int) *Joint { t.Parent = v; return t }
// SetChild sets the [Joint.Child]:
// Parent is index within an Object for parent body.
func (t *Joint) SetChild(v int) *Joint { t.Child = v; return t }
// SetType sets the [Joint.Type]:
// Type is the type of the joint.
func (t *Joint) SetType(v physics.JointTypes) *Joint { t.Type = v; return t }
// SetPPose sets the [Joint.PPose]:
// PPose is the parent position and orientation of the joint
// in the parent's body-centered coordinates.
func (t *Joint) SetPPose(v Pose) *Joint { t.PPose = v; return t }
// SetCPose sets the [Joint.CPose]:
// CPose is the child position and orientation of the joint
// in the parent's body-centered coordinates.
func (t *Joint) SetCPose(v Pose) *Joint { t.CPose = v; return t }
// SetParentFixed sets the [Joint.ParentFixed]:
// ParentFixed does not update the parent side of the joint.
func (t *Joint) SetParentFixed(v bool) *Joint { t.ParentFixed = v; return t }
// SetNoLinearRotation sets the [Joint.NoLinearRotation]:
// NoLinearRotation ignores the rotational (angular) effects of
// linear joint position constraints (i.e., Coriolis and centrifugal forces)
// which can otherwise interfere with rotational position constraints in
// joints with both linear and angular DoFs
// (e.g., [PlaneXZ], for which this is on by default).
func (t *Joint) SetNoLinearRotation(v bool) *Joint { t.NoLinearRotation = v; return t }
// SetLinearDoFN sets the [Joint.LinearDoFN]:
// LinearDoFN is the number of linear degrees of freedom (3 max).
func (t *Joint) SetLinearDoFN(v int) *Joint { t.LinearDoFN = v; return t }
// SetAngularDoFN sets the [Joint.AngularDoFN]:
// AngularDoFN is the number of linear degrees of freedom (3 max).
func (t *Joint) SetAngularDoFN(v int) *Joint { t.AngularDoFN = v; return t }
// SetDoFs sets the [Joint.DoFs]:
// DoFs are the degrees-of-freedom for this joint.
func (t *Joint) SetDoFs(v ...*DoF) *Joint { t.DoFs = v; return t }
// SetJointIndex sets the [Joint.JointIndex]:
// JointIndex is the index of this joint in [physics.Joints] when built.
func (t *Joint) SetJointIndex(v int32) *Joint { t.JointIndex = v; return t }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/physics/builder.Controls", IDName: "controls", Doc: "Controls are the per degrees-of-freedom (DoF) joint control inputs.", Fields: []types.Field{{Name: "Force", Doc: "Force is the force input driving the joint."}, {Name: "Pos", Doc: "Pos is the position target value, where 0 is the initial\nposition. For angular joints, this is in radians."}, {Name: "Stiff", Doc: "Stiff determines how strongly the target position\nis enforced: 0 = not at all; larger = stronger (e.g., 1000 or higher).\nSet to 0 to allow the joint to be fully flexible."}, {Name: "Vel", Doc: "Vel is the velocity target value. For example, 0\neffectively damps joint movement in proportion to Damp parameter."}, {Name: "Damp", Doc: "Damp determines how strongly the target velocity is enforced:\n0 = not at all; larger = stronger (e.g., 1 is reasonable).\nSet to 0 to allow the joint to be fully flexible."}}})
// SetForce sets the [Controls.Force]:
// Force is the force input driving the joint.
func (t *Controls) SetForce(v float32) *Controls { t.Force = v; return t }
// SetPos sets the [Controls.Pos]:
// Pos is the position target value, where 0 is the initial
// position. For angular joints, this is in radians.
func (t *Controls) SetPos(v float32) *Controls { t.Pos = v; return t }
// SetStiff sets the [Controls.Stiff]:
// Stiff determines how strongly the target position
// is enforced: 0 = not at all; larger = stronger (e.g., 1000 or higher).
// Set to 0 to allow the joint to be fully flexible.
func (t *Controls) SetStiff(v float32) *Controls { t.Stiff = v; return t }
// SetVel sets the [Controls.Vel]:
// Vel is the velocity target value. For example, 0
// effectively damps joint movement in proportion to Damp parameter.
func (t *Controls) SetVel(v float32) *Controls { t.Vel = v; return t }
// SetDamp sets the [Controls.Damp]:
// Damp determines how strongly the target velocity is enforced:
// 0 = not at all; larger = stronger (e.g., 1 is reasonable).
// Set to 0 to allow the joint to be fully flexible.
func (t *Controls) SetDamp(v float32) *Controls { t.Damp = v; return t }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/physics/builder.DoF", IDName: "do-f", Doc: "DoF is a degree-of-freedom for a [Joint].", Fields: []types.Field{{Name: "Axis", Doc: "Axis is the axis of articulation."}, {Name: "Limit", Doc: "Limit has the limits for motion of this DoF."}, {Name: "Init", Doc: "Init are the initial control values."}, {Name: "Current", Doc: "Current are the current control values (based on method calls)."}}})
// SetAxis sets the [DoF.Axis]:
// Axis is the axis of articulation.
func (t *DoF) SetAxis(v math32.Vector3) *DoF { t.Axis = v; return t }
// SetLimit sets the [DoF.Limit]:
// Limit has the limits for motion of this DoF.
func (t *DoF) SetLimit(v minmax.F32) *DoF { t.Limit = v; return t }
// SetInit sets the [DoF.Init]:
// Init are the initial control values.
func (t *DoF) SetInit(v Controls) *DoF { t.Init = v; return t }
// SetCurrent sets the [DoF.Current]:
// Current are the current control values (based on method calls).
func (t *DoF) SetCurrent(v Controls) *DoF { t.Current = v; return t }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/physics/builder.Object", IDName: "object", Doc: "Object is an object within the [World].\nEach object is a coherent collection of bodies, typically\nconnected by joints. This is an organizational convenience\nfor positioning elements; has no physical implications.", Fields: []types.Field{{Name: "World", Doc: "World is the world number for physics: -1 = globals, else positive\nare distinct non-interacting worlds."}, {Name: "WorldIndex", Doc: "WorldIndex is the index of world within builder Worlds list."}, {Name: "Object", Doc: "Object is the index within World's Objects list."}, {Name: "Bodies", Doc: "Bodies are the bodies in the object."}, {Name: "Joints", Doc: "Joints are joints connecting object bodies.\nJoint indexes here refer strictly within bodies."}, {Name: "Sensors", Doc: "Sensors are functions that can be configured to report arbitrary values\non given body element. The output must be stored directly somewhere via\nthe closure function: the utility of the sensor function is being able\nto capture all the configuration-time parameters needed to make it work,\nand to have it automatically called on replicated objects."}}})
// SetWorld sets the [Object.World]:
// World is the world number for physics: -1 = globals, else positive
// are distinct non-interacting worlds.
func (t *Object) SetWorld(v int) *Object { t.World = v; return t }
// SetWorldIndex sets the [Object.WorldIndex]:
// WorldIndex is the index of world within builder Worlds list.
func (t *Object) SetWorldIndex(v int) *Object { t.WorldIndex = v; return t }
// SetObject sets the [Object.Object]:
// Object is the index within World's Objects list.
func (t *Object) SetObject(v int) *Object { t.Object = v; return t }
// SetBodies sets the [Object.Bodies]:
// Bodies are the bodies in the object.
func (t *Object) SetBodies(v ...*Body) *Object { t.Bodies = v; return t }
// SetJoints sets the [Object.Joints]:
// Joints are joints connecting object bodies.
// Joint indexes here refer strictly within bodies.
func (t *Object) SetJoints(v ...*Joint) *Object { t.Joints = v; return t }
// SetSensors sets the [Object.Sensors]:
// Sensors are functions that can be configured to report arbitrary values
// on given body element. The output must be stored directly somewhere via
// the closure function: the utility of the sensor function is being able
// to capture all the configuration-time parameters needed to make it work,
// and to have it automatically called on replicated objects.
func (t *Object) SetSensors(v ...func(obj *Object)) *Object { t.Sensors = v; return t }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/physics/builder.Physics", IDName: "physics", Doc: "Physics provides a container and manager for the main physics elements:\n[Builder], [physics.Model], and [phyxyz.Scene]. This is helpful for\nmodels used within other apps (e.g., an AI simulation), whereas\n[phyxyz.Editor] provides a standalone GUI interface for testing models.", Fields: []types.Field{{Name: "Model", Doc: "Model has the physics Model."}, {Name: "Builder", Doc: "Builder for configuring the Model."}, {Name: "Scene", Doc: "Scene for visualizing the Model"}}})
// SetModel sets the [Physics.Model]:
// Model has the physics Model.
func (t *Physics) SetModel(v *physics.Model) *Physics { t.Model = v; return t }
// SetBuilder sets the [Physics.Builder]:
// Builder for configuring the Model.
func (t *Physics) SetBuilder(v *Builder) *Physics { t.Builder = v; return t }
// SetScene sets the [Physics.Scene]:
// Scene for visualizing the Model
func (t *Physics) SetScene(v *phyxyz.Scene) *Physics { t.Scene = v; return t }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/physics/builder.Pose", IDName: "pose", Doc: "Pose represents the 3D position and rotation.", Methods: []types.Method{{Name: "MoveOnAxis", Doc: "MoveOnAxis moves (translates) the specified distance on the specified local axis,\nrelative to the current rotation orientation.\nThe axis is normalized prior to aplying the distance factor.\nSets the LinVel to motion vector.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"x", "y", "z", "dist"}}, {Name: "MoveOnAxisAbs", Doc: "MoveOnAxisAbs moves (translates) the specified distance on the specified local axis,\nin absolute X,Y,Z coordinates (does not apply the Quat rotation factor.\nThe axis is normalized prior to aplying the distance factor.\nSets the LinVel to motion vector.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"x", "y", "z", "dist"}}, {Name: "SetEulerRotation", Doc: "SetEulerRotation sets the rotation in Euler angles (degrees).", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"x", "y", "z"}}, {Name: "EulerRotation", Doc: "EulerRotation returns the current rotation in Euler angles (degrees).", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Returns: []string{"Vector3"}}, {Name: "SetAxisRotation", Doc: "SetAxisRotation sets rotation from local axis and angle in degrees.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"x", "y", "z", "angle"}}, {Name: "RotateOnAxis", Doc: "RotateOnAxis rotates around the specified local axis the specified angle in degrees.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"x", "y", "z", "angle"}}, {Name: "RotateEuler", Doc: "RotateEuler rotates by given Euler angles (in degrees) relative to existing rotation.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"x", "y", "z"}}}, Fields: []types.Field{{Name: "Pos", Doc: "Pos is the position of center of mass of object."}, {Name: "Quat", Doc: "Quat is the rotation specified as a quaternion."}}})
// SetPos sets the [Pose.Pos]:
// Pos is the position of center of mass of object.
func (t *Pose) SetPos(v math32.Vector3) *Pose { t.Pos = v; return t }
// SetQuat sets the [Pose.Quat]:
// Quat is the rotation specified as a quaternion.
func (t *Pose) SetQuat(v math32.Quat) *Pose { t.Quat = v; return t }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/physics/builder.World", IDName: "world", Doc: "World is one world within the Builder.", Fields: []types.Field{{Name: "World", Doc: "World is the world number for physics: -1 = globals, else positive\nare distinct non-interacting worlds."}, {Name: "WorldIndex", Doc: "WorldIndex is the index of world within builder Worlds list."}, {Name: "Objects", Doc: "Objects are the objects within the [World].\nEach object is a coherent collection of bodies, typically\nconnected by joints. This is an organizational convenience\nfor positioning elements; has no physical implications."}}})
// SetWorld sets the [World.World]:
// World is the world number for physics: -1 = globals, else positive
// are distinct non-interacting worlds.
func (t *World) SetWorld(v int) *World { t.World = v; return t }
// SetObjects sets the [World.Objects]:
// Objects are the objects within the [World].
// Each object is a coherent collection of bodies, typically
// connected by joints. This is an organizational convenience
// for positioning elements; has no physical implications.
func (t *World) SetObjects(v ...*Object) *World { t.Objects = v; return t }
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package builder
import (
"cogentcore.org/core/math32"
"cogentcore.org/lab/physics/phyxyz"
)
// World is one world within the Builder.
type World struct {
// World is the world number for physics: -1 = globals, else positive
// are distinct non-interacting worlds.
World int
// WorldIndex is the index of world within builder Worlds list.
WorldIndex int `set:"-"`
// Objects are the objects within the [World].
// Each object is a coherent collection of bodies, typically
// connected by joints. This is an organizational convenience
// for positioning elements; has no physical implications.
Objects []*Object
}
func (wl *World) Object(idx int) *Object {
return wl.Objects[idx]
}
func (wl *World) NewObject() *Object {
idx := len(wl.Objects)
wl.Objects = append(wl.Objects, &Object{World: wl.World, WorldIndex: wl.WorldIndex, Object: idx})
return wl.Objects[idx]
}
// Copy copies all objects from given source world into this one.
// (The worlds will be identical afterward, regardless of their
// starting state.)
func (wl *World) Copy(ow *World) {
wl.Objects = make([]*Object, len(ow.Objects))
for i := range wl.Objects {
wl.Objects[i] = &Object{}
wl.Object(i).Copy(ow.Object(i))
}
}
// CopySkins makes new skins for bodies in world,
// based on those in source world, which must be a Copy.
func (wl *World) CopySkins(sc *phyxyz.Scene, ow *World) {
for i, ob := range wl.Objects {
ob.CopySkins(sc, ow.Object(i))
}
}
// SetWorldIndex sets the WorldIndex for this and all children.
func (wl *World) SetWorldIndex(wi int) {
wl.WorldIndex = wi
for _, ob := range wl.Objects {
ob.WorldIndex = wi
for _, bd := range ob.Bodies {
bd.WorldIndex = wi
}
for _, jd := range ob.Joints {
jd.WorldIndex = wi
}
}
}
// Move moves all objects in world by given delta.
func (wl *World) Move(delta math32.Vector3) {
for _, ob := range wl.Objects {
ob.Move(delta)
}
}
// PoseToPhysics sets the current body poses to the physics current state.
// For Dynamic bodies, sets dynamic state. Also updates world-anchored joints.
func (wl *World) PoseToPhysics() {
for _, ob := range wl.Objects {
ob.PoseToPhysics()
}
}
// RunSensors runs the sensor functions for this World.
func (wl *World) RunSensors() {
for _, ob := range wl.Objects {
ob.RunSensors()
}
}
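// A minimal usage sketch (illustrative, not part of the source): building a
// world by adding objects and then repositioning the whole group. The offset
// value is arbitrary; bodies and joints would be configured on obj.
//
//	wl := &World{World: 0}
//	obj := wl.NewObject()
//	_ = obj // configure Bodies and Joints on obj here
//	wl.Move(math32.Vec3(0, 1, 0)) // shift all objects up by 1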
// Code generated by "goal build"; DO NOT EDIT.
//line config.goal:1
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package physics
import (
// "fmt"
"cogentcore.org/lab/tensor"
)
// Config does final configuration prior to running
// after everything has been added. Does SetAsCurrent, GPUInit.
func (ml *Model) Config() {
ml.ConfigJoints()
ml.ConfigBodyCollidePairs()
ml.SetMaxContacts()
ml.SetAsCurrent()
ml.ConfigBodies()
ml.GPUInit()
ml.InitState()
}
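// A minimal usage sketch (illustrative): Config must be called once, after
// all bodies and joints have been added and before any simulation stepping.
//
//	// ml is a *Model with bodies and joints already added
//	ml.Config() // joints, collide pairs, max contacts, GPU init, initial state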
// ConfigJoints does all of the initialization associated with joints.
func (ml *Model) ConfigJoints() {
// accumulate parent and child joints per dynamic
params := &ml.Params[0]
nj := params.JointsN
nd := params.DynamicsN
bjp := make([][]int32, nd)
bjc := make([][]int32, nd)
maxi := 0
for ji := range nj {
jpi := JointParentIndex(ji)
jci := JointChildIndex(ji)
// bpi := DynamicBody(jpi)
// bci := DynamicBody(jci)
// todo: could ensure that all elements are in same world, but not really needed
if jpi >= 0 {
bjp[jpi] = append(bjp[jpi], ji)
maxi = max(maxi, len(bjp[jpi]))
}
bjc[jci] = append(bjc[jci], ji)
maxi = max(maxi, len(bjc[jci]))
}
params.BodyJointsMax = int32(maxi)
if nd == 0 {
nd = 1
}
if maxi == 0 {
maxi = 1
}
ml.BodyJoints.SetShapeSizes(int(nd), 2, maxi+1)
for di := range nd {
np := int32(len(bjp[di]))
ml.BodyJoints.Set(np, int(di), int(0), int(0))
for i, ji := range bjp[di] {
ml.BodyJoints.Set(ji, int(di), int(0), int(1+i))
}
nc := int32(len(bjc[di]))
ml.BodyJoints.Set(nc, int(di), int(1), int(0))
for i, ji := range bjc[di] {
ml.BodyJoints.Set(ji, int(di), int(1), int(1+i))
}
}
if nj == 0 {
ml.Objects = tensor.NewInt32(1, 1)
ml.Joints = tensor.NewFloat32(1, int(JointVarsN))
ml.JointDoFs = tensor.NewFloat32(1, int(JointDoFVarsN))
ml.JointControls = tensor.NewFloat32(1, int(JointControlVarsN))
}
}
// ConfigBodies updates computed body values from current values.
// Call if body params (mass, size) change.
func (ml *Model) ConfigBodies() {
params := &ml.Params[0]
nb := params.BodiesN
for bi := range nb {
shape := GetBodyShape(bi)
size := BodyHSize(bi)
mass := Bodies.Value(int(bi), int(BodyMass))
ml.SetMass(bi, shape, size, mass)
}
}
// InitControlState initializes the JointTargetPosCur values to 0.
// This is done on the CPU prior to copying up to GPU, in InitState.
func (ml *Model) InitControlState() {
params := GetParams(0)
for j := range params.JointDoFsN {
JointControls.Set(0, int(j), int(JointTargetPosCur))
}
}
// InitState initializes the simulation state.
func (ml *Model) InitState() {
params := GetParams(0)
ml.InitControlState()
ml.ToGPUInfra()
RunInitDynamics(int(params.DynamicsN))
RunDone(DynamicsVar)
}
// Code generated by "goal build"; DO NOT EDIT.
//line contact.goal:1
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package physics
import (
// "fmt"
"math"
"sync/atomic"
"cogentcore.org/core/math32"
"cogentcore.org/lab/gosl/slmath"
"cogentcore.org/lab/tensor"
)
//gosl:start
// ContactVars are the variables for one pairwise point of contact
// between two bodies, stored in tensor.Float32.
// Contacts are represented in spherical terms relative to the
// spherical BBox of A and B.
type ContactVars int32 //enums:enum
const (
// first body index
ContactA ContactVars = iota
// the other body index
ContactB
// contact point index for A-B pair
ContactPointIdx
// contact point on body A
ContactAPointX
ContactAPointY
ContactAPointZ
// contact point on body B
ContactBPointX
ContactBPointY
ContactBPointZ
// contact offset on body A
ContactAOffX
ContactAOffY
ContactAOffZ
// contact offset on body B
ContactBOffX
ContactBOffY
ContactBOffZ
// Contact thickness
ContactAThick
ContactBThick
// normal pointing from center of B to center of A
ContactNormX
ContactNormY
ContactNormZ
// contact weighting -- 1 if contact made; for restitution
// use this to filter contacts when updating body.
ContactWeight
// computed contact deltas, A
ContactADeltaX
ContactADeltaY
ContactADeltaZ
ContactAAngDeltaX
ContactAAngDeltaY
ContactAAngDeltaZ
// computed contact deltas, B
ContactBDeltaX
ContactBDeltaY
ContactBDeltaZ
ContactBAngDeltaX
ContactBAngDeltaY
ContactBAngDeltaZ
)
// number of broad-phase contact values: just the indexes
const BroadContactVarsN = ContactAPointX
func SetBroadContactA(idx, bodIdx int32) {
BroadContacts.Set(math.Float32frombits(uint32(bodIdx)), int(idx), int(ContactA))
}
func GetBroadContactA(idx int32) int32 {
return int32(math.Float32bits(BroadContacts.Value(int(idx), int(ContactA))))
}
func SetBroadContactB(idx, bodIdx int32) {
BroadContacts.Set(math.Float32frombits(uint32(bodIdx)), int(idx), int(ContactB))
}
func GetBroadContactB(idx int32) int32 {
return int32(math.Float32bits(BroadContacts.Value(int(idx), int(ContactB))))
}
func SetBroadContactPointIdx(idx, ptIdx int32) {
BroadContacts.Set(math.Float32frombits(uint32(ptIdx)), int(idx), int(ContactPointIdx))
}
func GetBroadContactPointIdx(idx int32) int32 {
return int32(math.Float32bits(BroadContacts.Value(int(idx), int(ContactPointIdx))))
}
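// Note on the bit-cast pattern above (explanatory): integer indexes are
// stored in the float32 contact tensors by reinterpreting the bits
// (math.Float32frombits / math.Float32bits) rather than converting the value,
// so the round trip is exact as long as the stored float is not modified:
//
//	f := math.Float32frombits(uint32(bodIdx)) // store
//	idx := int32(math.Float32bits(f))         // load: idx == bodIdx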
//////// Narrow
func SetContactA(idx, bodIdx int32) {
Contacts.Set(math.Float32frombits(uint32(bodIdx)), int(idx), int(ContactA))
}
func GetContactA(idx int32) int32 {
return int32(math.Float32bits(Contacts.Value(int(idx), int(ContactA))))
}
func SetContactB(idx, bodIdx int32) {
Contacts.Set(math.Float32frombits(uint32(bodIdx)), int(idx), int(ContactB))
}
func GetContactB(idx int32) int32 {
return int32(math.Float32bits(Contacts.Value(int(idx), int(ContactB))))
}
func SetContactPointIdx(idx, ptIdx int32) {
Contacts.Set(math.Float32frombits(uint32(ptIdx)), int(idx), int(ContactPointIdx))
}
func GetContactPointIdx(idx int32) int32 {
return int32(math.Float32bits(Contacts.Value(int(idx), int(ContactPointIdx))))
}
func ContactAPoint(idx int32) math32.Vector3 {
return math32.Vec3(Contacts.Value(int(idx), int(ContactAPointX)), Contacts.Value(int(idx), int(ContactAPointY)), Contacts.Value(int(idx), int(ContactAPointZ)))
}
func SetContactAPoint(idx int32, pos math32.Vector3) {
Contacts.Set(pos.X, int(idx), int(ContactAPointX))
Contacts.Set(pos.Y, int(idx), int(ContactAPointY))
Contacts.Set(pos.Z, int(idx), int(ContactAPointZ))
}
func ContactBPoint(idx int32) math32.Vector3 {
return math32.Vec3(Contacts.Value(int(idx), int(ContactBPointX)), Contacts.Value(int(idx), int(ContactBPointY)), Contacts.Value(int(idx), int(ContactBPointZ)))
}
func SetContactBPoint(idx int32, pos math32.Vector3) {
Contacts.Set(pos.X, int(idx), int(ContactBPointX))
Contacts.Set(pos.Y, int(idx), int(ContactBPointY))
Contacts.Set(pos.Z, int(idx), int(ContactBPointZ))
}
func ContactAOff(idx int32) math32.Vector3 {
return math32.Vec3(Contacts.Value(int(idx), int(ContactAOffX)), Contacts.Value(int(idx), int(ContactAOffY)), Contacts.Value(int(idx), int(ContactAOffZ)))
}
func SetContactAOff(idx int32, pos math32.Vector3) {
Contacts.Set(pos.X, int(idx), int(ContactAOffX))
Contacts.Set(pos.Y, int(idx), int(ContactAOffY))
Contacts.Set(pos.Z, int(idx), int(ContactAOffZ))
}
func ContactBOff(idx int32) math32.Vector3 {
return math32.Vec3(Contacts.Value(int(idx), int(ContactBOffX)), Contacts.Value(int(idx), int(ContactBOffY)), Contacts.Value(int(idx), int(ContactBOffZ)))
}
func SetContactBOff(idx int32, pos math32.Vector3) {
Contacts.Set(pos.X, int(idx), int(ContactBOffX))
Contacts.Set(pos.Y, int(idx), int(ContactBOffY))
Contacts.Set(pos.Z, int(idx), int(ContactBOffZ))
}
func ContactNorm(idx int32) math32.Vector3 {
return math32.Vec3(Contacts.Value(int(idx), int(ContactNormX)), Contacts.Value(int(idx), int(ContactNormY)), Contacts.Value(int(idx), int(ContactNormZ)))
}
func SetContactNorm(idx int32, pos math32.Vector3) {
Contacts.Set(pos.X, int(idx), int(ContactNormX))
Contacts.Set(pos.Y, int(idx), int(ContactNormY))
Contacts.Set(pos.Z, int(idx), int(ContactNormZ))
}
func ContactADelta(idx int32) math32.Vector3 {
return math32.Vec3(Contacts.Value(int(idx), int(ContactADeltaX)), Contacts.Value(int(idx), int(ContactADeltaY)), Contacts.Value(int(idx), int(ContactADeltaZ)))
}
func SetContactADelta(idx int32, pos math32.Vector3) {
Contacts.Set(pos.X, int(idx), int(ContactADeltaX))
Contacts.Set(pos.Y, int(idx), int(ContactADeltaY))
Contacts.Set(pos.Z, int(idx), int(ContactADeltaZ))
}
func ContactAAngDelta(idx int32) math32.Vector3 {
return math32.Vec3(Contacts.Value(int(idx), int(ContactAAngDeltaX)), Contacts.Value(int(idx), int(ContactAAngDeltaY)), Contacts.Value(int(idx), int(ContactAAngDeltaZ)))
}
func SetContactAAngDelta(idx int32, pos math32.Vector3) {
Contacts.Set(pos.X, int(idx), int(ContactAAngDeltaX))
Contacts.Set(pos.Y, int(idx), int(ContactAAngDeltaY))
Contacts.Set(pos.Z, int(idx), int(ContactAAngDeltaZ))
}
func ContactBDelta(idx int32) math32.Vector3 {
return math32.Vec3(Contacts.Value(int(idx), int(ContactBDeltaX)), Contacts.Value(int(idx), int(ContactBDeltaY)), Contacts.Value(int(idx), int(ContactBDeltaZ)))
}
func SetContactBDelta(idx int32, pos math32.Vector3) {
Contacts.Set(pos.X, int(idx), int(ContactBDeltaX))
Contacts.Set(pos.Y, int(idx), int(ContactBDeltaY))
Contacts.Set(pos.Z, int(idx), int(ContactBDeltaZ))
}
func ContactBAngDelta(idx int32) math32.Vector3 {
return math32.Vec3(Contacts.Value(int(idx), int(ContactBAngDeltaX)), Contacts.Value(int(idx), int(ContactBAngDeltaY)), Contacts.Value(int(idx), int(ContactBAngDeltaZ)))
}
func SetContactBAngDelta(idx int32, pos math32.Vector3) {
Contacts.Set(pos.X, int(idx), int(ContactBAngDeltaX))
Contacts.Set(pos.Y, int(idx), int(ContactBAngDeltaY))
Contacts.Set(pos.Z, int(idx), int(ContactBAngDeltaZ))
}
func WorldsCollide(wa, wb int32) bool {
if wa != -1 && wb != -1 && wa != wb {
return false
}
return true
}
func GroupsCollide(ga, gb int32) bool {
if ga == 0 || gb == 0 {
return false
}
if ga > 0 {
return ga == gb || gb < 0
}
if ga < 0 {
return ga != gb
}
return false
}
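// The collision rules in the two functions above, spelled out (derived from
// the code): group 0 never collides; two positive groups collide only if
// equal; a positive group always collides with a negative group; two negative
// groups collide only if they differ. For example:
//
//	GroupsCollide(1, 1)   // true:  same positive group
//	GroupsCollide(1, 2)   // false: different positive groups
//	GroupsCollide(1, -1)  // true:  positive vs. negative
//	GroupsCollide(-1, -1) // false: same negative group
//	GroupsCollide(-1, -2) // true:  different negative groups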
// newton: geometry/kernels.py: broadphase_collision_pairs
// CollisionBroad performs broad-phase collision detection, generating Contacts.
func CollisionBroad(i uint32) { //gosl:kernel
params := GetParams(0)
ci := int32(i)
if ci >= params.BodyCollidePairsN {
return
}
biA := BodyCollidePairs.Value(int(ci), int(0))
biB := BodyCollidePairs.Value(int(ci), int(1))
xwAR := BodyDynamicPos(biA, params.Cur)
xwAQ := BodyDynamicQuat(biA, params.Cur)
xwBR := BodyDynamicPos(biB, params.Cur)
// xwBQ := BodyDynamicQuat(bb, params.Cur)
// note: sA <= sB
sA := GetBodyShape(biA)
sB := GetBodyShape(biB)
rb := Bodies.Value(int(biB), int(BodyRadius))
// if type_a == GeoType.PLANE and type_b == GeoType.PLANE:
// return
// could be per-shape
// margin = wp.max(shape_contact_margin[shape_a], shape_contact_margin[shape_b])
margin := params.ContactMargin
// bounding sphere check
infPlane := false
if sA == Plane {
szA := BodyHSize(biA)
if szA.X == 0 {
infPlane = true
}
queryB := slmath.MulSpatialPoint(xwAR, xwAQ, xwBR)
closest := ClosestPointPlane(szA.X, szA.Z, queryB)
d := slmath.Length3(queryB.Sub(closest))
if d > rb+margin {
return
}
// fmt.Println("broad ct plane:", queryB, szA, closest, d, rb, margin)
} else {
d := slmath.Length3(xwAR.Sub(xwBR))
ra := Bodies.Value(int(biA), int(BodyRadius))
if d > ra+rb+margin {
return
}
}
var ncB int32
ncA := ShapePairContacts(sA, sB, infPlane, &ncB)
// note: ignoring contact_point_limit code for now
enci := atomic.AddInt32(&BroadContactsN.Values[0], ncA+ncB)
// Go returns post-added value, while WGSL returns pre-added value
//gosl:wgsl
// enci += ncA + ncB // wgsl now matches Go
//gosl:end
nci := enci - (ncA + ncB) // starting index
if nci >= params.ContactsMax { // shouldn't happen!
// fmt.Println("over max!", nci, params.ContactsMax)
return
}
AddBroadContacts(biA, biB, nci, ncA, ncB)
}
// newton: geometry/kernels.py: allocate_contact_points
// AddBroadContacts adds broad-phase contact records in prep for narrow phase.
func AddBroadContacts(biA, biB, nci, ncA, ncB int32) {
for i := range ncA {
SetBroadContactA(nci+i, biA)
SetBroadContactB(nci+i, biB)
SetBroadContactPointIdx(nci+i, i)
}
for i := range ncB {
SetBroadContactA(nci+ncA+i, biB) // flipped
SetBroadContactB(nci+ncA+i, biA)
SetBroadContactPointIdx(nci+ncA+i, i) // offset by ncA to match the A and B entries above
}
}
// newton: geometry/kernels.py: generate_handle_contact_pairs / handle_contact_pairs
// CollisionNarrow performs narrow-phase collision on Contacts.
func CollisionNarrow(i uint32) { //gosl:kernel
params := GetParams(0)
ci := int32(i)
cmax := BroadContactsN.Values[0]
if ci >= cmax {
return
}
biA := GetBroadContactA(ci)
biB := GetBroadContactB(ci)
cpi := GetBroadContactPointIdx(ci)
sA := GetBodyShape(biA)
sB := GetBodyShape(biB)
gdA := NewGeomData(biA, params.Cur, sA)
gdB := NewGeomData(biB, params.Cur, sB)
// could be per-shape
// margin = wp.max(shape_contact_margin[shape_a], shape_contact_margin[shape_b])
margin := params.ContactMargin
dist := float32(1.0e6)
maxIter := params.MaxGeomIter
// note: no Cone on anything
var ptA, ptB, norm, nnorm math32.Vector3
switch gdA.Shape {
case Plane:
switch gdB.Shape {
case Sphere:
dist = ColSpherePlane(cpi, maxIter, &gdB, &gdA, &ptB, &ptA, &nnorm) // reverse
norm = slmath.Negate3(nnorm)
case Capsule:
dist = ColCapsulePlane(cpi, maxIter, &gdB, &gdA, &ptB, &ptA, &nnorm) // reverse
norm = slmath.Negate3(nnorm)
case Cylinder:
dist = ColCylinderPlane(cpi, maxIter, &gdB, &gdA, &ptB, &ptA, &nnorm) // reverse
norm = slmath.Negate3(nnorm)
case Box:
dist = ColBoxPlane(cpi, maxIter, &gdB, &gdA, &ptB, &ptA, &nnorm) // reverse
norm = slmath.Negate3(nnorm)
default:
}
case Sphere:
switch gdB.Shape {
case Sphere:
dist = ColSphereSphere(cpi, maxIter, &gdA, &gdB, &ptA, &ptB, &norm)
case Capsule:
dist = ColSphereCapsule(cpi, maxIter, &gdA, &gdB, &ptA, &ptB, &norm)
// no cylinder
case Box:
dist = ColSphereBox(cpi, maxIter, &gdA, &gdB, &ptA, &ptB, &norm)
default:
}
case Capsule:
switch gdB.Shape {
case Capsule:
dist = ColCapsuleCapsule(cpi, maxIter, &gdA, &gdB, &ptA, &ptB, &norm)
// no cylinder
case Box:
dist = ColBoxCapsule(cpi, maxIter, &gdB, &gdA, &ptB, &ptA, &nnorm) // reverse
norm = slmath.Negate3(nnorm)
default:
}
case Box:
switch gdB.Shape {
case Box:
dist = ColBoxBox(cpi, maxIter, &gdA, &gdB, &ptA, &ptB, &norm)
default:
}
default:
}
var ctA, ctB, offA, offB math32.Vector3
var distActual, offMagA, offMagB float32
actual := ContactPoints(dist, margin, &gdA, &gdB, ptA, ptB, norm, &ctA, &ctB, &offA, &offB, &distActual, &offMagA, &offMagB)
if !actual {
return
}
enci := atomic.AddInt32(&ContactsN.Values[0], 1)
// Go returns post-added value, while WGSL returns pre-added value
//gosl:wgsl
// enci += int32(1) // wgsl now matches Go
//gosl:end
nci := enci - 1
SetContactA(nci, biA)
SetContactB(nci, biB)
SetContactPointIdx(nci, cpi)
SetContactAPoint(nci, ctA)
SetContactBPoint(nci, ctB)
SetContactAOff(nci, offA)
SetContactBOff(nci, offB)
SetContactNorm(nci, norm)
Contacts.Set(offMagA, int(nci), int(ContactAThick))
Contacts.Set(offMagB, int(nci), int(ContactBThick))
}
// newton: solvers/xpbd/kernels.py: solve_body_contact_positions
// StepBodyContacts generates contact forces for bodies.
func StepBodyContacts(i uint32) { //gosl:kernel
params := GetParams(0)
ci := int32(i)
cmax := ContactsN.Values[0]
if ci >= cmax {
return
}
biA := GetContactA(ci)
biB := GetContactB(ci)
diA := GetBodyDynamic(biA)
diB := GetBodyDynamic(biB)
r1A := BodyDynamicPos(biA, params.Next)
q1A := BodyDynamicQuat(biA, params.Next)
r1B := BodyDynamicPos(biB, params.Next)
q1B := BodyDynamicQuat(biB, params.Next)
ctA := ContactAPoint(ci)
offA := ContactAOff(ci)
ctB := ContactBPoint(ci)
offB := ContactBOff(ci)
ctAw := slmath.MulSpatialPoint(r1A, q1A, ctA)
ctBw := slmath.MulSpatialPoint(r1B, q1B, ctB)
thickA := Contacts.Value(int(ci), int(ContactAThick))
thickB := Contacts.Value(int(ci), int(ContactBThick))
thick := thickA + thickB
nnorm := ContactNorm(ci)
norm := slmath.Negate3(nnorm)
// margin := params.ContactMargin
d := slmath.Dot3(norm, ctBw.Sub(ctAw)) - thick
if d >= 0.0 { // todo: should this be margin or not?
Contacts.Set(0.0, int(ci), int(ContactWeight))
z := math32.Vec3(0, 0, 0)
SetContactADelta(ci, z)
SetContactBDelta(ci, z)
SetContactAAngDelta(ci, z)
SetContactBAngDelta(ci, z)
return
}
comA := BodyCom(biA)
mInvA := Bodies.Value(int(biA), int(BodyInvMass))
iInvA := BodyInvInertia(biA)
comB := BodyCom(biB)
mInvB := Bodies.Value(int(biB), int(BodyInvMass))
iInvB := BodyInvInertia(biB)
var w1A, w1B math32.Vector3
if diA >= 0 {
w1A = DynamicAngDelta(diA, params.Next)
}
if diB >= 0 {
w1B = DynamicAngDelta(diB, params.Next)
}
// use average contact material properties
mu := 0.5 * (Bodies.Value(int(biA), int(BodyFriction)) + Bodies.Value(int(biB), int(BodyFriction)))
frTors := 0.5 * (Bodies.Value(int(biA), int(BodyFrictionTortion)) + Bodies.Value(int(biB), int(BodyFrictionTortion)))
frRoll := 0.5 * (Bodies.Value(int(biA), int(BodyFrictionRolling)) + Bodies.Value(int(biB), int(BodyFrictionRolling)))
bounce := 0.5 * (Bodies.Value(int(biA), int(BodyBounce)) + Bodies.Value(int(biB), int(BodyBounce)))
// moment arms
dA := ctAw.Sub(slmath.MulSpatialPoint(r1A, q1A, comA))
dB := ctBw.Sub(slmath.MulSpatialPoint(r1B, q1B, comB))
angA := slmath.Negate3(slmath.Cross3(dA, norm))
angB := slmath.Cross3(dB, norm)
lambdaN := ContactConstraint(d, q1A, q1B, mInvA, mInvB, iInvA, iInvB, nnorm, norm, angA, angB, params.ContactRelax, params.Dt)
linDeltaA := slmath.Negate3(norm).MulScalar(lambdaN)
linDeltaB := norm.MulScalar(lambdaN)
angDeltaA := angA.MulScalar(lambdaN)
angDeltaB := angB.MulScalar(lambdaN)
// linear friction
if mu > 0.0 {
// add on displacement from surface offsets, this ensures
// we include any rotational effects due to thickness from feature
// need to use the current rotation to account for friction due to
// angular effects (e.g.: slipping contact)
ctAm := ctAw.Add(slmath.MulQuatVector(q1A, offA))
ctBm := ctBw.Add(slmath.MulQuatVector(q1B, offB))
// update delta
delta := ctBm.Sub(ctAm)
frDelta := delta.Sub(norm.MulScalar(slmath.Dot3(norm, delta)))
perp := slmath.Normal3(frDelta)
dAm := ctAm.Sub(slmath.MulSpatialPoint(r1A, q1A, comA))
dBm := ctBm.Sub(slmath.MulSpatialPoint(r1B, q1B, comB))
angA = slmath.Negate3(slmath.Cross3(dAm, perp))
angB = slmath.Cross3(dBm, perp)
err := slmath.Length3(frDelta)
if err > 0.0 {
lambdaFr := ContactConstraint(err, q1A, q1B, mInvA, mInvB, iInvA, iInvB, slmath.Negate3(perp), perp, angA, angB, params.ContactRelax, params.Dt)
// limit friction based on incremental normal force,
// good approximation to limiting on total force
lambdaFr = max(lambdaFr, -lambdaN*mu)
linDeltaA = linDeltaA.Sub(perp.MulScalar(lambdaFr))
linDeltaB = linDeltaB.Add(perp.MulScalar(lambdaFr))
angDeltaA = angDeltaA.Add(angA.MulScalar(lambdaFr))
angDeltaB = angDeltaB.Add(angB.MulScalar(lambdaFr))
}
}
deltaW := w1B.Sub(w1A)
if frTors > 0.0 {
err := slmath.Dot3(deltaW, norm) * params.Dt
if math32.Abs(err) > 0.0 {
lin := math32.Vec3(0, 0, 0)
lambdaTors := ContactConstraint(err, q1A, q1B, mInvA, mInvB, iInvA, iInvB, lin, lin, nnorm, norm, params.ContactRelax, params.Dt)
lambdaTors = math32.Clamp(lambdaTors, -lambdaN*frTors, lambdaN*frTors)
angDeltaA = angDeltaA.Sub(norm.MulScalar(lambdaTors))
angDeltaB = angDeltaB.Add(norm.MulScalar(lambdaTors))
}
}
if frRoll > 0.0 {
deltaW = deltaW.Sub(norm.MulScalar(slmath.Dot3(norm, deltaW)))
err := slmath.Length3(deltaW) * params.Dt
if err > 0.0 {
lin := math32.Vec3(0, 0, 0)
rollN := slmath.Normal3(deltaW)
lambdaRoll := ContactConstraint(err, q1A, q1B, mInvA, mInvB, iInvA, iInvB, lin, lin, slmath.Negate3(rollN), rollN, params.ContactRelax, params.Dt)
lambdaRoll = max(lambdaRoll, -lambdaN*frRoll)
angDeltaA = angDeltaA.Sub(rollN.MulScalar(lambdaRoll))
angDeltaB = angDeltaB.Add(rollN.MulScalar(lambdaRoll))
}
}
// restitution (bounce)
if params.Restitution.IsTrue() && bounce > 0 && (mInvA > 0 || mInvB > 0) {
var vA, vB, vAnew, vBnew, dAnew, dBnew math32.Vector3
var mInvAr, mInvBr float32
var q0A, q0B math32.Quat
grav := params.Gravity.V().MulScalar(params.Dt)
if diA >= 0 {
q0A = DynamicQuat(diA, params.Cur)
w0A := DynamicAngDelta(diA, params.Cur)
v0A := DynamicDelta(diA, params.Cur)
v1A := DynamicDelta(diA, params.Next)
vA = VelocityAtPoint(v0A, w0A, dA).Add(grav)
vAnew = VelocityAtPoint(v1A, w1A, dA)
dAnew = slmath.MulQuatVectorInverse(q0A, slmath.Cross3(dA, nnorm)) // norm is not - here..
mInvAr = mInvA + slmath.Dot3(dAnew, iInvA.MulVector3(dAnew))
}
if diB >= 0 {
q0B = DynamicQuat(diB, params.Cur)
w0B := DynamicAngDelta(diB, params.Cur)
v0B := DynamicDelta(diB, params.Cur)
v1B := DynamicDelta(diB, params.Next)
vB = VelocityAtPoint(v0B, w0B, dB).Add(grav)
vBnew = VelocityAtPoint(v1B, w1B, dB)
dBnew = slmath.MulQuatVectorInverse(q0B, slmath.Cross3(dB, norm)) // norm is not - here..
mInvBr = mInvB + slmath.Dot3(dBnew, iInvB.MulVector3(dBnew))
}
mInv := mInvAr + mInvBr
relVel0 := slmath.Dot3(nnorm, vA.Sub(vB))
relVel1 := slmath.Dot3(nnorm, vAnew.Sub(vBnew))
if relVel0 < 0 {
dv := -(relVel1 - relVel0*bounce) / mInv
// fmt.Println(dv, relVel1, relVel0, bounce, mInv)
if diA >= 0 {
dvA := nnorm.MulScalar(mInvA * dv)
dwA := slmath.MulQuatVector(q0A, iInvA.MulVector3(dAnew).MulScalar(dv))
linDeltaA = linDeltaA.Add(dvA)
angDeltaA = angDeltaA.Add(dwA)
}
if diB >= 0 {
dvB := norm.MulScalar(mInvB * dv)
dwB := slmath.MulQuatVector(q0B, iInvB.MulVector3(dBnew).MulScalar(dv))
linDeltaB = linDeltaB.Add(dvB)
angDeltaB = angDeltaB.Add(dwB)
}
}
}
Contacts.Set(1.0, int(ci), int(ContactWeight))
SetContactADelta(ci, linDeltaA)
SetContactBDelta(ci, linDeltaB)
SetContactAAngDelta(ci, angDeltaA)
SetContactBAngDelta(ci, angDeltaB)
}
// StepBodyContactDeltas gathers the raw linear and angular deltas from
// contacts for each dynamic body and integrates them via StepBodyDeltas.
func StepBodyContactDeltas(i uint32) { //gosl:kernel
params := GetParams(0)
di := int32(i)
if di >= params.DynamicsN {
return
}
bi := DynamicBody(di)
invMass := Bodies.Value(int(bi), int(BodyInvMass))
if invMass == 0 {
return // no updates
}
cmax := ContactsN.Values[0]
linDel := math32.Vec3(0, 0, 0)
angDel := math32.Vec3(0, 0, 0)
tw := float32(0)
for ci := range cmax {
wt := Contacts.Value(int(ci), int(ContactWeight))
if wt == 0 { // 0 = no actual; else 1
continue
}
biA := GetContactA(ci)
biB := GetContactB(ci)
if biA == bi {
tw += wt
d := ContactADelta(ci)
linDel = linDel.Add(d)
a := ContactAAngDelta(ci)
angDel = angDel.Add(a)
}
if biB == bi {
tw += wt
d := ContactBDelta(ci)
linDel = linDel.Add(d)
a := ContactBAngDelta(ci)
angDel = angDel.Add(a)
}
}
Dynamics.Set(tw, int(di), int(params.Next), int(DynContactWeight))
StepBodyDeltas(di, bi, true, tw, linDel, angDel)
}
func ContactConstraint(err float32, q0A, q0B math32.Quat, mInvA, mInvB float32, iInvA, iInvB math32.Matrix3, linA, linB, angA, angB math32.Vector3, relaxation, dt float32) float32 {
denom := float32(0.0)
denom += slmath.LengthSquared3(linA) * mInvA
denom += slmath.LengthSquared3(linB) * mInvB
// Eq. 2-3 (make sure to project into the frame of the body)
rotAngA := slmath.MulQuatVectorInverse(q0A, angA)
rotAngB := slmath.MulQuatVectorInverse(q0B, angB)
denom += slmath.Dot3(rotAngA, iInvA.MulVector3(rotAngA))
denom += slmath.Dot3(rotAngB, iInvB.MulVector3(rotAngB))
lambda := -err
if denom > 0.0 {
lambda /= dt * denom
}
return lambda * relaxation
}
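// Explanatory note on ContactConstraint above: this is an XPBD-style
// positional constraint correction,
//
//	lambda = -err / (dt * denom) * relaxation
//
// where denom is the generalized inverse mass of the constraint: the sum of
// |lin|^2 * invMass terms plus the angular terms ang^T * invInertia * ang,
// with the angular directions first rotated into each body's local frame.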
//gosl:end
// IsChildDynamic returns true if dic is a direct child of dip
// in any joint where dip is the parent.
func (ml *Model) IsChildDynamic(dip, dic int32) bool {
if dip < 0 || dic < 0 {
return false
}
npja := ml.BodyJoints.Value(int(dip), int(0), int(0))
for j := range npja {
ji := ml.BodyJoints.Value(int(dip), int(0), int(1+j))
jci := JointChildIndex(ji)
if jci == dic {
return true
}
}
return false
}
// newton: sim/builder.py: find_shape_contact_pairs
// ConfigBodyCollidePairs compiles a list of body pairs that could collide
// based on world and group settings, excluding pairs in a direct parent-child
// relationship within a joint. Result has A with lower shape type,
// so that shapes are in a canonical order.
func (ml *Model) ConfigBodyCollidePairs() {
params := &ml.Params[0]
nb := params.BodiesN
nalc := int(nb) * 10
pt := tensor.NewInt32(nalc, 2)
np := 0
for a := range nb {
wa := GetBodyWorld(a)
ga := GetBodyGroup(a)
dia := GetBodyDynamic(a)
for b := range nb {
if a == b {
continue
}
wb := GetBodyWorld(b)
gb := GetBodyGroup(b)
if !WorldsCollide(wa, wb) {
continue
}
if !GroupsCollide(ga, gb) {
continue
}
dib := GetBodyDynamic(b)
// now check joints (ConfigJoints must have been called first)
if ml.IsChildDynamic(dia, dib) || ml.IsChildDynamic(dib, dia) {
continue
}
if np >= nalc {
nalc += int(nb)
pt.SetShapeSizes(nalc, 2)
// fmt.Println("body pairs realoc", nalc)
}
sA := GetBodyShape(a)
sB := GetBodyShape(b)
if sA <= sB {
pt.Set(a, int(np), int(0))
pt.Set(b, int(np), int(1))
} else {
pt.Set(b, int(np), int(0))
pt.Set(a, int(np), int(1))
}
np++
}
}
params.BodyCollidePairsN = int32(np)
if np == 0 {
np = 1
}
pt.SetShapeSizes(np, 2)
ml.BodyCollidePairs = pt
BodyCollidePairs = pt
// fmt.Println("body pairs over alloc", nalc - np, "total:", np)
}
// newton: geometry/kernels.py: count_contact_points
// SetMaxContacts computes [Params.ContactsMax] based on the current list of
// [BodyCollidePairs].
func (ml *Model) SetMaxContacts() {
params := &ml.Params[0]
n := int32(0)
for ci := range params.BodyCollidePairsN {
biA := BodyCollidePairs.Value(int(ci), int(0))
biB := BodyCollidePairs.Value(int(ci), int(1))
// note: sA <= sB
sA := GetBodyShape(biA)
sB := GetBodyShape(biB)
infPlane := false
szA := BodyHSize(biA)
if szA.X == 0 {
infPlane = true
}
var ncB int32
ncA := ShapePairContacts(sA, sB, infPlane, &ncB)
n += ncA + ncB
}
// todo: this is a massive over-estimate, because there is no way everything
// could be colliding at once, except in a very small model.
if params.BodyCollidePairsN > 1000 {
n = 4 * max(int32(math32.Sqrt(float32(n))), int32(Bodies.DimSize(0)))
// fmt.Println("> 1000", params.BodyCollidePairsN, n)
}
n = max(n, 4*params.DynamicsN)
params.ContactsMax = n
ml.BroadContacts.SetShapeSizes(int(n), int(BroadContactVarsN))
ml.Contacts.SetShapeSizes(int(n), int(ContactVarsN))
}
// Code generated by "goal build"; DO NOT EDIT.
//line control.goal:1
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package physics
import "cogentcore.org/core/math32"
//gosl:start
// JointControlVars are external joint control input variables stored in tensor.Float32.
// These must be in one-to-one correspondence with the JointDoFs.
type JointControlVars int32 //enums:enum
const (
// Joint force and torque inputs
JointControlForce JointControlVars = iota
// JointTargetPos is the position target value input to the model,
// where 0 is the initial position. For angular joints, this is in radians.
// This is subject to a graded transition over time, [JointTargetPosCur]
// has the current effective value.
JointTargetPos
// JointTargetPosCur is the current position target value,
// updated from [JointTargetPos] input using the [Params.ControlDt]
// time constant.
JointTargetPosCur
// JointTargetStiff determines how strongly the target position
// is enforced: 0 = not at all; larger = stronger (e.g., 1000 or higher).
// Set to 0 to allow the joint to be fully flexible.
JointTargetStiff
// JointTargetVel is the velocity target value. For example, 0
// effectively damps joint movement in proportion to Damp parameter.
JointTargetVel
// JointTargetDamp determines how strongly the target velocity is enforced:
// 0 = not at all; larger = stronger (e.g., 1 is reasonable).
// Set to 0 to allow the joint to be fully flexible.
JointTargetDamp
)
// SetJointControl sets the control for given joint, dof and parameter
// to given value.
func SetJointControl(idx, dof int32, vr JointControlVars, value float32) {
JointControls.Set(value, int(JointDoFIndex(idx, dof)), int(vr))
}
func JointControl(idx, dof int32, vr JointControlVars) float32 {
return JointControls.Value(int(JointDoFIndex(idx, dof)), int(vr))
}
// SetJointControlForce sets the force for given joint, dof to given value.
func SetJointControlForce(idx, dof int32, value float32) {
SetJointControl(idx, dof, JointControlForce, value)
}
// SetJointTargetPos sets the target position and stiffness
// for given joint, DoF to given values.
// Stiffness determines how strongly the joint constraint is enforced
// (0 = not at all; 1000+ = strongly).
// For angular joints, values are in radians, see also
// [SetJointTargetAngle].
func SetJointTargetPos(idx, dof int32, pos, stiff float32) {
SetJointControl(idx, dof, JointTargetPos, pos)
SetJointControl(idx, dof, JointTargetStiff, stiff)
}
// SetJointTargetAngle sets the target angular position
// and stiffness for given joint, DoF to given values.
// Stiffness determines how strongly the joint constraint is enforced
// (0 = not at all; 1000+ = strongly).
// Angle is in degrees, not radians. The usable range is -180..180,
// which is enforced; values near the edges can be unstable at higher
// stiffness levels.
func SetJointTargetAngle(idx, dof int32, angDeg, stiff float32) {
pos := math32.WrapPi(math32.DegToRad(angDeg))
SetJointTargetPos(idx, dof, pos, stiff)
}
// GetJointTargetPos returns the target position
// for given joint, DoF.
func GetJointTargetPos(idx, dof int32) float32 {
return JointControl(idx, dof, JointTargetPos)
}
// SetJointTargetVel sets the target velocity and damping
// for given joint, DoF to given values. Damping determines
// how strongly the target velocity is enforced
// (0 = not at all; larger = stronger, e.g., 1 is reasonable).
func SetJointTargetVel(idx, dof int32, vel, damp float32) {
SetJointControl(idx, dof, JointTargetVel, vel)
SetJointControl(idx, dof, JointTargetDamp, damp)
}
// GetJointTargetVel returns the target velocity
// for given joint, DoF.
func GetJointTargetVel(idx, dof int32) float32 {
return JointControl(idx, dof, JointTargetVel)
}
//gosl:end
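// A minimal usage sketch (illustrative; the joint index and parameter values
// are arbitrary): driving DoF 0 of joint ji toward a 45 degree target with
// strong stiffness, while damping any residual motion.
//
//	SetJointTargetAngle(ji, 0, 45, 1000) // target 45 degrees, stiffness 1000
//	SetJointTargetVel(ji, 0, 0, 1)       // target zero velocity, damping 1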
// Code generated by "goal build"; DO NOT EDIT.
//line dynamics.goal:1
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package physics
import (
// "fmt"
"math"
"cogentcore.org/core/math32"
"cogentcore.org/lab/gosl/slmath"
)
//gosl:start
// DynamicVars are dynamic body variables stored in tensor.Float32.
type DynamicVars int32 //enums:enum
const (
// Index of body in list of bodies.
DynBody DynamicVars = iota
// 3D position of structural center.
DynPosX
DynPosY
DynPosZ
// Quaternion rotation.
DynQuatX
DynQuatY
DynQuatZ
DynQuatW
// Linear velocity.
DynVelX
DynVelY
DynVelZ
// Angular velocity.
DynAngVelX
DynAngVelY
DynAngVelZ
// Linear acceleration.
DynAccX
DynAccY
DynAccZ
// Angular acceleration due to applied torques.
DynAngAccX
DynAngAccY
DynAngAccZ
// Linear force driving linear acceleration (from joints, etc).
DynForceX
DynForceY
DynForceZ
// Torque driving angular acceleration (from joints, etc).
DynTorqueX
DynTorqueY
DynTorqueZ
// Linear deltas. These accumulate over time via StepBodyDeltas.
DynDeltaX
DynDeltaY
DynDeltaZ
// Angular deltas. These accumulate over time via StepBodyDeltas.
DynAngDeltaX
DynAngDeltaY
DynAngDeltaZ
// integrated weight of all contacts
DynContactWeight
)
// cni = current / next index
func SetDynamicBody(idx, bodyIdx int32) {
bi := math.Float32frombits(uint32(bodyIdx))
Dynamics.Set(bi, int(idx), int(0), int(DynBody))
Dynamics.Set(bi, int(idx), int(1), int(DynBody))
}
func DynamicBody(idx int32) int32 {
return int32(math.Float32bits(Dynamics.Value(int(idx), int(0), int(DynBody))))
}
func DynamicPos(idx, cni int32) math32.Vector3 {
return math32.Vec3(Dynamics.Value(int(idx), int(cni), int(DynPosX)), Dynamics.Value(int(idx), int(cni), int(DynPosY)), Dynamics.Value(int(idx), int(cni), int(DynPosZ)))
}
func SetDynamicPos(idx, cni int32, pos math32.Vector3) {
Dynamics.Set(pos.X, int(idx), int(cni), int(DynPosX))
Dynamics.Set(pos.Y, int(idx), int(cni), int(DynPosY))
Dynamics.Set(pos.Z, int(idx), int(cni), int(DynPosZ))
}
func DynamicQuat(idx, cni int32) math32.Quat {
return math32.NewQuat(Dynamics.Value(int(idx), int(cni), int(DynQuatX)), Dynamics.Value(int(idx), int(cni), int(DynQuatY)), Dynamics.Value(int(idx), int(cni), int(DynQuatZ)), Dynamics.Value(int(idx), int(cni), int(DynQuatW)))
}
func SetDynamicQuat(idx, cni int32, rot math32.Quat) {
Dynamics.Set(rot.X, int(idx), int(cni), int(DynQuatX))
Dynamics.Set(rot.Y, int(idx), int(cni), int(DynQuatY))
Dynamics.Set(rot.Z, int(idx), int(cni), int(DynQuatZ))
Dynamics.Set(rot.W, int(idx), int(cni), int(DynQuatW))
}
func DynamicVel(idx, cni int32) math32.Vector3 {
return math32.Vec3(Dynamics.Value(int(idx), int(cni), int(DynVelX)), Dynamics.Value(int(idx), int(cni), int(DynVelY)), Dynamics.Value(int(idx), int(cni), int(DynVelZ)))
}
func SetDynamicVel(idx, cni int32, vel math32.Vector3) {
Dynamics.Set(vel.X, int(idx), int(cni), int(DynVelX))
Dynamics.Set(vel.Y, int(idx), int(cni), int(DynVelY))
Dynamics.Set(vel.Z, int(idx), int(cni), int(DynVelZ))
}
func DynamicAcc(idx, cni int32) math32.Vector3 {
return math32.Vec3(Dynamics.Value(int(idx), int(cni), int(DynAccX)), Dynamics.Value(int(idx), int(cni), int(DynAccY)), Dynamics.Value(int(idx), int(cni), int(DynAccZ)))
}
func SetDynamicAcc(idx, cni int32, acc math32.Vector3) {
Dynamics.Set(acc.X, int(idx), int(cni), int(DynAccX))
Dynamics.Set(acc.Y, int(idx), int(cni), int(DynAccY))
Dynamics.Set(acc.Z, int(idx), int(cni), int(DynAccZ))
}
func DynamicForce(idx, cni int32) math32.Vector3 {
return math32.Vec3(Dynamics.Value(int(idx), int(cni), int(DynForceX)), Dynamics.Value(int(idx), int(cni), int(DynForceY)), Dynamics.Value(int(idx), int(cni), int(DynForceZ)))
}
func SetDynamicForce(idx, cni int32, force math32.Vector3) {
Dynamics.Set(force.X, int(idx), int(cni), int(DynForceX))
Dynamics.Set(force.Y, int(idx), int(cni), int(DynForceY))
Dynamics.Set(force.Z, int(idx), int(cni), int(DynForceZ))
}
func DynamicTorque(idx, cni int32) math32.Vector3 {
return math32.Vec3(Dynamics.Value(int(idx), int(cni), int(DynTorqueX)), Dynamics.Value(int(idx), int(cni), int(DynTorqueY)), Dynamics.Value(int(idx), int(cni), int(DynTorqueZ)))
}
func SetDynamicTorque(idx, cni int32, torque math32.Vector3) {
Dynamics.Set(torque.X, int(idx), int(cni), int(DynTorqueX))
Dynamics.Set(torque.Y, int(idx), int(cni), int(DynTorqueY))
Dynamics.Set(torque.Z, int(idx), int(cni), int(DynTorqueZ))
}
func DynamicAngVel(idx, cni int32) math32.Vector3 {
return math32.Vec3(Dynamics.Value(int(idx), int(cni), int(DynAngVelX)), Dynamics.Value(int(idx), int(cni), int(DynAngVelY)), Dynamics.Value(int(idx), int(cni), int(DynAngVelZ)))
}
func SetDynamicAngVel(idx, cni int32, angVel math32.Vector3) {
Dynamics.Set(angVel.X, int(idx), int(cni), int(DynAngVelX))
Dynamics.Set(angVel.Y, int(idx), int(cni), int(DynAngVelY))
Dynamics.Set(angVel.Z, int(idx), int(cni), int(DynAngVelZ))
}
func DynamicAngAcc(idx, cni int32) math32.Vector3 {
return math32.Vec3(Dynamics.Value(int(idx), int(cni), int(DynAngAccX)), Dynamics.Value(int(idx), int(cni), int(DynAngAccY)), Dynamics.Value(int(idx), int(cni), int(DynAngAccZ)))
}
func SetDynamicAngAcc(idx, cni int32, angAcc math32.Vector3) {
Dynamics.Set(angAcc.X, int(idx), int(cni), int(DynAngAccX))
Dynamics.Set(angAcc.Y, int(idx), int(cni), int(DynAngAccY))
Dynamics.Set(angAcc.Z, int(idx), int(cni), int(DynAngAccZ))
}
//////// Accumulating deltas
func DynamicDelta(idx, cni int32) math32.Vector3 {
return math32.Vec3(Dynamics.Value(int(idx), int(cni), int(DynDeltaX)), Dynamics.Value(int(idx), int(cni), int(DynDeltaY)), Dynamics.Value(int(idx), int(cni), int(DynDeltaZ)))
}
func SetDynamicDelta(idx, cni int32, delta math32.Vector3) {
Dynamics.Set(delta.X, int(idx), int(cni), int(DynDeltaX))
Dynamics.Set(delta.Y, int(idx), int(cni), int(DynDeltaY))
Dynamics.Set(delta.Z, int(idx), int(cni), int(DynDeltaZ))
}
func DynamicAngDelta(idx, cni int32) math32.Vector3 {
return math32.Vec3(Dynamics.Value(int(idx), int(cni), int(DynAngDeltaX)), Dynamics.Value(int(idx), int(cni), int(DynAngDeltaY)), Dynamics.Value(int(idx), int(cni), int(DynAngDeltaZ)))
}
func SetDynamicAngDelta(idx, cni int32, angDelta math32.Vector3) {
Dynamics.Set(angDelta.X, int(idx), int(cni), int(DynAngDeltaX))
Dynamics.Set(angDelta.Y, int(idx), int(cni), int(DynAngDeltaY))
Dynamics.Set(angDelta.Z, int(idx), int(cni), int(DynAngDeltaZ))
}
//gosl:end
// SetMass sets the mass of given body object (only relevant for dynamics),
// including a default inertia tensor based on solid shape of given size.
func (ml *Model) SetMass(idx int32, shape Shapes, size math32.Vector3, mass float32) {
Bodies.Set(shape.Radius(size), int(idx), int(BodyRadius))
Bodies.Set(mass, int(idx), int(BodyMass))
invm := mass
if mass > 0 {
invm = 1.0 / mass
}
Bodies.Set(invm, int(idx), int(BodyInvMass))
inertia := shape.Inertia(size, mass)
SetBodyInertia(idx, inertia)
SetBodyInvInertia(idx, inertia.Inverse())
}
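// A minimal usage sketch (illustrative values): setting a 1x1x1 solid box of
// mass 2 on body bi; half-sizes are used for the size argument, matching the
// BodyHSize values passed by ConfigBodies.
//
//	ml.SetMass(bi, Box, math32.Vec3(0.5, 0.5, 0.5), 2)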
// TotalKineticEnergy returns the total kinetic energy of the dynamic bodies,
// as a function of the velocities.
func (ml *Model) TotalKineticEnergy() float32 {
params := GetParams(0)
ke := float32(0)
n := int32(Dynamics.DimSize(0))
for di := range n {
bi := DynamicBody(di)
mass := Bodies.Value(int(bi), int(BodyMass))
inertia := BodyInertia(bi)
v := DynamicVel(di, params.Next)
mv := 0.5 * mass * slmath.LengthSquared3(v)
w := DynamicAngVel(di, params.Next)
iw := 0.5 * slmath.Dot3(w, inertia.MulVector3(w))
ke += mv + iw
}
return ke
}
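// Explanatory note: the loop above computes the standard rigid-body kinetic
// energy, KE = sum over bodies of 1/2 m |v|^2 + 1/2 w^T I w, using the Next
// (most recently integrated) linear and angular velocities.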
// AngularVelocityAt returns the angular velocity vector of given dynamic body
// index and Next index, relative to given rotation axis at given point
// relative to the structural center of the given dynamic body.
// For example, to get rotation around the XZ plane, axis = (0,1,0) and
// the velocity value will show up in the Z axis for an X-axis point,
// and vice-versa (X for a Z-axis point).
// This uses DynamicAngVel which is computed after each step (into Next).
func AngularVelocityAt(di int32, point, axis math32.Vector3) math32.Vector3 {
params := GetParams(0)
w := DynamicAngVel(di, params.Next)
wp := slmath.Cross3(w.Mul(axis), point)
return wp
}
// AngularAccelAt returns the angular acceleration vector of given dynamic body
// index and Next index, relative to given rotation axis at given point
// relative to the structural center of the given dynamic body.
// For example, to get rotation around the XZ plane, axis = (0,1,0) and
// the acceleration value will show up in the Z axis for an X-axis point,
// and vice-versa (X for a Z-axis point).
// This uses DynamicAngAcc which is computed after each step (into Next).
func AngularAccelAt(di int32, point, axis math32.Vector3) math32.Vector3 {
params := GetParams(0)
w := DynamicAngAcc(di, params.Next)
wp := slmath.Cross3(w.Mul(axis), point)
return wp
}
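// A minimal usage sketch (illustrative point and axis): reading the
// rotational velocity about the Y axis at a point one unit out along X;
// per the comment above, the value shows up on the Z component.
//
//	w := AngularVelocityAt(di, math32.Vec3(1, 0, 0), math32.Vec3(0, 1, 0))
//	spin := w.Z
//	_ = spin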
// Code generated by "core generate -add-types -gosl"; DO NOT EDIT.
package physics
import (
"cogentcore.org/core/enums"
)
var _BodyVarsValues = []BodyVars{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42}
// BodyVarsN is the highest valid value for type BodyVars, plus one.
//
//gosl:start
const BodyVarsN BodyVars = 43
//gosl:end
var _BodyVarsValueMap = map[string]BodyVars{`BodyShape`: 0, `BodyDynamic`: 1, `BodyWorld`: 2, `BodyGroup`: 3, `BodyHSizeX`: 4, `BodyHSizeY`: 5, `BodyHSizeZ`: 6, `BodyThick`: 7, `BodyMass`: 8, `BodyInvMass`: 9, `BodyBounce`: 10, `BodyFriction`: 11, `BodyFrictionTortion`: 12, `BodyFrictionRolling`: 13, `BodyPosX`: 14, `BodyPosY`: 15, `BodyPosZ`: 16, `BodyQuatX`: 17, `BodyQuatY`: 18, `BodyQuatZ`: 19, `BodyQuatW`: 20, `BodyComX`: 21, `BodyComY`: 22, `BodyComZ`: 23, `BodyInertiaXX`: 24, `BodyInertiaYX`: 25, `BodyInertiaZX`: 26, `BodyInertiaXY`: 27, `BodyInertiaYY`: 28, `BodyInertiaZY`: 29, `BodyInertiaXZ`: 30, `BodyInertiaYZ`: 31, `BodyInertiaZZ`: 32, `BodyInvInertiaXX`: 33, `BodyInvInertiaYX`: 34, `BodyInvInertiaZX`: 35, `BodyInvInertiaXY`: 36, `BodyInvInertiaYY`: 37, `BodyInvInertiaZY`: 38, `BodyInvInertiaXZ`: 39, `BodyInvInertiaYZ`: 40, `BodyInvInertiaZZ`: 41, `BodyRadius`: 42}
var _BodyVarsDescMap = map[BodyVars]string{0: `BodyShape is the shape type of the object, as a Shapes type.`, 1: `BodyDynamic is the index into Dynamics for this body, which is -1 for static bodies. Use this to get current Pos and Quat values for a dynamic body.`, 2: `BodyWorld partitions bodies into different worlds for collision detection: Global bodies = -1 can collide with everything; otherwise only items within the same world collide. NewBody uses [World.CurrentWorld] to initialize.`, 3: `BodyGroup partitions bodies within worlds into different groups for collision detection. 0 does not collide with anything. Negative numbers are global within a world, except they don't collide amongst themselves (all non-dynamic bodies should go in -1 because they don't collide amongst each-other, but do potentially collide with dynamics). Positive numbers only collide amongst themselves, and with negative groups, but not other positive groups. To avoid unwanted collisions, put bodies into separate groups. There is an automatic constraint that the two objects within a single joint do not collide with each other, so this does not need to be handled here.`, 4: `BodyHSize is the half-size (e.g., radius) of the body. Values depend on shape type: X is generally radius, Y is half-height.`, 5: ``, 6: ``, 7: `BodyThick is the thickness of the body, as a hollow shape. If 0, then it is a solid shape (default).`, 8: `BodyMass is the mass of the object.`, 9: `BodyInvMass is 1/mass of the object or 0 if no mass.`, 10: `BodyBounce specifies the COR or coefficient of restitution (0..1), which determines how elastic the collision is, i.e., final velocity / initial velocity.`, 11: `BodyFriction is the standard coefficient for linear friction (mu).`, 12: `BodyFrictionTortion is resistance to spinning at the contact point.`, 13: `BodyFrictionRolling is resistance to rolling motion at contact.`, 14: `3D position of body (structural center).`, 15: ``, 16: ``, 17: `Quaternion rotation of body.`, 18: ``, 19: ``, 20: ``, 21: `Relative center-of-mass offset from 3D position of body.`, 22: ``, 23: ``, 24: `Inertia 3x3 matrix (column matrix organization, r,c labels).`, 25: ``, 26: ``, 27: ``, 28: ``, 29: ``, 30: ``, 31: ``, 32: ``, 33: `InvInertia inverse inertia 3x3 matrix (column matrix organization, r,c labels).`, 34: ``, 35: ``, 36: ``, 37: ``, 38: ``, 39: ``, 40: ``, 41: ``, 42: `radius for broadphase collision`}
var _BodyVarsMap = map[BodyVars]string{0: `BodyShape`, 1: `BodyDynamic`, 2: `BodyWorld`, 3: `BodyGroup`, 4: `BodyHSizeX`, 5: `BodyHSizeY`, 6: `BodyHSizeZ`, 7: `BodyThick`, 8: `BodyMass`, 9: `BodyInvMass`, 10: `BodyBounce`, 11: `BodyFriction`, 12: `BodyFrictionTortion`, 13: `BodyFrictionRolling`, 14: `BodyPosX`, 15: `BodyPosY`, 16: `BodyPosZ`, 17: `BodyQuatX`, 18: `BodyQuatY`, 19: `BodyQuatZ`, 20: `BodyQuatW`, 21: `BodyComX`, 22: `BodyComY`, 23: `BodyComZ`, 24: `BodyInertiaXX`, 25: `BodyInertiaYX`, 26: `BodyInertiaZX`, 27: `BodyInertiaXY`, 28: `BodyInertiaYY`, 29: `BodyInertiaZY`, 30: `BodyInertiaXZ`, 31: `BodyInertiaYZ`, 32: `BodyInertiaZZ`, 33: `BodyInvInertiaXX`, 34: `BodyInvInertiaYX`, 35: `BodyInvInertiaZX`, 36: `BodyInvInertiaXY`, 37: `BodyInvInertiaYY`, 38: `BodyInvInertiaZY`, 39: `BodyInvInertiaXZ`, 40: `BodyInvInertiaYZ`, 41: `BodyInvInertiaZZ`, 42: `BodyRadius`}
// String returns the string representation of this BodyVars value.
func (i BodyVars) String() string { return enums.String(i, _BodyVarsMap) }
// SetString sets the BodyVars value from its string representation,
// and returns an error if the string is invalid.
func (i *BodyVars) SetString(s string) error {
return enums.SetString(i, s, _BodyVarsValueMap, "BodyVars")
}
// Int64 returns the BodyVars value as an int64.
func (i BodyVars) Int64() int64 { return int64(i) }
// SetInt64 sets the BodyVars value from an int64.
func (i *BodyVars) SetInt64(in int64) { *i = BodyVars(in) }
// Desc returns the description of the BodyVars value.
func (i BodyVars) Desc() string { return enums.Desc(i, _BodyVarsDescMap) }
// BodyVarsValues returns all possible values for the type BodyVars.
func BodyVarsValues() []BodyVars { return _BodyVarsValues }
// Values returns all possible values for the type BodyVars.
func (i BodyVars) Values() []enums.Enum { return enums.Values(_BodyVarsValues) }
// MarshalText implements the [encoding.TextMarshaler] interface.
func (i BodyVars) MarshalText() ([]byte, error) { return []byte(i.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *BodyVars) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "BodyVars") }
var _ContactVarsValues = []ContactVars{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}
// ContactVarsN is the highest valid value for type ContactVars, plus one.
//
//gosl:start
const ContactVarsN ContactVars = 33
//gosl:end
var _ContactVarsValueMap = map[string]ContactVars{`ContactA`: 0, `ContactB`: 1, `ContactPointIdx`: 2, `ContactAPointX`: 3, `ContactAPointY`: 4, `ContactAPointZ`: 5, `ContactBPointX`: 6, `ContactBPointY`: 7, `ContactBPointZ`: 8, `ContactAOffX`: 9, `ContactAOffY`: 10, `ContactAOffZ`: 11, `ContactBOffX`: 12, `ContactBOffY`: 13, `ContactBOffZ`: 14, `ContactAThick`: 15, `ContactBThick`: 16, `ContactNormX`: 17, `ContactNormY`: 18, `ContactNormZ`: 19, `ContactWeight`: 20, `ContactADeltaX`: 21, `ContactADeltaY`: 22, `ContactADeltaZ`: 23, `ContactAAngDeltaX`: 24, `ContactAAngDeltaY`: 25, `ContactAAngDeltaZ`: 26, `ContactBDeltaX`: 27, `ContactBDeltaY`: 28, `ContactBDeltaZ`: 29, `ContactBAngDeltaX`: 30, `ContactBAngDeltaY`: 31, `ContactBAngDeltaZ`: 32}
var _ContactVarsDescMap = map[ContactVars]string{0: `first body index`, 1: `the other body index`, 2: `contact point index for A-B pair`, 3: `contact point on body A`, 4: ``, 5: ``, 6: `contact point on body B`, 7: ``, 8: ``, 9: `contact offset on body A`, 10: ``, 11: ``, 12: `contact offset on body B`, 13: ``, 14: ``, 15: `Contact thickness`, 16: ``, 17: `normal pointing from center of B to center of A`, 18: ``, 19: ``, 20: `contact weighting -- 1 if contact made; for restitution use this to filter contacts when updating body.`, 21: `computed contact deltas, A`, 22: ``, 23: ``, 24: ``, 25: ``, 26: ``, 27: `computed contact deltas, B`, 28: ``, 29: ``, 30: ``, 31: ``, 32: ``}
var _ContactVarsMap = map[ContactVars]string{0: `ContactA`, 1: `ContactB`, 2: `ContactPointIdx`, 3: `ContactAPointX`, 4: `ContactAPointY`, 5: `ContactAPointZ`, 6: `ContactBPointX`, 7: `ContactBPointY`, 8: `ContactBPointZ`, 9: `ContactAOffX`, 10: `ContactAOffY`, 11: `ContactAOffZ`, 12: `ContactBOffX`, 13: `ContactBOffY`, 14: `ContactBOffZ`, 15: `ContactAThick`, 16: `ContactBThick`, 17: `ContactNormX`, 18: `ContactNormY`, 19: `ContactNormZ`, 20: `ContactWeight`, 21: `ContactADeltaX`, 22: `ContactADeltaY`, 23: `ContactADeltaZ`, 24: `ContactAAngDeltaX`, 25: `ContactAAngDeltaY`, 26: `ContactAAngDeltaZ`, 27: `ContactBDeltaX`, 28: `ContactBDeltaY`, 29: `ContactBDeltaZ`, 30: `ContactBAngDeltaX`, 31: `ContactBAngDeltaY`, 32: `ContactBAngDeltaZ`}
// String returns the string representation of this ContactVars value.
func (i ContactVars) String() string { return enums.String(i, _ContactVarsMap) }
// SetString sets the ContactVars value from its string representation,
// and returns an error if the string is invalid.
func (i *ContactVars) SetString(s string) error {
return enums.SetString(i, s, _ContactVarsValueMap, "ContactVars")
}
// Int64 returns the ContactVars value as an int64.
func (i ContactVars) Int64() int64 { return int64(i) }
// SetInt64 sets the ContactVars value from an int64.
func (i *ContactVars) SetInt64(in int64) { *i = ContactVars(in) }
// Desc returns the description of the ContactVars value.
func (i ContactVars) Desc() string { return enums.Desc(i, _ContactVarsDescMap) }
// ContactVarsValues returns all possible values for the type ContactVars.
func ContactVarsValues() []ContactVars { return _ContactVarsValues }
// Values returns all possible values for the type ContactVars.
func (i ContactVars) Values() []enums.Enum { return enums.Values(_ContactVarsValues) }
// MarshalText implements the [encoding.TextMarshaler] interface.
func (i ContactVars) MarshalText() ([]byte, error) { return []byte(i.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *ContactVars) UnmarshalText(text []byte) error {
return enums.UnmarshalText(i, text, "ContactVars")
}
var _JointControlVarsValues = []JointControlVars{0, 1, 2, 3, 4, 5}
// JointControlVarsN is the highest valid value for type JointControlVars, plus one.
//
//gosl:start
const JointControlVarsN JointControlVars = 6
//gosl:end
var _JointControlVarsValueMap = map[string]JointControlVars{`JointControlForce`: 0, `JointTargetPos`: 1, `JointTargetPosCur`: 2, `JointTargetStiff`: 3, `JointTargetVel`: 4, `JointTargetDamp`: 5}
var _JointControlVarsDescMap = map[JointControlVars]string{0: `Joint force and torque inputs`, 1: `JointTargetPos is the position target value input to the model, where 0 is the initial position. For angular joints, this is in radians. This is subject to a graded transition over time, [JointTargetPosCur] has the current effective value.`, 2: `JointTargetPosCur is the current position target value, updated from [JointTargetPos] input using the [Params.ControlDt] time constant.`, 3: `JointTargetStiff determines how strongly the target position is enforced: 0 = not at all; larger = stronger (e.g., 1000 or higher). Set to 0 to allow the joint to be fully flexible.`, 4: `JointTargetVel is the velocity target value. For example, 0 effectively damps joint movement in proportion to Damp parameter.`, 5: `JointTargetDamp determines how strongly the target velocity is enforced: 0 = not at all; larger = stronger (e.g., 1 is reasonable). Set to 0 to allow the joint to be fully flexible.`}
var _JointControlVarsMap = map[JointControlVars]string{0: `JointControlForce`, 1: `JointTargetPos`, 2: `JointTargetPosCur`, 3: `JointTargetStiff`, 4: `JointTargetVel`, 5: `JointTargetDamp`}
// String returns the string representation of this JointControlVars value.
func (i JointControlVars) String() string { return enums.String(i, _JointControlVarsMap) }
// SetString sets the JointControlVars value from its string representation,
// and returns an error if the string is invalid.
func (i *JointControlVars) SetString(s string) error {
return enums.SetString(i, s, _JointControlVarsValueMap, "JointControlVars")
}
// Int64 returns the JointControlVars value as an int64.
func (i JointControlVars) Int64() int64 { return int64(i) }
// SetInt64 sets the JointControlVars value from an int64.
func (i *JointControlVars) SetInt64(in int64) { *i = JointControlVars(in) }
// Desc returns the description of the JointControlVars value.
func (i JointControlVars) Desc() string { return enums.Desc(i, _JointControlVarsDescMap) }
// JointControlVarsValues returns all possible values for the type JointControlVars.
func JointControlVarsValues() []JointControlVars { return _JointControlVarsValues }
// Values returns all possible values for the type JointControlVars.
func (i JointControlVars) Values() []enums.Enum { return enums.Values(_JointControlVarsValues) }
// MarshalText implements the [encoding.TextMarshaler] interface.
func (i JointControlVars) MarshalText() ([]byte, error) { return []byte(i.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *JointControlVars) UnmarshalText(text []byte) error {
return enums.UnmarshalText(i, text, "JointControlVars")
}
var _DynamicVarsValues = []DynamicVars{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}
// DynamicVarsN is the highest valid value for type DynamicVars, plus one.
//
//gosl:start
const DynamicVarsN DynamicVars = 33
//gosl:end
var _DynamicVarsValueMap = map[string]DynamicVars{`DynBody`: 0, `DynPosX`: 1, `DynPosY`: 2, `DynPosZ`: 3, `DynQuatX`: 4, `DynQuatY`: 5, `DynQuatZ`: 6, `DynQuatW`: 7, `DynVelX`: 8, `DynVelY`: 9, `DynVelZ`: 10, `DynAngVelX`: 11, `DynAngVelY`: 12, `DynAngVelZ`: 13, `DynAccX`: 14, `DynAccY`: 15, `DynAccZ`: 16, `DynAngAccX`: 17, `DynAngAccY`: 18, `DynAngAccZ`: 19, `DynForceX`: 20, `DynForceY`: 21, `DynForceZ`: 22, `DynTorqueX`: 23, `DynTorqueY`: 24, `DynTorqueZ`: 25, `DynDeltaX`: 26, `DynDeltaY`: 27, `DynDeltaZ`: 28, `DynAngDeltaX`: 29, `DynAngDeltaY`: 30, `DynAngDeltaZ`: 31, `DynContactWeight`: 32}
var _DynamicVarsDescMap = map[DynamicVars]string{0: `Index of body in list of bodies.`, 1: `3D position of structural center.`, 2: ``, 3: ``, 4: `Quaternion rotation.`, 5: ``, 6: ``, 7: ``, 8: `Linear velocity.`, 9: ``, 10: ``, 11: `Angular velocity.`, 12: ``, 13: ``, 14: `Linear acceleration.`, 15: ``, 16: ``, 17: `Angular acceleration due to applied torques.`, 18: ``, 19: ``, 20: `Linear force driving linear acceleration (from joints, etc).`, 21: ``, 22: ``, 23: `Torque driving angular acceleration (from joints, etc).`, 24: ``, 25: ``, 26: `Linear deltas. These accumulate over time via StepBodyDeltas.`, 27: ``, 28: ``, 29: `Angular deltas. These accumulate over time via StepBodyDeltas.`, 30: ``, 31: ``, 32: `integrated weight of all contacts`}
var _DynamicVarsMap = map[DynamicVars]string{0: `DynBody`, 1: `DynPosX`, 2: `DynPosY`, 3: `DynPosZ`, 4: `DynQuatX`, 5: `DynQuatY`, 6: `DynQuatZ`, 7: `DynQuatW`, 8: `DynVelX`, 9: `DynVelY`, 10: `DynVelZ`, 11: `DynAngVelX`, 12: `DynAngVelY`, 13: `DynAngVelZ`, 14: `DynAccX`, 15: `DynAccY`, 16: `DynAccZ`, 17: `DynAngAccX`, 18: `DynAngAccY`, 19: `DynAngAccZ`, 20: `DynForceX`, 21: `DynForceY`, 22: `DynForceZ`, 23: `DynTorqueX`, 24: `DynTorqueY`, 25: `DynTorqueZ`, 26: `DynDeltaX`, 27: `DynDeltaY`, 28: `DynDeltaZ`, 29: `DynAngDeltaX`, 30: `DynAngDeltaY`, 31: `DynAngDeltaZ`, 32: `DynContactWeight`}
// String returns the string representation of this DynamicVars value.
func (i DynamicVars) String() string { return enums.String(i, _DynamicVarsMap) }
// SetString sets the DynamicVars value from its string representation,
// and returns an error if the string is invalid.
func (i *DynamicVars) SetString(s string) error {
return enums.SetString(i, s, _DynamicVarsValueMap, "DynamicVars")
}
// Int64 returns the DynamicVars value as an int64.
func (i DynamicVars) Int64() int64 { return int64(i) }
// SetInt64 sets the DynamicVars value from an int64.
func (i *DynamicVars) SetInt64(in int64) { *i = DynamicVars(in) }
// Desc returns the description of the DynamicVars value.
func (i DynamicVars) Desc() string { return enums.Desc(i, _DynamicVarsDescMap) }
// DynamicVarsValues returns all possible values for the type DynamicVars.
func DynamicVarsValues() []DynamicVars { return _DynamicVarsValues }
// Values returns all possible values for the type DynamicVars.
func (i DynamicVars) Values() []enums.Enum { return enums.Values(_DynamicVarsValues) }
// MarshalText implements the [encoding.TextMarshaler] interface.
func (i DynamicVars) MarshalText() ([]byte, error) { return []byte(i.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *DynamicVars) UnmarshalText(text []byte) error {
return enums.UnmarshalText(i, text, "DynamicVars")
}
var _GPUVarsValues = []GPUVars{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}
// GPUVarsN is the highest valid value for type GPUVars, plus one.
//
//gosl:start
const GPUVarsN GPUVars = 13
//gosl:end
var _GPUVarsValueMap = map[string]GPUVars{`ParamsVar`: 0, `BodiesVar`: 1, `ObjectsVar`: 2, `BodyJointsVar`: 3, `JointsVar`: 4, `JointDoFsVar`: 5, `BodyCollidePairsVar`: 6, `DynamicsVar`: 7, `BroadContactsNVar`: 8, `BroadContactsVar`: 9, `ContactsNVar`: 10, `ContactsVar`: 11, `JointControlsVar`: 12}
var _GPUVarsDescMap = map[GPUVars]string{0: ``, 1: ``, 2: ``, 3: ``, 4: ``, 5: ``, 6: ``, 7: ``, 8: ``, 9: ``, 10: ``, 11: ``, 12: ``}
var _GPUVarsMap = map[GPUVars]string{0: `ParamsVar`, 1: `BodiesVar`, 2: `ObjectsVar`, 3: `BodyJointsVar`, 4: `JointsVar`, 5: `JointDoFsVar`, 6: `BodyCollidePairsVar`, 7: `DynamicsVar`, 8: `BroadContactsNVar`, 9: `BroadContactsVar`, 10: `ContactsNVar`, 11: `ContactsVar`, 12: `JointControlsVar`}
// String returns the string representation of this GPUVars value.
func (i GPUVars) String() string { return enums.String(i, _GPUVarsMap) }
// SetString sets the GPUVars value from its string representation,
// and returns an error if the string is invalid.
func (i *GPUVars) SetString(s string) error {
return enums.SetString(i, s, _GPUVarsValueMap, "GPUVars")
}
// Int64 returns the GPUVars value as an int64.
func (i GPUVars) Int64() int64 { return int64(i) }
// SetInt64 sets the GPUVars value from an int64.
func (i *GPUVars) SetInt64(in int64) { *i = GPUVars(in) }
// Desc returns the description of the GPUVars value.
func (i GPUVars) Desc() string { return enums.Desc(i, _GPUVarsDescMap) }
// GPUVarsValues returns all possible values for the type GPUVars.
func GPUVarsValues() []GPUVars { return _GPUVarsValues }
// Values returns all possible values for the type GPUVars.
func (i GPUVars) Values() []enums.Enum { return enums.Values(_GPUVarsValues) }
// MarshalText implements the [encoding.TextMarshaler] interface.
func (i GPUVars) MarshalText() ([]byte, error) { return []byte(i.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *GPUVars) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "GPUVars") }
var _JointTypesValues = []JointTypes{0, 1, 2, 3, 4, 5, 6, 7}
// JointTypesN is the highest valid value for type JointTypes, plus one.
//
//gosl:start
const JointTypesN JointTypes = 8
//gosl:end
var _JointTypesValueMap = map[string]JointTypes{`Prismatic`: 0, `Revolute`: 1, `Ball`: 2, `Fixed`: 3, `Free`: 4, `Distance`: 5, `D6`: 6, `PlaneXZ`: 7}
var _JointTypesDescMap = map[JointTypes]string{0: `Prismatic allows translation along a single axis (slider): 1 DoF.`, 1: `Revolute allows rotation about a single axis (axle): 1 DoF.`, 2: `Ball allows rotation about all three axes (3 DoF, quaternion).`, 3: `Fixed locks all relative motion: 0 DoF.`, 4: `Free allows full 6-DoF motion (translation and rotation).`, 5: `Distance keeps two bodies at a distance within joint limits: 6 DoF.`, 6: `D6 is a generic 6-DoF joint.`, 7: `PlaneXZ is a version of D6 for navigation in the X-Z plane, which creates 2 linear DoF (X, Z) for movement.`}
var _JointTypesMap = map[JointTypes]string{0: `Prismatic`, 1: `Revolute`, 2: `Ball`, 3: `Fixed`, 4: `Free`, 5: `Distance`, 6: `D6`, 7: `PlaneXZ`}
// String returns the string representation of this JointTypes value.
func (i JointTypes) String() string { return enums.String(i, _JointTypesMap) }
// SetString sets the JointTypes value from its string representation,
// and returns an error if the string is invalid.
func (i *JointTypes) SetString(s string) error {
return enums.SetString(i, s, _JointTypesValueMap, "JointTypes")
}
// Int64 returns the JointTypes value as an int64.
func (i JointTypes) Int64() int64 { return int64(i) }
// SetInt64 sets the JointTypes value from an int64.
func (i *JointTypes) SetInt64(in int64) { *i = JointTypes(in) }
// Desc returns the description of the JointTypes value.
func (i JointTypes) Desc() string { return enums.Desc(i, _JointTypesDescMap) }
// JointTypesValues returns all possible values for the type JointTypes.
func JointTypesValues() []JointTypes { return _JointTypesValues }
// Values returns all possible values for the type JointTypes.
func (i JointTypes) Values() []enums.Enum { return enums.Values(_JointTypesValues) }
// MarshalText implements the [encoding.TextMarshaler] interface.
func (i JointTypes) MarshalText() ([]byte, error) { return []byte(i.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *JointTypes) UnmarshalText(text []byte) error {
return enums.UnmarshalText(i, text, "JointTypes")
}
var _JointVarsValues = []JointVars{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45}
// JointVarsN is the highest valid value for type JointVars, plus one.
//
//gosl:start
const JointVarsN JointVars = 46
//gosl:end
var _JointVarsValueMap = map[string]JointVars{`JointType`: 0, `JointEnabled`: 1, `JointParentFixed`: 2, `JointNoLinearRotation`: 3, `JointParent`: 4, `JointChild`: 5, `JointPPosX`: 6, `JointPPosY`: 7, `JointPPosZ`: 8, `JointPQuatX`: 9, `JointPQuatY`: 10, `JointPQuatZ`: 11, `JointPQuatW`: 12, `JointCPosX`: 13, `JointCPosY`: 14, `JointCPosZ`: 15, `JointCQuatX`: 16, `JointCQuatY`: 17, `JointCQuatZ`: 18, `JointCQuatW`: 19, `JointLinearDoFN`: 20, `JointAngularDoFN`: 21, `JointDoF1`: 22, `JointDoF2`: 23, `JointDoF3`: 24, `JointDoF4`: 25, `JointDoF5`: 26, `JointDoF6`: 27, `JointPForceX`: 28, `JointPForceY`: 29, `JointPForceZ`: 30, `JointPTorqueX`: 31, `JointPTorqueY`: 32, `JointPTorqueZ`: 33, `JointCForceX`: 34, `JointCForceY`: 35, `JointCForceZ`: 36, `JointCTorqueX`: 37, `JointCTorqueY`: 38, `JointCTorqueZ`: 39, `JointLinLambdaX`: 40, `JointLinLambdaY`: 41, `JointLinLambdaZ`: 42, `JointAngLambdaX`: 43, `JointAngLambdaY`: 44, `JointAngLambdaZ`: 45}
var _JointVarsDescMap = map[JointVars]string{0: `JointType (as an int32 from bits).`, 1: `JointEnabled allows joints to be dynamically enabled.`, 2: `JointParentFixed means that the parent is NOT updated based on the forces and positions for this joint. This can make dynamics cleaner when full accuracy is not necessary.`, 3: `JointNoLinearRotation ignores the rotational (angular) effects of linear joint position constraints (i.e., Coriolis and centrifugal forces) which can otherwise interfere with rotational position constraints in joints with both linear and angular DoFs (e.g., [PlaneXZ], for which this is on by default).`, 4: `JointParent is the dynamic body index for parent body. Can be -1 for a fixed parent for absolute anchor.`, 5: `JointChild is the dynamic body index for child body.`, 6: `relative position of joint, in parent frame. This is prior to parent body rotation.`, 7: ``, 8: ``, 9: `relative orientation of joint, in parent frame. This is prior to parent body rotation.`, 10: ``, 11: ``, 12: ``, 13: `relative position of joint, in child frame. This is prior to child body rotation.`, 14: ``, 15: ``, 16: `relative orientation of joint, in child frame. This is prior to parent body rotation.`, 17: ``, 18: ``, 19: ``, 20: `JointLinearDoFN is the number of linear degrees-of-freedom for the joint.`, 21: `JointAngularDoFN is the number of angular degrees-of-freedom for the joint.`, 22: `indexes in JointDoFs for each DoF`, 23: ``, 24: ``, 25: `angular starts here for Free, Distance, D6`, 26: ``, 27: ``, 28: `Computed parent joint force value.`, 29: ``, 30: ``, 31: `Computed parent joint torque value.`, 32: ``, 33: ``, 34: `Computed child joint force value.`, 35: ``, 36: ``, 37: `Computed child joint torque value.`, 38: ``, 39: ``, 40: `Computed linear lambdas.`, 41: ``, 42: ``, 43: `Computed angular lambdas.`, 44: ``, 45: ``}
var _JointVarsMap = map[JointVars]string{0: `JointType`, 1: `JointEnabled`, 2: `JointParentFixed`, 3: `JointNoLinearRotation`, 4: `JointParent`, 5: `JointChild`, 6: `JointPPosX`, 7: `JointPPosY`, 8: `JointPPosZ`, 9: `JointPQuatX`, 10: `JointPQuatY`, 11: `JointPQuatZ`, 12: `JointPQuatW`, 13: `JointCPosX`, 14: `JointCPosY`, 15: `JointCPosZ`, 16: `JointCQuatX`, 17: `JointCQuatY`, 18: `JointCQuatZ`, 19: `JointCQuatW`, 20: `JointLinearDoFN`, 21: `JointAngularDoFN`, 22: `JointDoF1`, 23: `JointDoF2`, 24: `JointDoF3`, 25: `JointDoF4`, 26: `JointDoF5`, 27: `JointDoF6`, 28: `JointPForceX`, 29: `JointPForceY`, 30: `JointPForceZ`, 31: `JointPTorqueX`, 32: `JointPTorqueY`, 33: `JointPTorqueZ`, 34: `JointCForceX`, 35: `JointCForceY`, 36: `JointCForceZ`, 37: `JointCTorqueX`, 38: `JointCTorqueY`, 39: `JointCTorqueZ`, 40: `JointLinLambdaX`, 41: `JointLinLambdaY`, 42: `JointLinLambdaZ`, 43: `JointAngLambdaX`, 44: `JointAngLambdaY`, 45: `JointAngLambdaZ`}
// String returns the string representation of this JointVars value.
func (i JointVars) String() string { return enums.String(i, _JointVarsMap) }
// SetString sets the JointVars value from its string representation,
// and returns an error if the string is invalid.
func (i *JointVars) SetString(s string) error {
return enums.SetString(i, s, _JointVarsValueMap, "JointVars")
}
// Int64 returns the JointVars value as an int64.
func (i JointVars) Int64() int64 { return int64(i) }
// SetInt64 sets the JointVars value from an int64.
func (i *JointVars) SetInt64(in int64) { *i = JointVars(in) }
// Desc returns the description of the JointVars value.
func (i JointVars) Desc() string { return enums.Desc(i, _JointVarsDescMap) }
// JointVarsValues returns all possible values for the type JointVars.
func JointVarsValues() []JointVars { return _JointVarsValues }
// Values returns all possible values for the type JointVars.
func (i JointVars) Values() []enums.Enum { return enums.Values(_JointVarsValues) }
// MarshalText implements the [encoding.TextMarshaler] interface.
func (i JointVars) MarshalText() ([]byte, error) { return []byte(i.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *JointVars) UnmarshalText(text []byte) error {
return enums.UnmarshalText(i, text, "JointVars")
}
var _JointDoFVarsValues = []JointDoFVars{0, 1, 2, 3, 4}
// JointDoFVarsN is the highest valid value for type JointDoFVars, plus one.
//
//gosl:start
const JointDoFVarsN JointDoFVars = 5
//gosl:end
var _JointDoFVarsValueMap = map[string]JointDoFVars{`JointAxisX`: 0, `JointAxisY`: 1, `JointAxisZ`: 2, `JointLimitLower`: 3, `JointLimitUpper`: 4}
var _JointDoFVarsDescMap = map[JointDoFVars]string{0: `axis of articulation for the DoF`, 1: ``, 2: ``, 3: `joint limits`, 4: ``}
var _JointDoFVarsMap = map[JointDoFVars]string{0: `JointAxisX`, 1: `JointAxisY`, 2: `JointAxisZ`, 3: `JointLimitLower`, 4: `JointLimitUpper`}
// String returns the string representation of this JointDoFVars value.
func (i JointDoFVars) String() string { return enums.String(i, _JointDoFVarsMap) }
// SetString sets the JointDoFVars value from its string representation,
// and returns an error if the string is invalid.
func (i *JointDoFVars) SetString(s string) error {
return enums.SetString(i, s, _JointDoFVarsValueMap, "JointDoFVars")
}
// Int64 returns the JointDoFVars value as an int64.
func (i JointDoFVars) Int64() int64 { return int64(i) }
// SetInt64 sets the JointDoFVars value from an int64.
func (i *JointDoFVars) SetInt64(in int64) { *i = JointDoFVars(in) }
// Desc returns the description of the JointDoFVars value.
func (i JointDoFVars) Desc() string { return enums.Desc(i, _JointDoFVarsDescMap) }
// JointDoFVarsValues returns all possible values for the type JointDoFVars.
func JointDoFVarsValues() []JointDoFVars { return _JointDoFVarsValues }
// Values returns all possible values for the type JointDoFVars.
func (i JointDoFVars) Values() []enums.Enum { return enums.Values(_JointDoFVarsValues) }
// MarshalText implements the [encoding.TextMarshaler] interface.
func (i JointDoFVars) MarshalText() ([]byte, error) { return []byte(i.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *JointDoFVars) UnmarshalText(text []byte) error {
return enums.UnmarshalText(i, text, "JointDoFVars")
}
var _ShapesValues = []Shapes{0, 1, 2, 3, 4, 5}
// ShapesN is the highest valid value for type Shapes, plus one.
//
//gosl:start
const ShapesN Shapes = 6
//gosl:end
var _ShapesValueMap = map[string]Shapes{`Plane`: 0, `Sphere`: 1, `Capsule`: 2, `Cylinder`: 3, `Box`: 4, `Cone`: 5}
var _ShapesDescMap = map[Shapes]string{0: `Plane cannot be a dynamic shape, but is most efficient for collision computations. Use size = 0 for an infinite plane. Natively extends in the X-Z plane: SizeX x SizeZ.`, 1: `Sphere. SizeX is the radius.`, 2: `Capsule is a cylinder with half-spheres on the ends. Natively oriented vertically along the Y axis. SizeX = radius of end caps, SizeY = _total_ half-height (i.e., SizeX + half-height of cylindrical portion, must be >= SizeX). This parameterization allows joint offsets to be SizeY, and direct swapping of shape across Box and Cylinder with same total extent.`, 3: `Cylinder, natively oriented vertically along the Y axis. SizeX = radius, SizeY = half-height along the Y axis. Cylinder does not support most collisions and is thus not recommended where collision data is needed.`, 4: `Box is a 3D rectilinear shape. The sizes are _half_ sizes along each dimension, relative to the center.`, 5: `Cone is like a cylinder with the top radius = 0, oriented up. SizeX = bottom radius, SizeY = half-height in Y. Cone does not support any collisions and is not recommended for interacting bodies.`}
var _ShapesMap = map[Shapes]string{0: `Plane`, 1: `Sphere`, 2: `Capsule`, 3: `Cylinder`, 4: `Box`, 5: `Cone`}
// String returns the string representation of this Shapes value.
func (i Shapes) String() string { return enums.String(i, _ShapesMap) }
// SetString sets the Shapes value from its string representation,
// and returns an error if the string is invalid.
func (i *Shapes) SetString(s string) error { return enums.SetString(i, s, _ShapesValueMap, "Shapes") }
// Int64 returns the Shapes value as an int64.
func (i Shapes) Int64() int64 { return int64(i) }
// SetInt64 sets the Shapes value from an int64.
func (i *Shapes) SetInt64(in int64) { *i = Shapes(in) }
// Desc returns the description of the Shapes value.
func (i Shapes) Desc() string { return enums.Desc(i, _ShapesDescMap) }
// ShapesValues returns all possible values for the type Shapes.
func ShapesValues() []Shapes { return _ShapesValues }
// Values returns all possible values for the type Shapes.
func (i Shapes) Values() []enums.Enum { return enums.Values(_ShapesValues) }
// MarshalText implements the [encoding.TextMarshaler] interface.
func (i Shapes) MarshalText() ([]byte, error) { return []byte(i.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *Shapes) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "Shapes") }
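// Illustrative sketch (not part of the generated code above): how the enum
// methods defined for these types are typically used, shown here for Shapes.
// SetString / String round-trip the name, and MarshalText / UnmarshalText
// provide the same text form for JSON/TOML encoders. The helper function
// name is hypothetical.
func exampleShapesRoundTrip() (Shapes, error) {
	var sh Shapes
	if err := sh.SetString("Capsule"); err != nil { // parse from the type name
		return sh, err
	}
	txt, err := sh.MarshalText() // text form: "Capsule"
	if err != nil {
		return sh, err
	}
	var back Shapes
	err = back.UnmarshalText(txt) // back == sh on success
	return back, err
}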
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package balls
//go:generate core generate -add-types
import (
"math/rand/v2"
"cogentcore.org/core/colors"
"cogentcore.org/core/math32"
"cogentcore.org/core/tree"
_ "cogentcore.org/lab/gosl/slbool/slboolcore" // include to get gui views
"cogentcore.org/lab/physics"
"cogentcore.org/lab/physics/phyxyz"
)
// Balls holds the simulation parameters.
type Balls struct {
// Number of balls. If Collide is on, memory runs out above roughly 1000.
NBalls int
// Collide is whether the balls collide with each other
Collide bool
// Size of each ball (m)
Size float32
// Mass of each ball (kg)
Mass float32
// Width, Depth, and Height specify the size of the box (m).
Width float32
Depth float32
Height float32
// Thick is the thickness of the box walls (m).
Thick float32
// Bounce is the bounciness (restitution) of each ball.
Bounce float32
// Friction is the sliding friction.
Friction float32
// FrictionTortion is the friction for spinning (torsional) motion.
FrictionTortion float32
// FrictionRolling is the rolling friction.
FrictionRolling float32
}
func (b *Balls) Defaults() {
b.NBalls = 1000
b.Collide = true
b.Size = 0.2
b.Mass = 0.1
b.Width = 50
b.Depth = 50
b.Height = 20
b.Thick = .1
b.Bounce = 0.5
b.Friction = 0
b.FrictionTortion = 0
b.FrictionRolling = 0
}
func Config(b tree.Node) {
ed := phyxyz.NewEditor(b)
bs := &Balls{}
bs.Defaults()
ed.CameraPos = math32.Vec3(0, bs.Width, bs.Width)
ed.SetUserParams(bs)
ed.SetConfigFunc(func() {
ml := ed.Model
ml.Params[0].SubSteps = 100
ml.Params[0].Dt = 0.001
// ml.GPU = false
// ml.ReportTotalKE = true
sc := ed.Scene
rot := math32.NewQuatIdentity()
sc.NewBody(ml, "floor", physics.Plane, "#D0D0D080", math32.Vec3(0, 0, 0),
math32.Vec3(0, 0, 0), rot)
hw := bs.Width / 2
hd := bs.Depth / 2
hh := bs.Height / 2
ht := bs.Thick / 2
sc.NewBody(ml, "back-wall", physics.Box, "#0000FFA0", math32.Vec3(hw, hh, ht),
math32.Vec3(0, hh, -hd), rot)
sc.NewBody(ml, "left-wall", physics.Box, "#FF0000A0", math32.Vec3(ht, hh, hd),
math32.Vec3(-hw, hh, 0), rot)
sc.NewBody(ml, "right-wall", physics.Box, "#00FF00A0", math32.Vec3(ht, hh, hd),
math32.Vec3(hw, hh, 0), rot)
sc.NewBody(ml, "front-wall", physics.Box, "#FFFF00A0", math32.Vec3(hw, hh, ht),
math32.Vec3(0, hh, hd), rot)
box := bs.Width * .9
size := bs.Size
for i := range bs.NBalls {
ht := rand.Float32() * bs.Height
x := rand.Float32()*box - 0.5*box
z := rand.Float32()*box - 0.5*box
clr := colors.Names[i%len(colors.Names)]
bl := sc.NewDynamic(ml, "ball", physics.Sphere, clr, bs.Mass, math32.Vec3(size, size, size),
math32.Vec3(x, size+ht, z), rot)
if !bs.Collide {
physics.SetBodyGroup(bl.BodyIndex, int32(i+1)) // only collide within same group
}
bl.SetBodyBounce(bs.Bounce)
bl.SetBodyFriction(bs.Friction)
bl.SetBodyFrictionTortion(bs.FrictionTortion)
bl.SetBodyFrictionRolling(bs.FrictionRolling)
}
})
}
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"cogentcore.org/core/core"
"cogentcore.org/lab/physics/examples/balls"
)
func main() {
b := core.NewBody("balls").SetTitle("Physics Balls")
balls.Config(b)
b.RunMainWindow()
}
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package collide
//go:generate core generate -add-types
import (
"cogentcore.org/core/core"
"cogentcore.org/core/math32"
"cogentcore.org/core/tree"
_ "cogentcore.org/lab/gosl/slbool/slboolcore" // include to get gui views
"cogentcore.org/lab/physics"
"cogentcore.org/lab/physics/phyxyz"
)
// Collide holds the simulation parameters.
type Collide struct {
// Shape of left body
ShapeA physics.Shapes
// Shape of right body
ShapeB physics.Shapes
// Size (radius) of left body; capsule, cylinder, and box are 2x taller.
SizeA float32
// Size (radius) of right body; capsule, cylinder, and box are 2x taller.
SizeB float32
// Mass of left object: if lighter than B, it will bounce back more.
MassA float32
// Mass of right object: if lighter than A, it will move faster.
MassB float32
// Z (depth) position: offset to get different collision angles.
ZposA float32
// Z (depth) position: offset to get different collision angles.
ZposB float32
// Mass of the pusher panel: if lighter, it transfers less energy.
PushMass float32
// Friction is for sliding: around 0.01 seems pretty realistic
Friction float32
// FrictionTortion is for rotating. Not generally relevant here.
FrictionTortion float32
// FrictionRolling is for rolling: around 0.01 seems pretty realistic
FrictionRolling float32
}
func (cl *Collide) Defaults() {
cl.ShapeA = physics.Sphere
cl.ShapeB = physics.Sphere
cl.SizeA = 0.5
cl.SizeB = 0.5
cl.MassA = 1
cl.MassB = 1
cl.PushMass = 1
cl.Friction = 0.01
cl.FrictionTortion = 0.01
cl.FrictionRolling = 0.01
}
func Config(b tree.Node) {
ed := phyxyz.NewEditor(b)
ed.CameraPos = math32.Vec3(0, 20, 20)
cl := &Collide{}
cl.Defaults()
ed.SetUserParams(cl)
core.NewText(b).SetText("Pusher target position:")
pos := float32(3)
sld := core.NewSlider(b).SetMin(0).SetMax(5).SetStep(.1).SetEnforceStep(true)
core.Bind(&pos, sld)
ed.SetConfigFunc(func() {
ml := ed.Model
ml.GPU = false
// ml.ReportTotalKE = true
sc := ed.Scene
rot := math32.NewQuatIdentity()
fl := sc.NewBody(ml, "floor", physics.Plane, "#D0D0D080", math32.Vec3(0, 0, 0), math32.Vec3(0, 0, 0), rot)
physics.SetBodyFriction(fl.BodyIndex, cl.Friction)
physics.SetBodyFrictionRolling(fl.BodyIndex, cl.FrictionRolling)
physics.SetBodyFrictionTortion(fl.BodyIndex, cl.FrictionTortion)
hhA := 2 * cl.SizeA
hhB := 2 * cl.SizeB
if cl.ShapeA == physics.Sphere {
hhA = cl.SizeA
}
if cl.ShapeB == physics.Sphere {
hhB = cl.SizeB
}
ba := sc.NewDynamic(ml, "A", cl.ShapeA, "blue", cl.MassA, math32.Vec3(cl.SizeA, 2*cl.SizeA, cl.SizeA), math32.Vec3(-5, hhA, cl.ZposA), rot)
physics.SetBodyFriction(ba.BodyIndex, cl.Friction)
physics.SetBodyFrictionRolling(ba.BodyIndex, cl.FrictionRolling)
physics.SetBodyFrictionTortion(ba.BodyIndex, cl.FrictionTortion)
bb := sc.NewDynamic(ml, "B", cl.ShapeB, "red", cl.MassB, math32.Vec3(cl.SizeB, 2*cl.SizeB, cl.SizeB), math32.Vec3(0, hhB, cl.ZposB), rot)
physics.SetBodyFriction(bb.BodyIndex, cl.Friction)
physics.SetBodyFrictionRolling(bb.BodyIndex, cl.FrictionRolling)
physics.SetBodyFrictionTortion(bb.BodyIndex, cl.FrictionTortion)
push := sc.NewDynamic(ml, "push", physics.Box, "grey", cl.PushMass, math32.Vec3(.1, 2, 2), math32.Vec3(-8, 2, 0), rot)
ml.NewObject()
sc.NewJointPrismatic(ml, nil, push, math32.Vec3(-8, 0, 0), math32.Vec3(0, -2, 0), math32.Vec3(1, 0, 0))
})
ed.SetControlFunc(func(timeStep int) {
physics.SetJointTargetPos(0, 0, pos, 100)
physics.SetJointTargetVel(0, 0, 0, 20)
})
}
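// Illustrative sketch (hypothetical helper, not part of the example above):
// the joint-control pattern used in SetControlFunc. Per the JointControlVars
// descriptions, SetJointTargetPos sets a position target with a stiffness
// (0 = off, larger = stronger) and SetJointTargetVel sets a velocity target
// with a damping (0 = off). The joint and DoF indexes of 0 match the single
// prismatic pusher joint created in Config above.
func controlSketch(targetPos float32) func(timeStep int) {
	return func(timeStep int) {
		physics.SetJointTargetPos(0, 0, targetPos, 100) // joint 0, DoF 0: pull toward targetPos
		physics.SetJointTargetVel(0, 0, 0, 20)          // damp velocity toward zero
	}
}

// A closure like this could be passed to ed.SetControlFunc; the inline
// version in Config above additionally reads the bound slider value each step.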
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"cogentcore.org/core/core"
"cogentcore.org/lab/physics/examples/collide"
)
func main() {
b := core.NewBody("collide").SetTitle("Physics Collide")
collide.Config(b)
b.RunMainWindow()
}
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
//go:generate core generate -add-types
import (
"fmt"
"cogentcore.org/core/colors"
"cogentcore.org/core/core"
"cogentcore.org/core/math32"
_ "cogentcore.org/lab/gosl/slbool/slboolcore" // include to get gui views
"cogentcore.org/lab/physics"
"cogentcore.org/lab/physics/builder"
"cogentcore.org/lab/physics/phyxyz"
)
// Pendula holds the simulation parameters.
type Pendula struct {
// Number of bar elements to add to the pendulum. More interesting the more you add!
NPendula int
// StartVert starts the pendulum in the vertical orientation
// (otherwise horizontal, so it has somewhere to go). A force must be added if starting vertical.
StartVert bool
// TargetDegFromVert is the target number of degrees off of vertical
// for each joint. It is critical that this is not 0 when StartVert is true.
TargetDegFromVert int
// Timestep in msec to add a force
ForceOn int
// Timestep in msec to stop adding force
ForceOff int
// Force to add
Force float32
// half-size of the pendulum elements.
HSize math32.Vector3
// Mass of each bar (kg)
Mass float32
// Collide determines whether the elements collide with each other. This currently works badly!
Collide bool
// Stiff is the strength of the positional constraint to keep
// each bar in a vertical position.
Stiff float32
// Damp is the strength of the velocity constraint to keep each
// bar not moving.
Damp float32
}
func (b *Pendula) Defaults() {
b.NPendula = 2
b.HSize.Set(0.05, .2, 0.05)
b.Mass = 0.1
b.ForceOn = 100
b.ForceOff = 102
b.Force = 0
b.Damp = 0
b.Stiff = 0
}
func main() {
b := core.NewBody("test1").SetTitle("Physics Pendula")
ed := phyxyz.NewEditor(b)
ed.CameraPos = math32.Vec3(0, 3, 3)
ps := &Pendula{}
ps.Defaults()
ed.SetUserParams(ps)
bld := builder.NewBuilder()
var botJoint *builder.Joint
ed.SetConfigFunc(func() {
bld.Reset()
wld := bld.NewWorld()
obj := wld.NewObject()
ml := ed.Model
ml.GPU = false
// ml.ReportTotalKE = true
sc := ed.Scene
rot := math32.NewQuatIdentity()
rleft := math32.NewQuatAxisAngle(math32.Vec3(0, 0, 1), -math32.Pi/2)
if ps.StartVert {
rleft = rot
}
stY := 4 * ps.HSize.Y
x := -ps.HSize.Y
y := stY
if ps.StartVert {
x = 0
y -= ps.HSize.Y
}
pb := obj.NewDynamicSkin(sc, "top", physics.Capsule, "blue", ps.Mass, ps.HSize, math32.Vec3(x, y, 0), rleft)
if !ps.Collide {
pb.SetGroup(1)
}
targ := math32.DegToRad(float32(ps.TargetDegFromVert))
jd := obj.NewJointRevolute(nil, pb, math32.Vec3(0, stY, 0), math32.Vec3(0, ps.HSize.Y, 0), math32.Vec3(0, 0, 1))
jd.DoF(0).Init.SetPos(targ).SetStiff(ps.Stiff).SetVel(0).SetDamp(ps.Damp)
for i := 1; i < ps.NPendula; i++ {
clr := colors.Names[12+i%len(colors.Names)]
x := -float32(i)*ps.HSize.Y*2 - ps.HSize.Y
y := stY
if ps.StartVert {
y = stY + x
x = 0
}
cb := obj.NewDynamicSkin(sc, "child", physics.Capsule, clr, ps.Mass, ps.HSize, math32.Vec3(x, y, 0), rleft)
if !ps.Collide {
cb.SetGroup(1 + i)
}
jd = obj.NewJointRevolute(pb, cb, math32.Vec3(0, -ps.HSize.Y, 0), math32.Vec3(0, ps.HSize.Y, 0), math32.Vec3(0, 0, 1))
jd.DoF(0).Init.SetPos(targ).SetStiff(ps.Stiff).SetVel(0).SetDamp(ps.Damp)
pb = cb
botJoint = jd
}
bld.ReplicateWorld(nil, 0, 2, 2, math32.Vec3(0, 0, -1), math32.Vec3(1, 0, 0))
bld.Build(ml, sc)
})
ed.SetControlFunc(func(timeStep int) {
if timeStep >= ps.ForceOn && timeStep < ps.ForceOff {
fmt.Println(timeStep, "\tforce on:", ps.Force)
physics.SetJointControlForce(botJoint.JointIndex, 0, ps.Force)
} else {
physics.SetJointControlForce(botJoint.JointIndex, 0, 0)
}
})
b.RunMainWindow()
}
// Copyright (c) 2019, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package virtroom
//go:generate core generate -add-types
import (
"fmt"
"image"
"math/rand/v2"
"cogentcore.org/core/base/iox/imagex"
"cogentcore.org/core/colors"
"cogentcore.org/core/colors/colormap"
"cogentcore.org/core/core"
"cogentcore.org/core/events"
"cogentcore.org/core/gpu"
"cogentcore.org/core/icons"
"cogentcore.org/core/math32"
"cogentcore.org/core/styles"
"cogentcore.org/core/styles/abilities"
"cogentcore.org/core/tree"
"cogentcore.org/core/xyz"
"cogentcore.org/core/xyz/xyzcore"
"cogentcore.org/lab/physics"
"cogentcore.org/lab/physics/builder"
"cogentcore.org/lab/physics/phyxyz"
)
// Emer is the robot agent in the environment.
type Emer struct {
// if true, emer is angry: changes face color
Angry bool
// VestibHRightEar is the horizontal rotation vestibular signal measured
// at the right ear, averaged over the steps.
VestibHRightEar float32
// full height of emer
Height float32
// emer object
Obj *builder.Object `display:"-"`
// PlaneXZ joint for controlling 2D position.
XZ *builder.Joint
// ball joint for the neck.
Neck *builder.Joint
// Right eye of emer
EyeR *builder.Body `display:"-"`
}
func (em *Emer) Defaults() {
em.Height = 1
}
// Env encapsulates the virtual environment
type Env struct { //types:add
// Emer state
Emer Emer `new-window:"+"`
// Stiffness for actions
Stiff float32
// how far to move every step
MoveStep float32
// how far to rotate every step
RotStep float32
// number of model steps to take
ModelSteps int
// width of room
Width float32
// depth of room
Depth float32
// height of room
Height float32
// thickness of walls of room
Thick float32
// current depth map
DepthVals []float32
// offscreen render camera settings
Camera phyxyz.Camera
// color map to use for rendering depth map
DepthMap core.ColorMapName
// The core physics elements: Model, Builder, Scene
Physics builder.Physics
// 3D visualization of the Scene
SceneEditor *xyzcore.SceneEditor
// snapshot image
EyeRImg *core.Image `display:"-"`
// depth map image
DepthImage *core.Image `display:"-"`
}
func (ev *Env) Defaults() {
ev.Emer.Defaults()
ev.Width = 10
ev.Depth = 15
ev.Height = 2
ev.Thick = 0.2
ev.Stiff = 1000
ev.MoveStep = ev.Emer.Height * .2
ev.RotStep = 15
ev.ModelSteps = 100
ev.DepthMap = core.ColorMapName("ColdHot")
ev.Camera.Defaults()
ev.Camera.FOV = 90
}
func (ev *Env) MakeModel(sc *xyz.Scene) {
ev.Physics.Model = physics.NewModel()
ev.Physics.Builder = builder.NewBuilder()
ev.Physics.Model.GPU = false
ev.Physics.Model.GetContacts = true
sc.Background = colors.Scheme.Select.Container
xyz.NewAmbient(sc, "ambient", 0.3, xyz.DirectSun)
dir := xyz.NewDirectional(sc, "dir", 1, xyz.DirectSun)
dir.Pos.Set(0, 2, 1) // default: 0,1,1 = above and behind us (we are at 0,0,X)
ev.Physics.Scene = phyxyz.NewScene(sc)
wl := ev.Physics.Builder.NewGlobalWorld()
ev.MakeRoom(wl, "room1", ev.Width, ev.Depth, ev.Height, ev.Thick)
ew := ev.Physics.Builder.NewWorld()
ev.MakeEmer(ew, &ev.Emer, "emer")
// vw.Physics.Builder.ReplicateWorld(vw.Physics.Scene, 1, 1, 8) // 1x8
ev.Physics.Build()
params := physics.GetParams(0)
// params.ControlDt = 0.1
params.Dt = 0.001
params.SubSteps = 1
params.Gravity.Y = 0 // note: critical to not have gravity for full rotation
// https://github.com/cogentcore/lab/issues/47
// params.MaxForce = 1.0e3
// params.AngularDamping = 0.5
// params.SubSteps = 1
}
// InitState reinitializes the physics model state.
func (ev *Env) InitState() { //types:add
ev.Physics.InitState()
ev.UpdateView()
}
// ConfigView3D makes the 3D view
func (ev *Env) ConfigView3D(sc *xyz.Scene) {
// sc.MultiSample = 1 // we are using depth grab so we need this = 1
}
// RenderEyeImg returns a snapshot from the perspective of Emer's right eye
func (ev *Env) RenderEyeImg() image.Image {
if ev.Emer.EyeR == nil {
return nil
}
return ev.Physics.Scene.RenderFrom(ev.Emer.EyeR.Skin, &ev.Camera)[0]
}
// GrabEyeImage takes a snapshot from the perspective of Emer's right eye
func (ev *Env) GrabEyeImage() { //types:add
img := ev.RenderEyeImg()
if img != nil {
ev.EyeRImg.SetImage(img)
ev.EyeRImg.NeedsRender()
}
// depth, err := ev.View3D.DepthImage()
// if err == nil && depth != nil {
// ev.DepthVals = depth
// ev.ViewDepth(depth)
// }
}
// Sensors reads sensors at various key points on body.
func (ev *Env) Sensors() {
ev.Emer.Obj.RunSensors()
}
// ViewDepth updates depth bitmap with depth data
func (ev *Env) ViewDepth(depth []float32) {
cmap := colormap.AvailableMaps[string(ev.DepthMap)]
img := image.NewRGBA(image.Rectangle{Max: ev.Camera.Size})
ev.DepthImage.SetImage(img)
phyxyz.DepthImage(img, depth, cmap, &ev.Camera)
ev.DepthImage.NeedsRender()
}
// UpdateView tells the 3D view that it needs to update.
func (ev *Env) UpdateView() {
if ev.SceneEditor.IsVisible() {
ev.SceneEditor.NeedsRender()
}
}
// ModelStep does one step of the physics model.
func (ev *Env) ModelStep() { //types:add
ev.Emer.VestibHRightEar = 0
for range ev.ModelSteps { // we're computing average over sensor data
ev.Physics.StepQuiet(1)
ev.Sensors()
}
ev.Physics.Step(1)
ev.Emer.VestibHRightEar /= float32(ev.ModelSteps)
fmt.Println("vestibH right ear:", ev.Emer.VestibHRightEar)
ev.Emer.Angry = false
ctN := physics.ContactsN.Value(0)
for ci := range ctN {
ca := physics.GetContactA(ci)
cb := physics.GetContactB(ci)
if ca == 0 || cb == 0 { // ignore the floor
continue
}
if ev.Emer.Obj.HasBodyIndex(ca, cb) {
ev.Emer.Angry = true
// fmt.Println("hit wall: turn around!")
rot := 100.0 + 90.0*rand.Float32()
ev.Emer.XZ.AddTargetAngle(2, rot, ev.Stiff)
}
}
ev.GrabEyeImage()
ev.UpdateView()
}
// StepForward moves Emer forward one step in the current facing direction.
func (ev *Env) StepForward() { //types:add
// doesn't integrate well with joints..
// ev.Emer.MoveOnAxisBody(0, 0, 0, 1, -ev.MoveStep)
// ev.Emer.PoseToPhysics()
ang := math32.Pi*.5 - ev.Emer.XZ.DoF(2).Current.Pos
// ang := float32(math32.Pi * .5)
ev.Emer.XZ.AddPlaneXZPos(ang, -ev.MoveStep, ev.Stiff)
ev.ModelStep()
}
// StepBackward moves Emer backward one step in the current facing direction.
func (ev *Env) StepBackward() { //types:add
ang := math32.Pi*.5 - ev.Emer.XZ.DoF(2).Current.Pos
// ang := float32(math32.Pi * .5)
ev.Emer.XZ.AddPlaneXZPos(ang, ev.MoveStep, ev.Stiff)
ev.ModelStep()
}
// RotBodyLeft rotates emer left.
func (ev *Env) RotBodyLeft() { //types:add
ev.Emer.XZ.AddTargetAngle(2, ev.RotStep, ev.Stiff)
ev.ModelStep()
}
// RotBodyRight rotates emer right.
func (ev *Env) RotBodyRight() { //types:add
ev.Emer.XZ.AddTargetAngle(2, -ev.RotStep, ev.Stiff)
ev.ModelStep()
}
// RotHeadLeft rotates emer's head left.
func (ev *Env) RotHeadLeft() { //types:add
ev.Emer.Neck.AddTargetAngle(0, ev.RotStep, ev.Stiff)
ev.ModelStep()
}
// RotHeadRight rotates emer's head right.
func (ev *Env) RotHeadRight() { //types:add
ev.Emer.Neck.AddTargetAngle(0, -ev.RotStep, ev.Stiff)
ev.ModelStep()
}
// MakeRoom constructs a new room with given params
func (ev *Env) MakeRoom(wl *builder.World, name string, width, depth, height, thick float32) {
rot := math32.NewQuatIdentity()
hw := width / 2
hd := depth / 2
hh := height / 2
ht := thick / 2
obj := wl.NewObject()
sc := ev.Physics.Scene
obj.NewBodySkin(sc, name+"_floor", physics.Plane, "grey", math32.Vec3(hw, 0, hd),
math32.Vec3(0, 0, 0), rot)
obj.NewBodySkin(sc, name+"_back-wall", physics.Box, "blue", math32.Vec3(hw, hh, ht),
math32.Vec3(0, hh, -hd), rot)
obj.NewBodySkin(sc, name+"_left-wall", physics.Box, "red", math32.Vec3(ht, hh, hd),
math32.Vec3(-hw, hh, 0), rot)
obj.NewBodySkin(sc, name+"_right-wall", physics.Box, "green", math32.Vec3(ht, hh, hd),
math32.Vec3(hw, hh, 0), rot)
obj.NewBodySkin(sc, name+"_front-wall", physics.Box, "yellow", math32.Vec3(hw, hh, ht),
math32.Vec3(0, hh, hd), rot)
}
// MakeEmer constructs a new Emer virtual robot of given height (e.g., 1).
func (ev *Env) MakeEmer(wl *builder.World, em *Emer, name string) {
hh := em.Height / 2
hw := hh * .4
hd := hh * .15
headsz := hd * 1.5
eyesz := headsz * .2
mass := float32(1) // kg
rot := math32.NewQuatIdentity()
obj := wl.NewObject()
em.Obj = obj
sc := ev.Physics.Scene
off := float32(0.01) // note: critical to float slightly off the plane!
// otherwise, rotation problems arise.
emr := obj.NewDynamicSkin(sc, name+"_body", physics.Box, "purple", mass, math32.Vec3(hw, hh, hd), math32.Vec3(0, hh+off, 0), rot)
// body := physics.NewCapsule(emr, "body", math32.Vec3(0, hh, 0), hh, hw)
// body := physics.NewCylinder(emr, "body", math32.Vec3(0, hh, 0), hh, hw)
em.XZ = obj.NewJointPlaneXZ(nil, emr, math32.Vec3(0, 0, 0), math32.Vec3(0, -hh, 0))
// emr.Group = 0 // no collide (temporary)
headPos := math32.Vec3(0, 2*hh+headsz+off, 0)
head := obj.NewDynamicSkin(sc, name+"_head", physics.Box, "tan", mass*.1, math32.Vec3(headsz, headsz, headsz), headPos, rot)
// head.Group = 0
hdsk := head.Skin
hdsk.InitSkin = func(sld *xyz.Solid) {
hdsk.BoxInit(sld)
sld.Updater(func() {
clr := hdsk.Color
if ev.Emer.Angry {
clr = "pink"
}
hdsk.UpdateColor(clr, sld)
})
}
// em.Neck = obj.NewJointBall(emr, head, math32.Vec3(0, hh, 0), math32.Vec3(0, -headsz, 0))
em.Neck = obj.NewJointRevolute(emr, head, math32.Vec3(0, hh, 0), math32.Vec3(0, -headsz, 0), math32.Vec3(0, 1, 0))
em.Neck.ParentFixed = true
em.Neck.NoLinearRotation = true
// obj.NewJointFixed(emr, head, math32.Vec3(0, hh, 0), math32.Vec3(0, -headsz, 0))
obj.NewSensor(func(obj *builder.Object) {
hd := obj.Body(1)
av := physics.AngularVelocityAt(hd.DynamicIndex, math32.Vec3(headsz, 0, 0), math32.Vec3(0, 1, 0))
em.VestibHRightEar += av.Z // shows up in Z
})
eyeoff := math32.Vec3(-headsz*.6, headsz*.1, -(headsz + eyesz*.3))
bd := obj.NewDynamicSkin(sc, name+"_eye-l", physics.Box, "green", mass*.001, math32.Vec3(eyesz, eyesz*.5, eyesz*.2), headPos.Add(eyeoff), rot)
// bd.Group = 0
ej := obj.NewJointFixed(head, bd, eyeoff, math32.Vec3(0, 0, -eyesz*.3))
ej.ParentFixed = true
eyeoff.X = headsz * .6
em.EyeR = obj.NewDynamicSkin(sc, name+"_eye-r", physics.Box, "green", mass*.001, math32.Vec3(eyesz, eyesz*.5, eyesz*.2), headPos.Add(eyeoff), rot)
// em.EyeR.Group = 0
ej = obj.NewJointFixed(head, em.EyeR, eyeoff, math32.Vec3(0, 0, -eyesz*.3))
ej.ParentFixed = true
}
func (ev *Env) ConfigGUI(b tree.Node) {
// vgpu.Debug = true
tb := core.NewToolbar(b)
tb.Maker(ev.MakeToolbar)
split := core.NewSplits(b)
core.NewForm(split).SetStruct(ev)
imfr := core.NewFrame(split)
tbvw := core.NewTabs(split)
scfr, _ := tbvw.NewTab("3D View")
split.SetSplits(.2, .2, .6)
//////// 3D Scene
etb := core.NewToolbar(scfr)
_ = etb
ev.SceneEditor = xyzcore.NewSceneEditor(scfr)
ev.SceneEditor.UpdateWidget()
sc := ev.SceneEditor.SceneXYZ()
ev.MakeModel(sc)
// local toolbar for manipulating emer
// etb.Maker(phyxyz.MakeStateToolbar(&ev.Emer.Rel, func() {
// ev.World.Update()
// ev.SceneEditor.NeedsRender()
// }))
sc.Camera.Pose.Pos = math32.Vec3(0, 40, 3.5)
sc.Camera.LookAt(math32.Vec3(0, 5, 0), math32.Vec3(0, 1, 0))
sc.SaveCamera("3")
sc.Camera.Pose.Pos = math32.Vec3(0, 20, 30)
sc.Camera.LookAt(math32.Vec3(0, 5, 0), math32.Vec3(0, 1, 0))
sc.SaveCamera("2")
sc.Camera.Pose.Pos = math32.Vec3(-.86, .97, 2.7)
sc.Camera.LookAt(math32.Vec3(0, .8, 0), math32.Vec3(0, 1, 0))
sc.SaveCamera("1")
sc.SaveCamera("default")
//////// Image
imfr.Styler(func(s *styles.Style) {
s.Direction = styles.Column
})
core.NewText(imfr).SetText("Right Eye Image:")
ev.EyeRImg = core.NewImage(imfr)
ev.EyeRImg.SetName("eye-r-img")
ev.EyeRImg.Image = image.NewRGBA(image.Rectangle{Max: ev.Camera.Size})
core.NewText(imfr).SetText("Right Eye Depth:")
ev.DepthImage = core.NewImage(imfr)
ev.DepthImage.SetName("depth-img")
ev.DepthImage.Image = image.NewRGBA(image.Rectangle{Max: ev.Camera.Size})
}
func (ev *Env) MakeToolbar(p *tree.Plan) {
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(ev.InitState).SetText("Init").SetIcon(icons.Update)
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(ev.GrabEyeImage).SetText("Grab Image").SetIcon(icons.Image)
})
tree.Add(p, func(w *core.Separator) {})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(ev.ModelStep).SetText("Step").SetIcon(icons.SkipNext).
Styler(func(s *styles.Style) {
s.SetAbilities(true, abilities.RepeatClickable)
})
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(ev.StepForward).SetText("Fwd").SetIcon(icons.SkipNext).
Styler(func(s *styles.Style) {
s.SetAbilities(true, abilities.RepeatClickable)
})
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(ev.StepBackward).SetText("Bkw").SetIcon(icons.SkipPrevious).
Styler(func(s *styles.Style) {
s.SetAbilities(true, abilities.RepeatClickable)
})
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(ev.RotBodyLeft).SetText("Body Left").SetIcon(icons.KeyboardArrowLeft).
Styler(func(s *styles.Style) {
s.SetAbilities(true, abilities.RepeatClickable)
})
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(ev.RotBodyRight).SetText("Body Right").SetIcon(icons.KeyboardArrowRight).
Styler(func(s *styles.Style) {
s.SetAbilities(true, abilities.RepeatClickable)
})
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(ev.RotHeadLeft).SetText("Head Left").SetIcon(icons.KeyboardArrowLeft).
Styler(func(s *styles.Style) {
s.SetAbilities(true, abilities.RepeatClickable)
})
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(ev.RotHeadRight).SetText("Head Right").SetIcon(icons.KeyboardArrowRight).
Styler(func(s *styles.Style) {
s.SetAbilities(true, abilities.RepeatClickable)
})
})
tree.Add(p, func(w *core.Separator) {})
tree.Add(p, func(w *core.Button) {
w.SetText("README").SetIcon(icons.FileMarkdown).
SetTooltip("Open browser on README.").
OnClick(func(e events.Event) {
core.TheApp.OpenURL("https://github.com/cogentcore/core/blob/master/xyz/examples/physics/README.md")
})
})
}
func (ev *Env) NoGUIRun() {
gp, dev, err := gpu.NoDisplayGPU()
if err != nil {
panic(err)
}
sc := phyxyz.NoDisplayScene(gp, dev)
ev.MakeModel(sc)
img := ev.RenderEyeImg()
if img != nil {
imagex.Save(img, "eyer_0.png")
}
}
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"os"
"cogentcore.org/core/core"
"cogentcore.org/lab/physics/examples/virtroom"
)
var NoGUI bool
func main() {
if len(os.Args) > 1 && os.Args[1] == "-nogui" {
NoGUI = true
}
ev := &virtroom.Env{}
ev.Defaults()
if NoGUI {
ev.NoGUIRun()
return
}
// core.RenderTrace = true
b := core.NewBody("virtroom").SetTitle("Physics Virtual Room")
ev.ConfigGUI(b)
b.RunMainWindow()
}
// Code generated by "gosl"; DO NOT EDIT
package physics
import (
"embed"
"fmt"
"math"
"unsafe"
"cogentcore.org/core/gpu"
"cogentcore.org/lab/tensor"
)
//go:embed shaders/*.wgsl
var shaders embed.FS
var (
// GPUInitialized is true once the GPU system has been initialized.
// Prevents multiple initializations.
GPUInitialized bool
// ComputeGPU is the compute gpu device.
// Set this prior to calling GPUInit() to use an existing device.
ComputeGPU *gpu.GPU
// BorrowedGPU is true if our ComputeGPU is set externally,
// versus created specifically for this system. If external,
// we don't release it.
BorrowedGPU bool
// UseGPU indicates whether to use GPU vs. CPU.
UseGPU bool
)
// GPUSystem is a GPU compute System with kernels operating on the
// same set of data variables.
var GPUSystem *gpu.ComputeSystem
// GPUVars is an enum for GPU variables, for specifying what to sync.
type GPUVars int32 //enums:enum
const (
ParamsVar GPUVars = 0
BodiesVar GPUVars = 1
ObjectsVar GPUVars = 2
BodyJointsVar GPUVars = 3
JointsVar GPUVars = 4
JointDoFsVar GPUVars = 5
BodyCollidePairsVar GPUVars = 6
DynamicsVar GPUVars = 7
BroadContactsNVar GPUVars = 8
BroadContactsVar GPUVars = 9
ContactsNVar GPUVars = 10
ContactsVar GPUVars = 11
JointControlsVar GPUVars = 12
)
// Tensor stride variables
var TensorStrides tensor.Uint32
// GPUInit initializes the GPU compute system,
// configuring system(s), variables and kernels.
// It is safe to call multiple times: detects if already run.
func GPUInit() {
if GPUInitialized {
return
}
GPUInitialized = true
if ComputeGPU == nil { // set prior to this call to use an external
ComputeGPU = gpu.NewComputeGPU()
} else {
BorrowedGPU = true
}
gp := ComputeGPU
_ = fmt.Sprintf("%g",math.NaN()) // keep imports happy
{
sy := gpu.NewComputeSystem(gp, "Default")
GPUSystem = sy
vars := sy.Vars()
{
sgp := vars.AddGroup(gpu.Storage, "Params")
var vr *gpu.Var
_ = vr
vr = sgp.Add("TensorStrides", gpu.Uint32, 1, gpu.ComputeShader)
vr.ReadOnly = true
vr = sgp.AddStruct("Params", int(unsafe.Sizeof(PhysicsParams{})), 1, gpu.ComputeShader)
sgp.SetNValues(1)
}
{
sgp := vars.AddGroup(gpu.Storage, "Bodies")
var vr *gpu.Var
_ = vr
vr = sgp.Add("Bodies", gpu.Float32, 1, gpu.ComputeShader)
vr = sgp.Add("Objects", gpu.Int32, 1, gpu.ComputeShader)
vr = sgp.Add("BodyJoints", gpu.Int32, 1, gpu.ComputeShader)
vr = sgp.Add("Joints", gpu.Float32, 1, gpu.ComputeShader)
vr = sgp.Add("JointDoFs", gpu.Float32, 1, gpu.ComputeShader)
vr = sgp.Add("BodyCollidePairs", gpu.Int32, 1, gpu.ComputeShader)
sgp.SetNValues(1)
}
{
sgp := vars.AddGroup(gpu.Storage, "Dynamics")
var vr *gpu.Var
_ = vr
vr = sgp.Add("Dynamics", gpu.Float32, 1, gpu.ComputeShader)
vr = sgp.Add("BroadContactsN", gpu.Int32, 1, gpu.ComputeShader)
vr = sgp.Add("BroadContacts", gpu.Float32, 1, gpu.ComputeShader)
vr = sgp.Add("ContactsN", gpu.Int32, 1, gpu.ComputeShader)
vr = sgp.Add("Contacts", gpu.Float32, 1, gpu.ComputeShader)
sgp.SetNValues(1)
}
{
sgp := vars.AddGroup(gpu.Storage, "Controls")
var vr *gpu.Var
_ = vr
vr = sgp.Add("JointControls", gpu.Float32, 1, gpu.ComputeShader)
sgp.SetNValues(1)
}
var pl *gpu.ComputePipeline
pl = gpu.NewComputePipelineShaderFS(shaders, "shaders/CollisionBroad.wgsl", sy)
pl.AddVarUsed(0, "TensorStrides")
pl.AddVarUsed(1, "Bodies")
pl.AddVarUsed(1, "BodyCollidePairs")
pl.AddVarUsed(2, "BroadContacts")
pl.AddVarUsed(2, "BroadContactsN")
pl.AddVarUsed(2, "Dynamics")
pl.AddVarUsed(0, "Params")
pl = gpu.NewComputePipelineShaderFS(shaders, "shaders/CollisionNarrow.wgsl", sy)
pl.AddVarUsed(0, "TensorStrides")
pl.AddVarUsed(1, "Bodies")
pl.AddVarUsed(2, "BroadContacts")
pl.AddVarUsed(2, "BroadContactsN")
pl.AddVarUsed(2, "Contacts")
pl.AddVarUsed(2, "ContactsN")
pl.AddVarUsed(2, "Dynamics")
pl.AddVarUsed(0, "Params")
pl = gpu.NewComputePipelineShaderFS(shaders, "shaders/DynamicsCurToNext.wgsl", sy)
pl.AddVarUsed(0, "TensorStrides")
pl.AddVarUsed(2, "Dynamics")
pl.AddVarUsed(0, "Params")
pl = gpu.NewComputePipelineShaderFS(shaders, "shaders/ForcesFromJoints.wgsl", sy)
pl.AddVarUsed(0, "TensorStrides")
pl.AddVarUsed(1, "BodyJoints")
pl.AddVarUsed(2, "Dynamics")
pl.AddVarUsed(1, "Joints")
pl.AddVarUsed(0, "Params")
pl = gpu.NewComputePipelineShaderFS(shaders, "shaders/InitDynamics.wgsl", sy)
pl.AddVarUsed(0, "TensorStrides")
pl.AddVarUsed(1, "Bodies")
pl.AddVarUsed(2, "Dynamics")
pl.AddVarUsed(0, "Params")
pl = gpu.NewComputePipelineShaderFS(shaders, "shaders/StepBodyContactDeltas.wgsl", sy)
pl.AddVarUsed(0, "TensorStrides")
pl.AddVarUsed(1, "Bodies")
pl.AddVarUsed(2, "Contacts")
pl.AddVarUsed(2, "ContactsN")
pl.AddVarUsed(2, "Dynamics")
pl.AddVarUsed(0, "Params")
pl = gpu.NewComputePipelineShaderFS(shaders, "shaders/StepBodyContacts.wgsl", sy)
pl.AddVarUsed(0, "TensorStrides")
pl.AddVarUsed(1, "Bodies")
pl.AddVarUsed(2, "Contacts")
pl.AddVarUsed(2, "ContactsN")
pl.AddVarUsed(2, "Dynamics")
pl.AddVarUsed(0, "Params")
pl = gpu.NewComputePipelineShaderFS(shaders, "shaders/StepInit.wgsl", sy)
pl.AddVarUsed(0, "TensorStrides")
pl.AddVarUsed(2, "BroadContactsN")
pl.AddVarUsed(2, "ContactsN")
pl.AddVarUsed(3, "JointControls")
pl.AddVarUsed(0, "Params")
pl = gpu.NewComputePipelineShaderFS(shaders, "shaders/StepIntegrateBodies.wgsl", sy)
pl.AddVarUsed(0, "TensorStrides")
pl.AddVarUsed(1, "Bodies")
pl.AddVarUsed(2, "Dynamics")
pl.AddVarUsed(0, "Params")
pl = gpu.NewComputePipelineShaderFS(shaders, "shaders/StepJointForces.wgsl", sy)
pl.AddVarUsed(0, "TensorStrides")
pl.AddVarUsed(1, "Bodies")
pl.AddVarUsed(2, "Dynamics")
pl.AddVarUsed(3, "JointControls")
pl.AddVarUsed(1, "JointDoFs")
pl.AddVarUsed(1, "Joints")
pl.AddVarUsed(0, "Params")
pl = gpu.NewComputePipelineShaderFS(shaders, "shaders/StepSolveJoints.wgsl", sy)
pl.AddVarUsed(0, "TensorStrides")
pl.AddVarUsed(1, "Bodies")
pl.AddVarUsed(2, "Dynamics")
pl.AddVarUsed(3, "JointControls")
pl.AddVarUsed(1, "JointDoFs")
pl.AddVarUsed(1, "Joints")
pl.AddVarUsed(1, "Objects")
pl.AddVarUsed(0, "Params")
sy.Config()
}
}
// GPURelease releases the GPU compute system resources.
// Call this at program exit.
func GPURelease() {
if GPUSystem != nil {
GPUSystem.Release()
GPUSystem = nil
}
if !BorrowedGPU && ComputeGPU != nil {
ComputeGPU.Release()
}
ComputeGPU = nil
}
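// Illustrative sketch (hypothetical helper, not generated): the typical
// lifecycle for the GPU compute system defined above. An existing device can
// be supplied via ComputeGPU before calling GPUInit, in which case GPUInit
// marks it as borrowed and GPURelease will not release it.
func gpuLifecycleSketch(existing *gpu.GPU) {
	if existing != nil {
		ComputeGPU = existing // use an external device instead of creating one
	}
	UseGPU = true // have the Run* kernels dispatch on the GPU instead of the CPU
	GPUInit()     // safe to call multiple times; only the first call configures
	defer GPURelease()
	// ... call Run* kernels, followed by RunDone(...) to sync, here ...
}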
// RunCollisionBroad runs the CollisionBroad kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// Can call multiple Run* kernels in a row, which are then all launched
// in the same command submission on the GPU, which is by far the most efficient.
// MUST call RunDone (with optional vars to sync) after all Run calls.
// Alternatively, a single-shot RunOneCollisionBroad call does Run and Done for a
// single run-and-sync case.
func RunCollisionBroad(n int) {
if UseGPU {
RunCollisionBroadGPU(n)
} else {
RunCollisionBroadCPU(n)
}
}
// RunCollisionBroadGPU runs the CollisionBroad kernel on the GPU. See [RunCollisionBroad] for more info.
func RunCollisionBroadGPU(n int) {
sy := GPUSystem
pl := sy.ComputePipelines["CollisionBroad"]
ce, _ := sy.BeginComputePass()
pl.Dispatch1D(ce, n, 64)
}
// RunCollisionBroadCPU runs the CollisionBroad kernel on the CPU.
func RunCollisionBroadCPU(n int) {
gpu.VectorizeFunc(0, n, CollisionBroad)
}
// RunOneCollisionBroad runs the CollisionBroad kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// This version then calls RunDone with the given variables to sync
// after the Run, for a single-shot Run-and-Done call. If multiple kernels
// can be run in sequence, it is much more efficient to do multiple Run*
// calls followed by a RunDone call.
func RunOneCollisionBroad(n int, syncVars ...GPUVars) {
if UseGPU {
RunCollisionBroadGPU(n)
RunDone(syncVars...)
} else {
RunCollisionBroadCPU(n)
}
}
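// Illustrative sketch (hypothetical, not generated): batching several Run*
// kernels followed by a single RunDone, as described in the RunCollisionBroad
// comment above. The element counts and the synced variables are placeholders;
// on the GPU all of these launch in one command submission.
func collisionPassSketch(nPairs, nContacts int) {
	RunCollisionBroad(nPairs)          // broad-phase candidate pairs
	RunCollisionNarrow(nContacts)      // narrow-phase contact resolution
	RunDone(ContactsNVar, ContactsVar) // one RunDone after all Run* calls, syncing contact data
}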
// RunCollisionNarrow runs the CollisionNarrow kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// Can call multiple Run* kernels in a row, which are then all launched
// in the same command submission on the GPU, which is by far the most efficient.
// MUST call RunDone (with optional vars to sync) after all Run calls.
// Alternatively, a single-shot RunOneCollisionNarrow call does Run and Done for a
// single run-and-sync case.
func RunCollisionNarrow(n int) {
if UseGPU {
RunCollisionNarrowGPU(n)
} else {
RunCollisionNarrowCPU(n)
}
}
// RunCollisionNarrowGPU runs the CollisionNarrow kernel on the GPU. See [RunCollisionNarrow] for more info.
func RunCollisionNarrowGPU(n int) {
sy := GPUSystem
pl := sy.ComputePipelines["CollisionNarrow"]
ce, _ := sy.BeginComputePass()
pl.Dispatch1D(ce, n, 64)
}
// RunCollisionNarrowCPU runs the CollisionNarrow kernel on the CPU.
func RunCollisionNarrowCPU(n int) {
gpu.VectorizeFunc(0, n, CollisionNarrow)
}
// RunOneCollisionNarrow runs the CollisionNarrow kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// This version then calls RunDone with the given variables to sync
// after the Run, for a single-shot Run-and-Done call. If multiple kernels
// can be run in sequence, it is much more efficient to do multiple Run*
// calls followed by a RunDone call.
func RunOneCollisionNarrow(n int, syncVars ...GPUVars) {
if UseGPU {
RunCollisionNarrowGPU(n)
RunDone(syncVars...)
} else {
RunCollisionNarrowCPU(n)
}
}
// RunDynamicsCurToNext runs the DynamicsCurToNext kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// Can call multiple Run* kernels in a row, which are then all launched
// in the same command submission on the GPU, which is by far the most efficient.
// MUST call RunDone (with optional vars to sync) after all Run calls.
// Alternatively, a single-shot RunOneDynamicsCurToNext call does Run and Done for a
// single run-and-sync case.
func RunDynamicsCurToNext(n int) {
if UseGPU {
RunDynamicsCurToNextGPU(n)
} else {
RunDynamicsCurToNextCPU(n)
}
}
// RunDynamicsCurToNextGPU runs the DynamicsCurToNext kernel on the GPU. See [RunDynamicsCurToNext] for more info.
func RunDynamicsCurToNextGPU(n int) {
sy := GPUSystem
pl := sy.ComputePipelines["DynamicsCurToNext"]
ce, _ := sy.BeginComputePass()
pl.Dispatch1D(ce, n, 64)
}
// RunDynamicsCurToNextCPU runs the DynamicsCurToNext kernel on the CPU.
func RunDynamicsCurToNextCPU(n int) {
gpu.VectorizeFunc(0, n, DynamicsCurToNext)
}
// RunOneDynamicsCurToNext runs the DynamicsCurToNext kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// This version then calls RunDone with the given variables to sync
// after the Run, for a single-shot Run-and-Done call. If multiple kernels
// can be run in sequence, it is much more efficient to do multiple Run*
// calls followed by a RunDone call.
func RunOneDynamicsCurToNext(n int, syncVars ...GPUVars) {
if UseGPU {
RunDynamicsCurToNextGPU(n)
RunDone(syncVars...)
} else {
RunDynamicsCurToNextCPU(n)
}
}
// RunForcesFromJoints runs the ForcesFromJoints kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// Can call multiple Run* kernels in a row, which are then all launched
// in the same command submission on the GPU, which is by far the most efficient.
// MUST call RunDone (with optional vars to sync) after all Run calls.
// Alternatively, a single-shot RunOneForcesFromJoints call does Run and Done for a
// single run-and-sync case.
func RunForcesFromJoints(n int) {
if UseGPU {
RunForcesFromJointsGPU(n)
} else {
RunForcesFromJointsCPU(n)
}
}
// RunForcesFromJointsGPU runs the ForcesFromJoints kernel on the GPU. See [RunForcesFromJoints] for more info.
func RunForcesFromJointsGPU(n int) {
sy := GPUSystem
pl := sy.ComputePipelines["ForcesFromJoints"]
ce, _ := sy.BeginComputePass()
pl.Dispatch1D(ce, n, 64)
}
// RunForcesFromJointsCPU runs the ForcesFromJoints kernel on the CPU.
func RunForcesFromJointsCPU(n int) {
gpu.VectorizeFunc(0, n, ForcesFromJoints)
}
// RunOneForcesFromJoints runs the ForcesFromJoints kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// This version then calls RunDone with the given variables to sync
// after the Run, for a single-shot Run-and-Done call. If multiple kernels
// can be run in sequence, it is much more efficient to do multiple Run*
// calls followed by a RunDone call.
func RunOneForcesFromJoints(n int, syncVars ...GPUVars) {
if UseGPU {
RunForcesFromJointsGPU(n)
RunDone(syncVars...)
} else {
RunForcesFromJointsCPU(n)
}
}
// RunInitDynamics runs the InitDynamics kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// Can call multiple Run* kernels in a row, which are then all launched
// in the same command submission on the GPU, which is by far the most efficient.
// MUST call RunDone (with optional vars to sync) after all Run calls.
// Alternatively, a single-shot RunOneInitDynamics call does Run and Done for a
// single run-and-sync case.
func RunInitDynamics(n int) {
if UseGPU {
RunInitDynamicsGPU(n)
} else {
RunInitDynamicsCPU(n)
}
}
// RunInitDynamicsGPU runs the InitDynamics kernel on the GPU. See [RunInitDynamics] for more info.
func RunInitDynamicsGPU(n int) {
sy := GPUSystem
pl := sy.ComputePipelines["InitDynamics"]
ce, _ := sy.BeginComputePass()
pl.Dispatch1D(ce, n, 64)
}
// RunInitDynamicsCPU runs the InitDynamics kernel on the CPU.
func RunInitDynamicsCPU(n int) {
gpu.VectorizeFunc(0, n, InitDynamics)
}
// RunOneInitDynamics runs the InitDynamics kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// This version then calls RunDone with the given variables to sync
// after the Run, for a single-shot Run-and-Done call. If multiple kernels
// can be run in sequence, it is much more efficient to do multiple Run*
// calls followed by a RunDone call.
func RunOneInitDynamics(n int, syncVars ...GPUVars) {
if UseGPU {
RunInitDynamicsGPU(n)
RunDone(syncVars...)
} else {
RunInitDynamicsCPU(n)
}
}
// RunStepBodyContactDeltas runs the StepBodyContactDeltas kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// Can call multiple Run* kernels in a row, which are then all launched
// in the same command submission on the GPU, which is by far the most efficient.
// MUST call RunDone (with optional vars to sync) after all Run calls.
// Alternatively, a single-shot RunOneStepBodyContactDeltas call does Run and Done for a
// single run-and-sync case.
func RunStepBodyContactDeltas(n int) {
if UseGPU {
RunStepBodyContactDeltasGPU(n)
} else {
RunStepBodyContactDeltasCPU(n)
}
}
// RunStepBodyContactDeltasGPU runs the StepBodyContactDeltas kernel on the GPU. See [RunStepBodyContactDeltas] for more info.
func RunStepBodyContactDeltasGPU(n int) {
sy := GPUSystem
pl := sy.ComputePipelines["StepBodyContactDeltas"]
ce, _ := sy.BeginComputePass()
pl.Dispatch1D(ce, n, 64)
}
// RunStepBodyContactDeltasCPU runs the StepBodyContactDeltas kernel on the CPU.
func RunStepBodyContactDeltasCPU(n int) {
gpu.VectorizeFunc(0, n, StepBodyContactDeltas)
}
// RunOneStepBodyContactDeltas runs the StepBodyContactDeltas kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// This version then calls RunDone with the given variables to sync
// after the Run, for a single-shot Run-and-Done call. If multiple kernels
// can be run in sequence, it is much more efficient to do multiple Run*
// calls followed by a RunDone call.
func RunOneStepBodyContactDeltas(n int, syncVars ...GPUVars) {
if UseGPU {
RunStepBodyContactDeltasGPU(n)
RunDone(syncVars...)
} else {
RunStepBodyContactDeltasCPU(n)
}
}
// RunStepBodyContacts runs the StepBodyContacts kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// Can call multiple Run* kernels in a row, which are then all launched
// in the same command submission on the GPU, which is by far the most efficient.
// MUST call RunDone (with optional vars to sync) after all Run calls.
// Alternatively, a single-shot RunOneStepBodyContacts call does Run and Done for a
// single run-and-sync case.
func RunStepBodyContacts(n int) {
if UseGPU {
RunStepBodyContactsGPU(n)
} else {
RunStepBodyContactsCPU(n)
}
}
// RunStepBodyContactsGPU runs the StepBodyContacts kernel on the GPU. See [RunStepBodyContacts] for more info.
func RunStepBodyContactsGPU(n int) {
sy := GPUSystem
pl := sy.ComputePipelines["StepBodyContacts"]
ce, _ := sy.BeginComputePass()
pl.Dispatch1D(ce, n, 64)
}
// RunStepBodyContactsCPU runs the StepBodyContacts kernel on the CPU.
func RunStepBodyContactsCPU(n int) {
gpu.VectorizeFunc(0, n, StepBodyContacts)
}
// RunOneStepBodyContacts runs the StepBodyContacts kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// This version then calls RunDone with the given variables to sync
// after the Run, for a single-shot Run-and-Done call. If multiple kernels
// can be run in sequence, it is much more efficient to do multiple Run*
// calls followed by a RunDone call.
func RunOneStepBodyContacts(n int, syncVars ...GPUVars) {
if UseGPU {
RunStepBodyContactsGPU(n)
RunDone(syncVars...)
} else {
RunStepBodyContactsCPU(n)
}
}
// RunStepInit runs the StepInit kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// Can call multiple Run* kernels in a row, which are then all launched
// in the same command submission on the GPU, which is by far the most efficient.
// MUST call RunDone (with optional vars to sync) after all Run calls.
// Alternatively, a single-shot RunOneStepInit call does Run and Done for a
// single run-and-sync case.
func RunStepInit(n int) {
if UseGPU {
RunStepInitGPU(n)
} else {
RunStepInitCPU(n)
}
}
// RunStepInitGPU runs the StepInit kernel on the GPU. See [RunStepInit] for more info.
func RunStepInitGPU(n int) {
sy := GPUSystem
pl := sy.ComputePipelines["StepInit"]
ce, _ := sy.BeginComputePass()
pl.Dispatch1D(ce, n, 64)
}
// RunStepInitCPU runs the StepInit kernel on the CPU.
func RunStepInitCPU(n int) {
gpu.VectorizeFunc(0, n, StepInit)
}
// RunOneStepInit runs the StepInit kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// This version then calls RunDone with the given variables to sync
// after the Run, for a single-shot Run-and-Done call. If multiple kernels
// can be run in sequence, it is much more efficient to do multiple Run*
// calls followed by a RunDone call.
func RunOneStepInit(n int, syncVars ...GPUVars) {
if UseGPU {
RunStepInitGPU(n)
RunDone(syncVars...)
} else {
RunStepInitCPU(n)
}
}
// RunStepIntegrateBodies runs the StepIntegrateBodies kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// Can call multiple Run* kernels in a row, which are then all launched
// in the same command submission on the GPU, which is by far the most efficient.
// MUST call RunDone (with optional vars to sync) after all Run calls.
// Alternatively, a single-shot RunOneStepIntegrateBodies call does Run and Done for a
// single run-and-sync case.
func RunStepIntegrateBodies(n int) {
if UseGPU {
RunStepIntegrateBodiesGPU(n)
} else {
RunStepIntegrateBodiesCPU(n)
}
}
// RunStepIntegrateBodiesGPU runs the StepIntegrateBodies kernel on the GPU. See [RunStepIntegrateBodies] for more info.
func RunStepIntegrateBodiesGPU(n int) {
sy := GPUSystem
pl := sy.ComputePipelines["StepIntegrateBodies"]
ce, _ := sy.BeginComputePass()
pl.Dispatch1D(ce, n, 64)
}
// RunStepIntegrateBodiesCPU runs the StepIntegrateBodies kernel on the CPU.
func RunStepIntegrateBodiesCPU(n int) {
gpu.VectorizeFunc(0, n, StepIntegrateBodies)
}
// RunOneStepIntegrateBodies runs the StepIntegrateBodies kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// This version then calls RunDone with the given variables to sync
// after the Run, for a single-shot Run-and-Done call. If multiple kernels
// can be run in sequence, it is much more efficient to do multiple Run*
// calls followed by a RunDone call.
func RunOneStepIntegrateBodies(n int, syncVars ...GPUVars) {
if UseGPU {
RunStepIntegrateBodiesGPU(n)
RunDone(syncVars...)
} else {
RunStepIntegrateBodiesCPU(n)
}
}
// RunStepJointForces runs the StepJointForces kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// Can call multiple Run* kernels in a row, which are then all launched
// in the same command submission on the GPU, which is by far the most efficient.
// MUST call RunDone (with optional vars to sync) after all Run calls.
// Alternatively, a single-shot RunOneStepJointForces call does Run and Done for a
// single run-and-sync case.
func RunStepJointForces(n int) {
if UseGPU {
RunStepJointForcesGPU(n)
} else {
RunStepJointForcesCPU(n)
}
}
// RunStepJointForcesGPU runs the StepJointForces kernel on the GPU. See [RunStepJointForces] for more info.
func RunStepJointForcesGPU(n int) {
sy := GPUSystem
pl := sy.ComputePipelines["StepJointForces"]
ce, _ := sy.BeginComputePass()
pl.Dispatch1D(ce, n, 64)
}
// RunStepJointForcesCPU runs the StepJointForces kernel on the CPU.
func RunStepJointForcesCPU(n int) {
gpu.VectorizeFunc(0, n, StepJointForces)
}
// RunOneStepJointForces runs the StepJointForces kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// This version then calls RunDone with the given variables to sync
// after the Run, for a single-shot Run-and-Done call. If multiple kernels
// can be run in sequence, it is much more efficient to do multiple Run*
// calls followed by a RunDone call.
func RunOneStepJointForces(n int, syncVars ...GPUVars) {
if UseGPU {
RunStepJointForcesGPU(n)
RunDone(syncVars...)
} else {
RunStepJointForcesCPU(n)
}
}
// RunStepSolveJoints runs the StepSolveJoints kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// Can call multiple Run* kernels in a row, which are then all launched
// in the same command submission on the GPU, which is by far the most efficient.
// MUST call RunDone (with optional vars to sync) after all Run calls.
// Alternatively, a single-shot RunOneStepSolveJoints call does Run and Done for a
// single run-and-sync case.
func RunStepSolveJoints(n int) {
if UseGPU {
RunStepSolveJointsGPU(n)
} else {
RunStepSolveJointsCPU(n)
}
}
// RunStepSolveJointsGPU runs the StepSolveJoints kernel on the GPU. See [RunStepSolveJoints] for more info.
func RunStepSolveJointsGPU(n int) {
sy := GPUSystem
pl := sy.ComputePipelines["StepSolveJoints"]
ce, _ := sy.BeginComputePass()
pl.Dispatch1D(ce, n, 64)
}
// RunStepSolveJointsCPU runs the StepSolveJoints kernel on the CPU.
func RunStepSolveJointsCPU(n int) {
gpu.VectorizeFunc(0, n, StepSolveJoints)
}
// RunOneStepSolveJoints runs the StepSolveJoints kernel with given number of elements,
// on either the CPU or GPU depending on the UseGPU variable.
// This version then calls RunDone with the given variables to sync
// after the Run, for a single-shot Run-and-Done call. If multiple kernels
// can be run in sequence, it is much more efficient to do multiple Run*
// calls followed by a RunDone call.
func RunOneStepSolveJoints(n int, syncVars ...GPUVars) {
if UseGPU {
RunStepSolveJointsGPU(n)
RunDone(syncVars...)
} else {
RunStepSolveJointsCPU(n)
}
}
// RunDone must be called after Run* calls to start compute kernels.
// This actually submits the kernel jobs to the GPU, and adds commands
// to synchronize the given variables back from the GPU to the CPU.
// After this function completes, the GPU results will be available in
// the specified variables.
func RunDone(syncVars ...GPUVars) {
if !UseGPU {
return
}
sy := GPUSystem
sy.ComputeEncoder.End()
ReadFromGPU(syncVars...)
sy.EndComputePass()
SyncFromGPU(syncVars...)
}
// ToGPU copies given variables to the GPU for the system.
func ToGPU(vars ...GPUVars) {
if !UseGPU {
return
}
sy := GPUSystem
syVars := sy.Vars()
for _, vr := range vars {
switch vr {
case ParamsVar:
v, _ := syVars.ValueByIndex(0, "Params", 0)
gpu.SetValueFrom(v, Params)
case BodiesVar:
v, _ := syVars.ValueByIndex(1, "Bodies", 0)
gpu.SetValueFrom(v, Bodies.Values)
case ObjectsVar:
v, _ := syVars.ValueByIndex(1, "Objects", 0)
gpu.SetValueFrom(v, Objects.Values)
case BodyJointsVar:
v, _ := syVars.ValueByIndex(1, "BodyJoints", 0)
gpu.SetValueFrom(v, BodyJoints.Values)
case JointsVar:
v, _ := syVars.ValueByIndex(1, "Joints", 0)
gpu.SetValueFrom(v, Joints.Values)
case JointDoFsVar:
v, _ := syVars.ValueByIndex(1, "JointDoFs", 0)
gpu.SetValueFrom(v, JointDoFs.Values)
case BodyCollidePairsVar:
v, _ := syVars.ValueByIndex(1, "BodyCollidePairs", 0)
gpu.SetValueFrom(v, BodyCollidePairs.Values)
case DynamicsVar:
v, _ := syVars.ValueByIndex(2, "Dynamics", 0)
gpu.SetValueFrom(v, Dynamics.Values)
case BroadContactsNVar:
v, _ := syVars.ValueByIndex(2, "BroadContactsN", 0)
gpu.SetValueFrom(v, BroadContactsN.Values)
case BroadContactsVar:
v, _ := syVars.ValueByIndex(2, "BroadContacts", 0)
gpu.SetValueFrom(v, BroadContacts.Values)
case ContactsNVar:
v, _ := syVars.ValueByIndex(2, "ContactsN", 0)
gpu.SetValueFrom(v, ContactsN.Values)
case ContactsVar:
v, _ := syVars.ValueByIndex(2, "Contacts", 0)
gpu.SetValueFrom(v, Contacts.Values)
case JointControlsVar:
v, _ := syVars.ValueByIndex(3, "JointControls", 0)
gpu.SetValueFrom(v, JointControls.Values)
}
}
}
// RunGPUSync can be called to synchronize data between CPU and GPU.
// Any prior ToGPU* calls will execute to send data to the GPU,
// and any subsequent RunDone* calls will copy data back from the GPU.
func RunGPUSync() {
if !UseGPU {
return
}
sy := GPUSystem
sy.BeginComputePass()
}
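// exampleGPUSyncRoundTrip is an illustrative sketch (not part of the generated
// API) of one possible use of RunGPUSync implied by the doc comment above:
// upload freshly set joint controls, begin a compute pass, and read the
// dynamics state back without dispatching any kernels. The choice of
// variables here is arbitrary.
func exampleGPUSyncRoundTrip() {
// ... set JointControls values on the CPU here ...
ToGPU(JointControlsVar)
RunGPUSync()
RunDone(DynamicsVar)
}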
// ToGPUTensorStrides gets tensor strides and starts copying to the GPU.
func ToGPUTensorStrides() {
if !UseGPU {
return
}
sy := GPUSystem
syVars := sy.Vars()
TensorStrides.SetShapeSizes(120)
TensorStrides.SetInt1D(Bodies.Shape().Strides[0], 0)
TensorStrides.SetInt1D(Bodies.Shape().Strides[1], 1)
TensorStrides.SetInt1D(Objects.Shape().Strides[0], 10)
TensorStrides.SetInt1D(Objects.Shape().Strides[1], 11)
TensorStrides.SetInt1D(BodyJoints.Shape().Strides[0], 20)
TensorStrides.SetInt1D(BodyJoints.Shape().Strides[1], 21)
TensorStrides.SetInt1D(BodyJoints.Shape().Strides[2], 22)
TensorStrides.SetInt1D(Joints.Shape().Strides[0], 30)
TensorStrides.SetInt1D(Joints.Shape().Strides[1], 31)
TensorStrides.SetInt1D(JointDoFs.Shape().Strides[0], 40)
TensorStrides.SetInt1D(JointDoFs.Shape().Strides[1], 41)
TensorStrides.SetInt1D(BodyCollidePairs.Shape().Strides[0], 50)
TensorStrides.SetInt1D(BodyCollidePairs.Shape().Strides[1], 51)
TensorStrides.SetInt1D(Dynamics.Shape().Strides[0], 60)
TensorStrides.SetInt1D(Dynamics.Shape().Strides[1], 61)
TensorStrides.SetInt1D(Dynamics.Shape().Strides[2], 62)
TensorStrides.SetInt1D(BroadContactsN.Shape().Strides[0], 70)
TensorStrides.SetInt1D(BroadContacts.Shape().Strides[0], 80)
TensorStrides.SetInt1D(BroadContacts.Shape().Strides[1], 81)
TensorStrides.SetInt1D(ContactsN.Shape().Strides[0], 90)
TensorStrides.SetInt1D(Contacts.Shape().Strides[0], 100)
TensorStrides.SetInt1D(Contacts.Shape().Strides[1], 101)
TensorStrides.SetInt1D(JointControls.Shape().Strides[0], 110)
TensorStrides.SetInt1D(JointControls.Shape().Strides[1], 111)
v, _ := syVars.ValueByIndex(0, "TensorStrides", 0)
gpu.SetValueFrom(v, TensorStrides.Values)
}
// ReadFromGPU starts the process of copying vars back from the GPU to the CPU.
func ReadFromGPU(vars ...GPUVars) {
sy := GPUSystem
syVars := sy.Vars()
for _, vr := range vars {
switch vr {
case ParamsVar:
v, _ := syVars.ValueByIndex(0, "Params", 0)
v.GPUToRead(sy.CommandEncoder)
case BodiesVar:
v, _ := syVars.ValueByIndex(1, "Bodies", 0)
v.GPUToRead(sy.CommandEncoder)
case ObjectsVar:
v, _ := syVars.ValueByIndex(1, "Objects", 0)
v.GPUToRead(sy.CommandEncoder)
case BodyJointsVar:
v, _ := syVars.ValueByIndex(1, "BodyJoints", 0)
v.GPUToRead(sy.CommandEncoder)
case JointsVar:
v, _ := syVars.ValueByIndex(1, "Joints", 0)
v.GPUToRead(sy.CommandEncoder)
case JointDoFsVar:
v, _ := syVars.ValueByIndex(1, "JointDoFs", 0)
v.GPUToRead(sy.CommandEncoder)
case BodyCollidePairsVar:
v, _ := syVars.ValueByIndex(1, "BodyCollidePairs", 0)
v.GPUToRead(sy.CommandEncoder)
case DynamicsVar:
v, _ := syVars.ValueByIndex(2, "Dynamics", 0)
v.GPUToRead(sy.CommandEncoder)
case BroadContactsNVar:
v, _ := syVars.ValueByIndex(2, "BroadContactsN", 0)
v.GPUToRead(sy.CommandEncoder)
case BroadContactsVar:
v, _ := syVars.ValueByIndex(2, "BroadContacts", 0)
v.GPUToRead(sy.CommandEncoder)
case ContactsNVar:
v, _ := syVars.ValueByIndex(2, "ContactsN", 0)
v.GPUToRead(sy.CommandEncoder)
case ContactsVar:
v, _ := syVars.ValueByIndex(2, "Contacts", 0)
v.GPUToRead(sy.CommandEncoder)
case JointControlsVar:
v, _ := syVars.ValueByIndex(3, "JointControls", 0)
v.GPUToRead(sy.CommandEncoder)
}
}
}
// SyncFromGPU synchronizes vars from the GPU to the actual variables.
func SyncFromGPU(vars ...GPUVars) {
sy := GPUSystem
syVars := sy.Vars()
for _, vr := range vars {
switch vr {
case ParamsVar:
v, _ := syVars.ValueByIndex(0, "Params", 0)
v.ReadSync()
gpu.ReadToBytes(v, Params)
case BodiesVar:
v, _ := syVars.ValueByIndex(1, "Bodies", 0)
v.ReadSync()
gpu.ReadToBytes(v, Bodies.Values)
case ObjectsVar:
v, _ := syVars.ValueByIndex(1, "Objects", 0)
v.ReadSync()
gpu.ReadToBytes(v, Objects.Values)
case BodyJointsVar:
v, _ := syVars.ValueByIndex(1, "BodyJoints", 0)
v.ReadSync()
gpu.ReadToBytes(v, BodyJoints.Values)
case JointsVar:
v, _ := syVars.ValueByIndex(1, "Joints", 0)
v.ReadSync()
gpu.ReadToBytes(v, Joints.Values)
case JointDoFsVar:
v, _ := syVars.ValueByIndex(1, "JointDoFs", 0)
v.ReadSync()
gpu.ReadToBytes(v, JointDoFs.Values)
case BodyCollidePairsVar:
v, _ := syVars.ValueByIndex(1, "BodyCollidePairs", 0)
v.ReadSync()
gpu.ReadToBytes(v, BodyCollidePairs.Values)
case DynamicsVar:
v, _ := syVars.ValueByIndex(2, "Dynamics", 0)
v.ReadSync()
gpu.ReadToBytes(v, Dynamics.Values)
case BroadContactsNVar:
v, _ := syVars.ValueByIndex(2, "BroadContactsN", 0)
v.ReadSync()
gpu.ReadToBytes(v, BroadContactsN.Values)
case BroadContactsVar:
v, _ := syVars.ValueByIndex(2, "BroadContacts", 0)
v.ReadSync()
gpu.ReadToBytes(v, BroadContacts.Values)
case ContactsNVar:
v, _ := syVars.ValueByIndex(2, "ContactsN", 0)
v.ReadSync()
gpu.ReadToBytes(v, ContactsN.Values)
case ContactsVar:
v, _ := syVars.ValueByIndex(2, "Contacts", 0)
v.ReadSync()
gpu.ReadToBytes(v, Contacts.Values)
case JointControlsVar:
v, _ := syVars.ValueByIndex(3, "JointControls", 0)
v.ReadSync()
gpu.ReadToBytes(v, JointControls.Values)
}
}
}
// GetParams returns a pointer to the given global variable:
// [Params] []PhysicsParams at given index. This is processed directly in the GPU code,
// so this function call is the CPU equivalent.
func GetParams(idx uint32) *PhysicsParams {
return &Params[idx]
}
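// exampleReadDt is an illustrative sketch (not part of the generated API):
// on the CPU, GetParams provides the same indexed access to the global
// Params slice that the GPU kernels use directly.
func exampleReadDt() float32 {
return GetParams(0).Dt
}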
// Code generated by "goal build"; DO NOT EDIT.
//line joint.goal:1
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package physics
import (
// "fmt"
"math"
"cogentcore.org/core/math32"
)
//gosl:start
// Sentinel value for unlimited joint limits
const JointLimitUnlimited = 1e10
// JointTypes are joint types that determine nature of interaction.
type JointTypes int32 //enums:enum
const (
// Prismatic allows translation along a single axis (slider): 1 DoF.
Prismatic JointTypes = iota
// Revolute allows rotation about a single axis (axel): 1 DoF.
Revolute
// Ball allows rotation about all three axes (3 DoF, quaternion).
Ball
// Fixed locks all relative motion: 0 DoF.
Fixed
// Free allows full 6-DoF motion (translation and rotation).
Free
// Distance keeps two bodies at a distance within joint limits: 6 DoF.
Distance
// D6 is a generic 6-DoF joint.
D6
// PlaneXZ is a version of D6 for navigation in the X-Z plane,
// which creates 2 linear DoF (X, Z) for movement and 1 angular DoF
// for rotation about the Y axis.
)
// JointVars are joint state variables stored in tensor.Float32.
// These are all static joint properties; dynamic control variables
// are in [JointControlVars] and [JointControls].
type JointVars int32 //enums:enum
const (
// JointType (as an int32 from bits).
JointType JointVars = iota
// JointEnabled allows joints to be dynamically enabled.
JointEnabled
// JointParentFixed means that the parent is NOT updated based on
// the forces and positions for this joint. This can make dynamics
// cleaner when full accuracy is not necessary.
JointParentFixed
// JointNoLinearRotation ignores the rotational (angular) effects of
// linear joint position constraints (i.e., Coriolis and centrifugal forces)
// which can otherwise interfere with rotational position constraints in
// joints with both linear and angular DoFs
// (e.g., [PlaneXZ], for which this is on by default).
JointNoLinearRotation
// JointParent is the dynamic body index for parent body.
// Can be -1 for a fixed parent (an absolute world anchor).
JointParent
// JointChild is the dynamic body index for child body.
JointChild
// relative position of joint, in parent frame.
// This is prior to parent body rotation.
JointPPosX
JointPPosY
JointPPosZ
// relative orientation of joint, in parent frame.
// This is prior to parent body rotation.
JointPQuatX
JointPQuatY
JointPQuatZ
JointPQuatW
// relative position of joint, in child frame.
// This is prior to child body rotation.
JointCPosX
JointCPosY
JointCPosZ
// relative orientation of joint, in child frame.
// This is prior to child body rotation.
JointCQuatX
JointCQuatY
JointCQuatZ
JointCQuatW
// JointLinearDoFN is the number of linear degrees-of-freedom for the joint.
JointLinearDoFN
// JointAngularDoFN is the number of angular degrees-of-freedom for the joint.
JointAngularDoFN
// indexes in JointDoFs for each DoF
JointDoF1
JointDoF2
JointDoF3
// angular starts here for Free, Distance, D6
JointDoF4
JointDoF5
JointDoF6
// Computed forces (temp storage until aggregated by bodies).
// Computed parent joint force value.
JointPForceX
JointPForceY
JointPForceZ
// Computed parent joint torque value.
JointPTorqueX
JointPTorqueY
JointPTorqueZ
// Computed child joint force value.
JointCForceX
JointCForceY
JointCForceZ
// Computed child joint torque value.
JointCTorqueX
JointCTorqueY
JointCTorqueZ
// Computed linear lambdas.
JointLinLambdaX
JointLinLambdaY
JointLinLambdaZ
// Computed angular lambdas.
JointAngLambdaX
JointAngLambdaY
JointAngLambdaZ
)
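// exampleJointBitsRoundTrip is an illustrative sketch (not part of the API)
// of the encoding used by the accessors below: integer- and boolean-valued
// joint fields are stored bit-for-bit in the float32 Joints tensor via
// math.Float32frombits, and decoded with math.Float32bits, so no precision
// is lost in the round trip.
func exampleJointBitsRoundTrip() bool {
enc := math.Float32frombits(uint32(Revolute)) // encode the enum bits into a float32 slot
dec := JointTypes(math.Float32bits(enc))      // decode the bits back to the enum value
return dec == Revolute
}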
func GetJointType(idx int32) JointTypes {
return JointTypes(math.Float32bits(Joints.Value(int(idx), int(JointType))))
}
func SetJointType(idx int32, typ JointTypes) {
Joints.Set(math.Float32frombits(uint32(typ)), int(idx), int(JointType))
}
func GetJointEnabled(idx int32) bool {
je := math.Float32bits(Joints.Value(int(idx), int(JointEnabled)))
return je != 0
}
func SetJointEnabled(idx int32, enabled bool) {
je := uint32(0)
if enabled {
je = 1
}
Joints.Set(math.Float32frombits(je), int(idx), int(JointEnabled))
}
func GetJointParentFixed(idx int32) bool {
je := math.Float32bits(Joints.Value(int(idx), int(JointParentFixed)))
return je != 0
}
func SetJointParentFixed(idx int32, enabled bool) {
je := uint32(0)
if enabled {
je = 1
}
Joints.Set(math.Float32frombits(je), int(idx), int(JointParentFixed))
}
func GetJointNoLinearRotation(idx int32) bool {
je := math.Float32bits(Joints.Value(int(idx), int(JointNoLinearRotation)))
return je != 0
}
func SetJointNoLinearRotation(idx int32, enabled bool) {
je := uint32(0)
if enabled {
je = 1
}
Joints.Set(math.Float32frombits(je), int(idx), int(JointNoLinearRotation))
}
func SetJointParent(idx, bodyIdx int32) {
Joints.Set(math.Float32frombits(uint32(bodyIdx)), int(idx), int(JointParent))
}
func JointParentIndex(idx int32) int32 {
return int32(math.Float32bits(Joints.Value(int(idx), int(JointParent))))
}
func SetJointChild(idx, bodyIdx int32) {
Joints.Set(math.Float32frombits(uint32(bodyIdx)), int(idx), int(JointChild))
}
func JointChildIndex(idx int32) int32 {
return int32(math.Float32bits(Joints.Value(int(idx), int(JointChild))))
}
func SetJointLinearDoFN(idx, dofN int32) {
Joints.Set(math.Float32frombits(uint32(dofN)), int(idx), int(JointLinearDoFN))
}
func GetJointLinearDoFN(idx int32) int32 {
return int32(math.Float32bits(Joints.Value(int(idx), int(JointLinearDoFN))))
}
func SetJointAngularDoFN(idx, dofN int32) {
Joints.Set(math.Float32frombits(uint32(dofN)), int(idx), int(JointAngularDoFN))
}
func GetJointAngularDoFN(idx int32) int32 {
return int32(math.Float32bits(Joints.Value(int(idx), int(JointAngularDoFN))))
}
func SetJointDoFIndex(idx, dof, dofIdx int32) {
Joints.Set(math.Float32frombits(uint32(dofIdx)), int(idx), int(int32(JointDoF1)+dof))
}
func JointDoFIndex(idx, dof int32) int32 {
return int32(math.Float32bits(Joints.Value(int(idx), int(int32(JointDoF1)+dof))))
}
func JointPPos(idx int32) math32.Vector3 {
return math32.Vec3(Joints.Value(int(idx), int(JointPPosX)), Joints.Value(int(idx), int(JointPPosY)), Joints.Value(int(idx), int(JointPPosZ)))
}
func SetJointPPos(idx int32, pos math32.Vector3) {
Joints.Set(pos.X, int(idx), int(JointPPosX))
Joints.Set(pos.Y, int(idx), int(JointPPosY))
Joints.Set(pos.Z, int(idx), int(JointPPosZ))
}
func JointPQuat(idx int32) math32.Quat {
return math32.NewQuat(Joints.Value(int(idx), int(JointPQuatX)), Joints.Value(int(idx), int(JointPQuatY)), Joints.Value(int(idx), int(JointPQuatZ)), Joints.Value(int(idx), int(JointPQuatW)))
}
func SetJointPQuat(idx int32, rot math32.Quat) {
Joints.Set(rot.X, int(idx), int(JointPQuatX))
Joints.Set(rot.Y, int(idx), int(JointPQuatY))
Joints.Set(rot.Z, int(idx), int(JointPQuatZ))
Joints.Set(rot.W, int(idx), int(JointPQuatW))
}
func JointCPos(idx int32) math32.Vector3 {
return math32.Vec3(Joints.Value(int(idx), int(JointCPosX)), Joints.Value(int(idx), int(JointCPosY)), Joints.Value(int(idx), int(JointCPosZ)))
}
func SetJointCPos(idx int32, pos math32.Vector3) {
Joints.Set(pos.X, int(idx), int(JointCPosX))
Joints.Set(pos.Y, int(idx), int(JointCPosY))
Joints.Set(pos.Z, int(idx), int(JointCPosZ))
}
func JointCQuat(idx int32) math32.Quat {
return math32.NewQuat(Joints.Value(int(idx), int(JointCQuatX)), Joints.Value(int(idx), int(JointCQuatY)), Joints.Value(int(idx), int(JointCQuatZ)), Joints.Value(int(idx), int(JointCQuatW)))
}
func SetJointCQuat(idx int32, rot math32.Quat) {
Joints.Set(rot.X, int(idx), int(JointCQuatX))
Joints.Set(rot.Y, int(idx), int(JointCQuatY))
Joints.Set(rot.Z, int(idx), int(JointCQuatZ))
Joints.Set(rot.W, int(idx), int(JointCQuatW))
}
func JointPForce(idx int32) math32.Vector3 {
return math32.Vec3(Joints.Value(int(idx), int(JointPForceX)), Joints.Value(int(idx), int(JointPForceY)), Joints.Value(int(idx), int(JointPForceZ)))
}
func SetJointPForce(idx int32, f math32.Vector3) {
Joints.Set(f.X, int(idx), int(JointPForceX))
Joints.Set(f.Y, int(idx), int(JointPForceY))
Joints.Set(f.Z, int(idx), int(JointPForceZ))
}
func JointPTorque(idx int32) math32.Vector3 {
return math32.Vec3(Joints.Value(int(idx), int(JointPTorqueX)), Joints.Value(int(idx), int(JointPTorqueY)), Joints.Value(int(idx), int(JointPTorqueZ)))
}
func SetJointPTorque(idx int32, t math32.Vector3) {
Joints.Set(t.X, int(idx), int(JointPTorqueX))
Joints.Set(t.Y, int(idx), int(JointPTorqueY))
Joints.Set(t.Z, int(idx), int(JointPTorqueZ))
}
func JointCForce(idx int32) math32.Vector3 {
return math32.Vec3(Joints.Value(int(idx), int(JointCForceX)), Joints.Value(int(idx), int(JointCForceY)), Joints.Value(int(idx), int(JointCForceZ)))
}
func SetJointCForce(idx int32, f math32.Vector3) {
Joints.Set(f.X, int(idx), int(JointCForceX))
Joints.Set(f.Y, int(idx), int(JointCForceY))
Joints.Set(f.Z, int(idx), int(JointCForceZ))
}
func JointCTorque(idx int32) math32.Vector3 {
return math32.Vec3(Joints.Value(int(idx), int(JointCTorqueX)), Joints.Value(int(idx), int(JointCTorqueY)), Joints.Value(int(idx), int(JointCTorqueZ)))
}
func SetJointCTorque(idx int32, t math32.Vector3) {
Joints.Set(t.X, int(idx), int(JointCTorqueX))
Joints.Set(t.Y, int(idx), int(JointCTorqueY))
Joints.Set(t.Z, int(idx), int(JointCTorqueZ))
}
func JointLinLambda(idx int32) math32.Vector3 {
return math32.Vec3(Joints.Value(int(idx), int(JointLinLambdaX)), Joints.Value(int(idx), int(JointLinLambdaY)), Joints.Value(int(idx), int(JointLinLambdaZ)))
}
func SetJointLinLambda(idx int32, t math32.Vector3) {
Joints.Set(t.X, int(idx), int(JointLinLambdaX))
Joints.Set(t.Y, int(idx), int(JointLinLambdaY))
Joints.Set(t.Z, int(idx), int(JointLinLambdaZ))
}
func JointAngLambda(idx int32) math32.Vector3 {
return math32.Vec3(Joints.Value(int(idx), int(JointAngLambdaX)), Joints.Value(int(idx), int(JointAngLambdaY)), Joints.Value(int(idx), int(JointAngLambdaZ)))
}
func SetJointAngLambda(idx int32, t math32.Vector3) {
Joints.Set(t.X, int(idx), int(JointAngLambdaX))
Joints.Set(t.Y, int(idx), int(JointAngLambdaY))
Joints.Set(t.Z, int(idx), int(JointAngLambdaZ))
}
// JointDoFVars are joint DoF state variables stored in tensor.Float32,
// one for each DoF.
type JointDoFVars int32 //enums:enum
const (
// axis of articulation for the DoF
JointAxisX JointDoFVars = iota
JointAxisY
JointAxisZ
// joint limits
JointLimitLower
JointLimitUpper
)
func JointAxisDoF(didx int32) math32.Vector3 {
return math32.Vec3(JointDoFs.Value(int(didx), int(JointAxisX)), JointDoFs.Value(int(didx), int(JointAxisY)), JointDoFs.Value(int(didx), int(JointAxisZ)))
}
func SetJointAxisDoF(didx int32, axis math32.Vector3) {
JointDoFs.Set(axis.X, int(didx), int(JointAxisX))
JointDoFs.Set(axis.Y, int(didx), int(JointAxisY))
JointDoFs.Set(axis.Z, int(didx), int(JointAxisZ))
}
func JointAxis(idx, dof int32) math32.Vector3 {
return JointAxisDoF(JointDoFIndex(idx, dof))
}
func SetJointAxis(idx, dof int32, axis math32.Vector3) {
SetJointAxisDoF(JointDoFIndex(idx, dof), axis)
}
func JointDoF(idx, dof int32, vr JointDoFVars) float32 {
return JointDoFs.Value(int(JointDoFIndex(idx, dof)), int(vr))
}
func SetJointDoF(idx, dof int32, vr JointDoFVars, value float32) {
JointDoFs.Set(value, int(JointDoFIndex(idx, dof)), int(vr))
}
//gosl:end
func (ml *Model) JointDefaults(idx int32) {
rot := math32.NewQuatIdentity()
SetJointPQuat(idx, rot)
SetJointCQuat(idx, rot)
}
func (ml *Model) JointDoFDefaults(didx int32) {
JointDoFs.Set(-JointLimitUnlimited, int(didx), int(JointLimitLower))
JointDoFs.Set(JointLimitUnlimited, int(didx), int(JointLimitUpper))
JointControls.Set(1, int(didx), int(JointTargetDamp))
}
// NewJointFixed adds a new Fixed joint
// between parent and child dynamic object indexes.
// Use -1 for parent to add a world-anchored joint.
// ppos, cpos are the relative positions from the parent, child.
// Sets relative rotations to identity by default.
func (ml *Model) NewJointFixed(parent, child int32, ppos, cpos math32.Vector3) int32 {
idx := ml.newJoint(Fixed, parent, child, ppos, cpos)
SetJointNoLinearRotation(idx, true)
return idx
}
// NewJointPrismatic adds a new Prismatic (slider) joint
// between parent and child dynamic object indexes.
// Use -1 for parent to add a world-anchored joint.
// ppos, cpos are the relative positions from the parent, child.
// Sets relative rotations to identity by default.
// axis is the axis of articulation for the joint.
func (ml *Model) NewJointPrismatic(parent, child int32, ppos, cpos, axis math32.Vector3) int32 {
idx := ml.newJoint(Prismatic, parent, child, ppos, cpos)
SetJointLinearDoFN(idx, 1)
ml.newJointDoF(idx, 0, axis)
return idx
}
// NewJointRevolute adds a new Revolute (hinge, axel) joint
// between parent and child dynamic object indexes.
// Use -1 for parent to add a world-anchored joint.
// ppos, cpos are the relative positions from the parent, child.
// Sets relative rotations to identity by default.
// axis is the axis of articulation for the joint.
func (ml *Model) NewJointRevolute(parent, child int32, ppos, cpos, axis math32.Vector3) int32 {
idx := ml.newJoint(Revolute, parent, child, ppos, cpos)
SetJointAngularDoFN(idx, 1)
ml.newJointDoF(idx, 0, axis)
return idx
}
// NewJointBall adds a new Ball joint (3 angular DoF)
// between parent and child dynamic object indexes.
// Use -1 for parent to add a world-anchored joint.
// ppos, cpos are the relative positions from the parent, child.
// Sets relative rotations to identity by default.
func (ml *Model) NewJointBall(parent, child int32, ppos, cpos math32.Vector3) int32 {
idx := ml.newJoint(Ball, parent, child, ppos, cpos)
SetJointAngularDoFN(idx, 3)
for d := range math32.W {
axis := math32.Vector3{}
axis.SetDim(d, 1)
ml.newJointDoF(idx, int32(d), axis)
}
return idx
}
// NewJointPlaneXZ adds a new 3 DoF planar motion joint suitable for
// controlling the motion of a body on the standard X-Z plane (Y = up).
// The two linear DoF control position in X and Z, and the 3rd, angular,
// DoF controls rotation about the Y axis. Sets [JointNoLinearRotation].
// Use -1 for parent to add a world-anchored joint (typical).
// ppos, cpos are the relative positions from the parent, child.
// Sets relative rotations to identity by default.
func (ml *Model) NewJointPlaneXZ(parent, child int32, ppos, cpos math32.Vector3) int32 {
idx := ml.NewJointD6(parent, child, ppos, cpos, 2, 1)
ml.newJointDoF(idx, 0, math32.Vec3(1, 0, 0))
ml.newJointDoF(idx, 1, math32.Vec3(0, 0, 1))
ml.newJointDoF(idx, 2, math32.Vec3(0, 1, 0))
SetJointNoLinearRotation(idx, true)
return idx
}
// NewJointDistance adds a new Distance joint (6 DoF)
// between parent and child dynamic object indexes,
// with distance constrained only along the first linear (X) axis.
// Use -1 for parent to add a world-anchored joint.
// ppos, cpos are the relative positions from the parent, child.
// Sets relative rotations to identity by default.
func (ml *Model) NewJointDistance(parent, child int32, ppos, cpos math32.Vector3, minDist, maxDist float32) int32 {
idx := ml.newJoint(Distance, parent, child, ppos, cpos)
SetJointLinearDoFN(idx, 3)
SetJointAngularDoFN(idx, 3)
for d := range math32.W {
axis := math32.Vector3{}
axis.SetDim(d, 1)
ml.newJointDoF(idx, int32(d), axis)
}
for d := range math32.W {
axis := math32.Vector3{}
axis.SetDim(d, 1)
ml.newJointDoF(idx, int32(d), axis)
}
// only on the X linear axis
SetJointDoF(idx, 0, JointLimitLower, minDist)
SetJointDoF(idx, 0, JointLimitUpper, maxDist)
return idx
}
// NewJointD6 adds a new D6 6 DoF joint with given number of actual
// linear and angular degrees-of-freedom,
// between parent and child dynamic object indexes.
// Use -1 for parent to add a world-anchored joint.
// ppos, cpos are the relative positions from the parent, child.
// Sets relative rotations to identity by default.
func (ml *Model) NewJointD6(parent, child int32, ppos, cpos math32.Vector3, linDoF, angDoF int32) int32 {
idx := ml.newJoint(D6, parent, child, ppos, cpos)
SetJointLinearDoFN(idx, linDoF)
SetJointAngularDoFN(idx, angDoF)
return idx
}
// NewJointFree adds a new Free joint (of which there is little point)
// between parent and child dynamic object indexes.
// Use -1 for parent to add a world-anchored joint.
// ppos, cpos are the relative positions from the parent, child.
// Sets relative rotations to identity by default.
func (ml *Model) NewJointFree(parent, child int32, ppos, cpos math32.Vector3) int32 {
idx := ml.newJoint(Free, parent, child, ppos, cpos)
SetJointLinearDoFN(idx, 0)
SetJointAngularDoFN(idx, 0)
return idx
}
// newJoint adds a new joint between parent and child
// dynamic object indexes.
// Use -1 for parent to add a world-anchored joint.
// ppos, cpos are the relative positions from the parent, child.
// Sets relative rotations to identity by default.
func (ml *Model) newJoint(joint JointTypes, parent, child int32, ppos, cpos math32.Vector3) int32 {
sizes := ml.Joints.ShapeSizes()
idx := int32(sizes[0])
params := &ml.Params[0]
params.JointsN = idx + 1
ml.Joints.SetShapeSizes(int(idx+1), int(JointVarsN))
ml.JointDefaults(idx)
SetJointType(idx, joint)
SetJointEnabled(idx, true)
SetJointParent(idx, parent)
SetJointChild(idx, child)
SetJointPPos(idx, ppos)
SetJointCPos(idx, cpos)
if ml.CurrentObjectJoint >= int(params.MaxObjectJoints)-1 {
params.MaxObjectJoints = int32(ml.CurrentObjectJoint + 1)
ml.Objects.SetShapeSizes(ml.CurrentObject+1, int(params.MaxObjectJoints+1))
}
ml.Objects.Set(idx, int(ml.CurrentObject), int(1+ml.CurrentObjectJoint))
ml.CurrentObjectJoint++
ml.Objects.Set(int32(ml.CurrentObjectJoint), int(ml.CurrentObject), int(0))
return idx
}
// newJointDoF adds new JointDoFs and JointControls entries
// initialized to defaults. Returns the new DoF index.
func (ml *Model) newJointDoF(jidx, dof int32, axis math32.Vector3) int32 {
sizes := ml.JointDoFs.ShapeSizes()
didx := int32(sizes[0])
ml.JointDoFs.SetShapeSizes(int(didx+1), int(JointDoFVarsN))
ml.JointControls.SetShapeSizes(int(didx+1), int(JointControlVarsN))
ml.Params[0].JointDoFsN = didx + 1
ml.JointDoFDefaults(didx)
SetJointDoFIndex(jidx, dof, didx)
SetJointAxis(jidx, dof, axis)
return didx
}
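// exampleLimitHinge is an illustrative sketch (not part of the API) showing
// how a newly created joint's DoF limits can be set with the generic DoF
// accessors above. The axis and the +/- 0.785 limit values are arbitrary
// placeholders (assumed here to be radians for an angular DoF).
func exampleLimitHinge(ml *Model, parent, child int32) {
idx := ml.NewJointRevolute(parent, child,
math32.Vec3(0, -0.5, 0), math32.Vec3(0, 0.5, 0), math32.Vec3(1, 0, 0))
SetJointDoF(idx, 0, JointLimitLower, -0.785)
SetJointDoF(idx, 0, JointLimitUpper, 0.785)
}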
// Code generated by "goal build"; DO NOT EDIT.
//line model.goal:1
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package physics
import (
"cogentcore.org/core/math32"
"cogentcore.org/lab/tensor"
)
//go:generate core generate -add-types -gosl
// Model contains and manages all of the physics elements.
type Model struct {
// GPU determines whether to use GPU (else CPU).
GPU bool
// Params are global parameters.
Params []PhysicsParams
// GetContacts will download Contacts from the GPU, if processing them on the CPU.
GetContacts bool
// ReportTotalKE prints out the total computed kinetic energy in the system after
// every step.
ReportTotalKE bool
// CurrentWorld is the [BodyWorld] value to use when creating new bodies.
// Set to -1 to create global elements that interact with everything,
// while 0 and positive numbers only interact amongst themselves.
CurrentWorld int
// CurrentObject is the Object to use when creating new joints.
// Call NewObject to increment.
CurrentObject int `edit:"-"`
// CurrentObjectJoint is the Joint index in CurrentObject
// to use when creating new joints.
CurrentObjectJoint int `edit:"-"`
// ReplicasN is the number of replicated worlds.
// Total bodies from ReplicaBodiesStart should be ReplicasN * ReplicaBodiesN.
ReplicasN int32 `edit:"-"`
// ReplicaBodiesStart is the starting body index for replicated world bodies,
// which is needed to efficiently select a body from a specific world.
// This is the start of the World=0 first instance.
ReplicaBodiesStart int32 `edit:"-"`
// ReplicaBodiesN is the number of body elements within each set of
// replicated world bodies, which is needed to efficiently select
// a body from a specific world.
ReplicaBodiesN int32 `edit:"-"`
// ReplicaJointsStart is the starting joint index for replicated world joints,
// which is needed to efficiently select a joint from a specific world.
// This is the start of the World=0 first instance.
ReplicaJointsStart int32 `edit:"-"`
// ReplicaJointsN is the number of joint elements within each set of
// replicated world joints, which is needed to efficiently select
// a joint from a specific world.
ReplicaJointsN int32 `edit:"-"`
// Bodies are the rigid body elements (dynamic and static),
// specifying the constant, non-dynamic properties,
// which is initial state for dynamics.
// [body][BodyVarsN]
Bodies *tensor.Float32 `display:"no-inline"`
// Objects is a list of joint indexes for each object, where each object
// contains all the joints interconnecting an overlapping set of bodies.
// This is known as an articulation in other physics software.
// Joints must be added in parent -> child order within objects, as joints
// are updated in sequential order within object.
// [object][MaxObjectJoints+1]
Objects *tensor.Int32 `display:"no-inline"`
// BodyJoints is a list of joint indexes for each dynamic body, for aggregating.
// [dyn body][parent, child][Params.BodyJointsMax]
BodyJoints *tensor.Int32 `display:"no-inline"`
// Joints is a list of permanent joints connecting bodies,
// which do not change (no dynamic variables).
// [joint][JointVarsN]
Joints *tensor.Float32 `display:"no-inline"`
// JointDoFs is a list of joint DoF parameters, allocated per joint.
// [dof][JointDoFVars]
JointDoFs *tensor.Float32 `display:"no-inline"`
// BodyCollidePairs are pairs of Body indexes that could potentially collide
// based on precomputed collision logic, using World, Group, and Joint indexes.
// [BodyCollidePairsN][2]
BodyCollidePairs *tensor.Int32
// Dynamics are the dynamic rigid body elements: these actually move.
// The first set of variables are for initial values, and the second current.
// [body][cur/next][DynamicVarsN]
Dynamics *tensor.Float32 `display:"no-inline"`
// BroadContactsN has the number of points of broad contact
// between bodies. [1]
BroadContactsN *tensor.Int32 `display:"no-inline"`
// BroadContacts are the results of broad-phase contact processing,
// establishing possible points of contact between bodies.
// [ContactsMax][BroadContactVarsN]
BroadContacts *tensor.Float32 `display:"no-inline"`
// ContactsN has the number of points of narrow (final) contact
// between bodies. [1]
ContactsN *tensor.Int32 `display:"no-inline"`
// Contacts are the results of narrow-phase contact processing,
// where only actual contacts with fully-specified values are present.
// [ContactsMax][ContactVarsN]
Contacts *tensor.Float32 `display:"no-inline"`
// JointControls are dynamic joint control inputs, per joint DoF
// (in correspondence with [JointDoFs]). This can be uploaded to the
// GPU at every step.
// [dof][JointControlVarsN]
JointControls *tensor.Float32 `display:"no-inline"`
}
func NewModel() *Model {
ml := &Model{}
ml.Init()
return ml
}
// Init makes initial vars. Called in NewModel.
// Must call Config once configured.
func (ml *Model) Init() {
ml.GPU = true
ml.Params = make([]PhysicsParams, 1)
ml.Params[0].Defaults()
ml.Reset()
}
// Reset resets all data to empty: starting over.
func (ml *Model) Reset() {
ml.CurrentWorld = 0
ml.CurrentObject = 0
ml.CurrentObjectJoint = 0
ml.Params[0].Reset()
ml.Bodies = tensor.NewFloat32(0, int(BodyVarsN))
ml.Objects = tensor.NewInt32(0, 1)
ml.Joints = tensor.NewFloat32(0, int(JointVarsN))
ml.JointDoFs = tensor.NewFloat32(0, int(JointDoFVarsN))
ml.BodyJoints = tensor.NewInt32(0, 2, 2)
ml.BodyCollidePairs = tensor.NewInt32(0, 2)
ml.Dynamics = tensor.NewFloat32(0, 2, int(DynamicVarsN))
ml.BroadContactsN = tensor.NewInt32(1)
ml.BroadContacts = tensor.NewFloat32(0, int(ContactVarsN))
ml.ContactsN = tensor.NewInt32(1)
ml.Contacts = tensor.NewFloat32(0, int(ContactVarsN))
ml.JointControls = tensor.NewFloat32(0, int(JointControlVarsN))
ml.SetAsCurrentVars()
}
// NewObject adds a new object. Returns the CurrentObject.
func (ml *Model) NewObject() int32 {
params := &ml.Params[0]
sizes := ml.Objects.ShapeSizes()
idx := int32(sizes[0])
ml.Objects.SetShapeSizes(int(idx+1), int(params.MaxObjectJoints+1))
params.ObjectsN = idx + 1
ml.CurrentObject = int(idx)
ml.CurrentObjectJoint = 0
return idx
}
// NewBody adds a new body with given parameters. Returns the index.
// Use this for Static elements; NewDynamic for dynamic elements.
func (ml *Model) NewBody(shape Shapes, hsize, pos math32.Vector3, rot math32.Quat) int32 {
sizes := ml.Bodies.ShapeSizes()
idx := int32(sizes[0])
ml.Bodies.SetShapeSizes(int(idx+1), int(BodyVarsN))
ml.Params[0].BodiesN = idx + 1
SetBodyShape(idx, shape)
SetBodyDynamic(idx, -1)
if shape == Capsule {
hsize.Y = max(hsize.Y, hsize.X*1.01)
}
SetBodyHSize(idx, hsize)
SetBodyPos(idx, pos)
SetBodyQuat(idx, rot)
SetBodyGroup(idx, -1) // assume static
SetBodyWorld(idx, int32(ml.CurrentWorld))
ml.SetMass(idx, shape, hsize, 0) // assume static
return idx
}
// NewDynamic adds a new dynamic body with given parameters. Returns the index.
// Shape cannot be [Plane].
func (ml *Model) NewDynamic(shape Shapes, mass float32, hsize, pos math32.Vector3, rot math32.Quat) (bodyIdx, dynIdx int32) {
if shape == Plane {
panic("physics.NewDynamic: shape cannot be Plane")
}
bodyIdx = ml.NewBody(shape, hsize, pos, rot)
sizes := ml.Dynamics.ShapeSizes()
dynIdx = int32(sizes[0])
ml.Dynamics.SetShapeSizes(int(dynIdx+1), 2, int(DynamicVarsN))
ml.Params[0].DynamicsN = dynIdx + 1
SetDynamicBody(dynIdx, bodyIdx)
SetBodyDynamic(bodyIdx, dynIdx)
SetBodyGroup(bodyIdx, 1) // dynamic
ml.SetMass(bodyIdx, shape, hsize, mass)
return
}
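// exampleBuildPendulum is an illustrative sketch (not part of the API) of the
// construction sequence implied by the methods above: create an object, add a
// dynamic body, and connect it to the world with a joint. The shape, size,
// mass, and positions are arbitrary placeholder values.
func exampleBuildPendulum(ml *Model) {
ml.NewObject()
// a dynamic capsule hanging from a world-anchored hinge (parent = -1)
_, dynIdx := ml.NewDynamic(Capsule, 1, math32.Vec3(0.1, 0.5, 0.1),
math32.Vec3(0, 2, 0), math32.NewQuatIdentity())
ml.NewJointRevolute(-1, dynIdx, math32.Vec3(0, 2.5, 0),
math32.Vec3(0, 0.5, 0), math32.Vec3(0, 0, 1))
}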
// SetAsCurrent sets these as the current global values that are
// processed in the code (on the GPU). If this model was not already
// the current one, its parameter variables are copied up to the GPU.
func (ml *Model) SetAsCurrent() {
isCur := (Bodies == ml.Bodies)
CurModel = ml
ml.SetAsCurrentVars()
if GPUInitialized && !isCur {
ml.ToGPUInfra()
}
}
// SetAsCurrentVars sets these as the current global values that are
// processed in the code (on the GPU).
func (ml *Model) SetAsCurrentVars() {
Params = ml.Params
Bodies = ml.Bodies
Objects = ml.Objects
Joints = ml.Joints
JointDoFs = ml.JointDoFs
BodyJoints = ml.BodyJoints
BodyCollidePairs = ml.BodyCollidePairs
Dynamics = ml.Dynamics
BroadContactsN = ml.BroadContactsN
BroadContacts = ml.BroadContacts
ContactsN = ml.ContactsN
Contacts = ml.Contacts
JointControls = ml.JointControls
}
// GPUInit initializes the GPU and transfers Infra.
// Should have already called SetAsCurrent (needed for CPU and GPU).
func (ml *Model) GPUInit() {
GPUInit()
UseGPU = ml.GPU
ml.ToGPUInfra()
}
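// exampleGPUSetup is an illustrative sketch (not part of the API) of the
// initialization sequence implied by the doc comments above: build a model,
// set it as current, and then initialize the GPU, which also uploads the
// infrastructure via ToGPUInfra.
func exampleGPUSetup() *Model {
ml := NewModel()
// ... add bodies and joints here ...
ml.SetAsCurrent()
ml.GPUInit()
return ml
}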
// ToGPUInfra copies all the infrastructure for this model up to
// the GPU. This is done in GPUInit, and when the current model is switched.
func (ml *Model) ToGPUInfra() {
ToGPUTensorStrides()
ToGPU(ParamsVar, BodiesVar, ObjectsVar, JointsVar, JointDoFsVar, BodyJointsVar, BodyCollidePairsVar, DynamicsVar, BroadContactsNVar, BroadContactsVar, ContactsNVar, ContactsVar, JointControlsVar)
}
// ReplicasBodyIndexes returns the body and dynamics (if dynamic) indexes
// for given replica world and source body index, if ReplicasN is > 0.
// Otherwise, returns bi and corresponding dynamic index.
func (ml *Model) ReplicasBodyIndexes(bi, replica int32) (bodyIdx, dynIdx int32) {
start := ml.ReplicaBodiesStart
n := ml.ReplicaBodiesN
if ml.ReplicasN == 0 || bi < start {
return bi, GetBodyDynamic(bi)
}
rbi := (bi - start) % n
bodyIdx = start + rbi + replica*n
dynIdx = GetBodyDynamic(bodyIdx)
return
}
// ReplicasJointIndex returns the joint index for given replica
// world and source joint index, if ReplicasN is > 0.
// Otherwise, returns ji unchanged.
func (ml *Model) ReplicasJointIndex(ji, replica int32) int32 {
start := ml.ReplicaJointsStart
n := ml.ReplicaJointsN
if ml.ReplicasN == 0 || ji < start {
return ji
}
rji := (ji - start) % n
nji := start + rji + replica*n
return nji
}
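// Illustrative index arithmetic for the replica mapping above, using
// hypothetical values: with ReplicaBodiesStart = 4 and ReplicaBodiesN = 3,
// source body bi = 5 in replica 2 maps to 4 + ((5-4) % 3) + 2*3 = 11.
// ReplicasJointIndex follows the same pattern using the joint start and count.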
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package physics
import (
"cogentcore.org/core/math32"
"cogentcore.org/lab/gosl/slbool"
"cogentcore.org/lab/gosl/slvec"
)
//gosl:start
// PhysicsParams are the physics parameters
type PhysicsParams struct {
// Iterations is the number of integration iterations to perform
// within each solver step. Muller et al (2020) report that 1 is best.
Iterations int32 `default:"1"`
// Dt is the integration stepsize.
// For highly kinetic situations (e.g., rapidly moving bouncing balls)
// 0.0001 is needed to ensure contact registration. Use SubSteps to
// accomplish a target effective read-out step size.
Dt float32 `default:"0.0001"`
// SubSteps is the number of integration steps to take per Step()
// function call. These sub steps are taken without any sync to/from
// the GPU and are therefore much faster.
SubSteps int32 `default:"10" min:"1"`
// ControlDt is the stepsize for integrating joint control position values
// [JointTargetPos] over time, to avoid sudden strong changes in force.
// For higher-DoF joints (e.g., Ball), this can be important for stability,
// but it can also result in under-shoot of the target position.
ControlDt float32 `default:"1,0.1"`
// ControlDtThr is the threshold on the control delta above which
// ControlDt is used. ControlDt is most important for large changes,
// and can result in under-shoot if engaged for small changes.
ControlDtThr float32 `default:"1"`
// ContactMargin is the extra distance for broadphase collision
// around rigid bodies. This can make some joints potentially unstable if > 0.
ContactMargin float32 `default:"0,0.1"`
// ContactRelax is the rigid contact relaxation constant.
// Higher values cause errors.
ContactRelax float32 `default:"0.8"` // 0.8 def
// ContactWeighting determines whether to balance (weight) contact forces across contacts.
ContactWeighting slbool.Bool `default:"true"` // true
// Restitution takes into account bounciness of objects.
Restitution slbool.Bool `default:"false"` // false
// JointLinearRelax is joint linear relaxation constant.
JointLinearRelax float32 `default:"0.7"` // 0.7 def
// JointAngularRelax is joint angular relaxation constant.
JointAngularRelax float32 `default:"0.4"` // 0.4 def
// JointLinearComply is joint linear compliance constant.
JointLinearComply float32 `default:"0"` // 0 def
// JointAngularComply is joint angular compliance constant.
JointAngularComply float32 `default:"0"` // 0 def
// AngularDamping is damping of angular motion.
AngularDamping float32 `default:"0"` // 0 def
// SoftRelax is soft-body relaxation constant.
SoftRelax float32 `default:"0.9"`
// MaxForce is the maximum computed force value, which prevents
// runaway numerical overflow.
MaxForce float32 `default:"1e5"`
// MaxDelta is the maximum computed change in position magnitude,
// which prevents runaway numerical overflow.
MaxDelta float32 `default:"2"`
// MaxGeomIter is the number of iterations to perform in shape-based
// geometry collision computations.
MaxGeomIter int32 `default:"10"`
// Maximum number of contacts to process at any given point.
ContactsMax int32 `edit:"-"`
// Index for the current state (0 or 1, alternates with Next).
Cur int32 `edit:"-"`
// Index for the next state (1 or 0, alternates with Cur).
Next int32 `edit:"-"`
// BodiesN is number of rigid bodies.
BodiesN int32 `edit:"-"`
// DynamicsN is number of dynamics bodies.
DynamicsN int32 `edit:"-"`
// ObjectsN is number of objects.
ObjectsN int32 `edit:"-"`
// MaxObjectJoints is max number of joints per object.
MaxObjectJoints int32 `edit:"-"`
// JointsN is number of joints.
JointsN int32 `edit:"-"`
// JointDoFsN is number of joint DoFs.
JointDoFsN int32 `edit:"-"`
// BodyJointsMax is max number of joints per body + 1 for actual n.
BodyJointsMax int32 `edit:"-"`
// BodyCollidePairsN is the total number of pre-compiled collision pairs
// to examine.
BodyCollidePairsN int32 `edit:"-"`
pad, pad1, pad2 int32
// Gravity is the gravity acceleration vector.
Gravity slvec.Vector3
}
func (pr *PhysicsParams) Defaults() {
pr.Iterations = 1
pr.Dt = 0.0001
pr.SubSteps = 10
pr.ControlDt = 1
pr.ControlDtThr = 1
pr.Gravity.Set(0, -9.81, 0)
pr.ContactMargin = 0
pr.ContactRelax = 0.8
pr.ContactWeighting.SetBool(true)
pr.Restitution.SetBool(false)
pr.JointLinearRelax = 0.7
pr.JointAngularRelax = 0.4
pr.JointLinearComply = 0
pr.JointAngularComply = 0
pr.AngularDamping = 0
pr.SoftRelax = 0.9
pr.MaxForce = 1.0e5
pr.MaxDelta = 2
pr.MaxGeomIter = 10
}
// Reset resets the N's
func (pr *PhysicsParams) Reset() {
pr.BodiesN = 0
pr.DynamicsN = 0
pr.ObjectsN = 0
pr.MaxObjectJoints = 0
pr.JointsN = 0
pr.JointDoFsN = 0
pr.BodyJointsMax = 0
pr.BodyCollidePairsN = 0
}
// StepsToMsec returns the given number of individual Step calls
// converted into milliseconds, suitable for driving controls.
func (pr *PhysicsParams) StepsToMsec(steps int) int {
msper := 1000 * pr.Dt * float32(pr.SubSteps)
return int(math32.Round(float32(steps) * msper))
}
// StepsToMsec returns the given number of individual Step calls
// converted into milliseconds, suitable for driving controls,
// using the currently-set Params.
func StepsToMsec(steps int) int {
return GetParams(0).StepsToMsec(steps)
}
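// As a worked example: with the default Dt = 0.0001 and SubSteps = 10, one
// Step call advances 1000 * 0.0001 * 10 = 1 millisecond, so StepsToMsec(50)
// returns 50.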
//gosl:end
// Copyright (c) 2019, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package phyxyz
import (
"image"
"cogentcore.org/core/math32"
)
// Camera defines the properties of a camera needed for rendering from a node.
type Camera struct {
// size of image to record
Size image.Point
// field of view in degrees
FOV float32
// near plane z coordinate
Near float32 `default:"0.01"`
// far plane z coordinate
Far float32 `default:"1000"`
// maximum distance for depth maps. Anything above is 1.
// This is independent of Near / Far rendering (though must be < Far)
// and is for normalized depth maps.
MaxD float32 `default:"20"`
// use the natural log of 1 + depth for normalized depth values in display etc.
LogD bool `default:"true"`
// number of multi-samples to use for antialiasing -- 4 is best and default.
MSample int `default:"4"`
// up direction for camera. Defaults to positive Y axis,
// and is reset by call to LookAt method.
UpDir math32.Vector3
}
func (cm *Camera) Defaults() {
cm.Size = image.Point{320, 180}
cm.FOV = 30
cm.Near = .01
cm.Far = 1000
cm.MaxD = 20
cm.LogD = true
cm.MSample = 4
cm.UpDir = math32.Vec3(0, 1, 0)
}
// Copyright (c) 2019, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package phyxyz
import (
"image"
"cogentcore.org/core/base/slicesx"
"cogentcore.org/core/colors/colormap"
"cogentcore.org/core/math32"
)
// DepthNorm renders a normalized linear depth map from the GPU (0-1 normalized floats) to
// the given float slice, which is resized if not already the appropriate size.
// If flipY is true, the Y axis is flipped (input has bottom Y = 0).
// Camera params determine whether log is used and the max cutoff distance for the sensitive
// range of distances -- Near / Far are also required to transform the values into
// linearized distances.
func DepthNorm(nd *[]float32, depth []float32, cam *Camera, flipY bool) {
sz := cam.Size
totn := sz.X * sz.Y
*nd = slicesx.SetLength(*nd, totn)
fpn := cam.Far + cam.Near
fmn := cam.Far - cam.Near
var norm float32
if cam.LogD {
norm = 1 / math32.Log(1+cam.MaxD)
} else {
norm = 1 / cam.MaxD
}
twonf := (2.0 * cam.Near * cam.Far)
for y := 0; y < sz.Y; y++ {
for x := 0; x < sz.X; x++ {
oi := y*sz.X + x
ii := oi
if flipY {
ii = (sz.Y-y-1)*sz.X + x
}
d := depth[ii]
z := d*2 - 1 // convert from 0..1 to -1..1
lind := twonf / (fpn - (z * fmn)) // untransform
effd := float32(1)
if lind < cam.MaxD {
if cam.LogD {
effd = norm * math32.Log(1+lind)
} else {
effd = norm * lind
}
}
(*nd)[oi] = effd
}
}
}
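// The following is an illustrative usage sketch (depthVals is an assumed slice
// of 0-1 depth values grabbed from the renderer, ordered with Y = 0 at the
// bottom, hence flipY = true; its length must equal cam.Size.X * cam.Size.Y):
//
//    cam := &Camera{}
//    cam.Defaults()
//    var norm []float32
//    DepthNorm(&norm, depthVals, cam, true)
//    // norm now holds normalized linearized depths, with 1 for anything beyond cam.MaxD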
// DepthImage renders an image of the linear depth map from the GPU (0-1 normalized floats)
// into the given image, which must be of the appropriate size for the map, using the given colormap.
// Camera parameters determine whether the log transform is used, and the max cutoff
// distance for the sensitive range of distances; Near / Far are required to transform
// the raw values into linearized distance values. The Y axis is always flipped.
func DepthImage(img *image.RGBA, depth []float32, cmap *colormap.Map, cam *Camera) {
if img == nil {
return
}
sz := img.Bounds().Size()
fpn := cam.Far + cam.Near
fmn := cam.Far - cam.Near
var norm float32
if cam.LogD {
norm = 1 / math32.Log(1+cam.MaxD)
} else {
norm = 1 / cam.MaxD
}
twonf := (2.0 * cam.Near * cam.Far)
for y := 0; y < sz.Y; y++ {
for x := 0; x < sz.X; x++ {
ii := (sz.Y-y-1)*sz.X + x // always flip for images
d := depth[ii]
z := d*2 - 1 // convert from 0..1 to -1..1
lind := twonf / (fpn - (z * fmn)) // untransform
effd := float32(1)
if lind < cam.MaxD {
if cam.LogD {
effd = norm * math32.Log(1+lind)
} else {
effd = norm * lind
}
}
clr := cmap.Map(effd)
img.Set(x, y, clr)
}
}
}
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package phyxyz
import (
"fmt"
"time"
"cogentcore.org/core/colors"
"cogentcore.org/core/core"
"cogentcore.org/core/events"
"cogentcore.org/core/icons"
"cogentcore.org/core/math32"
"cogentcore.org/core/styles"
"cogentcore.org/core/styles/abilities"
"cogentcore.org/core/tree"
"cogentcore.org/core/xyz"
"cogentcore.org/core/xyz/xyzcore"
_ "cogentcore.org/lab/gosl/slbool/slboolcore" // include to get gui views
"cogentcore.org/lab/physics"
)
// Editor provides a basic viewer and parameter controller widget
// for exploring physics models. It creates and manages its own
// [physics.Model] and [phyxyz.Scene].
type Editor struct { //types:add
core.Frame
// Model has the physics simulation.
Model *physics.Model
// Scene has the 3D GUI visualization.
Scene *Scene
// UserParams is a struct with parameters for configuring the physics sim.
// These are displayed in the editor.
UserParams any
// ConfigFunc is the function that configures the [physics.Model].
ConfigFunc func()
// ControlFunc is the function that sets control parameters,
// based on the current timestep (in milliseconds, converted from physics time).
ControlFunc func(timeStep int)
// CameraPos provides the default initial camera position, looking at the origin.
// Set this to larger numbers to zoom out, and smaller numbers to zoom in.
// Defaults to math32.Vec3(0, 25, 20).
CameraPos math32.Vector3
// Replica is the replica world to view, if replicas are present in model.
Replica int
// isRunning is true if the sim is currently running.
isRunning bool
// stop requests that a running sim stop.
stop bool
// TimeStep is current time step in physics update cycles.
TimeStep int
// editor is the xyz GUI visualization widget.
editor *xyzcore.SceneEditor
// toolbar is the top toolbar.
toolbar *core.Toolbar
// splits is the container for elements.
splits *core.Splits
// userParamsForm has the user's config parameters.
userParamsForm *core.Form
// paramsForm has the physics parameters.
paramsForm *core.Form
}
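// A minimal usage sketch (assumptions: b is the parent widget, e.g. a core.Body,
// and myParams is a user-defined config struct):
//
//    ed := NewEditor(b)
//    ed.SetUserParams(&myParams)
//    ed.SetConfigFunc(func() {
//        // build bodies, skins, and joints using ed.Model and ed.Scene here
//    })
//    ed.SetControlFunc(func(timeMS int) {
//        // set control targets based on simulated time in milliseconds
//    })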
func (pe *Editor) CopyFieldsFrom(frm tree.Node) {
fr := frm.(*Editor)
pe.Frame.CopyFieldsFrom(&fr.Frame)
}
func (pe *Editor) Init() {
pe.Frame.Init()
pe.CameraPos = math32.Vec3(0, 25, 20)
pe.Styler(func(s *styles.Style) {
s.Grow.Set(1, 1)
s.Direction = styles.Column
})
tree.AddChildAt(pe, "tb", func(w *core.Toolbar) {
pe.toolbar = w
w.Maker(pe.MakeToolbar)
})
tree.AddChildAt(pe, "splits", func(w *core.Splits) {
pe.splits = w
pe.splits.SetSplits(0.2, 0.8)
tree.AddChildAt(w, "forms", func(w *core.Frame) {
w.Styler(func(s *styles.Style) {
s.Direction = styles.Column
s.Grow.Set(1, 1)
})
tree.AddChildAt(w, "users", func(w *core.Form) {
pe.userParamsForm = w
})
tree.AddChildAt(w, "params", func(w *core.Form) {
pe.paramsForm = w
if pe.UserParams != nil {
pe.userParamsForm.SetStruct(pe.UserParams)
}
params := &pe.Model.Params[0]
pe.paramsForm.SetStruct(params)
})
})
tree.AddChildAt(w, "scene", func(w *xyzcore.SceneEditor) {
pe.editor = w
w.UpdateWidget()
sc := pe.editor.SceneXYZ()
sc.Background = colors.Scheme.Select.Container
xyz.NewAmbient(sc, "ambient", 0.3, xyz.DirectSun)
dir := xyz.NewDirectional(sc, "dir", 1, xyz.DirectSun)
dir.Pos.Set(0, 2, 1)
pe.Scene = NewScene(sc)
pe.Model = physics.NewModel()
sc.Camera.Pose.Pos = math32.Vec3(0, 40, 3.5)
sc.Camera.LookAt(math32.Vec3(0, 5, 0), math32.Vec3(0, 1, 0))
sc.SaveCamera("3")
sc.Camera.Pose.Pos = math32.Vec3(-1.33, 2.24, 3.55)
sc.Camera.LookAt(math32.Vec3(0, .5, 0), math32.Vec3(0, 1, 0))
sc.SaveCamera("2")
sc.Camera.Pose.Pos = pe.CameraPos
sc.Camera.LookAt(math32.Vec3(0, 0, 0), math32.Vec3(0, 1, 0))
sc.SaveCamera("1")
sc.SaveCamera("default")
pe.ConfigModel()
})
})
}
// ConfigModel configures the physics model.
func (pe *Editor) ConfigModel() {
if pe.isRunning {
core.MessageSnackbar(pe, "Simulation is still running...")
return
}
pe.Scene.Reset()
pe.Model.Reset()
if pe.ConfigFunc != nil {
pe.ConfigFunc()
}
pe.Scene.Init(pe.Model)
pe.stop = false
pe.TimeStep = 0
pe.editor.NeedsRender()
}
// Restart restarts the simulation, returning true if successful (i.e., not running).
func (pe *Editor) Restart() bool {
if pe.isRunning {
core.MessageSnackbar(pe, "Simulation is still running...")
return false
}
pe.stop = false
pe.TimeStep = 0
pe.Scene.InitState(pe.Model)
pe.editor.NeedsRender()
return true
}
// Step steps the world n times, with updates. Must be called as a goroutine.
func (pe *Editor) Step(n int) {
if pe.isRunning {
return
}
pe.isRunning = true
pe.Model.SetAsCurrent()
pe.toolbar.AsyncLock()
pe.toolbar.UpdateRender()
pe.toolbar.AsyncUnlock()
for range n {
if pe.ControlFunc != nil {
pe.ControlFunc(physics.StepsToMsec(pe.TimeStep))
}
pe.Model.Step()
pe.TimeStep++
pe.Scene.Update()
pe.editor.AsyncLock()
pe.editor.NeedsRender()
pe.editor.AsyncUnlock()
if !pe.Model.GPU {
time.Sleep(time.Nanosecond) // essential for web (wasm) to actually update:
// without this, the non-GPU path never yields the thread and the display never updates
// (the GPU path updates fine on its own).
}
if pe.stop {
pe.stop = false
break
}
}
pe.isRunning = false
pe.AsyncLock()
pe.Update()
pe.AsyncUnlock()
}
func (pe *Editor) MakeToolbar(p *tree.Plan) {
stepNButton := func(n int) {
nm := fmt.Sprintf("Step %d", n)
tree.AddAt(p, nm, func(w *core.Button) {
w.FirstStyler(func(s *styles.Style) { s.SetEnabled(!pe.isRunning) })
w.SetText(nm).SetIcon(icons.PlayArrow).
SetTooltip(fmt.Sprintf("Step state %d times", n)).
OnClick(func(e events.Event) {
if pe.isRunning {
fmt.Println("still running...")
return
}
go pe.Step(n)
})
w.Styler(func(s *styles.Style) {
s.SetAbilities(true, abilities.RepeatClickable)
})
})
}
tree.Add(p, func(w *core.Button) {
w.SetText("Restart").SetIcon(icons.Reset).
SetTooltip("Reset physics state back to starting.").
OnClick(func(e events.Event) {
pe.Restart()
})
w.FirstStyler(func(s *styles.Style) { s.SetEnabled(!pe.isRunning) })
})
tree.Add(p, func(w *core.Button) {
w.SetText("Stop").SetIcon(icons.Stop).
SetTooltip("Stop running").
OnClick(func(e events.Event) {
pe.stop = true
})
w.FirstStyler(func(s *styles.Style) { s.SetEnabled(pe.isRunning) })
})
tree.Add(p, func(w *core.Separator) {})
stepNButton(1)
stepNButton(10)
stepNButton(100)
stepNButton(1000)
stepNButton(10000)
tree.Add(p, func(w *core.Separator) {})
tree.Add(p, func(w *core.Button) {
w.SetText("Rebuild").SetIcon(icons.Reset).
SetTooltip("Rebuild the environment, when you change parameters").
OnClick(func(e events.Event) {
pe.ConfigModel()
})
w.FirstStyler(func(s *styles.Style) { s.SetEnabled(!pe.isRunning) })
})
tree.Add(p, func(w *core.Separator) {})
tt := "Replica world to view"
tree.Add(p, func(w *core.Text) { w.SetText("Replica:").SetTooltip(tt) })
tree.Add(p, func(w *core.Spinner) {
core.Bind(&pe.Replica, w)
w.SetMin(0).SetTooltip(tt)
w.Styler(func(s *styles.Style) {
replN := int32(0)
if physics.CurModel != nil && pe.Scene != nil {
replN = physics.CurModel.ReplicasN
pe.Scene.ReplicasView = replN > 0
}
w.SetMax(float32(replN - 1))
s.SetEnabled(replN > 1)
})
w.OnChange(func(e events.Event) {
pe.Scene.ReplicasIndex = pe.Replica
pe.Scene.Update()
pe.NeedsRender()
})
})
}
// Copyright (c) 2019, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package phyxyz
import (
"image"
"cogentcore.org/core/gpu"
"cogentcore.org/core/xyz"
)
// NoDisplayScene returns an [xyz.Scene] initialized and ready to use
// in NoGUI offscreen rendering mode, using given GPU and device.
// Must manually call Init3D and Style3D on the Scene prior to
// a RenderFromNode call to grab the image from a specific camera.
func NoDisplayScene(gp *gpu.GPU, dev *gpu.Device) *xyz.Scene {
sc := xyz.NewScene()
sc.MultiSample = 4
sc.Geom.Size = image.Point{1024, 768}
sc.ConfigOffscreen(gp, dev)
return sc
}
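// Illustrative sketch (gp and dev are assumed to have been obtained from the
// gpu package elsewhere; Init3D / Style3D are still required before rendering,
// per the note above):
//
//    sc := NoDisplayScene(gp, dev)
//    sc.Camera.FOV = 40 // adjust the scene camera as needed
//    vw := NewScene(sc) // attach the phyxyz Scene for physics visualization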
// Copyright (c) 2019, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package phyxyz implements visualization of [physics] using [xyz]
// 3D graphics.
package phyxyz
//go:generate core generate -add-types
import (
"image"
"cogentcore.org/core/math32"
"cogentcore.org/core/tree"
"cogentcore.org/core/xyz"
"cogentcore.org/lab/physics"
)
// Scene displays a [physics.Model] using a [xyz.Scene].
// One Scene can be used for multiple different [physics.Model]s which
// is more efficient when running multiple in parallel.
// Initial construction of the physics and visualization happens here.
type Scene struct {
// Scene is the [xyz.Scene] object for visualizing.
Scene *xyz.Scene
// Root is the root Group node in the Scene under which the world is rendered.
Root *xyz.Group
// Skins are the view elements for each body in [physics.Model].
Skins []*Skin
// ReplicasView enables viewing of different replicated worlds
// using the same skins.
ReplicasView bool
// ReplicasIndex is the replicated world to view.
ReplicasIndex int
}
// NewScene returns a new Scene for visualizing a [physics.Model]
// with the given [xyz.Scene], making a top-level Root group in the scene.
func NewScene(sc *xyz.Scene) *Scene {
rgp := xyz.NewGroup(sc)
rgp.SetName("world")
xysc := &Scene{Scene: sc, Root: rgp}
return xysc
}
// Init configures the visual world based on Skins,
// and calls Config on [physics.Model].
// Call this _once_ after making all the new Skins and Bodies
// (it returns early if already configured). This calls Update().
func (sc *Scene) Init(ml *physics.Model) {
ml.Config()
if ml.ReplicasN > 0 {
sc.ReplicasView = true
} else {
sc.ReplicasView = false
}
if len(sc.Root.Makers.Normal) > 0 {
sc.Update()
return
}
sc.Root.Maker(func(p *tree.Plan) {
for _, sk := range sc.Skins {
sk.Add(p)
}
})
sc.Update()
}
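// A minimal construction sketch (sizes, colors, and positions are arbitrary;
// the identity rotation literal is an assumption about math32.Quat):
//
//    xsc := xyz.NewScene() // or the SceneXYZ from an xyzcore.SceneEditor
//    vw := NewScene(xsc)
//    ml := physics.NewModel()
//    ident := math32.Quat{W: 1} // assumed identity rotation
//    vw.NewBody(ml, "floor", physics.Plane, "grey", math32.Vector3{}, math32.Vector3{}, ident)
//    vw.NewDynamic(ml, "ball", physics.Sphere, "red", 1, math32.Vec3(0.5, 0.5, 0.5), math32.Vec3(0, 3, 0), ident)
//    vw.Init(ml) // calls ml.Config() and builds the xyz nodes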
// InitState calls InitState on the Model and then Update.
func (sc *Scene) InitState(ml *physics.Model) {
ml.InitState()
sc.Update()
}
// Reset resets any existing views, starting fresh for a new configuration.
func (sc *Scene) Reset() {
sc.Skins = nil
if sc.Scene != nil {
sc.Scene.Update()
}
}
// Update updates the xyz scene from current physics node state.
// (use physics.Model.SetAsCurrent()).
func (sc *Scene) Update() {
sc.UpdateFromPhysics()
if sc.Scene != nil {
sc.Scene.Update()
}
}
// UpdateFromPhysics updates the Scene from currently active
// physics state (use physics.Model.SetAsCurrent()).
func (sc *Scene) UpdateFromPhysics() {
for _, sk := range sc.Skins {
sk.UpdateFromPhysics(sc)
}
}
// RenderFrom does an offscreen render using given [Skin]
// for the camera position and orientation, returning the render image(s)
// for each replicated world (1 if no replicas).
// Current scene camera is saved and restored.
func (sc *Scene) RenderFrom(sk *Skin, cam *Camera) []image.Image {
xysc := sc.Scene
camnm := "scene-renderfrom-save"
xysc.SaveCamera(camnm)
rep := sc.ReplicasIndex
xysc.Camera.FOV = cam.FOV
xysc.Camera.Near = cam.Near
xysc.Camera.Far = cam.Far
xysc.Camera.Pose.Pos = sk.Pos
xysc.Camera.Pose.Quat = sk.Quat
xysc.Camera.Pose.Scale.Set(1, 1, 1)
xysc.UseAltFrame(cam.Size)
ml := physics.CurModel
var imgs []image.Image
if sc.ReplicasView {
imgs = make([]image.Image, ml.ReplicasN)
for i := range ml.ReplicasN {
sc.ReplicasIndex = int(i)
sc.Update() // full Update needed, beyond just UpdateFromPhysics.
xysc.Camera.Pose.Pos = sk.Pos
xysc.Camera.Pose.Quat = sk.Quat
img := xysc.RenderGrabImage()
imgs[i] = img
}
sc.ReplicasIndex = rep
sc.UpdateFromPhysics()
} else {
img := xysc.RenderGrabImage()
imgs = []image.Image{img}
}
xysc.SetCamera(camnm)
xysc.UseMainFrame()
return imgs
}
// DepthImage returns the current rendered depth image
// func (vw *Scene) DepthImage() ([]float32, error) {
// return vw.Scene.DepthImage()
// }
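// NewSkin adds a new Skin with the given parameters to the Scene's Skins list,
// returning it for further customization. It only creates the visualization
// element; use [Scene.NewBody] or [Scene.NewDynamic] to also add the
// corresponding physics body.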
func (sc *Scene) NewSkin(shape physics.Shapes, name, clr string, hsize math32.Vector3, pos math32.Vector3, rot math32.Quat) *Skin {
sk := &Skin{Name: name, Shape: shape, Color: clr, HSize: hsize, DynamicIndex: -1, Pos: pos, Quat: rot}
sc.Skins = append(sc.Skins, sk)
return sk
}
// AddSkinClone adds a cloned copy of the given skin.
func (sc *Scene) AddSkinClone(sk *Skin) {
nsk := &Skin{}
*nsk = *sk
sc.Skins = append(sc.Skins, nsk) // append the clone, not the original
}
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package phyxyz
import (
"fmt"
"strconv"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/colors"
"cogentcore.org/core/math32"
"cogentcore.org/core/tree"
"cogentcore.org/core/xyz"
"cogentcore.org/lab/physics"
)
// Skin has visualization functions for physics elements.
type Skin struct { //types:add -setters
// Name is a name for element (index always appended, so it is unique).
Name string
// Shape is the physical shape of the element.
Shape physics.Shapes
// Color is the color of the element.
Color string
// HSize is the half-size (e.g., radius) of the body.
// Values depend on shape type: X is generally radius,
// Y is half-height.
HSize math32.Vector3
// Pos is the position.
Pos math32.Vector3
// Quat is the rotation as a quaternion.
Quat math32.Quat
// NewSkin is a function that returns a new [xyz.Node]
// to represent this element. If nil, uses appropriate defaults.
NewSkin func() tree.Node
// InitSkin is a function that initializes a new [xyz.Node]
// that represents this element. If nil, uses appropriate defaults.
InitSkin func(sld *xyz.Solid)
// BodyIndex is the index of the body in [physics.Bodies]
BodyIndex int32
// DynamicIndex is the index in [physics.Dynamics] (-1 if not dynamic).
DynamicIndex int32
}
// NewBody adds a new body with given parameters.
// Returns the Skin which can then be further customized.
// Use this for Static elements; NewDynamic for dynamic elements.
func (sc *Scene) NewBody(ml *physics.Model, name string, shape physics.Shapes, clr string, hsize, pos math32.Vector3, rot math32.Quat) *Skin {
idx := ml.NewBody(shape, hsize, pos, rot)
sk := sc.NewSkin(shape, name, clr, hsize, pos, rot)
sk.SetBodyIndex(idx)
return sk
}
// NewDynamic adds a new dynamic body with given parameters.
// Returns the Skin which can then be further customized.
func (sc *Scene) NewDynamic(ml *physics.Model, name string, shape physics.Shapes, clr string, mass float32, hsize, pos math32.Vector3, rot math32.Quat) *Skin {
idx, dyIdx := ml.NewDynamic(shape, mass, hsize, pos, rot)
sk := sc.NewSkin(shape, name, clr, hsize, pos, rot)
sk.SetBodyIndex(idx).SetDynamicIndex(dyIdx)
return sk
}
// UpdateFromPhysics updates the Skin from physics state.
func (sk *Skin) UpdateFromPhysics(sc *Scene) {
params := physics.GetParams(0)
di := int32(sk.DynamicIndex)
bi := int32(sk.BodyIndex)
if sc.ReplicasView {
bi, di = physics.CurModel.ReplicasBodyIndexes(bi, int32(sc.ReplicasIndex))
}
if di >= 0 {
sk.Pos = physics.DynamicPos(di, params.Cur)
sk.Quat = physics.DynamicQuat(di, params.Cur)
} else {
sk.Pos = physics.BodyPos(bi)
sk.Quat = physics.BodyQuat(bi)
}
}
// UpdatePose updates the xyz node pose from skin.
func (sk *Skin) UpdatePose(sld *xyz.Solid) {
sld.Pose.Pos = sk.Pos
sld.Pose.Quat = sk.Quat
}
// UpdateColor updates the xyz node color from skin.
func (sk *Skin) UpdateColor(clr string, sld *xyz.Solid) {
if clr == "" {
return
}
sld.Material.Color = errors.Log1(colors.FromString(clr))
}
// Add adds given physics node to the [tree.Plan], using NewSkin
// function on the node, or default.
func (sk *Skin) Add(p *tree.Plan) {
nm := sk.Name + strconv.Itoa(int(sk.BodyIndex))
newFunc := sk.NewSkin
if newFunc == nil {
newFunc = func() tree.Node {
return any(tree.New[xyz.Solid]()).(tree.Node)
}
}
p.Add(nm, newFunc, func(n tree.Node) { sk.Init(n.(*xyz.Solid)) })
}
// Init initializes xyz node using InitSkin function or default.
func (sk *Skin) Init(sld *xyz.Solid) {
initFunc := sk.InitSkin
if initFunc != nil {
initFunc(sld)
return
}
switch sk.Shape {
case physics.Plane:
sk.PlaneInit(sld)
case physics.Sphere:
sk.SphereInit(sld)
case physics.Capsule:
sk.CapsuleInit(sld)
case physics.Cylinder:
sk.CylinderInit(sld)
case physics.Box:
sk.BoxInit(sld)
}
}
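// An illustrative sketch of customizing the visualization via InitSkin
// (sk is an assumed *Skin obtained from [Scene.NewDynamic]; the color name
// is arbitrary):
//
//    sk.SetInitSkin(func(sld *xyz.Solid) {
//        sk.SphereInit(sld)          // start from the default sphere setup
//        sk.UpdateColor("gold", sld) // then override properties as desired
//    })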
// BoxInit is the default InitSkin function for [physics.Box].
// Only updates Pose in Updater: if node will change size or color,
// add updaters for that.
func (sk *Skin) BoxInit(sld *xyz.Solid) {
mnm := "physics.Box"
if ms, _ := sld.Scene.MeshByName(mnm); ms == nil {
xyz.NewBox(sld.Scene, mnm, 1, 1, 1)
}
sld.SetMeshName(mnm)
sld.Pose.Scale = sk.HSize.MulScalar(2)
sk.UpdateColor(sk.Color, sld)
sld.Updater(func() {
sk.UpdatePose(sld)
})
}
// PlaneInit is the default InitSkin function for [physics.Plane].
// Only updates Pose in Updater: if node will change size or color,
// add updaters for that.
func (sk *Skin) PlaneInit(sld *xyz.Solid) {
mnm := "physics.Plane"
if ms, _ := sld.Scene.MeshByName(mnm); ms == nil {
pl := xyz.NewPlane(sld.Scene, mnm, 1, 1)
pl.Segs.Set(4, 4)
}
sld.SetMeshName(mnm)
if sk.HSize.X == 0 {
inf := float32(1e3)
sld.Pose.Scale = math32.Vec3(inf, 1, inf)
} else {
sld.Pose.Scale = sk.HSize.MulScalar(2)
}
sk.UpdateColor(sk.Color, sld)
sld.Updater(func() {
sk.UpdatePose(sld)
})
}
// CylinderInit is the default InitSkin function for [physics.Cylinder].
// Only updates Pose in Updater: if node will change size or color,
// add updaters for that.
func (sk *Skin) CylinderInit(sld *xyz.Solid) {
mnm := "physics.Cylinder"
if ms, _ := sld.Scene.MeshByName(mnm); ms == nil {
xyz.NewCylinder(sld.Scene, mnm, 1, 1, 32, 1, true, true)
}
sld.SetMeshName(mnm)
sld.Pose.Scale = sk.HSize
sld.Pose.Scale.Y *= 2
sk.UpdateColor(sk.Color, sld)
sld.Updater(func() {
sk.UpdatePose(sld)
})
}
// CapsuleInit is the default InitSkin function for [physics.Capsule].
// Only updates Pose in Updater: if node will change size or color,
// add updaters for that.
func (sk *Skin) CapsuleInit(sld *xyz.Solid) {
rat := sk.HSize.Y / sk.HSize.X
mnm := fmt.Sprintf("physics.Capsule_%g", math32.Truncate(rat, 3))
if ms, _ := sld.Scene.MeshByName(mnm); ms == nil {
ms = xyz.NewCapsule(sld.Scene, mnm, 2*(sk.HSize.Y-sk.HSize.X)/sk.HSize.X, 1, 32, 1)
}
sld.SetMeshName(mnm)
sld.Pose.Scale.Set(sk.HSize.X, sk.HSize.X, sk.HSize.X)
sk.UpdateColor(sk.Color, sld)
sld.Updater(func() {
sk.UpdatePose(sld)
})
}
// SphereInit is the default InitSkin function for [physics.Sphere].
// Only updates Pose in Updater: if node will change size or color,
// add updaters for that.
func (sk *Skin) SphereInit(sld *xyz.Solid) {
mnm := "physics.Sphere"
if ms, _ := sld.Scene.MeshByName(mnm); ms == nil {
ms = xyz.NewSphere(sld.Scene, mnm, 1, 32)
}
sld.SetMeshName(mnm)
sld.Pose.Scale.SetScalar(sk.HSize.X)
sk.UpdateColor(sk.Color, sld)
sld.Updater(func() {
sk.UpdatePose(sld)
})
}
// SetBodyWorld partitions bodies into different worlds for
// collision detection: global bodies (world = -1) can collide with
// everything; otherwise only items within the same world collide.
func (sk *Skin) SetBodyWorld(world int) {
physics.SetBodyWorld(sk.BodyIndex, int32(world))
}
// SetBodyGroup partitions bodies within worlds into different groups
// for collision detection. Group 0 does not collide with anything.
// Negative numbers are global within a world, except that they don't
// collide amongst themselves (all non-dynamic bodies should go
// in -1 because they don't collide with each other, but do
// potentially collide with dynamics).
// Positive numbers only collide amongst themselves and with
// negative groups, but not with other positive groups. This is for
// more special-purpose dynamics: in general, use 1 for all dynamic
// bodies. There is an automatic constraint that the two objects
// within a single joint do not collide with each other, so that
// does not need to be handled here.
func (sk *Skin) SetBodyGroup(group int) {
physics.SetBodyGroup(sk.BodyIndex, int32(group))
}
// SetBodyBounce sets the coefficient of restitution (COR, 0..1),
// which determines how elastic the collision is,
// i.e., the ratio of final to initial velocity.
func (sk *Skin) SetBodyBounce(val float32) {
physics.Bodies.Set(val, int(sk.BodyIndex), int(physics.BodyBounce))
}
// SetBodyFriction sets the standard coefficient for linear friction (mu).
func (sk *Skin) SetBodyFriction(val float32) {
physics.Bodies.Set(val, int(sk.BodyIndex), int(physics.BodyFriction))
}
// SetBodyFrictionTortion sets the torsional friction: resistance to spinning at the contact point.
func (sk *Skin) SetBodyFrictionTortion(val float32) {
physics.Bodies.Set(val, int(sk.BodyIndex), int(physics.BodyFrictionTortion))
}
// SetBodyFrictionRolling sets the resistance to rolling motion at the contact point.
func (sk *Skin) SetBodyFrictionRolling(val float32) {
physics.Bodies.Set(val, int(sk.BodyIndex), int(physics.BodyFrictionRolling))
}
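// An illustrative collision-setup sketch (floor and ball are assumed *Skin
// values; the specific numbers are arbitrary):
//
//    floor.SetBodyGroup(-1) // non-dynamic bodies: global within world, no self-collision
//    ball.SetBodyGroup(1)   // dynamic bodies generally go in group 1
//    ball.SetBodyBounce(0.5)
//    ball.SetBodyFriction(0.8)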
// NewJointFixed adds a new Fixed joint between given parent and child.
// Use nil for parent to add a world-anchored joint.
// ppos, cpos are the relative positions from the parent, child.
// These are for the non-rotated body (i.e., body rotation is applied
// to these positions as well).
// Sets relative rotation matrices to identity by default.
func (sc *Scene) NewJointFixed(ml *physics.Model, parent, child *Skin, ppos, cpos math32.Vector3) int32 {
pidx := int32(-1)
if parent != nil {
pidx = parent.DynamicIndex
}
return ml.NewJointFixed(pidx, child.DynamicIndex, ppos, cpos)
}
// NewJointPrismatic adds a new Prismatic (slider) joint between given
// parent and child. Use nil for parent to add a world-anchored joint.
// ppos, cpos are the relative positions from the parent, child.
// These are for the non-rotated body (i.e., body rotation is applied
// to these positions as well).
// Sets relative rotation matrices to identity by default.
// axis is the axis of articulation for the joint.
// Use [SetJointDoF] to set the remaining DoF parameters.
func (sc *Scene) NewJointPrismatic(ml *physics.Model, parent, child *Skin, ppos, cpos, axis math32.Vector3) int32 {
pidx := int32(-1)
if parent != nil {
pidx = parent.DynamicIndex
}
return ml.NewJointPrismatic(pidx, child.DynamicIndex, ppos, cpos, axis)
}
// NewJointRevolute adds a new Revolute (hinge, axel) joint between given
// parent and child. Use nil for parent to add a world-anchored joint.
// ppos, cpos are the relative positions from the parent, child.
// These are for the non-rotated body (i.e., body rotation is applied
// to these positions as well).
// Sets relative rotation matrices to identity by default.
// axis is the axis of articulation for the joint.
// Use [SetJointDoF] to set the remaining DoF parameters.
func (sc *Scene) NewJointRevolute(ml *physics.Model, parent, child *Skin, ppos, cpos, axis math32.Vector3) int32 {
pidx := int32(-1)
if parent != nil {
pidx = parent.DynamicIndex
}
return ml.NewJointRevolute(pidx, child.DynamicIndex, ppos, cpos, axis)
}
// NewJointBall adds a new Ball joint (3 angular DoF) between given parent
// and child. Use nil for parent to add a world-anchored joint.
// ppos, cpos are the relative positions from the parent, child.
// These are for the non-rotated body (i.e., body rotation is applied
// to these positions as well).
// Sets relative rotation matrices to identity by default.
// Use [SetJointDoF] to set the remaining DoF parameters.
func (sc *Scene) NewJointBall(ml *physics.Model, parent, child *Skin, ppos, cpos math32.Vector3) int32 {
pidx := int32(-1)
if parent != nil {
pidx = parent.DynamicIndex
}
return ml.NewJointBall(pidx, child.DynamicIndex, ppos, cpos)
}
// NewJointDistance adds a new Distance joint (6 DoF),
// with distance constrained only on the first linear X axis,
// between given parent and child. Use nil for parent to add a world-anchored joint.
// ppos, cpos are the relative positions from the parent, child.
// These are for the non-rotated body (i.e., body rotation is applied
// to these positions as well).
// Sets relative rotation matrices to identity by default.
// Use [SetJointDoF] to set the remaining DoF parameters.
func (sc *Scene) NewJointDistance(ml *physics.Model, parent, child *Skin, ppos, cpos math32.Vector3, minDist, maxDist float32) int32 {
pidx := int32(-1)
if parent != nil {
pidx = parent.DynamicIndex
}
return ml.NewJointDistance(pidx, child.DynamicIndex, ppos, cpos, minDist, maxDist)
}
// NewJointFree adds a new Free joint between given parent and child.
// Use nil for parent to add a world-anchored joint.
// ppos, cpos are the relative positions from the parent, child.
// These are for the non-rotated body (i.e., body rotation is applied
// to these positions as well).
// Sets relative rotation matrices to identity by default.
// Use [SetJointDoF] to set the remaining DoF parameters.
func (sc *Scene) NewJointFree(ml *physics.Model, parent, child *Skin, ppos, cpos math32.Vector3) int32 {
pidx := int32(-1)
if parent != nil {
pidx = parent.DynamicIndex
}
return ml.NewJointFree(pidx, child.DynamicIndex, ppos, cpos)
}
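// An illustrative sketch of connecting two dynamic skins with a hinge
// (vw and ml are the *Scene and *physics.Model being configured; base and arm
// are assumed *Skin values; positions and axis are arbitrary):
//
//    ji := vw.NewJointRevolute(ml, base, arm,
//        math32.Vec3(0, 0.5, 0),  // anchor relative to parent (non-rotated frame)
//        math32.Vec3(0, -0.5, 0), // anchor relative to child (non-rotated frame)
//        math32.Vec3(1, 0, 0))    // hinge axis
//    // use [SetJointDoF] with ji to set limits, targets, and stiffness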
// Code generated by "core generate -add-types"; DO NOT EDIT.
package phyxyz
import (
"cogentcore.org/core/math32"
"cogentcore.org/core/tree"
"cogentcore.org/core/types"
"cogentcore.org/core/xyz"
"cogentcore.org/lab/physics"
)
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/physics/phyxyz.Camera", IDName: "camera", Doc: "Camera defines the properties of a camera needed for rendering from a node.", Fields: []types.Field{{Name: "Size", Doc: "size of image to record"}, {Name: "FOV", Doc: "field of view in degrees"}, {Name: "Near", Doc: "near plane z coordinate"}, {Name: "Far", Doc: "far plane z coordinate"}, {Name: "MaxD", Doc: "maximum distance for depth maps. Anything above is 1.\nThis is independent of Near / Far rendering (though must be < Far)\nand is for normalized depth maps."}, {Name: "LogD", Doc: "use the natural log of 1 + depth for normalized depth values in display etc."}, {Name: "MSample", Doc: "number of multi-samples to use for antialising -- 4 is best and default."}, {Name: "UpDir", Doc: "up direction for camera. Defaults to positive Y axis,\nand is reset by call to LookAt method."}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/physics/phyxyz.Editor", IDName: "editor", Doc: "Editor provides a basic viewer and parameter controller widget\nfor exploring physics models. It creates and manages its own\n[physics.Model] and [phyxyz.Scene].", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Embeds: []types.Field{{Name: "Frame"}}, Fields: []types.Field{{Name: "Model", Doc: "Model has the physics simulation."}, {Name: "Scene", Doc: "Scene has the 3D GUI visualization."}, {Name: "UserParams", Doc: "UserParams is a struct with parameters for configuring the physics sim.\nThese are displayed in the editor."}, {Name: "ConfigFunc", Doc: "ConfigFunc is the function that configures the [physics.Model]."}, {Name: "ControlFunc", Doc: "ControlFunc is the function that sets control parameters,\nbased on the current timestep (in milliseconds, converted from physics time)."}, {Name: "CameraPos", Doc: "CameraPos provides the default initial camera position, looking at the origin.\nSet this to larger numbers to zoom out, and smaller numbers to zoom in.\nDefaults to math32.Vec3(0, 25, 20)."}, {Name: "Replica", Doc: "Replica is the replica world to view, if replicas are present in model."}, {Name: "isRunning", Doc: "IsRunning is true if currently running sim."}, {Name: "stop", Doc: "Stop triggers topping of running."}, {Name: "TimeStep", Doc: "TimeStep is current time step in physics update cycles."}, {Name: "editor", Doc: "editor is the xyz GUI visualization widget."}, {Name: "toolbar", Doc: "Toolbar is the top toolbar."}, {Name: "splits", Doc: "Splits is the container for elements."}, {Name: "userParamsForm", Doc: "UserParamsForm has the user's config parameters."}, {Name: "paramsForm", Doc: "ParamsForm has the Physics parameters."}}})
// NewEditor returns a new [Editor] with the given optional parent:
// Editor provides a basic viewer and parameter controller widget
// for exploring physics models. It creates and manages its own
// [physics.Model] and [phyxyz.Scene].
func NewEditor(parent ...tree.Node) *Editor { return tree.New[Editor](parent...) }
// SetModel sets the [Editor.Model]:
// Model has the physics simulation.
func (t *Editor) SetModel(v *physics.Model) *Editor { t.Model = v; return t }
// SetScene sets the [Editor.Scene]:
// Scene has the 3D GUI visualization.
func (t *Editor) SetScene(v *Scene) *Editor { t.Scene = v; return t }
// SetUserParams sets the [Editor.UserParams]:
// UserParams is a struct with parameters for configuring the physics sim.
// These are displayed in the editor.
func (t *Editor) SetUserParams(v any) *Editor { t.UserParams = v; return t }
// SetConfigFunc sets the [Editor.ConfigFunc]:
// ConfigFunc is the function that configures the [physics.Model].
func (t *Editor) SetConfigFunc(v func()) *Editor { t.ConfigFunc = v; return t }
// SetControlFunc sets the [Editor.ControlFunc]:
// ControlFunc is the function that sets control parameters,
// based on the current timestep (in milliseconds, converted from physics time).
func (t *Editor) SetControlFunc(v func(timeStep int)) *Editor { t.ControlFunc = v; return t }
// SetCameraPos sets the [Editor.CameraPos]:
// CameraPos provides the default initial camera position, looking at the origin.
// Set this to larger numbers to zoom out, and smaller numbers to zoom in.
// Defaults to math32.Vec3(0, 25, 20).
func (t *Editor) SetCameraPos(v math32.Vector3) *Editor { t.CameraPos = v; return t }
// SetReplica sets the [Editor.Replica]:
// Replica is the replica world to view, if replicas are present in model.
func (t *Editor) SetReplica(v int) *Editor { t.Replica = v; return t }
// SetTimeStep sets the [Editor.TimeStep]:
// TimeStep is current time step in physics update cycles.
func (t *Editor) SetTimeStep(v int) *Editor { t.TimeStep = v; return t }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/physics/phyxyz.Scene", IDName: "scene", Doc: "Scene displays a [physics.Model] using a [xyz.Scene].\nOne Scene can be used for multiple different [physics.Model]s which\nis more efficient when running multiple in parallel.\nInitial construction of the physics and visualization happens here.", Fields: []types.Field{{Name: "Scene", Doc: "Scene is the [xyz.Scene] object for visualizing."}, {Name: "Root", Doc: "Root is the root Group node in the Scene under which the world is rendered."}, {Name: "Skins", Doc: "Skins are the view elements for each body in [physics.Model]."}, {Name: "ReplicasView", Doc: "ReplicasView enables viewing of different replicated worlds\nusing the same skins."}, {Name: "ReplicasIndex", Doc: "ReplicasIndex is the replicated world to view."}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/physics/phyxyz.Skin", IDName: "skin", Doc: "Skin has visualization functions for physics elements.", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"-setters"}}}, Fields: []types.Field{{Name: "Name", Doc: "Name is a name for element (index always appended, so it is unique)."}, {Name: "Shape", Doc: "Shape is the physical shape of the element."}, {Name: "Color", Doc: "Color is the color of the element."}, {Name: "HSize", Doc: "HSize is the half-size (e.g., radius) of the body.\nValues depend on shape type: X is generally radius,\nY is half-height."}, {Name: "Pos", Doc: "Pos is the position."}, {Name: "Quat", Doc: "Quat is the rotation as a quaternion."}, {Name: "NewSkin", Doc: "NewSkin is a function that returns a new [xyz.Node]\nto represent this element. If nil, uses appropriate defaults."}, {Name: "InitSkin", Doc: "InitSkin is a function that initializes a new [xyz.Node]\nthat represents this element. If nil, uses appropriate defaults."}, {Name: "BodyIndex", Doc: "BodyIndex is the index of the body in [physics.Bodies]"}, {Name: "DynamicIndex", Doc: "DynamicIndex is the index in [physics.Dynamics] (-1 if not dynamic)."}}})
// SetName sets the [Skin.Name]:
// Name is a name for element (index always appended, so it is unique).
func (t *Skin) SetName(v string) *Skin { t.Name = v; return t }
// SetShape sets the [Skin.Shape]:
// Shape is the physical shape of the element.
func (t *Skin) SetShape(v physics.Shapes) *Skin { t.Shape = v; return t }
// SetColor sets the [Skin.Color]:
// Color is the color of the element.
func (t *Skin) SetColor(v string) *Skin { t.Color = v; return t }
// SetHSize sets the [Skin.HSize]:
// HSize is the half-size (e.g., radius) of the body.
// Values depend on shape type: X is generally radius,
// Y is half-height.
func (t *Skin) SetHSize(v math32.Vector3) *Skin { t.HSize = v; return t }
// SetPos sets the [Skin.Pos]:
// Pos is the position.
func (t *Skin) SetPos(v math32.Vector3) *Skin { t.Pos = v; return t }
// SetQuat sets the [Skin.Quat]:
// Quat is the rotation as a quaternion.
func (t *Skin) SetQuat(v math32.Quat) *Skin { t.Quat = v; return t }
// SetNewSkin sets the [Skin.NewSkin]:
// NewSkin is a function that returns a new [xyz.Node]
// to represent this element. If nil, uses appropriate defaults.
func (t *Skin) SetNewSkin(v func() tree.Node) *Skin { t.NewSkin = v; return t }
// SetInitSkin sets the [Skin.InitSkin]:
// InitSkin is a function that initializes a new [xyz.Node]
// that represents this element. If nil, uses appropriate defaults.
func (t *Skin) SetInitSkin(v func(sld *xyz.Solid)) *Skin { t.InitSkin = v; return t }
// SetBodyIndex sets the [Skin.BodyIndex]:
// BodyIndex is the index of the body in [physics.Bodies]
func (t *Skin) SetBodyIndex(v int32) *Skin { t.BodyIndex = v; return t }
// SetDynamicIndex sets the [Skin.DynamicIndex]:
// DynamicIndex is the index in [physics.Dynamics] (-1 if not dynamic).
func (t *Skin) SetDynamicIndex(v int32) *Skin { t.DynamicIndex = v; return t }
// Code generated by "goal build"; DO NOT EDIT.
//line shapecollide.goal:1
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package physics
import (
"cogentcore.org/core/math32"
"cogentcore.org/lab/gosl/slmath"
)
//gosl:start
// newton: geometry/kernels.py class GeoData
// GeomData contains all geometric data for narrow-phase collision.
type GeomData struct {
BodyIdx int32
Shape Shapes
// MinSize is the min of the Size dimensions.
MinSize float32
// Thickness of shape.
Thick float32
// Radius is the effective radius for sphere-like elements (Sphere, Capsule, Cone)
Radius float32
Size math32.Vector3
// Wb is the pose of the body in World coordinates (body-to-world transform),
// mapping body-local points into world coordinates.
// Position (R) (i.e., BodyPos)
WbR math32.Vector3
// Quaternion (Q) (i.e., BodyQuat)
WbQ math32.Quat
// Bw is the inverse (world-to-body) transform,
// mapping world points into body coordinates.
// Position (R)
BwR math32.Vector3
// Quaternion (Q)
BwQ math32.Quat
}
func NewGeomData(bi, cni int32, shp Shapes) GeomData {
var gd GeomData
gd.BodyIdx = bi
gd.Shape = shp
gd.Size = BodyHSize(bi)
gd.Thick = Bodies.Value(int(bi), int(BodyThick))
gd.MinSize = min(gd.Size.X, gd.Size.Y)
gd.MinSize = min(gd.MinSize, gd.Size.Z)
gd.WbR = BodyDynamicPos(bi, cni)
gd.WbQ = BodyDynamicQuat(bi, cni)
InitGeomData(bi, &gd)
return gd
}
func InitGeomData(bi int32, gd *GeomData) {
var bwR math32.Vector3
var bwQ math32.Quat
slmath.SpatialTransformInverse(gd.WbR, gd.WbQ, &bwR, &bwQ)
gd.BwR = bwR
gd.BwQ = bwQ
gd.Radius = 0
if gd.Shape == Sphere || gd.Shape == Capsule || gd.Shape == Cone {
gd.Radius = gd.Size.X
}
}
// ContactPoints is the common final pathway for all Col* shape-specific
// collision functions, first determining if the computed distance is
// within the given margin and returning false if not (not a true contact).
// Otherwise, sets the actual points of contact and their offsets
// based on ptA, ptB, norm and dist values returned from collision functions.
// The actual distance is reduced by the radius values for Sphere,
// Capsule, and Cone types, and is returned in distActual.
// This is broken out here to support independent testing of the collision functions.
func ContactPoints(dist, margin float32, gdA *GeomData, gdB *GeomData, ptA, ptB, norm math32.Vector3, ctA, ctB, offA, offB *math32.Vector3, distActual, offMagA, offMagB *float32) bool {
thick := gdA.Thick + gdB.Thick
// Total separation required by radii and additional thicknesses
totSepReq := gdA.Radius + gdB.Radius + thick
*distActual = dist - totSepReq
if *distActual >= margin {
return false
}
// transform from world into body frame (so the contact point includes the shape transform)
*ctA = slmath.MulSpatialPoint(gdA.BwR, gdA.BwQ, ptA)
*ctB = slmath.MulSpatialPoint(gdB.BwR, gdB.BwQ, ptB)
// fmt.Println(ptA, ptB, *ctA, *ctB, gdA.BwR, gdA.BwQ)
*offMagA = gdA.Radius + gdA.Thick
*offMagB = gdB.Radius + gdB.Thick
*offA = slmath.MulQuatVector(gdA.BwQ, norm.MulScalar(-(*offMagA)))
*offB = slmath.MulQuatVector(gdB.BwQ, norm.MulScalar(*offMagB))
return true
}
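// For example (illustrative): for two spheres whose centers are dist apart,
// ColSphereSphere returns dist directly, so distActual becomes
// dist - (radiusA + radiusB + thickA + thickB), and a contact is registered
// only when that value is below the collision margin.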
/////// Collision methods: in geometry/kernels.py
// note: have to pass a non-pointer arg as first arg, due to gosl issue.
// cpi = contact point index.
// X_wb, X_ws -> Wb* (body/shape pose in world)
// X_bw, X_sw -> Bw* (inverse: world into body/shape frame)
// pAw = point in A, world coords; b = body coords
func ColSphereSphere(cpi, maxIter int32, gdA *GeomData, gdB *GeomData, pA, pB, norm *math32.Vector3) float32 {
pAw := gdA.WbR
pBw := gdB.WbR
diff := pAw.Sub(pBw)
*pA = pAw
*pB = pBw
*norm = slmath.Normal3(diff)
return slmath.Dot3(diff, *norm)
}
func ColCapsulePlane(cpi, maxIter int32, gdA *GeomData, gdB *GeomData, pA, pB, norm *math32.Vector3) float32 {
var pAw, pBw, diff math32.Vector3
hh := gdA.Size.Y - gdA.Size.X
if cpi < 2 { // vertex. Note: radius is automatically subtracted!! so this is correct with hh
side := float32(cpi)*2 - 1
pAw = slmath.MulSpatialPoint(gdA.WbR, gdA.WbQ, math32.Vec3(0, side*hh, 0))
queryB := slmath.MulSpatialPoint(gdB.BwR, gdB.BwQ, pAw)
pBb := ClosestPointPlane(gdB.Size.X, gdB.Size.Z, queryB)
pBw = slmath.MulSpatialPoint(gdB.WbR, gdB.WbQ, pBb)
diff = pAw.Sub(pBw)
if gdB.Size.X > 0 {
*norm = slmath.Normal3(diff)
} else {
*norm = slmath.MulQuatVector(gdB.WbQ, math32.Vec3(0, 1, 0))
}
} else { // edges of finite plane -- only here if plane is finite
var edge0, edge1 math32.Vector3
PlaneEdge(cpi-2, gdB.Size.X, gdB.Size.Z, &edge0, &edge1)
edge0w := slmath.MulSpatialPoint(gdB.WbR, gdB.WbQ, edge0)
edge1w := slmath.MulSpatialPoint(gdB.WbR, gdB.WbQ, edge1)
edge0a := slmath.MulSpatialPoint(gdA.BwR, gdA.BwQ, edge0w)
edge1a := slmath.MulSpatialPoint(gdA.BwR, gdA.BwQ, edge1w)
u := ClosestEdgeCapsule(gdA.Size.X, hh, edge0a, edge1a, maxIter)
pBw = edge0w.MulScalar(1 - u).Add(edge1w.MulScalar(u))
// find closest point + contact normal on capsule A
p0Aw := slmath.MulSpatialPoint(gdA.WbR, gdA.WbQ, math32.Vec3(0, hh, 0))
p1Aw := slmath.MulSpatialPoint(gdA.WbR, gdA.WbQ, math32.Vec3(0, -hh, 0))
pAw = ClosestPointLineSegment(p0Aw, p1Aw, pBw)
diff = pAw.Sub(pBw)
*norm = slmath.MulQuatVector(gdB.WbQ, math32.Vec3(0, 1, 0))
}
*pA = pAw
*pB = pBw
return slmath.Dot3(diff, *norm)
}
// Handle collision between two capsules (gdA and gdB).
func ColCapsuleCapsule(cpi, maxIter int32, gdA *GeomData, gdB *GeomData, pA, pB, norm *math32.Vector3) float32 {
// find closest edge coordinate to capsule SDF B
hhA := gdA.Size.Y - gdA.Size.X
hhB := gdB.Size.Y - gdB.Size.X
// edge from capsule A, along its local Y axis (capsules are Y-oriented)
// depending on point id, we query an edge from 0 to 0.5 or 0.5 to 1
e0 := math32.Vec3(0, hhA*float32(cpi%2), 0)
e1 := math32.Vec3(0, -hhA*float32((cpi+1)%2), 0)
edge0w := slmath.MulSpatialPoint(gdA.WbR, gdA.WbQ, e0)
edge1w := slmath.MulSpatialPoint(gdA.WbR, gdA.WbQ, e1)
edge0b := slmath.MulSpatialPoint(gdB.BwR, gdB.BwQ, edge0w)
edge1b := slmath.MulSpatialPoint(gdB.BwR, gdB.BwQ, edge1w)
u := ClosestEdgeCapsule(gdB.Size.X, hhB, edge0b, edge1b, maxIter)
pAw := edge0w.MulScalar(1 - u).Add(edge1w.MulScalar(u))
p0Bw := slmath.MulSpatialPoint(gdB.WbR, gdB.WbQ, math32.Vec3(0, hhB, 0))
p1Bw := slmath.MulSpatialPoint(gdB.WbR, gdB.WbQ, math32.Vec3(0, -hhB, 0))
pBw := ClosestPointLineSegment(p0Bw, p1Bw, pAw)
diff := pAw.Sub(pBw)
*norm = slmath.Normal3(diff)
*pA = pAw
*pB = pBw
return slmath.Dot3(diff, *norm)
}
// Handle collision between two boxes (gdA and gdB).
func ColBoxBox(cpi, maxIter int32, gdA *GeomData, gdB *GeomData, pA, pB, norm *math32.Vector3) float32 {
// edge-based box contact
var edge0, edge1 math32.Vector3
BoxEdge(cpi, gdA.Size, &edge0, &edge1)
edge0w := slmath.MulSpatialPoint(gdA.WbR, gdA.WbQ, edge0)
edge1w := slmath.MulSpatialPoint(gdA.WbR, gdA.WbQ, edge1)
edge0b := slmath.MulSpatialPoint(gdB.BwR, gdB.BwQ, edge0w)
edge1b := slmath.MulSpatialPoint(gdB.BwR, gdB.BwQ, edge1w)
u := ClosestEdgeBox(gdB.Size, edge0b, edge1b, maxIter)
pAw := edge0w.MulScalar(1 - u).Add(edge1w.MulScalar(u))
// find closest point + contact normal on box B
queryB := slmath.MulSpatialPoint(gdB.BwR, gdB.BwQ, pAw)
pBody := ClosestPointBox(gdB.Size, queryB)
pBw := slmath.MulSpatialPoint(gdB.WbR, gdB.WbQ, pBody)
diff := pAw.Sub(pBw)
*norm = slmath.MulQuatVector(gdB.WbQ, BoxSDFGrad(gdB.Size, queryB))
*pA = pAw
*pB = pBw
return slmath.Dot3(diff, *norm)
}
// Handle collision between a box (gdA) and a capsule (gdB).
func ColBoxCapsule(cpi, maxIter int32, gdA *GeomData, gdB *GeomData, pA, pB, norm *math32.Vector3) float32 {
hhB := gdB.Size.Y - gdB.Size.X
// capsule B
// depending on point id, we query an edge from 0 to 0.5 or 0.5 to 1
e0 := math32.Vec3(0, -hhB*float32(cpi%2), 0)
e1 := math32.Vec3(0, hhB*float32((cpi+1)%2), 0)
edge0w := slmath.MulSpatialPoint(gdB.WbR, gdB.WbQ, e0)
edge1w := slmath.MulSpatialPoint(gdB.WbR, gdB.WbQ, e1)
edge0a := slmath.MulSpatialPoint(gdA.BwR, gdA.BwQ, edge0w)
edge1a := slmath.MulSpatialPoint(gdA.BwR, gdA.BwQ, edge1w)
u := ClosestEdgeBox(gdA.Size, edge0a, edge1a, maxIter)
pBw := edge0w.MulScalar(1 - u).Add(edge1w.MulScalar(u))
// find closest point + contact normal on box A
queryA := slmath.MulSpatialPoint(gdA.BwR, gdA.BwQ, pBw)
pABody := ClosestPointBox(gdA.Size, queryA)
pAw := slmath.MulSpatialPoint(gdA.WbR, gdA.WbQ, pABody)
diff := pAw.Sub(pBw)
// the contact point inside the capsule should already be outside the box
*norm = slmath.Negate3(slmath.MulQuatVector(gdA.WbQ, BoxSDFGrad(gdA.Size, queryA)))
*pA = pAw
*pB = pBw
return slmath.Dot3(diff, *norm)
}
// Handle collision between a box (gdA) and a plane (gdB).
func ColBoxPlane(cpi, maxIter int32, gdA *GeomData, gdB *GeomData, pA, pB, norm *math32.Vector3) float32 {
width := gdB.Size.X
length := gdB.Size.Z
var pAw, pBw, diff math32.Vector3
if cpi < 8 {
// vertex-based contact
pABody := BoxVertex(cpi, gdA.Size)
pAw = slmath.MulSpatialPoint(gdA.WbR, gdA.WbQ, pABody)
queryB := slmath.MulSpatialPoint(gdB.BwR, gdB.BwQ, pAw)
pBody := ClosestPointPlane(width, length, queryB)
pBw = slmath.MulSpatialPoint(gdB.WbR, gdB.WbQ, pBody)
diff = pAw.Sub(pBw)
*norm = slmath.MulQuatVector(gdB.WbQ, math32.Vec3(0, 1, 0))
if width > 0 && length > 0 {
if math32.Abs(queryB.X) > width || math32.Abs(queryB.Z) > length {
// skip, we will evaluate the plane edge contact with the box later
return 1e6 // invalid
}
// note: commented out in original:
// check whether the COM is above the plane
// sign = wp.sign(slmath.Dot3(wp.transform_get_translation(gdA.X_ws) - pBw, normal))
// if sign < 0:
//
// // the entire box is most likely below the plane
// return
}
// the contact point is within plane boundaries
} else {
// contact between box A and edges of finite plane B
var edge0, edge1 math32.Vector3
PlaneEdge(cpi-8, width, length, &edge0, &edge1)
edge0w := slmath.MulSpatialPoint(gdB.WbR, gdB.WbQ, edge0)
edge1w := slmath.MulSpatialPoint(gdB.WbR, gdB.WbQ, edge1)
edge0a := slmath.MulSpatialPoint(gdA.BwR, gdA.BwQ, edge0w)
edge1a := slmath.MulSpatialPoint(gdA.BwR, gdA.BwQ, edge1w)
u := ClosestEdgeBox(gdA.Size, edge0a, edge1a, maxIter)
pBw = edge0w.MulScalar(1 - u).Add(edge1w.MulScalar(u))
// find closest point + contact normal on box A
queryA := slmath.MulSpatialPoint(gdA.BwR, gdA.BwQ, pBw)
pABody := ClosestPointBox(gdA.Size, queryA)
pAw = slmath.MulSpatialPoint(gdA.WbR, gdA.WbQ, pABody)
queryB := slmath.MulSpatialPoint(gdA.BwR, gdA.BwQ, pAw)
if math32.Abs(queryB.X) > width || math32.Abs(queryB.Z) > length {
// ensure that the closest point is actually inside the plane
return 1e6 // invalid
}
diff = pAw.Sub(pBw)
comA := gdA.WbR
queryB = slmath.MulSpatialPoint(gdB.BwR, gdB.BwQ, comA)
if math32.Abs(queryB.X) > width || math32.Abs(queryB.Z) > length {
// the COM is outside the plane
*norm = slmath.Normal3(comA.Sub(pBw))
} else {
*norm = slmath.MulQuatVector(gdB.WbQ, math32.Vec3(0, 1, 0))
}
}
*pA = pAw
*pB = pBw
return slmath.Dot3(diff, *norm)
}
// Handle collision between a sphere (gdA) and a box (gdB).
func ColSphereBox(cpi, maxIter int32, gdA *GeomData, gdB *GeomData, pA, pB, norm *math32.Vector3) float32 {
pAw := gdA.WbR
// contact point in frame of body B
pABody := slmath.MulSpatialPoint(gdB.BwR, gdB.BwQ, pAw)
pBody := ClosestPointBox(gdB.Size, pABody)
pBw := slmath.MulSpatialPoint(gdB.WbR, gdB.WbQ, pBody)
diff := pAw.Sub(pBw)
*norm = slmath.Normal3(diff)
*pA = pAw
*pB = pBw
return slmath.Dot3(diff, *norm)
}
// Handle collision between a sphere (gdA) and a capsule (gdB).
func ColSphereCapsule(cpi, maxIter int32, gdA *GeomData, gdB *GeomData, pA, pB, norm *math32.Vector3) float32 {
pAw := gdA.WbR
hhB := gdB.Size.Y - gdB.Size.X
// capsule B
AB := slmath.MulSpatialPoint(gdB.WbR, gdB.WbQ, math32.Vec3(0, hhB, 0))
BB := slmath.MulSpatialPoint(gdB.WbR, gdB.WbQ, math32.Vec3(0, -hhB, 0))
pBw := ClosestPointLineSegment(AB, BB, pAw)
diff := pAw.Sub(pBw)
*norm = slmath.Normal3(diff)
*pA = pAw
*pB = pBw
return slmath.Dot3(diff, *norm)
}
// Handle collision between a sphere (gdA) and a plane (gdB).
func ColSpherePlane(cpi, maxIter int32, gdA *GeomData, gdB *GeomData, pA, pB, norm *math32.Vector3) float32 {
pAw := gdA.WbR
pBody := ClosestPointPlane(gdB.Size.X, gdB.Size.Z, slmath.MulSpatialPoint(gdB.BwR, gdB.BwQ, pAw))
pBw := slmath.MulSpatialPoint(gdB.WbR, gdB.WbQ, pBody)
diff := pAw.Sub(pBw)
*norm = slmath.MulQuatVector(gdB.WbQ, math32.Vec3(0, 1, 0))
*pA = pAw
*pB = pBw
return slmath.Dot3(diff, *norm)
}
// Handle collision between a cylinder (gdA) and an infinite plane (gdB).
func ColCylinderPlane(cpi, maxIter int32, gdA *GeomData, gdB *GeomData, pA, pB, norm *math32.Vector3) float32 {
// World-space plane
plNorm := slmath.MulQuatVector(gdB.WbQ, math32.Vec3(0, 1, 0))
plPos := slmath.MulSpatialPoint(gdB.WbR, gdB.WbQ, math32.Vec3(0, 0, 0))
// World-space cylinder params
cylCtr := gdA.WbR
cylAx := slmath.Normal3(slmath.MulQuatVector(gdA.WbQ, math32.Vec3(0, 1, 0)))
cylRad := gdA.Size.X
cylHh := gdA.Size.Y
var dist float32
var pos math32.Vector3
n := plNorm
axis := cylAx
// Project, make sure axis points toward plane
prjaxis := slmath.Dot3(n, axis)
if prjaxis > 0 {
axis = slmath.Negate3(axis)
prjaxis = -prjaxis
}
// Compute normal distance from plane to cylinder center
dist0 := slmath.Dot3(cylCtr.Sub(plPos), n)
// Remove component of -normal along cylinder axis
vec := axis.MulScalar(prjaxis).Sub(n)
lenSqr := slmath.Dot3(vec, vec)
// If vector is nondegenerate, normalize and scale by radius
// Otherwise use cylinder's x-axis scaled by radius
if lenSqr >= 1e-12 {
vec = vec.MulScalar(cylRad / math32.Sqrt(lenSqr))
} else {
vec = math32.Vec3(1, 0, 0).MulScalar(cylRad) // Default x-axis when degenerate
}
// Project scaled vector on normal
prjvec := slmath.Dot3(vec, n)
// Scale cylinder axis by half-length
axis = axis.MulScalar(cylHh)
prjaxis *= cylHh
switch cpi {
case 0: // First contact point (end cap closer to plane)
dist = dist0 + prjaxis + prjvec
pos = cylCtr.Add(vec).Add(axis).Sub(n.MulScalar(dist * 0.5))
case 1: // Second contact point (end cap farther from plane)
dist = dist0 - prjaxis + prjvec
pos = cylCtr.Add(vec).Sub(axis).Sub(n.MulScalar(dist * 0.5))
case 2, 3: // Try triangle contact points on side closer to plane
prjvec1 := prjvec * -0.5
dist = dist0 + prjaxis + prjvec1
// Compute sideways vector scaled by radius*sqrt(3)/2
vec1 := slmath.Cross3(vec, axis)
vec1 = slmath.Normal3(vec1).MulScalar(cylRad * math32.Sqrt(3.0) * 0.5)
pextra := vec1.Add(axis).Sub(vec.MulScalar(0.5)).Sub(n.MulScalar(dist * 0.5))
pos = cylCtr.Add(pextra)
if cpi == 3 { // Add contact point B - adjust to closest side
pos = cylCtr.Sub(pextra)
}
default:
}
// Split midpoint into shape-plane endpoints
*pA = pos.Add(n.MulScalar(dist * 0.5))
*pB = pos.Sub(n.MulScalar(dist * 0.5))
*norm = n
return dist
}
// todo: newton geometry/collision_primitive.py supports collide_sphere_cylinder
// could adapt that. But there is no Box-Cylinder collision, nor anything with Capsule.
// so in general it is not super urgent.
//gosl:end
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package physics
import (
"cogentcore.org/core/math32"
"cogentcore.org/lab/gosl/slmath"
)
//gosl:start
func SphereSDF(center math32.Vector3, radius float32, p math32.Vector3) float32 {
return slmath.Length3(p.Sub(center)) - radius
}
func BoxSDF(upper, p math32.Vector3) float32 {
// adapted from https://www.iquilezles.org/www/articles/distfunctions/distfunctions.htm
qx := math32.Abs(p.X) - upper.X
qy := math32.Abs(p.Y) - upper.Y
qz := math32.Abs(p.Z) - upper.Z
e := math32.Vec3(max(qx, 0.0), max(qy, 0.0), max(qz, 0.0))
return slmath.Length3(e) + min(max(qx, max(qy, qz)), 0.0)
}
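// Worked check (illustrative): for a box with half-size upper = (1, 1, 1),
// BoxSDF at the origin is -1 (one unit inside the nearest face), and at
// (2, 0, 0) it is +1 (one unit outside along X).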
func BoxSDFGrad(upper, p math32.Vector3) math32.Vector3 {
qx := math32.Abs(p.X) - upper.X
qy := math32.Abs(p.Y) - upper.Y
qz := math32.Abs(p.Z) - upper.Z
// exterior case
if qx > 0.0 || qy > 0.0 || qz > 0.0 {
x := math32.Clamp(p.X, -upper.X, upper.X)
y := math32.Clamp(p.Y, -upper.Y, upper.Y)
z := math32.Clamp(p.Z, -upper.Z, upper.Z)
return slmath.Normal3(p.Sub(math32.Vec3(x, y, z)))
}
sx := math32.Sign(p.X)
sy := math32.Sign(p.Y)
sz := math32.Sign(p.Z)
// x projection
if (qx > qy && qx > qz) || (qy == 0.0 && qz == 0.0) {
return math32.Vec3(sx, 0.0, 0.0)
}
// y projection
if (qy > qx && qy > qz) || (qx == 0.0 && qz == 0.0) {
return math32.Vec3(0.0, sy, 0.0)
}
// z projection
return math32.Vec3(0.0, 0.0, sz)
}
func CapsuleSDF(radius, hh float32, p math32.Vector3) float32 {
if p.Y > hh {
return slmath.Length3(math32.Vec3(p.X, p.Y-hh, p.Z)) - radius
}
if p.Y < -hh {
return slmath.Length3(math32.Vec3(p.X, p.Y+hh, p.Z)) - radius
}
return slmath.Length3(math32.Vec3(p.X, 0.0, p.Z)) - radius
}
func CylinderSDF(radius, hh float32, p math32.Vector3) float32 {
dx := slmath.Length3(math32.Vec3(p.X, 0.0, p.Z)) - radius
dy := math32.Abs(p.Y) - hh
return min(max(dx, dy), 0.0) + slmath.Length2(math32.Vec2(max(dx, 0.0), max(dy, 0.0)))
}
// Cone with apex at +hh and base at -hh
func ConeSDF(radius, hh float32, p math32.Vector3) float32 {
dx := slmath.Length3(math32.Vec3(p.X, 0.0, p.Z)) - radius*(hh-p.Y)/(2.0*hh)
dy := math32.Abs(p.Y) - hh
return min(max(dx, dy), 0.0) + slmath.Length2(math32.Vec2(max(dx, 0.0), max(dy, 0.0)))
}
// SDF for a quad in the xz plane
func PlaneSDF(width, length float32, p math32.Vector3) float32 {
if width > 0.0 && length > 0.0 {
d := max(math32.Abs(p.X)-width, math32.Abs(p.Z)-length)
return max(d, math32.Abs(p.Y))
}
return p.Y
}
// ClosestPointPlane projects the point onto the quad in
// the xz plane (if size > 0.0), otherwise infinite.
func ClosestPointPlane(width, length float32, pt math32.Vector3) math32.Vector3 {
cp := pt
cp.Y = 0
if width == 0.0 {
return cp
}
cp.X = math32.Clamp(pt.X, -width, width)
cp.Z = math32.Clamp(pt.Z, -length, length)
return cp
}
func ClosestPointLineSegment(a, b, pt math32.Vector3) math32.Vector3 {
ab := b.Sub(a)
ap := pt.Sub(a)
t := slmath.Dot3(ap, ab) / slmath.Dot3(ab, ab)
t = math32.Clamp(t, 0.0, 1.0)
return a.Add(ab.MulScalar(t))
}
// closest point to box surface
func ClosestPointBox(upper, pt math32.Vector3) math32.Vector3 {
x := math32.Clamp(pt.X, -upper.X, upper.X)
y := math32.Clamp(pt.Y, -upper.Y, upper.Y)
z := math32.Clamp(pt.Z, -upper.Z, upper.Z)
if math32.Abs(pt.X) <= upper.X && math32.Abs(pt.Y) <= upper.Y && math32.Abs(pt.Z) <= upper.Z {
// the point is inside, find closest face
sx := math32.Abs(math32.Abs(pt.X) - upper.X)
sy := math32.Abs(math32.Abs(pt.Y) - upper.Y)
sz := math32.Abs(math32.Abs(pt.Z) - upper.Z)
// return closest point on closest side, handle corner cases
if (sx < sy && sx < sz) || (sy == 0.0 && sz == 0.0) {
x = math32.Sign(pt.X) * upper.X
} else if (sy < sx && sy < sz) || (sx == 0.0 && sz == 0.0) {
y = math32.Sign(pt.Y) * upper.Y
} else {
z = math32.Sign(pt.Z) * upper.Z
}
}
return math32.Vec3(x, y, z)
}
// box vertex numbering:
//
//    6---7
//    |\  |\       y
//    | 2-+-3      |
//    4-+-5 |   z \|
//     \|  \|      o---x
//      0---1
//
// get the vertex of the box given its ID (0-7)
func BoxVertex(ptId int32, upper math32.Vector3) math32.Vector3 {
sign_x := float32(ptId%2)*2.0 - 1.0
sign_y := float32((ptId/2)%2)*2.0 - 1.0
sign_z := float32((ptId/4)%2)*2.0 - 1.0
return math32.Vec3(sign_x*upper.X, sign_y*upper.Y, sign_z*upper.Z)
}
// get the edge of the box given its ID (0-11)
func BoxEdge(edgeId int32, upper math32.Vector3, edge0, edge1 *math32.Vector3) {
eid := edgeId
if eid < 4 {
// edges along x: 0-1, 2-3, 4-5, 6-7
i := eid * 2
j := i + 1
*edge0 = BoxVertex(i, upper)
*edge1 = BoxVertex(j, upper)
} else if eid < 8 {
// edges along y: 0-2, 1-3, 4-6, 5-7
eid -= 4
i := eid%2 + (eid/2)*4 // integer division: maps 0,1,2,3 -> 0,1,4,5
j := i + 2
*edge0 = BoxVertex(i, upper)
*edge1 = BoxVertex(j, upper)
} else {
// edges along z: 0-4, 1-5, 2-6, 3-7
eid -= 8
i := eid
j := i + 4
*edge0 = BoxVertex(i, upper)
*edge1 = BoxVertex(j, upper)
}
}
// get the edge of the plane given its ID (0-3)
func PlaneEdge(edgeId int32, width, length float32, edge0, edge1 *math32.Vector3) {
p0x := (2*float32(edgeId%2) - 1) * width
p0z := (2*float32(edgeId/2) - 1) * length
var p1x, p1z float32
if edgeId == 0 || edgeId == 3 {
p1x = p0x
p1z = -p0z
} else {
p1x = -p0x
p1z = p0z
}
*edge0 = math32.Vec3(p0x, 0, p0z)
*edge1 = math32.Vec3(p1x, 0, p1z)
}
// find point on edge closest to box, return its barycentric edge coordinate
func ClosestEdgeBox(upper, edgeA, edgeB math32.Vector3, maxIter int32) float32 {
// Golden-section search
a := float32(0.0)
b := float32(1.0)
h := b - a
invphi := float32(0.61803398875) // 1 / phi
invphi2 := float32(0.38196601125) // 1 / phi^2
c := a + invphi2*h
d := a + invphi*h
query := edgeA.MulScalar(1.0 - c).Add(edgeB.MulScalar(c))
yc := BoxSDF(upper, query)
query = edgeA.MulScalar(1.0 - d).Add(edgeB.MulScalar(d))
yd := BoxSDF(upper, query)
for range maxIter {
if yc < yd { // yc > yd to find the maximum
b = d
d = c
yd = yc
h = invphi * h
c = a + invphi2*h
query = edgeA.MulScalar(1.0 - c).Add(edgeB.MulScalar(c))
yc = BoxSDF(upper, query)
} else {
a = c
c = d
yc = yd
h = invphi * h
d = a + invphi*h
query = edgeA.MulScalar(1.0 - d).Add(edgeB.MulScalar(d))
yd = BoxSDF(upper, query)
}
}
if yc < yd {
return 0.5 * (a + d)
}
return 0.5 * (c + b)
}
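// Note (illustrative): each golden-section iteration shrinks the search
// bracket by a factor of invphi (about 0.618), so with maxIter = 10
// (the MaxGeomIter default) the final bracket is about 0.618^10, roughly
// 0.8% of the edge length. The same scheme is used in ClosestEdgePlane and
// ClosestEdgeCapsule below.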
// find point on edge closest to plane, return its barycentric edge coordinate
func ClosestEdgePlane(width, length float32, edgeA, edgeB math32.Vector3, maxIter int32) float32 {
// Golden-section search
a := float32(0.0)
b := float32(1.0)
h := b - a
invphi := float32(0.61803398875) // 1 / phi
invphi2 := float32(0.38196601125) // 1 / phi^2
c := a + invphi2*h
d := a + invphi*h
query := edgeA.MulScalar(1.0 - c).Add(edgeB.MulScalar(c))
yc := PlaneSDF(width, length, query)
query = edgeA.MulScalar(1.0 - d).Add(edgeB.MulScalar(d))
yd := PlaneSDF(width, length, query)
for range maxIter {
if yc < yd { // yc > yd to find the maximum
b = d
d = c
yd = yc
h = invphi * h
c = a + invphi2*h
query = edgeA.MulScalar(1.0 - c).Add(edgeB.MulScalar(c))
yc = PlaneSDF(width, length, query)
} else {
a = c
c = d
yc = yd
h = invphi * h
d = a + invphi*h
query = edgeA.MulScalar(1.0 - d).Add(edgeB.MulScalar(d))
yd = PlaneSDF(width, length, query)
}
}
if yc < yd {
return 0.5 * (a + d)
}
return 0.5 * (c + b)
}
// ClosestEdgeCapsule finds the point on the edge closest to the capsule, returning its barycentric edge coordinate.
func ClosestEdgeCapsule(radius, hh float32, edgeA, edgeB math32.Vector3, maxIter int32) float32 {
// Golden-section search
a := float32(0.0)
b := float32(1.0)
h := b - a
invphi := float32(0.61803398875) // 1 / phi
invphi2 := float32(0.38196601125) // 1 / phi^2
c := a + invphi2*h
d := a + invphi*h
query := edgeA.MulScalar(1.0 - c).Add(edgeB.MulScalar(c))
yc := CylinderSDF(radius, hh, query)
query = edgeA.MulScalar(1.0 - d).Add(edgeB.MulScalar(d))
yd := CylinderSDF(radius, hh, query)
for range maxIter {
if yc < yd { // yc > yd to find the maximum
b = d
d = c
yd = yc
h = invphi * h
c = a + invphi2*h
query = edgeA.MulScalar(1.0 - c).Add(edgeB.MulScalar(c))
yc = CylinderSDF(radius, hh, query)
} else {
a = c
c = d
yc = yd
h = invphi * h
d = a + invphi*h
query = edgeA.MulScalar(1.0 - d).Add(edgeB.MulScalar(d))
yd = CylinderSDF(radius, hh, query)
}
}
if yc < yd {
return 0.5 * (a + d)
}
return 0.5 * (c + b)
}
//gosl:end
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package physics
import (
"cogentcore.org/core/math32"
)
// see: newton/geometry for lots of helpful methods.
//gosl:start
// newton: geometry/types.py
// Shapes are elemental shapes for rigid bodies.
// In general, size dimensions are half values
// (e.g., radius, half-height, etc), which is natural for
// center-based body coordinates.
type Shapes int32 //enums:enum
const (
// Plane cannot be a dynamic shape, but is most efficient for
// collision computations. Use size = 0 for an infinite plane.
// Natively extends in the X-Z plane: SizeX x SizeZ.
Plane Shapes = iota
// todo: HeightField here (terrain)
// Sphere. SizeX is the radius.
Sphere
// Capsule is a cylinder with half-spheres on the ends.
// Natively oriented vertically along the Y axis.
// SizeX = radius of end caps, SizeY = _total_ half-height
// (i.e., SizeX + half-height of cylindrical portion, must
// be >= SizeX). This parameterization allows joint offsets
// to be SizeY, and direct swapping of shape across Box and
// Cylinder with same total extent.
Capsule
// todo: Ellipsoid goes here
// Cylinder, natively oriented vertically along the Y axis.
// SizeX = radius, SizeY = half-height of Y axis
// Cylinder does not support most collisions and is thus not recommended
// where collision data is needed.
Cylinder
// Box is a 3D rectilinear shape.
// The sizes are _half_ sizes along each dimension,
// relative to the center.
Box
// todo: Mesh, SDF here
// Cone is like a cylinder with the top radius = 0,
// oriented up. SizeX = bottom radius, SizeY = half-height in Y.
// Cone does not support any collisions and is not recommended for
// interacting bodies.
Cone
)
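// For example (hypothetical values): a sphere of radius 0.3 has size
// math32.Vec3(0.3, 0, 0); a capsule with cap radius 0.1 and cylindrical
// half-height 0.4 has size math32.Vec3(0.1, 0.5, 0) (total half-height 0.5);
// a box with full extents 0.5 x 1 x 0.5 has size math32.Vec3(0.25, 0.5, 0.25).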
// newton: geometry/kernels.py: count_contact_points_for_pair
// ShapePairContacts returns the number of contact points possible
// for the given pair of shapes, which must be given in a <= b order.
// The return value is the count from a to b; ba is set to the count
// from b to a (mostly 0).
// infPlane means that a is a Plane and it is infinite (size = 0).
func ShapePairContacts(a, b Shapes, infPlane bool, ba *int32) int32 {
*ba = 0
switch a {
case Plane:
switch b {
case Plane:
return 0
case Sphere:
return 1
case Capsule:
if infPlane {
return 2
} else {
return 2 + 4
}
case Cylinder:
return 4
case Box:
if infPlane {
return 8
} else {
return 8 + 4
}
default:
return 0
}
case Sphere:
return 1
case Capsule:
switch b {
case Capsule:
return 2
case Box:
return 8
default:
return 0
}
case Cylinder:
return 0 // no box collisions!
case Box:
*ba = 12
return 12
default: // note: Cone has no collision points!
return 0
}
}
//gosl:end
// Radius returns the shape radius for given size.
// This is used for broad-phase collision.
func (sh Shapes) Radius(sz math32.Vector3) float32 {
switch sh {
case Plane:
if sz.X > 0 {
return sz.Length()
}
return 1.0e6 // infinite
case Sphere:
return sz.X
case Capsule:
return sz.Y // full half-height
case Cylinder:
return sz.X + sz.Y // over-estimate for cylinder
case Box:
return sz.Length()
}
return 0
}
// BBox returns the bounding box for shape of given size.
func (sh Shapes) BBox(sz math32.Vector3) math32.Box3 {
var bb math32.Box3
switch sh {
case Sphere:
bb.SetMinMax(math32.Vec3(-sz.X, -sz.X, -sz.X), math32.Vec3(sz.X, sz.X, sz.X))
case Capsule:
bb.SetMinMax(math32.Vec3(-sz.X, -sz.Y, -sz.X), math32.Vec3(sz.X, sz.Y, sz.X))
case Cylinder:
bb.SetMinMax(math32.Vec3(-sz.X, -sz.Y, -sz.X), math32.Vec3(sz.X, sz.Y, sz.X))
case Box:
bb.SetMinMax(sz.Negate(), sz)
}
return bb
}
// Inertia returns the inertia tensor for solid shape of given size,
// with uniform density and given mass.
func (sh Shapes) Inertia(sz math32.Vector3, mass float32) math32.Matrix3 {
var inertia math32.Matrix3
switch sh {
// todo: other shapes!! see below.
case Sphere:
r := sz.X
// v := 4.0 / 3.0 * math32.Pi * r * r * r
ia := 2.0 / 5.0 * mass * r * r
inertia = math32.Mat3(ia, 0.0, 0.0, 0.0, ia, 0.0, 0.0, 0.0, ia)
case Capsule:
r := sz.X
h := (sz.Y - sz.X) * 2
vs := (4.0 / 3.0) * math32.Pi * r * r * r
vc := math32.Pi * r * r * h
ms := mass * (vs / (vs + vc))
mc := mass * (vc / (vs + vc))
ia := mc*(0.25*r*r+(1.0/12.0)*h*h) + ms*(0.4*r*r+0.375*r*h+0.25*h*h)
ib := (mc*0.5 + ms*0.4) * r * r
inertia = math32.Mat3(ia, 0.0, 0.0, 0.0, ib, 0.0, 0.0, 0.0, ia)
case Cylinder:
r := sz.X
h := sz.Y * 2
ia := (1.0 / 12) * mass * (3*r*r + h*h)
ib := (1.0 / 2.0) * mass * r * r
inertia = math32.Mat3(ia, 0.0, 0.0, 0.0, ib, 0.0, 0.0, 0.0, ia)
case Box:
w := 2 * sz.X
h := 2 * sz.Y
d := 2 * sz.Z
ia := 1.0 / 12.0 * mass * (h*h + d*d)
ib := 1.0 / 12.0 * mass * (w*w + d*d)
ic := 1.0 / 12.0 * mass * (w*w + h*h)
inertia = math32.Mat3(ia, 0.0, 0.0, 0.0, ib, 0.0, 0.0, 0.0, ic)
}
return inertia
}
/*
def compute_cone_inertia(density: float, r: float, h: float) -> tuple[float, wp.vec3, wp.mat33]:
"""Helper to compute mass and inertia of a solid cone extending along the z-axis
Args:
density: The cone density
r: The cone radius
h: The cone height (extent along the z-axis)
Returns:
A tuple of (mass, center of mass, inertia) with inertia specified around the center of mass
"""
m = density * wp.pi * r * r * h / 3.0
# Center of mass is at -h/4 from the geometric center
# Since the cone has base at -h/2 and apex at +h/2, the COM is 1/4 of the height from base toward apex
com = wp.vec3(0.0, 0.0, -h / 4.0)
# Inertia about the center of mass
Ia = 3 / 20 * m * r * r + 3 / 80 * m * h * h
Ib = 3 / 10 * m * r * r
# For Z-axis orientation: I_xx = I_yy = Ia, I_zz = Ib
I = wp.mat33([[Ia, 0.0, 0.0], [0.0, Ia, 0.0], [0.0, 0.0, Ib]])
return (m, com, I)
def compute_ellipsoid_inertia(density: float, a: float, b: float, c: float) -> tuple[float, wp.vec3, wp.mat33]:
"""Helper to compute mass and inertia of a solid ellipsoid
The ellipsoid is centered at the origin with semi-axes a, b, c along the x, y, z axes respectively.
Args:
density: The ellipsoid density
a: The semi-axis along the x-axis
b: The semi-axis along the y-axis
c: The semi-axis along the z-axis
Returns:
A tuple of (mass, center of mass, inertia) with inertia specified around the center of mass
"""
# Volume of ellipsoid: V = (4/3) * pi * a * b * c
v = (4.0 / 3.0) * wp.pi * a * b * c
m = density * v
# Inertia tensor for a solid ellipsoid about its center of mass:
# Ixx = (1/5) * m * (b² + c²)
# Iyy = (1/5) * m * (a² + c²)
# Izz = (1/5) * m * (a² + b²)
Ixx = (1.0 / 5.0) * m * (b * b + c * c)
Iyy = (1.0 / 5.0) * m * (a * a + c * c)
Izz = (1.0 / 5.0) * m * (a * a + b * b)
I = wp.mat33([[Ixx, 0.0, 0.0], [0.0, Iyy, 0.0], [0.0, 0.0, Izz]])
return (m, wp.vec3(), I)
*/
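// coneInertiaSketch is a hypothetical, unused sketch (not part of the Inertia
// method above) showing how the cone formulas from the reference Python could be
// expressed in this package's conventions: mass given directly instead of density,
// hh = half-height, axis along Y instead of Z. The inertia is about the center of
// mass, which for a cone lies h/4 below the geometric center along its axis.
func coneInertiaSketch(mass, r, hh float32) math32.Matrix3 {
h := 2 * hh
ia := (3.0/20.0)*mass*r*r + (3.0/80.0)*mass*h*h // transverse (X and Z) axes
ib := (3.0 / 10.0) * mass * r * r // axial (Y, symmetry) axis
return math32.Mat3(ia, 0.0, 0.0, 0.0, ib, 0.0, 0.0, 0.0, ia)
}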
// Code generated by "goal build"; DO NOT EDIT.
//line step.goal:1
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This code is adapted directly from https://github.com/newton-physics/newton
// Copyright (c) 2025 The Newton Developers, Released under an Apache-2.0 license
package physics
import (
"fmt"
"cogentcore.org/core/math32"
)
//gosl:start
//gosl:import "cogentcore.org/lab/gosl/slmath"
func OneIfNonzero(f float32) float32 {
if f != 0.0 {
return 1.0
}
return 0.0
}
// StepInit performs initialization at start of Step.
func StepInit(i uint32) { //gosl:kernel read-write:Params
if i > 0 {
return
}
params := GetParams(0)
BroadContactsN.Values[0] = 0
ContactsN.Values[0] = 0
if params.Cur == 0 {
params.Cur = 1
params.Next = 0
} else {
params.Cur = 0
params.Next = 1
}
for j := range params.JointDoFsN {
tpos := JointControls.Value(int(j), int(JointTargetPos))
tcur := JointControls.Value(int(j), int(JointTargetPosCur))
if math32.Abs(tpos-tcur) < params.ControlDtThr {
tcur = tpos
} else {
tcur += params.ControlDt * (tpos - tcur)
}
JointControls.Set(tcur, int(j), int(JointTargetPosCur))
}
}
// newton step does the following:
// if self.compute_body_velocity_from_position_delta or self.enable_restitution:
// // save initial state:
// body_q_init = wp.clone(state_in.body_q)
// body_qd_init = wp.clone(state_in.body_qd)
// body_deltas = wp.empty_like(state_out.body_qd)
// kernel=apply_joint_forces,
// self.integrate_bodies(model, state_in, state_out, dt, self.angular_damping)
// for i in range(self.iterations):
// kernel=solve_body_joints,
// body_q, body_qd = self.apply_body_deltas(model, state_in, state_out, body_deltas, dt)
// kernel=solve_body_contact_positions,
// if self.enable_restitution and i == 0:
// # remember contact constraint weighting from the first iteration
// if self.rigid_contact_con_weighting:
// rigid_contact_inv_weight_init = wp.clone(rigid_contact_inv_weight)
// else:
// rigid_contact_inv_weight_init = None
// body_q, body_qd = self.apply_body_deltas(
// model, state_in, state_out, body_deltas, dt, rigid_contact_inv_weight
// )
// # update body velocities from position changes
// if self.compute_body_velocity_from_position_delta and model.body_count and not requires_grad:
// kernel=update_body_velocities,
// kernel=apply_rigid_restitution,
// kernel=apply_body_delta_velocities,
//gosl:end
// Step runs one physics step, sending Params and JointControls
// to the GPU, and getting the Dynamics state vars back.
// Each step has SubSteps integration sub-steps.
func (ml *Model) Step() {
params := GetParams(0)
ToGPU(ParamsVar, JointControlsVar)
if params.SubSteps > 1 {
for range params.SubSteps - 1 {
ml.StepGet()
}
}
vars := []GPUVars{ParamsVar, DynamicsVar, ContactsNVar}
if ml.GetContacts {
vars = append(vars, ContactsVar)
}
ml.StepGet(vars...)
if ContactsN.Value(0) >= params.ContactsMax {
fmt.Println("Warning: over ContactsMax:", ContactsN.Value(0), "Max:", params.ContactsMax)
}
if ml.ReportTotalKE {
ke := ml.TotalKineticEnergy()
fmt.Println("Total KE:", ke)
}
}
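// For example, a typical driver (hypothetical setup; model construction and
// GPU initialization are handled elsewhere) sets the desired JointControls
// and then calls Step once per control frame:
//
//	for frame := 0; frame < nFrames; frame++ {
//		// ... write JointControls target values for this frame ...
//		ml.Step() // runs SubSteps sub-steps and syncs Dynamics back
//	}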
// StepGet runs one physics step and gets the given vars back
// from the GPU.
func (ml *Model) StepGet(vars ...GPUVars) {
params := GetParams(0)
RunStepInit(1)
ml.StepCollision()
ml.StepJointForces()
ml.StepIntegrateBodies()
for range params.Iterations {
ml.StepSolveJoints()
ml.StepBodyContacts()
}
RunDone(vars...)
}
func (ml *Model) StepCollision() {
params := GetParams(0)
RunCollisionBroad(int(params.BodyCollidePairsN))
// note: time getting BroadContactsN back down and using that vs. running full
RunCollisionNarrow(int(params.ContactsMax))
// note: too slow to get this back, so just using ContactsMax always.
// RunDone(ContactsNVar) // we do multiple iterations so useful to have this
// fmt.Println("contacts:", ContactsN.Value(0), "max:", params.ContactsMax)
}
func (ml *Model) StepJointForces() {
params := GetParams(0)
RunStepJointForces(int(params.JointsN))
RunForcesFromJoints(int(params.DynamicsN))
}
func (ml *Model) StepIntegrateBodies() {
params := GetParams(0)
RunStepIntegrateBodies(int(params.DynamicsN))
}
func (ml *Model) StepSolveJoints() {
params := GetParams(0)
RunStepSolveJoints(int(params.ObjectsN))
}
func (ml *Model) StepBodyContacts() {
params := GetParams(0)
if !ml.GPU {
cmax := int(ContactsN.Values[0])
if cmax > 0 {
RunStepBodyContacts(cmax)
}
RunStepBodyContactDeltas(int(params.DynamicsN))
} else {
RunStepBodyContacts(int(params.ContactsMax)) // just do max and let the routines bail
RunStepBodyContactDeltas(int(params.DynamicsN))
}
}
// Code generated by "goal build"; DO NOT EDIT.
//line step_body.goal:1
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This code is adapted directly from https://github.com/newton-physics/newton
// Copyright (c) 2025 The Newton Developers, Released under an Apache-2.0 license
package physics
import (
// "fmt"
"cogentcore.org/core/math32"
"cogentcore.org/lab/gosl/slmath"
)
//gosl:start
//gosl:import "cogentcore.org/lab/gosl/slmath"
// InitDynamics copies Body initial state to dynamic state (cur and next).
func InitDynamics(i uint32) { //gosl:kernel
params := GetParams(0)
ii := int32(i)
if ii >= params.DynamicsN {
return
}
for cni := range 2 {
bi := DynamicBody(ii)
Dynamics.Set(Bodies.Value(int(bi), int(BodyPosX)), int(ii), int(cni), int(DynPosX))
Dynamics.Set(Bodies.Value(int(bi), int(BodyPosY)), int(ii), int(cni), int(DynPosY))
Dynamics.Set(Bodies.Value(int(bi), int(BodyPosZ)), int(ii), int(cni), int(DynPosZ))
Dynamics.Set(Bodies.Value(int(bi), int(BodyQuatX)), int(ii), int(cni), int(DynQuatX))
Dynamics.Set(Bodies.Value(int(bi), int(BodyQuatY)), int(ii), int(cni), int(DynQuatY))
Dynamics.Set(Bodies.Value(int(bi), int(BodyQuatZ)), int(ii), int(cni), int(DynQuatZ))
Dynamics.Set(Bodies.Value(int(bi), int(BodyQuatW)), int(ii), int(cni), int(DynQuatW))
for v := DynVelX; v < DynamicVarsN; v++ {
Dynamics.Set(0.0, int(ii), int(cni), int(v))
}
}
}
// DynamicsCurToNext copies [Dynamics] state from Cur to Next.
func DynamicsCurToNext(i uint32) { //gosl:kernel
params := GetParams(0)
ii := int32(i)
if ii >= params.DynamicsN {
return
}
for di := DynBody; di < DynamicVarsN; di++ {
Dynamics.Set(Dynamics.Value(int(ii), int(params.Cur), int(di)), int(ii), int(params.Next), int(di))
}
}
// ForcesFromJoints gathers forces and torques from joints for each dynamic body.
func ForcesFromJoints(i uint32) { //gosl:kernel
params := GetParams(0)
di := int32(i)
if di >= params.DynamicsN {
return
}
np := BodyJoints.Value(int(di), int(0), int(0))
nc := BodyJoints.Value(int(di), int(1), int(0))
tf := math32.Vec3(0, 0, 0)
tt := math32.Vec3(0, 0, 0)
for i := int32(1); i <= np; i++ {
ji := BodyJoints.Value(int(di), int(0), int(i))
f := JointPForce(ji)
tf = tf.Add(f)
t := JointPTorque(ji)
tt = tt.Add(t)
}
for i := int32(1); i <= nc; i++ {
ji := BodyJoints.Value(int(di), int(1), int(i))
f := JointCForce(ji)
tf = tf.Add(f)
t := JointCTorque(ji)
tt = tt.Add(t)
}
SetDynamicForce(di, params.Next, tf)
SetDynamicTorque(di, params.Next, tt)
}
// newton: solvers/solver.py: integrate_rigid_body
// StepIntegrateBodies applies forces to update pos and deltas
func StepIntegrateBodies(i uint32) { //gosl:kernel
params := GetParams(0)
di := int32(i)
if di >= params.DynamicsN {
return
}
bi := DynamicBody(di)
invMass := Bodies.Value(int(bi), int(BodyInvMass))
inertia := BodyInertia(bi)
invInertia := BodyInvInertia(bi)
grav := params.Gravity.V()
com := BodyCom(bi)
// current pos
r0 := DynamicPos(di, params.Cur)
q0 := DynamicQuat(di, params.Cur)
// current deltas
v0 := DynamicDelta(di, params.Cur)
w0 := DynamicAngDelta(di, params.Cur)
// new forces integrated from joints
f0 := DynamicForce(di, params.Next)
t0 := DynamicTorque(di, params.Next)
pcom := slmath.MulQuatVector(q0, com).Add(r0)
// linear part
v1 := v0.Add(f0.MulScalar(invMass).Add(grav.MulScalar(OneIfNonzero(invMass))).MulScalar(params.Dt))
p1 := pcom.Add(v1.MulScalar(params.Dt))
// angular part (compute in body frame)
wb := slmath.MulQuatVectorInverse(q0, w0)
tb := slmath.MulQuatVectorInverse(q0, t0).Sub(slmath.Cross3(wb, inertia.MulVector3(wb))) // coriolis forces
tb = slmath.ClampMagnitude3(tb, params.MaxForce)
w1 := slmath.MulQuatVector(q0, wb.Add(invInertia.MulVector3(tb).MulScalar(params.Dt)))
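// orientation update uses quaternion kinematics, qdot = 0.5 * quat(w, 0) * q,
// integrated with an explicit Euler step and renormalized below.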
q1 := slmath.QuatAdd(q0, slmath.MulQuats(math32.NewQuat(w1.X, w1.Y, w1.Z, 0), q0).MulScalar(0.5*params.Dt))
q1 = slmath.QuatNormalize(q1)
// angular damping
w1 = w1.MulScalar(1.0 - params.AngularDamping*params.Dt)
w1 = slmath.ClampMagnitude3(w1, params.MaxForce)
p1a := p1.Sub(slmath.MulQuatVector(q1, com)) // pos corrected to nominal center.
// fmt.Println(params.Next, "integrate:", v0, v1)
// if p1a.IsNaN() || q1.IsNaN() {
// if di == 0 {
// fmt.Println("integ:", di, p1a, q1, "r0:", r0, "q0:", q0, "v0:", v0, "w0:", w0, "f0:", f0, "t0:", t0, "pcom:", pcom, "v1:", v1, "p1:", p1, "wb:", wb, "tb:", tb, "w1:", w1)
// }
SetDynamicPos(di, params.Next, p1a)
SetDynamicQuat(di, params.Next, q1)
SetDynamicDelta(di, params.Next, v1)
SetDynamicAngDelta(di, params.Next, w1)
}
// newton: solvers/xpbd/kernels.py: apply_body_deltas
// StepBodyDeltas updates Next position with deltas from joints
// or contacts (if contacts true). Also updates kinetics (velocity and acceleration)
// based on position & orientation changes if contacts=true (usually just 1 iteration).
func StepBodyDeltas(di, bi int32, contacts bool, cWt float32, linDel, angDel math32.Vector3) {
params := GetParams(0)
invMass := Bodies.Value(int(bi), int(BodyInvMass))
inertia := BodyInertia(bi)
invInertia := BodyInvInertia(bi)
// starting pos (from force integration)
r0 := DynamicPos(di, params.Next)
q0 := DynamicQuat(di, params.Next)
// starting deltas
v0 := DynamicDelta(di, params.Next)
w0 := DynamicAngDelta(di, params.Next)
weight := float32(1.0)
if contacts && params.ContactWeighting.IsTrue() {
if cWt > 0 {
weight = 1.0 / cWt
}
}
// weighted
dp := linDel.MulScalar(invMass * weight)
dq := angDel.MulScalar(weight)
// note: this is essential for rationalizing PlaneXZ and ball collision behavior!
dp = LimitDelta(dp, params.MaxDelta)
dq = LimitDelta(dq, params.MaxDelta)
wb := slmath.MulQuatVectorInverse(q0, w0)
dwb := invInertia.MulVector3(slmath.MulQuatVectorInverse(q0, dq))
// coriolis forces delta from dwb = (wb + dwb) I (wb + dwb) - wb I wb
tb := slmath.Cross3(dwb, inertia.MulVector3(wb.Add(dwb))).Add(slmath.Cross3(wb, inertia.MulVector3(dwb)))
dw1 := slmath.MulQuatVector(q0, dwb.Sub(invInertia.MulVector3(tb).MulScalar(params.Dt)))
// update orientation
q1 := q0.Add(slmath.MulQuats(math32.NewQuat(dw1.X, dw1.Y, dw1.Z, 0), q0).MulScalar(0.5 * params.Dt))
// q1 := q0 + 0.5 * wp.quat(dw1 * dt, 0.0) * q0
q1 = slmath.QuatNormalize(q1)
// update position
com := BodyCom(bi)
pcom := slmath.MulQuatVector(q0, com).Add(r0)
p1 := pcom.Add(dp.MulScalar(params.Dt))
p1 = p1.Sub(slmath.MulQuatVector(q1, com))
// update linear and angular velocity
v1 := v0.Add(dp)
w1 := w0.Add(dw1)
// this improves gradient stability
if slmath.Length3(v1) < 1e-4 {
v1 = math32.Vec3(0, 0, 0)
}
if slmath.Length3(w1) < 1e-4 {
w1 = math32.Vec3(0, 0, 0)
}
SetDynamicPos(di, params.Next, p1)
SetDynamicQuat(di, params.Next, q1)
SetDynamicDelta(di, params.Next, v1)
SetDynamicAngDelta(di, params.Next, w1)
if contacts {
StepBodyKinetics(di, bi)
}
}
// StepBodyKinetics computes the empirical velocities and
// accelerations from the final changes in position and orientation
// from Cur to Next.
func StepBodyKinetics(di, bi int32) {
params := GetParams(0)
r0 := DynamicPos(di, params.Cur)
q0 := DynamicQuat(di, params.Cur)
v0 := DynamicVel(di, params.Cur)
w0 := DynamicAngVel(di, params.Cur)
r1 := DynamicPos(di, params.Next)
q1 := DynamicQuat(di, params.Next)
com := BodyCom(bi)
com0 := slmath.MulQuatVector(q0, com).Add(r0)
com1 := slmath.MulQuatVector(q1, com).Add(r1)
v1 := com1.Sub(com0).DivScalar(params.Dt)
dq := slmath.MulQuats(q1, slmath.QuatInverse(q0))
w1 := math32.Vec3(dq.X, dq.Y, dq.Z).MulScalar(2 / params.Dt)
if dq.W < 0 {
w1 = slmath.Negate3(w1)
}
SetDynamicVel(di, params.Next, v1)
SetDynamicAngVel(di, params.Next, w1)
a1 := v1.Sub(v0).DivScalar(params.Dt)
wa1 := w1.Sub(w0).DivScalar(params.Dt)
SetDynamicAcc(di, params.Next, a1)
SetDynamicAngAcc(di, params.Next, wa1)
}
// LimitDelta limits the magnitude of a delta vector to lim.
func LimitDelta(v math32.Vector3, lim float32) math32.Vector3 {
l := slmath.Length3(v)
if l < lim {
return v
}
return v.MulScalar(lim / l)
}
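// VelocityAtPoint returns the world velocity of a point at offset r from the
// center of mass, given linear velocity lin and angular velocity ang: lin + ang x r.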
func VelocityAtPoint(lin, ang, r math32.Vector3) math32.Vector3 {
return lin.Add(slmath.Cross3(ang, r))
}
//gosl:end
// Code generated by "goal build"; DO NOT EDIT.
//line step_joint.goal:1
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This code is adapted directly from https://github.com/newton-physics/newton
// Copyright (c) 2025 The Newton Developers, Released under an Apache-2.0 license
package physics
import (
// "fmt"
"cogentcore.org/core/math32"
"cogentcore.org/lab/gosl/slmath"
)
// notation convention:
// spatial transform: R = position, Q = quat rotation
// P = parent, C = child
// x = transform, w = world
// d = moment arm
//gosl:start
//gosl:import "cogentcore.org/lab/gosl/slmath"
// newton: solvers/xpbd/kernels.py: apply_joint_forces
// StepJointForces computes joint forces.
func StepJointForces(i uint32) { //gosl:kernel
params := GetParams(0)
ji := int32(i)
if ji >= params.JointsN {
return
}
zv := math32.Vec3(0, 0, 0)
SetJointPForce(ji, zv)
SetJointCForce(ji, zv)
SetJointPTorque(ji, zv)
SetJointCTorque(ji, zv)
jt := GetJointType(ji)
if !GetJointEnabled(ji) {
return
}
jPi := JointParentIndex(ji)
jPbi := int32(-1)
if jPi >= 0 {
jPbi = DynamicBody(jPi)
}
jCi := JointChildIndex(ji)
jCbi := DynamicBody(jCi)
jLinearN := GetJointLinearDoFN(ji)
jAngularN := GetJointAngularDoFN(ji)
jPR := JointPPos(ji)
jPQ := JointPQuat(ji)
// parent world transform
xwPR := jPR
xwPQ := jPQ
posePR := jPR
posePQ := jPQ
comP := math32.Vec3(0, 0, 0)
if jPi >= 0 { // can be fixed
posePR = DynamicPos(jPi, params.Cur)
posePQ = DynamicQuat(jPi, params.Cur)
slmath.MulSpatialTransforms(posePR, posePQ, jPR, jPQ, &xwPR, &xwPQ)
comP = BodyCom(jPbi)
}
dP := xwPR.Sub(slmath.MulSpatialPoint(posePR, posePQ, comP)) // parent moment arm
// child world transform
poseCR := DynamicPos(jCi, params.Cur)
poseCQ := DynamicQuat(jCi, params.Cur)
// note: NOT doing this: slmath.MulSpatialTransforms(poseCR, poseCQ, jCR, jCQ, &xwCR, &xwCQ)
// https://github.com/newton-physics/newton/issues/1261
comC := BodyCom(jCbi)
dC := poseCR.Sub(slmath.MulSpatialPoint(poseCR, poseCQ, comC)) // child moment arm
var f, t math32.Vector3
switch jt {
case Free, Distance:
f = math32.Vec3(JointControl(ji, 0, JointControlForce), JointControl(ji, 1, JointControlForce), JointControl(ji, 2, JointControlForce))
t = math32.Vec3(JointControl(ji, 3, JointControlForce), JointControl(ji, 4, JointControlForce), JointControl(ji, 5, JointControlForce))
case Ball:
// note: assuming the axes are x, y, z
t = math32.Vec3(JointControl(ji, 0, JointControlForce), JointControl(ji, 1, JointControlForce), JointControl(ji, 2, JointControlForce))
case Revolute:
axis := JointAxis(ji, 0)
t = slmath.MulQuatVector(xwPQ, axis).MulScalar(JointControl(ji, 0, JointControlForce))
case Prismatic:
axis := JointAxis(ji, 0)
f = slmath.MulQuatVector(xwPQ, axis).MulScalar(JointControl(ji, 0, JointControlForce))
default:
for dof := range jLinearN {
axis := JointAxis(ji, int32(dof))
f = f.Add(slmath.MulQuatVector(xwPQ, axis).MulScalar(JointControl(ji, int32(dof), JointControlForce)))
}
for dof := range jAngularN {
di := int32(jLinearN) + int32(dof)
axis := JointAxis(ji, di)
t = t.Add(slmath.MulQuatVector(xwPQ, axis).MulScalar(JointControl(ji, di, JointControlForce)))
}
}
// These are unique to joint: aggregate into dynamics Next in [ForcesFromJoints]
SetJointPForce(ji, slmath.Negate3(f))
SetJointCForce(ji, f)
SetJointPTorque(ji, slmath.Negate3(t.Add(slmath.Cross3(dP, f))))
SetJointCTorque(ji, t.Add(slmath.Cross3(dC, f)))
}
// newton: solvers/xpbd/kernels.py: solve_body_joints
// StepSolveJoints applies target positions to joints.
// This is per Object because it needs to solve joints in parent -> child order.
func StepSolveJoints(i uint32) { //gosl:kernel
params := GetParams(0)
oi := int32(i)
if oi >= params.ObjectsN {
return
}
n := Objects.Value(int(oi), int(0))
for i := int32(1); i < n+1; i++ {
ji := Objects.Value(int(oi), int(i))
jt := GetJointType(ji)
if jt == Free || !GetJointEnabled(ji) {
continue
}
StepSolveJoint(ji)
}
}
// StepSolveJoint solves the linear and angular constraints for the given joint,
// applying limits and target positions to its DoFs.
// The linear (positional) constraints are solved prior to the angular ones.
func StepSolveJoint(ji int32) {
params := GetParams(0)
jt := GetJointType(ji)
jPi := JointParentIndex(ji)
jPbi := int32(-1)
parentFixed := true
if jPi >= 0 {
jPbi = DynamicBody(jPi)
parentFixed = GetJointParentFixed(ji)
}
jCi := JointChildIndex(ji)
jCbi := DynamicBody(jCi)
noLinearRot := GetJointNoLinearRotation(ji)
jLinearN := GetJointLinearDoFN(ji)
// jAngularN := GetJointAngularDoFN(ji)
jPR := JointPPos(ji)
jPQ := JointPQuat(ji)
xwPR := jPR // world xform, parent, pos
xwPQ := jPQ // quat
mInvP := float32(0.0)
iInvP := math32.Mat3(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
posePR := jPR
posePQ := jPQ
var comP, vP, wP math32.Vector3
// parent transform and moment arm
if jPi >= 0 {
posePR = DynamicPos(jPi, params.Next) // now using next
posePQ = DynamicQuat(jPi, params.Next)
slmath.MulSpatialTransforms(posePR, posePQ, jPR, jPQ, &xwPR, &xwPQ)
comP = BodyCom(jPbi)
mInvP = Bodies.Value(int(jPbi), int(BodyInvMass))
iInvP = BodyInvInertia(jPbi)
vP = DynamicDelta(jPi, params.Next)
wP = DynamicAngDelta(jPi, params.Next)
if mInvP == 0 {
parentFixed = true
}
}
// child transform and moment arm
poseCR := DynamicPos(jCi, params.Next)
poseCQ := DynamicQuat(jCi, params.Next)
jCR := JointCPos(ji)
jCQ := JointCQuat(ji)
xwCR := jCR
xwCQ := jCQ
slmath.MulSpatialTransforms(poseCR, poseCQ, jCR, jCQ, &xwCR, &xwCQ)
comC := BodyCom(jCbi)
mInvC := Bodies.Value(int(jCbi), int(BodyInvMass))
iInvC := BodyInvInertia(jCbi)
vC := DynamicDelta(jCi, params.Next)
wC := DynamicAngDelta(jCi, params.Next)
if mInvP == 0.0 && mInvC == 0.0 { // connection between two immovable bodies
return
}
// accumulate constraint deltas
var linDeltaP, angDeltaP, linDeltaC, angDeltaC math32.Vector3
relPoseR := xwPR
relPoseQ := xwPQ
slmath.SpatialTransformInverse(xwPR, xwPQ, &relPoseR, &relPoseQ)
slmath.MulSpatialTransforms(relPoseR, relPoseQ, xwCR, xwCQ, &relPoseR, &relPoseQ)
wComP := slmath.MulSpatialPoint(posePR, posePQ, comP)
wComC := slmath.MulSpatialPoint(poseCR, poseCQ, comC)
lambdaPrev := JointLinLambda(ji)
lambdaNext := math32.Vec3(0, 0, 0)
// handle positional constraints
if jt == Distance {
dP := xwPR.Sub(wComP)
dC := xwCR.Sub(wComC)
lo := JointDoF(ji, 0, JointLimitLower) // only first one has constraint
up := JointDoF(ji, 0, JointLimitUpper)
if lo < 0 && up < 0 { // not limited
return
}
d := slmath.Length3(relPoseR)
err := float32(0.0)
if lo >= 0.0 && d < lo {
err = d - lo
// use a more descriptive direction vector for the constraint
// in case the joint parent and child anchors are very close
relPoseR = slmath.Normal3(wComC.Sub(wComP)).MulScalar(err)
} else if up >= 0.0 && d > up {
err = d - up
}
if math32.Abs(err) > 1e-9 {
// compute gradients
linearC := relPoseR
linearP := slmath.Negate3(linearC)
dC = xwCR.Sub(wComC)
angularP := slmath.Negate3(slmath.Cross3(dP, linearC))
angularC := slmath.Cross3(dC, linearC)
// constraint time derivative
derr := slmath.Dot3(linearP, vP) + slmath.Dot3(linearC, vC) + slmath.Dot3(angularP, wP) + slmath.Dot3(angularC, wC)
lambdaIn := float32(0.0) // note: multiple iter is supposed to increment these
compliance := params.JointLinearComply
ke := JointControl(ji, 0, JointTargetStiff)
kd := JointControl(ji, 0, JointTargetDamp)
if ke > 0.0 {
compliance = 1.0 / ke
}
dLambda := PositionalCorrection(err, derr, posePQ, poseCQ, mInvP, mInvC,
iInvP, iInvC, linearP, linearC, angularP, angularC, lambdaIn, compliance, kd, params.Dt)
linDeltaP = linDeltaP.Add(linearP.MulScalar(dLambda * params.JointLinearRelax))
linDeltaC = linDeltaC.Add(linearC.MulScalar(dLambda * params.JointLinearRelax))
if !noLinearRot {
angDeltaP = angDeltaP.Add(angularP.MulScalar(dLambda * params.JointAngularRelax))
angDeltaC = angDeltaC.Add(angularC.MulScalar(dLambda * params.JointAngularRelax))
}
}
} else {
// all joints impose linear constraints!
var axisLimitsD, axisLimitsA math32.Vector3
var axisTargetPosKeD, axisTargetPosKeA math32.Vector3
var axisTargetVelKdD, axisTargetVelKdA math32.Vector3
for dof := range jLinearN {
axis := JointAxis(ji, dof)
JointAxisLimitsUpdate(dof, axis, JointDoF(ji, dof, JointLimitLower), JointDoF(ji, dof, JointLimitUpper), &axisLimitsD, &axisLimitsA)
ke := JointControl(ji, dof, JointTargetStiff)
kd := JointControl(ji, dof, JointTargetDamp)
targetPos := JointControl(ji, dof, JointTargetPosCur)
targetVel := JointControl(ji, dof, JointTargetVel)
if ke > 0.0 { // has position control
JointAxisTarget(axis, targetPos, ke, &axisTargetPosKeD, &axisTargetPosKeA)
}
if kd > 0.0 { // has velocity control
JointAxisTarget(axis, targetVel, kd, &axisTargetVelKdD, &axisTargetVelKdA)
}
}
axisStiffness := axisTargetPosKeA
axisDamping := axisTargetVelKdA
axisTargetPosKeD = slmath.DivSafe3(axisTargetPosKeD, axisStiffness)
axisTargetVelKdD = slmath.DivSafe3(axisTargetVelKdD, axisDamping)
axisLimitsLower := axisLimitsD
axisLimitsUpper := axisLimitsA
// note that xwCR appearing in both is correct:
dP := xwCR.Sub(wComP)
dC := xwCR.Sub(slmath.MulSpatialPoint(poseCR, poseCQ, comC))
for dim := range int32(3) {
e := slmath.Dim3(relPoseR, dim)
// compute gradients
// matrix indexing is [row, col] here: dim = col
// quat_to_matrix cols are q rotations of axis vectors
dima := slmath.SetDim3(math32.Vec3(0, 0, 0), dim, 1) // axis for dim
linearC := slmath.MulQuatVector(xwPQ, dima)
linearP := slmath.Negate3(linearC)
angularP := slmath.Negate3(slmath.Cross3(dP, linearC))
angularC := slmath.Cross3(dC, linearC)
// constraint time derivative
derr := slmath.Dot3(linearP, vP) + slmath.Dot3(linearC, vC) + slmath.Dot3(angularP, wP) + slmath.Dot3(angularC, wC)
err := float32(0.0)
compliance := params.JointLinearComply
damping := float32(0.0)
targetVel := slmath.Dim3(axisTargetVelKdD, dim)
derrRel := derr - targetVel
// consider joint limits irrespective of axis mode
lower := slmath.Dim3(axisLimitsLower, dim)
upper := slmath.Dim3(axisLimitsUpper, dim)
if e < lower {
err = e - lower
} else if e > upper {
err = e - upper
} else {
targetPos := slmath.Dim3(axisTargetPosKeD, dim)
targetPos = math32.Clamp(targetPos, lower, upper)
ke := slmath.Dim3(axisStiffness, dim)
kd := slmath.Dim3(axisDamping, dim)
if ke > 0.0 {
err = e - targetPos
compliance = 1.0 / ke
damping = slmath.Dim3(axisDamping, dim)
} else if kd > 0.0 {
compliance = 1.0 / kd
damping = kd
}
}
if math32.Abs(err) > 1e-9 || math32.Abs(derrRel) > 1e-9 {
// lambdaIn := slmath.Dim3(lambdaPrev, dim)
lambdaIn := float32(0)
dLambda := PositionalCorrection(err, derrRel, posePQ, poseCQ, mInvP, mInvC,
iInvP, iInvC, linearP, linearC, angularP, angularC, lambdaIn, compliance, damping, params.Dt)
linDeltaP = linDeltaP.Add(linearP.MulScalar(dLambda * params.JointLinearRelax))
linDeltaC = linDeltaC.Add(linearC.MulScalar(dLambda * params.JointLinearRelax))
if !noLinearRot {
angDeltaP = angDeltaP.Add(angularP.MulScalar(dLambda * params.JointAngularRelax))
angDeltaC = angDeltaC.Add(angularC.MulScalar(dLambda * params.JointAngularRelax))
}
lambdaNext = slmath.SetDim3(lambdaNext, dim, dLambda)
}
}
}
SetJointLinLambda(ji, lambdaNext)
//////// Angular DoFs
jAngularN := GetJointAngularDoFN(ji)
qP := xwPQ
qC := xwCQ
// make quats lie in same hemisphere
if slmath.QuatDot(qP, qC) < 0 {
qC = slmath.QuatMulScalar(qC, -1.0)
}
relQ := slmath.MulQuats(slmath.QuatInverse(qP), qC)
qtwist := slmath.QuatNormalize(math32.NewQuat(relQ.X, 0.0, 0.0, relQ.W))
qswing := slmath.MulQuats(relQ, slmath.QuatInverse(qtwist))
// decompose into a compound rotation about each axis
s := math32.Sqrt(relQ.X*relQ.X + relQ.W*relQ.W)
if s == 0 {
// fmt.Println("s = 0", relQ, qP, qC)
s = 1
}
invs := 1.0 / s
invscube := invs * invs * invs
// handle axis-angle joints
// rescale twist from quaternion space to angular
err0 := 2.0 * math32.Asin(math32.Clamp(qtwist.X, -1.0, 1.0))
err1 := qswing.Y
err2 := qswing.Z
// analytic gradients of swing-twist decomposition
grad0 := math32.NewQuat(invs-relQ.X*relQ.X*invscube, 0.0, 0.0, -(relQ.W*relQ.X)*invscube)
grad1 := math32.NewQuat(
-relQ.W*(relQ.W*relQ.Z+relQ.X*relQ.Y)*invscube,
relQ.W*invs,
-relQ.X*invs,
relQ.X*(relQ.W*relQ.Z+relQ.X*relQ.Y)*invscube)
grad2 := math32.NewQuat(
relQ.W*(relQ.W*relQ.Y-relQ.X*relQ.Z)*invscube,
relQ.X*invs,
relQ.W*invs,
relQ.X*(relQ.Z*relQ.X-relQ.W*relQ.Y)*invscube)
grad0 = slmath.QuatMulScalar(grad0, 2.0/math32.Abs(qtwist.W))
// grad0 *= 2.0 / wp.sqrt(1.0-qtwist[0]*qtwist[0]) // derivative of asin(x) = 1/sqrt(1-x^2)
// rescale swing
swing_sq := qswing.W * qswing.W
// if swing axis magnitude close to zero vector, just treat in quaternion space
angularEps := float32(1.0e-4)
if swing_sq+angularEps < 1.0 {
d := math32.Sqrt(1.0 - qswing.W*qswing.W)
theta := 2.0 * math32.Acos(math32.Clamp(qswing.W, -1.0, 1.0))
scale := theta / d
err1 *= scale
err2 *= scale
grad1 = slmath.QuatMulScalar(grad1, scale)
grad2 = slmath.QuatMulScalar(grad2, scale)
}
errs := math32.Vec3(err0, err1, err2)
gradX := math32.Vec3(grad0.X, grad1.X, grad2.X)
gradY := math32.Vec3(grad0.Y, grad1.Y, grad2.Y)
gradZ := math32.Vec3(grad0.Z, grad1.Z, grad2.Z)
gradW := math32.Vec3(grad0.W, grad1.W, grad2.W)
// compute joint target, stiffness, damping
var axisLimitsD, axisLimitsA math32.Vector3
var axisTargetPosKeD, axisTargetPosKeA math32.Vector3
var axisTargetVelKdD, axisTargetVelKdA math32.Vector3
lambdaPrev = JointAngLambda(ji)
lambdaNext = math32.Vec3(0, 0, 0)
_ = lambdaPrev
for dof := range jAngularN {
di := dof + jLinearN
axis := JointAxis(ji, di)
JointAxisLimitsUpdate(dof, axis, JointDoF(ji, di, JointLimitLower), JointDoF(ji, di, JointLimitUpper), &axisLimitsD, &axisLimitsA)
ke := JointControl(ji, di, JointTargetStiff)
kd := JointControl(ji, di, JointTargetDamp)
targetPos := JointControl(ji, di, JointTargetPosCur)
targetVel := JointControl(ji, di, JointTargetVel)
if ke > 0.0 { // has position control
JointAxisTarget(axis, targetPos, ke, &axisTargetPosKeD, &axisTargetPosKeA)
}
if kd > 0.0 { // has velocity control
JointAxisTarget(axis, targetVel, kd, &axisTargetVelKdD, &axisTargetVelKdA)
}
}
axisStiffness := axisTargetPosKeA
axisDamping := axisTargetVelKdA
axisTargetPosKeD = slmath.DivSafe3(axisTargetPosKeD, axisStiffness)
axisTargetVelKdD = slmath.DivSafe3(axisTargetVelKdD, axisDamping)
axisLimitsLower := axisLimitsD
axisLimitsUpper := axisLimitsA
for dim := range int32(3) {
e := slmath.Dim3(errs, dim)
// analytic gradients of swing-twist decomposition
grad := math32.NewQuat(slmath.Dim3(gradX, dim), slmath.Dim3(gradY, dim), slmath.Dim3(gradZ, dim), slmath.Dim3(gradW, dim))
// todo: verify -- does the 0.5 go inside??
// quat_c = 0.5 * q_p * grad * wp.quat_inverse(q_c)
quatC := slmath.MulQuats(slmath.MulQuats(slmath.QuatMulScalar(qP, 0.5), grad), slmath.QuatInverse(qC))
angularC := math32.Vec3(quatC.X, quatC.Y, quatC.Z)
angularP := slmath.Negate3(angularC)
// constraint time derivative
derr := slmath.Dot3(angularP, wP) + slmath.Dot3(angularC, wC)
err := float32(0.0)
compliance := params.JointLinearComply
damping := float32(0.0)
targetVel := slmath.Dim3(axisTargetVelKdD, dim)
angularClen := slmath.Length3(angularC)
derrRel := derr - targetVel*angularClen
// consider joint limits irrespective of axis mode
lower := slmath.Dim3(axisLimitsLower, dim)
upper := slmath.Dim3(axisLimitsUpper, dim)
if e < lower {
err = e - lower
} else if e > upper {
err = e - upper
} else {
targetPos := slmath.Dim3(axisTargetPosKeD, dim)
targetPos = math32.Clamp(targetPos, lower, upper)
ke := slmath.Dim3(axisStiffness, dim)
kd := slmath.Dim3(axisDamping, dim)
if ke > 0.0 {
err = slmath.MinAngleDiff(e, targetPos)
compliance = 1.0 / ke
damping = slmath.Dim3(axisDamping, dim)
} else if kd > 0.0 {
compliance = 1.0 / kd
damping = kd
}
}
// lambdaIn := slmath.Dim3(lambdaPrev, dim)
lambdaIn := float32(0)
dLambda := AngularCorrection(err, derrRel, posePQ, poseCQ, iInvP, iInvC, angularP, angularC, lambdaIn, compliance, damping, params.Dt)
// note: no relaxation factors here:
angDeltaP = angDeltaP.Add(angularP.MulScalar(dLambda))
angDeltaC = angDeltaC.Add(angularC.MulScalar(dLambda))
lambdaNext = slmath.SetDim3(lambdaNext, dim, dLambda)
}
SetJointAngLambda(ji, lambdaNext)
if !parentFixed {
StepBodyDeltas(jPi, jPbi, false, 0, linDeltaP, angDeltaP)
}
if mInvC > 0 {
StepBodyDeltas(jCi, jCbi, false, 0, linDeltaC, angDeltaC)
}
}
func JointAxisTarget(axis math32.Vector3, targ, weight float32, axisTargets, axisWeights *math32.Vector3) {
weightedAxis := axis.MulScalar(weight)
*axisTargets = (*axisTargets).Add(weightedAxis.MulScalar(targ)) // weighted target (to be normalized later by sum of weights)
*axisWeights = (*axisWeights).Add(slmath.Abs3(weightedAxis))
}
func PositionalCorrection(err, derr float32, tfaQ, tfbQ math32.Quat, mInvA, mInvB float32, iInvA, iInvB math32.Matrix3, linA, linB, angA, angB math32.Vector3, lambdaIn, compliance, damping, dt float32) float32 {
denom := float32(0.0)
denom += slmath.LengthSquared3(linA) * mInvA
denom += slmath.LengthSquared3(linB) * mInvB
// # Eq. 2-3 (make sure to project into the frame of the body)
rotAngA := slmath.MulQuatVectorInverse(tfaQ, angA)
rotAngB := slmath.MulQuatVectorInverse(tfbQ, angB)
denom += slmath.Dot3(rotAngA, iInvA.MulVector3(rotAngA))
denom += slmath.Dot3(rotAngB, iInvB.MulVector3(rotAngB))
alpha := compliance
gamma := compliance * damping
lambda := -(err + alpha*lambdaIn + gamma*derr)
if denom+alpha > 0.0 {
lambda /= (dt+gamma)*denom + alpha/dt
}
return lambda
}
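// AngularCorrection is the rotation-only analog of PositionalCorrection.
// Both compute a damped, compliant (XPBD-style) constraint multiplier:
// the numerator combines the constraint error err, the compliance term
// alpha*lambdaIn, and the damping term gamma*derr; the denominator sums the
// generalized inverse masses projected along the constraint gradients,
// scaled by (dt+gamma), plus alpha/dt.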
func AngularCorrection(err, derr float32, tfaQ, tfbQ math32.Quat, iInvA, iInvB math32.Matrix3, angA, angB math32.Vector3, lambdaIn, compliance, damping, dt float32) float32 {
// # Eq. 2-3 (make sure to project into the frame of the body)
rotAngA := slmath.MulQuatVectorInverse(tfaQ, angA)
rotAngB := slmath.MulQuatVectorInverse(tfbQ, angB)
denom := float32(0.0)
denom += slmath.Dot3(rotAngA, iInvA.MulVector3(rotAngA))
denom += slmath.Dot3(rotAngB, iInvB.MulVector3(rotAngB))
alpha := compliance
gamma := compliance * damping
deltaLambda := -(err + alpha*lambdaIn + gamma*derr)
if denom+alpha > 0.0 {
deltaLambda /= (dt+gamma)*denom + alpha/dt
}
return deltaLambda
}
// JointAxisLimitsUpdate updates the 3D linear or angular limits
// (axisLimitsD = lower, axisLimitsA = upper) given the axis vector and its limits.
func JointAxisLimitsUpdate(dof int32, axis math32.Vector3, lower, upper float32, axisLimitsD, axisLimitsA *math32.Vector3) {
loTemp := axis.MulScalar(lower)
upTemp := axis.MulScalar(upper)
lo := slmath.Min3(loTemp, upTemp)
up := slmath.Max3(loTemp, upTemp)
if dof == 0 {
*axisLimitsD = lo
*axisLimitsA = up
} else {
*axisLimitsD = slmath.Min3(*axisLimitsD, lo)
*axisLimitsA = slmath.Max3(*axisLimitsA, up)
}
}
//gosl:end
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Adapted initially from gonum/plot:
// Copyright ©2015 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package plot
import (
"math"
"cogentcore.org/core/math32"
"cogentcore.org/core/math32/minmax"
"cogentcore.org/core/styles"
"cogentcore.org/core/styles/units"
)
// AxisScales are the scaling options for how values are distributed
// along an axis: Linear, Log, etc.
type AxisScales int32 //enums:enum
const (
// Linear is a linear axis scale.
Linear AxisScales = iota
// Log is a Logarithmic axis scale.
Log
// InverseLinear is an inverted linear axis scale.
InverseLinear
// InverseLog is an inverted log axis scale.
InverseLog
)
func (as AxisScales) Normalizer() Normalizer {
switch as {
case Linear:
return LinearScale{}
case Log:
return LogScale{}
case InverseLinear:
return InvertedScale{LinearScale{}}
case InverseLog:
return InvertedScale{LogScale{}}
}
return LinearScale{}
}
// AxisStyle has style properties for the axis.
type AxisStyle struct { //types:add -setters
// On determines whether the axis is rendered.
On bool
// Text has the text style parameters for the text label.
Text TextStyle
// Line has styling properties for the axis line.
Line LineStyle
// Padding between the axis line and the data. Having
// non-zero padding ensures that the data is never drawn
// on the axis, thus making it easier to see.
Padding units.Value
// NTicks is the desired number of ticks (actual likely
// will be different). If < 2 then the axis will not be drawn.
NTicks int
// Scale specifies how values are scaled along the axis:
// Linear, Log, Inverted
Scale AxisScales
// TickText has the text style for rendering tick labels,
// and is shared for actual rendering.
TickText TextStyle
// TickLine has line style for drawing tick lines.
TickLine LineStyle
// TickLength is the length of tick lines.
TickLength units.Value
}
func (ax *AxisStyle) Defaults() {
ax.On = true
ax.Line.Defaults()
ax.Text.Defaults()
ax.Text.Size.Dp(20)
ax.Padding.Pt(5)
ax.NTicks = 5
ax.TickText.Defaults()
ax.TickText.Size.Dp(16)
ax.TickText.Padding.Dp(2)
ax.TickLine.Defaults()
ax.TickLength.Pt(8)
}
// Axis represents either a horizontal or vertical axis of a plot.
// This is the "internal" data structure and should not be used for styling.
type Axis struct {
// Range has the Min, Max range of values for the axis (in raw data units.)
Range minmax.F64
// specifies which axis this is: X, Y or Z.
Axis math32.Dims
// For a Y axis, this puts the axis on the right (i.e., the second Y axis).
RightY bool
// Label for the axis.
Label Text
// Style has the style parameters for the Axis,
// copied from [PlotStyle] source.
Style AxisStyle
// TickText is used for rendering the tick text labels.
TickText Text
// Ticker generates the tick marks. Any tick marks
// returned by the Marker function that are not in
// range of the axis are not drawn.
Ticker Ticker
// Scale transforms a value given in the data coordinate system
// to the normalized coordinate system of the axis—its distance
// along the axis as a fraction of the axis range.
Scale Normalizer
// AutoRescale enables an axis to automatically adapt its minimum
// and maximum boundaries, according to its underlying Ticker.
AutoRescale bool
// cached list of ticks, set in size
ticks []Tick
}
// Defaults sets the default values. The initial range is (∞, −∞), so that any
// finite value is less than Min and greater than Max.
func (ax *Axis) Defaults(dim math32.Dims) {
ax.Style.Defaults()
ax.Axis = dim
if dim == math32.Y {
ax.Label.Style.Rotation = -90
if ax.RightY {
ax.Style.TickText.Align = styles.Start
} else {
ax.Style.TickText.Align = styles.End
}
}
ax.Scale = LinearScale{}
ax.Ticker = DefaultTicks{}
}
// drawConfig configures for drawing.
func (ax *Axis) drawConfig() {
ax.Scale = ax.Style.Scale.Normalizer()
}
// SanitizeRange ensures that the range of the axis makes sense.
func (ax *Axis) SanitizeRange() {
ax.Range.Sanitize()
if ax.AutoRescale {
marks := ax.Ticker.Ticks(ax.Range.Min, ax.Range.Max, ax.Style.NTicks)
for _, t := range marks {
ax.Range.FitValInRange(t.Value)
}
}
}
func (ax *Axis) SetTickLabel(i int, lbl string) {
if len(ax.ticks) <= i {
return
}
ax.ticks[i].Label = lbl
}
// Normalizer rescales values from the data coordinate system to the
// normalized coordinate system.
type Normalizer interface {
// Normalize transforms a value x in the data coordinate system to
// the normalized coordinate system.
Normalize(min, max, x float64) float64
}
// LinearScale can be used as the value of an Axis.Scale function to
// set the axis to a standard linear scale.
type LinearScale struct{}
var _ Normalizer = LinearScale{}
// Normalize returns the fractional distance of x between min and max.
func (LinearScale) Normalize(min, max, x float64) float64 {
return (x - min) / (max - min)
}
// LogScale can be used as the value of an Axis.Scale function to
// set the axis to a log scale.
type LogScale struct{}
var _ Normalizer = LogScale{}
// Normalize returns the fractional logarithmic distance of
// x between min and max.
func (LogScale) Normalize(min, max, x float64) float64 {
if min <= 0 || max <= 0 || x <= 0 {
panic("Values must be greater than 0 for a log scale.")
}
logMin := math.Log(min)
return (math.Log(x) - logMin) / (math.Log(max) - logMin)
}
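// For example, LogScale{}.Normalize(1, 100, 10) returns 0.5, since log(10)
// lies halfway between log(1) and log(100).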
// InvertedScale can be used as the value of an Axis.Scale function to
// invert the axis using any Normalizer.
type InvertedScale struct{ Normalizer }
var _ Normalizer = InvertedScale{}
// Normalize returns a normalized [0, 1] value for the position of x.
func (is InvertedScale) Normalize(min, max, x float64) float64 {
return is.Normalizer.Normalize(max, min, x)
}
// Norm returns the value of x, given in the data coordinate
// system, normalized to its distance as a fraction of the
// range of this axis. For example, if x is a.Min then the return
// value is 0, and if x is a.Max then the return value is 1.
func (ax *Axis) Norm(x float64) float64 {
return ax.Scale.Normalize(ax.Range.Min, ax.Range.Max, x)
}
//////// VirtualAxis
// VirtualAxisStyle has style properties for a virtual (non-plotted) axis.
type VirtualAxisStyle struct { //types:add -setters
// Scale specifies how values are scaled along the axis:
// Linear, Log, Inverted
Scale AxisScales
}
// VirtualAxis represents a data role that is not plotted as a visible axis,
// such as the Size role controlling size of points.
// This is the "internal" data structure and should not be used for styling.
type VirtualAxis struct {
// Range has the Min, Max range of values for the axis (in raw data units.)
Range minmax.F64
// Style has the style parameters for the Axis,
// copied from [PlotStyle] source.
Style VirtualAxisStyle
// Scale transforms a value given in the data coordinate system
// to the normalized coordinate system of the axis—its distance
// along the axis as a fraction of the axis range.
Scale Normalizer
}
// drawConfig configures for drawing.
func (ax *VirtualAxis) drawConfig() {
ax.Scale = ax.Style.Scale.Normalizer()
}
// Norm returns the value of x, given in the data coordinate
// system, normalized to its distance as a fraction of the
// range of this axis. For example, if x is a.Min then the return
// value is 0, and if x is a.Max then the return value is 1.
func (ax *VirtualAxis) Norm(x float64) float64 {
return ax.Scale.Normalize(ax.Range.Min, ax.Range.Max, x)
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Adapted from github.com/gonum/plot:
// Copyright ©2015 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package plot
import (
"log/slog"
"math"
"reflect"
"strconv"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/reflectx"
"cogentcore.org/core/math32/minmax"
)
// data defines the main data interfaces for plotting
// and the different Roles for data.
var (
ErrInfinity = errors.New("plotter: infinite data point")
ErrNoData = errors.New("plotter: no data points")
)
// DataOrValuer processes the given argument which can be either a [Data]
// or a [Valuer]. If the latter, it is returned as a [Data] with the given Role.
// This is used for New Plotter methods that can take the default required Valuer,
// or a Data with roles explicitly defined. If the arg is neither, an error
// is returned.
func DataOrValuer(v any, role Roles) (Data, error) {
switch x := v.(type) {
case Valuer:
return Data{role: x}, nil
case Data:
return x, nil
case *Data:
return *x, nil
default:
return nil, errors.New("plot: argument was not a Data or Valuer (e.g., Tensor)")
}
}
// Data is a map from a Role to the [Valuer] data for that Role, providing the
// primary way of passing data to a Plotter.
type Data map[Roles]Valuer
// Valuer is the data interface for plotting, supporting either
// float64 or string representations. It is satisfied by the tensor.Tensor
// interface, so a tensor can be used directly for plot Data.
type Valuer interface {
// Len returns the number of values.
Len() int
// Float1D(i int) returns float64 value at given index.
Float1D(i int) float64
// String1D(i int) returns string value at given index.
String1D(i int) string
}
// Roles are the roles that a given set of data values can play,
// designed to be sufficiently generalizable across all different
// types of plots, even if sometimes it is a bit of a stretch.
type Roles int32 //enums:enum
const (
// NoRole is the default no-role specified case.
NoRole Roles = iota
// X axis
X
// Y axis
Y
// Z axis
Z
// U is the X component of a vector or first quartile in Box plot, etc.
U
// V is the Y component of a vector or third quartile in a Box plot, etc.
V
// W is the Z component of a vector
W
// Low is a lower error bar or region.
Low
// High is an upper error bar or region.
High
// Size controls the size of points etc.
Size
// Label renders a label, typically from string data,
// but can also be used for values.
Label
// Split is a special role for table-based plots. The
// unique values of this data are used to split the other
// plot data into groups, with each group added to the legend.
// A different default color will be used for each such group.
Split
)
// CheckFloats returns an error if any of the arguments are Infinity,
// or if there are no non-NaN data points available for plotting.
func CheckFloats(fs ...float64) error {
n := 0
for _, f := range fs {
switch {
case math.IsNaN(f):
case math.IsInf(f, 0):
return ErrInfinity
default:
n++
}
}
if n == 0 {
return ErrNoData
}
return nil
}
// CheckNaNs returns true if any of the floats are NaN
func CheckNaNs(fs ...float64) bool {
for _, f := range fs {
if math.IsNaN(f) {
return true
}
}
return false
}
// Range updates given Range with values from data.
func Range(data Valuer, rng *minmax.F64) {
for i := 0; i < data.Len(); i++ {
v := data.Float1D(i)
if math.IsNaN(v) {
continue
}
rng.FitValInRange(v)
}
}
// RangeClamp updates the given axis Min, Max range values based
// on the range of values in the given [Data], and the given style range.
func RangeClamp(data Valuer, axisRng *minmax.F64, styleRng *minmax.Range64) {
Range(data, axisRng)
axisRng.Min, axisRng.Max = styleRng.Clamp(axisRng.Min, axisRng.Max)
}
// CheckLengths checks that all the data elements have the same length.
// Logs and returns an error if not.
func (dt Data) CheckLengths() error {
n := 0
for _, v := range dt {
if n == 0 {
n = v.Len()
} else {
if v.Len() != n {
err := errors.New("plot.Data has inconsistent lengths -- all data elements must have the same length -- plotting aborted")
return errors.Log(err)
}
}
}
return nil
}
// NewXY returns a new Data with X, Y roles for given values.
func NewXY(x, y Valuer) Data {
return Data{X: x, Y: y}
}
// NewY returns a new Data with Y role for given values.
func NewY(y Valuer) Data {
return Data{Y: y}
}
// Values provides a minimal implementation of the [Valuer] data interface
// using a slice of float64.
type Values []float64
func (vs Values) Len() int {
return len(vs)
}
func (vs Values) Float1D(i int) float64 {
return vs[i]
}
func (vs Values) String1D(i int) string {
return strconv.FormatFloat(vs[i], 'g', -1, 64)
}
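// For example, a minimal XY dataset can be built directly from slices
// (hypothetical usage sketch):
//
//	data := NewXY(Values{1, 2, 3}, Values{10, 20, 30})
//	_ = data[X].Float1D(0) // 1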
// CopyValues returns a Values that is a copy of the values
// from Data, or an error if there are no values or if one of
// the copied values is infinite.
// NaN values are skipped in the copying process.
func CopyValues(data Valuer) (Values, error) {
if reflectx.IsNil(reflect.ValueOf(data)) {
return nil, ErrNoData
}
cpy := make(Values, 0, data.Len())
for i := 0; i < data.Len(); i++ {
v := data.Float1D(i)
if math.IsNaN(v) {
continue
}
if err := CheckFloats(v); err != nil {
return nil, err
}
cpy = append(cpy, v)
}
return cpy, nil
}
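// For example, CopyValues(Values{1, math.NaN(), 3}) returns Values{1, 3}:
// NaN values are skipped rather than treated as an error.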
// MustCopyRole returns Values copy of given role from given data map,
// logging an error and returning nil if not present.
func MustCopyRole(data Data, role Roles) Values {
d, ok := data[role]
if !ok {
slog.Error("plot Data role not present, but is required", "role:", role)
return nil
}
v, _ := CopyValues(d)
return v
}
// CopyRole returns Values copy of given role from given data map,
// returning nil if role not present.
func CopyRole(data Data, role Roles) Values {
d, ok := data[role]
if !ok {
return nil
}
v, _ := CopyValues(d)
return v
}
// CopyRoleLabels returns Labels copy of given role from given data map,
// returning nil if role not present.
func CopyRoleLabels(data Data, role Roles) Labels {
d, ok := data[role]
if !ok {
return nil
}
l := make(Labels, d.Len())
for i := range l {
l[i] = d.String1D(i)
}
return l
}
// PlotX returns plot pixel X coordinate values for given data.
func PlotX(plt *Plot, data Valuer) []float32 {
px := make([]float32, data.Len())
for i := range px {
px[i] = plt.PX(data.Float1D(i))
}
return px
}
// PlotY returns plot pixel Y coordinate values for given data.
func PlotY(plt *Plot, data Valuer) []float32 {
py := make([]float32, data.Len())
for i := range py {
py[i] = plt.PY(data.Float1D(i))
}
return py
}
// PlotYR returns plot pixel YR right axis coordinate values for given data.
func PlotYR(plt *Plot, data Valuer) []float32 {
py := make([]float32, data.Len())
for i := range py {
py[i] = plt.PYR(data.Float1D(i))
}
return py
}
//////// Labels
// Labels provides a minimal implementation of the [Valuer] data interface
// using a slice of string. It always returns 0 for Float1D.
type Labels []string
func (lb Labels) Len() int {
return len(lb)
}
func (lb Labels) Float1D(i int) float64 {
return 0
}
func (lb Labels) String1D(i int) string {
return lb[i]
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Adapted from gonum/plot:
// Copyright ©2015 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package plot
import (
"image"
"os"
"cogentcore.org/core/base/iox/imagex"
"cogentcore.org/core/math32"
"cogentcore.org/core/paint"
"cogentcore.org/core/paint/render"
"cogentcore.org/core/styles"
"cogentcore.org/core/styles/sides"
"cogentcore.org/core/text/shaped"
)
// RenderImage renders the plot to an image and returns it.
func (pt *Plot) RenderImage() image.Image {
return paint.RenderToImage(pt.Draw(nil))
}
// RenderSVG renders the plot to an SVG document and returns it.
func (pt *Plot) RenderSVG() []byte {
return paint.RenderToSVG(pt.Draw(nil))
}
// SaveImage renders the plot to an image and saves it to given filename,
// using the filename extension to determine the file type.
func (pt *Plot) SaveImage(fname string) error {
return imagex.Save(pt.RenderImage(), fname)
}
// SaveSVG renders the plot to an SVG document and saves it to given filename.
func (pt *Plot) SaveSVG(fname string) error {
return os.WriteFile(fname, pt.RenderSVG(), 0666)
}
// drawConfig configures everything for drawing, applying styles etc.
func (pt *Plot) drawConfig() {
pt.applyStyle()
pt.X.drawConfig()
pt.Y.drawConfig()
pt.YR.drawConfig()
pt.Z.drawConfig()
pt.SizeAxis.drawConfig()
pt.Painter.ToDots()
}
// Draw draws the plot to a core paint.Painter, which then
// can be used with a core render.Renderer to generate an image or SVG
// document, etc. If painter is nil, then one is created.
// See [Plot.RenderImage] and [Plot.RenderSVG] for convenience methods.
// Plotters are drawn in the order in which they were added to the plot.
func (pt *Plot) Draw(pc *paint.Painter) *paint.Painter {
ptb := pt.PaintBox
off := math32.FromPoint(pt.PaintBox.Min)
sz := pt.PaintBox.Size()
ptw := float32(sz.X)
pth := float32(sz.Y)
if pc != nil {
pt.Painter = pc
if pt.unitContext != nil && pt.unitContext.DPI != pc.UnitContext.DPI {
pt.unitContext = nil // regenerate with given painter
}
} else {
pt.Painter = paint.NewPainter(math32.FromPoint(sz))
pc = pt.Painter
}
if pt.TextShaper == nil {
shaperMu.Lock()
if plotShaper == nil {
plotShaper = shaped.NewShaper()
}
pt.TextShaper = plotShaper
shaperMu.Unlock()
defer func() {
pt.TextShaper = nil
}()
}
pt.drawConfig()
pc.PushContext(pc.Paint, render.NewBoundsRect(pt.PaintBox, sides.Floats{}))
if pt.Style.Background != nil {
pc.BlitBox(off, math32.FromPoint(sz), pt.Style.Background)
}
if pt.Title.Text != "" {
pt.Title.Config(pt)
pos := pt.Title.PosX(ptw)
pad := pt.Title.Style.Padding.Dots
pos.Y = pad
pt.Title.Draw(pt, pos.Add(off))
rsz := pt.Title.PaintText.Bounds.Size().Ceil()
th := rsz.Y + 2*pad
pth -= th
ptb.Min.Y += int(math32.Ceil(th))
}
pt.X.SanitizeRange()
pt.Y.SanitizeRange()
pt.YR.SanitizeRange()
ywidth, tickWidth, tpad, bpad := pt.Y.sizeY(pt, ptb.Min.Y)
yrwidth, yrtickWidth, yrtpad, yrbpad := pt.YR.sizeY(pt, ptb.Min.Y)
xheight, lpad, rpad := pt.X.sizeX(pt, float32(ywidth), float32(yrwidth), float32(sz.X-int(ywidth+yrwidth)))
tb := ptb
tb.Min.X += ywidth
tb.Max.X -= yrwidth
pt.X.drawX(pt, tb, lpad, rpad)
tb = ptb
tb.Max.Y -= xheight
pt.Y.drawY(pt, tb, tickWidth, tpad, bpad)
pt.YR.drawY(pt, tb, yrtickWidth, yrtpad, yrbpad)
tb = ptb
tb.Min.X += ywidth + lpad
tb.Max.X -= yrwidth + rpad
tb.Max.Y -= xheight + bpad
tb.Min.Y += tpad
pt.PlotBox.SetFromRect(tb)
// don't cut off lines
tb.Min.X -= 2
tb.Min.Y -= 2
tb.Max.X += 2
tb.Max.Y += 2
pt.PushBounds(tb)
for _, plt := range pt.Plotters {
plt.Plot(pt)
}
pt.Legend.draw(pt)
pc.PopContext()
pc.PopContext() // global
pt.Painter = nil
return pc
}
//////// Axis
// drawTicks returns true if the tick marks should be drawn.
func (ax *Axis) drawTicks() bool {
return ax.Style.TickLine.Width.Value > 0 && ax.Style.TickLength.Value > 0
}
// sizeX returns the total height of the axis and the left and right padding.
func (ax *Axis) sizeX(pt *Plot, yw, yrw, axw float32) (ht, lpad, rpad int) {
if !ax.Style.On {
return
}
uc := pt.UnitContext()
ax.Style.TickLength.ToDots(uc)
ax.ticks = ax.Ticker.Ticks(ax.Range.Min, ax.Range.Max, ax.Style.NTicks)
h := float32(0)
if ax.Label.Text != "" { // We assume that the label isn't rotated.
ax.Label.Config(pt)
h += ax.Label.Size().Y
h += ax.Label.Style.Padding.Dots
}
lw := ax.Style.Line.Width.Dots
lpad = int(math32.Ceil(lw)) + 4
rpad = int(math32.Ceil(lw)) + 4
tht := float32(0)
if len(ax.ticks) > 0 {
if ax.drawTicks() {
h += ax.Style.TickLength.Dots
}
ftk := ax.firstTickLabel()
if ftk.Label != "" {
px, _ := ax.tickPosX(pt, ftk, axw)
if px < -yw {
lpad += int(math32.Ceil(-px - yw))
}
tht = max(tht, ax.TickText.Size().Y)
}
ltk := ax.lastTickLabel()
if ltk.Label != "" {
px, wd := ax.tickPosX(pt, ltk, axw)
if px+wd > axw+yrw {
rpad += int(math32.Ceil((px + wd) - (axw + yrw)))
}
tht = max(tht, ax.TickText.Size().Y)
}
ax.TickText.Text = ax.longestTickLabel()
if ax.TickText.Text != "" {
ax.TickText.Config(pt)
tht = max(tht, ax.TickText.Size().Y)
}
h += ax.TickText.Style.Padding.Dots
}
h += tht + lw + ax.Style.Padding.Dots
ht = int(math32.Ceil(h))
return
}
// tickPosX returns the relative position and width for given tick along the X axis,
// for the given total axis width.
func (ax *Axis) tickPosX(pt *Plot, t Tick, axw float32) (px, wd float32) {
x := axw * float32(ax.Norm(t.Value))
if x < 0 || x > axw {
return
}
ax.TickText.Text = t.Label
ax.TickText.Config(pt)
pos := ax.TickText.PosX(0)
px = pos.X + x
wd = ax.TickText.Size().X
return
}
func (ax *Axis) firstTickLabel() Tick {
for _, tk := range ax.ticks {
if tk.Label != "" {
return tk
}
}
return Tick{}
}
func (ax *Axis) lastTickLabel() Tick {
n := len(ax.ticks)
for i := n - 1; i >= 0; i-- {
tk := ax.ticks[i]
if tk.Label != "" {
return tk
}
}
return Tick{}
}
func (ax *Axis) longestTickLabel() string {
lst := ""
for _, tk := range ax.ticks {
if len(tk.Label) > len(lst) {
lst = tk.Label
}
}
return lst
}
func (ax *Axis) sizeY(pt *Plot, theight int) (ywidth, tickWidth, tpad, bpad int) {
if !ax.Style.On {
return
}
uc := pt.UnitContext()
ax.ticks = ax.Ticker.Ticks(ax.Range.Min, ax.Range.Max, ax.Style.NTicks)
ax.Style.TickLength.ToDots(uc)
w := float32(0)
if ax.Label.Text != "" {
ax.Label.Config(pt)
w += ax.Label.Size().X
w += ax.Label.Style.Padding.Dots
}
lw := ax.Style.Line.Width.Dots
tpad = int(math32.Ceil(lw)) + 2
bpad = int(math32.Ceil(lw)) + 2
if len(ax.ticks) > 0 {
if ax.drawTicks() {
w += ax.Style.TickLength.Dots
}
ax.TickText.Text = ax.longestTickLabel()
if ax.TickText.Text != "" {
ax.TickText.Config(pt)
tw := math32.Ceil(ax.TickText.Size().X + ax.TickText.Style.Padding.Dots)
w += tw
tickWidth = int(tw)
tht := int(math32.Ceil(0.5 * ax.TickText.Size().X))
if theight == 0 {
tpad += tht
}
}
}
w += lw + ax.Style.Padding.Dots
ywidth = int(math32.Ceil(w))
return
}
// drawX draws the horizontal axis
func (ax *Axis) drawX(pt *Plot, ab image.Rectangle, lpad, rpad int) {
if !ax.Style.On {
return
}
ab.Min.X += lpad
ab.Max.X -= rpad
axw := float32(ab.Size().X)
// axh := float32(ab.Size().Y) // height of entire plot
if ax.Label.Text != "" {
ax.Label.Config(pt)
pos := ax.Label.PosX(axw)
pos.X += float32(ab.Min.X)
th := ax.Label.Size().Y
pos.Y = float32(ab.Max.Y) - th
ax.Label.Draw(pt, pos)
ab.Max.Y -= int(math32.Ceil(th + ax.Label.Style.Padding.Dots))
}
tickHt := float32(0)
for _, t := range ax.ticks {
x := axw * float32(ax.Norm(t.Value))
if x < 0 || x > axw || t.IsMinor() {
continue
}
ax.TickText.Text = t.Label
ax.TickText.Config(pt)
pos := ax.TickText.PosX(0)
pos.X += x + float32(ab.Min.X)
tickHt = ax.TickText.Size().Y + ax.TickText.Style.Padding.Dots
pos.Y += float32(ab.Max.Y) - tickHt
ax.TickText.Draw(pt, pos)
}
if len(ax.ticks) > 0 {
ab.Max.Y -= int(math32.Ceil(tickHt))
// } else {
// y += ax.Width / 2
}
if len(ax.ticks) > 0 && ax.drawTicks() {
ln := ax.Style.TickLength.Dots
for _, t := range ax.ticks {
yoff := float32(0)
if t.IsMinor() {
yoff = 0.5 * ln
}
x := axw * float32(ax.Norm(t.Value))
if x < 0 || x > axw {
continue
}
x += float32(ab.Min.X)
ax.Style.TickLine.Draw(pt, math32.Vec2(x, float32(ab.Max.Y)-yoff), math32.Vec2(x, float32(ab.Max.Y)-ln))
}
ab.Max.Y -= int(ln - 0.5*ax.Style.Line.Width.Dots)
}
ax.Style.Line.Draw(pt, math32.Vec2(float32(ab.Min.X), float32(ab.Max.Y)), math32.Vec2(float32(ab.Min.X)+axw, float32(ab.Max.Y)))
}
// drawY draws the Y axis along the left side, or along the right side if RightY is set.
func (ax *Axis) drawY(pt *Plot, ab image.Rectangle, tickWidth, tpad, bpad int) {
if !ax.Style.On {
return
}
ab.Min.Y += tpad
ab.Max.Y -= bpad
axh := float32(ab.Size().Y)
xpos := float32(ab.Min.X)
if ax.RightY {
xpos = float32(ab.Max.X)
}
if ax.Label.Text != "" {
ax.Label.Style.Align = styles.Center
pos := ax.Label.PosY(axh)
tw := math32.Ceil(ax.Label.Size().X + ax.Label.Style.Padding.Dots)
if ax.RightY {
pos.Y += float32(ab.Min.Y)
pos.X = xpos
xpos -= tw
} else {
pos.Y += float32(ab.Min.Y) + ax.Label.Size().Y
pos.X = xpos
xpos += tw
}
ax.Label.Draw(pt, pos)
}
if len(ax.ticks) > 0 && ax.RightY {
xpos -= float32(tickWidth)
}
for _, t := range ax.ticks {
y := axh * (1 - float32(ax.Norm(t.Value)))
if y < 0 || y > axh || t.IsMinor() {
continue
}
ax.TickText.Text = t.Label
ax.TickText.Config(pt)
pos := ax.TickText.PosX(float32(tickWidth))
pos.X += xpos
pos.Y = float32(ab.Min.Y) + y - 0.5*ax.TickText.Size().Y
ax.TickText.Draw(pt, pos)
}
if len(ax.ticks) > 0 && !ax.RightY {
xpos += float32(tickWidth)
}
if len(ax.ticks) > 0 && ax.drawTicks() {
ln := ax.Style.TickLength.Dots
if ax.RightY {
xpos -= math32.Ceil(ln + 0.5*ax.Style.Line.Width.Dots)
}
for _, t := range ax.ticks {
xoff := float32(0)
eln := ln
if t.IsMinor() {
if ax.RightY {
eln *= .5
} else {
xoff = 0.5 * ln
}
}
y := axh * (1 - float32(ax.Norm(t.Value)))
if y < 0 || y > axh {
continue
}
y += float32(ab.Min.Y)
ax.Style.TickLine.Draw(pt, math32.Vec2(xpos+xoff, y), math32.Vec2(xpos+eln, y))
}
if !ax.RightY {
xpos += math32.Ceil(ln + 0.5*ax.Style.Line.Width.Dots)
}
}
ax.Style.Line.Draw(pt, math32.Vec2(xpos, float32(ab.Min.Y)), math32.Vec2(xpos, float32(ab.Max.Y)))
}
//////// Legend
// draw draws the legend
func (lg *Legend) draw(pt *Plot) {
pc := pt.Painter
uc := pt.UnitContext()
ptb := pt.CurBounds()
lg.Style.ThumbnailWidth.ToDots(uc)
lg.Style.Position.XOffs.ToDots(uc)
lg.Style.Position.YOffs.ToDots(uc)
var ltxt Text
ltxt.Defaults()
ltxt.Style = lg.Style.Text
ltxt.ToDots(uc)
pad := math32.Ceil(ltxt.Style.Padding.Dots)
em := ltxt.textStyle.FontHeight(&ltxt.font)
var sz image.Point
maxTht := 0
for _, e := range lg.Entries {
ltxt.Text = e.Text
ltxt.Config(pt)
sz.X = max(sz.X, int(math32.Ceil(ltxt.Size().X)))
tht := int(math32.Ceil(ltxt.Size().Y + pad))
maxTht = max(tht, maxTht)
}
sz.X += int(em)
sz.Y = len(lg.Entries) * maxTht
txsz := sz
sz.X += int(lg.Style.ThumbnailWidth.Dots)
pos := ptb.Min
if lg.Style.Position.Left {
pos.X += int(lg.Style.Position.XOffs.Dots)
} else {
pos.X = ptb.Max.X - sz.X - int(lg.Style.Position.XOffs.Dots)
}
if lg.Style.Position.Top {
pos.Y += int(lg.Style.Position.YOffs.Dots)
} else {
pos.Y = ptb.Max.Y - sz.Y - int(lg.Style.Position.YOffs.Dots)
}
if lg.Style.Fill != nil {
pc.FillBox(math32.FromPoint(pos), math32.FromPoint(sz), lg.Style.Fill)
}
cp := pos
thsz := image.Point{X: int(lg.Style.ThumbnailWidth.Dots), Y: maxTht - 2*int(pad)}
for _, e := range lg.Entries {
tp := cp
tp.X += int(txsz.X)
tp.Y += int(pad)
tb := image.Rectangle{Min: tp, Max: tp.Add(thsz)}
pt.PushBounds(tb)
for _, t := range e.Thumbs {
t.Thumbnail(pt)
}
pc.PopContext()
ltxt.Text = e.Text
ltxt.Config(pt)
ltxt.Draw(pt, math32.FromPoint(cp))
cp.Y += maxTht
}
}
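// A minimal rendering sketch (not from the original sources), assuming the
// cogentcore.org/lab/plot import path used elsewhere in this repository:
// the same Plot can be rasterized standalone via RenderImage, or drawn into a
// caller-provided Painter, e.g. to composite the plot into a larger scene.
package plot_test
import (
"image"
"cogentcore.org/core/math32"
"cogentcore.org/core/paint"
"cogentcore.org/lab/plot"
)
func renderSketch() (image.Image, *paint.Painter) {
pt := plot.New()
pt.SetSize(image.Point{X: 640, Y: 480})
img := pt.RenderImage() // Draw allocates its own Painter internally
pc := paint.NewPainter(math32.Vec2(640, 480))
pt.Draw(pc) // draws within pt.PaintBox using the given Painter
return img, pc
}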
// Code generated by "core generate -add-types"; DO NOT EDIT.
package plot
import (
"cogentcore.org/core/enums"
)
var _AxisScalesValues = []AxisScales{0, 1, 2, 3}
// AxisScalesN is the highest valid value for type AxisScales, plus one.
const AxisScalesN AxisScales = 4
var _AxisScalesValueMap = map[string]AxisScales{`Linear`: 0, `Log`: 1, `InverseLinear`: 2, `InverseLog`: 3}
var _AxisScalesDescMap = map[AxisScales]string{0: `Linear is a linear axis scale.`, 1: `Log is a Logarithmic axis scale.`, 2: `InverseLinear is an inverted linear axis scale.`, 3: `InverseLog is an inverted log axis scale.`}
var _AxisScalesMap = map[AxisScales]string{0: `Linear`, 1: `Log`, 2: `InverseLinear`, 3: `InverseLog`}
// String returns the string representation of this AxisScales value.
func (i AxisScales) String() string { return enums.String(i, _AxisScalesMap) }
// SetString sets the AxisScales value from its string representation,
// and returns an error if the string is invalid.
func (i *AxisScales) SetString(s string) error {
return enums.SetString(i, s, _AxisScalesValueMap, "AxisScales")
}
// Int64 returns the AxisScales value as an int64.
func (i AxisScales) Int64() int64 { return int64(i) }
// SetInt64 sets the AxisScales value from an int64.
func (i *AxisScales) SetInt64(in int64) { *i = AxisScales(in) }
// Desc returns the description of the AxisScales value.
func (i AxisScales) Desc() string { return enums.Desc(i, _AxisScalesDescMap) }
// AxisScalesValues returns all possible values for the type AxisScales.
func AxisScalesValues() []AxisScales { return _AxisScalesValues }
// Values returns all possible values for the type AxisScales.
func (i AxisScales) Values() []enums.Enum { return enums.Values(_AxisScalesValues) }
// MarshalText implements the [encoding.TextMarshaler] interface.
func (i AxisScales) MarshalText() ([]byte, error) { return []byte(i.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *AxisScales) UnmarshalText(text []byte) error {
return enums.UnmarshalText(i, text, "AxisScales")
}
var _RolesValues = []Roles{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}
// RolesN is the highest valid value for type Roles, plus one.
const RolesN Roles = 12
var _RolesValueMap = map[string]Roles{`NoRole`: 0, `X`: 1, `Y`: 2, `Z`: 3, `U`: 4, `V`: 5, `W`: 6, `Low`: 7, `High`: 8, `Size`: 9, `Label`: 10, `Split`: 11}
var _RolesDescMap = map[Roles]string{0: `NoRole is the default no-role specified case.`, 1: `X axis`, 2: `Y axis`, 3: `Z axis`, 4: `U is the X component of a vector or first quartile in Box plot, etc.`, 5: `V is the Y component of a vector or third quartile in a Box plot, etc.`, 6: `W is the Z component of a vector`, 7: `Low is a lower error bar or region.`, 8: `High is an upper error bar or region.`, 9: `Size controls the size of points etc.`, 10: `Label renders a label, typically from string data, but can also be used for values.`, 11: `Split is a special role for table-based plots. The unique values of this data are used to split the other plot data into groups, with each group added to the legend. A different default color will be used for each such group.`}
var _RolesMap = map[Roles]string{0: `NoRole`, 1: `X`, 2: `Y`, 3: `Z`, 4: `U`, 5: `V`, 6: `W`, 7: `Low`, 8: `High`, 9: `Size`, 10: `Label`, 11: `Split`}
// String returns the string representation of this Roles value.
func (i Roles) String() string { return enums.String(i, _RolesMap) }
// SetString sets the Roles value from its string representation,
// and returns an error if the string is invalid.
func (i *Roles) SetString(s string) error { return enums.SetString(i, s, _RolesValueMap, "Roles") }
// Int64 returns the Roles value as an int64.
func (i Roles) Int64() int64 { return int64(i) }
// SetInt64 sets the Roles value from an int64.
func (i *Roles) SetInt64(in int64) { *i = Roles(in) }
// Desc returns the description of the Roles value.
func (i Roles) Desc() string { return enums.Desc(i, _RolesDescMap) }
// RolesValues returns all possible values for the type Roles.
func RolesValues() []Roles { return _RolesValues }
// Values returns all possible values for the type Roles.
func (i Roles) Values() []enums.Enum { return enums.Values(_RolesValues) }
// MarshalText implements the [encoding.TextMarshaler] interface.
func (i Roles) MarshalText() ([]byte, error) { return []byte(i.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *Roles) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "Roles") }
var _StepKindValues = []StepKind{0, 1, 2, 3}
// StepKindN is the highest valid value for type StepKind, plus one.
const StepKindN StepKind = 4
var _StepKindValueMap = map[string]StepKind{`NoStep`: 0, `PreStep`: 1, `MidStep`: 2, `PostStep`: 3}
var _StepKindDescMap = map[StepKind]string{0: `NoStep connects two points by simple line.`, 1: `PreStep connects two points by following lines: vertical, horizontal.`, 2: `MidStep connects two points by following lines: horizontal, vertical, horizontal. Vertical line is placed in the middle of the interval.`, 3: `PostStep connects two points by following lines: horizontal, vertical.`}
var _StepKindMap = map[StepKind]string{0: `NoStep`, 1: `PreStep`, 2: `MidStep`, 3: `PostStep`}
// String returns the string representation of this StepKind value.
func (i StepKind) String() string { return enums.String(i, _StepKindMap) }
// SetString sets the StepKind value from its string representation,
// and returns an error if the string is invalid.
func (i *StepKind) SetString(s string) error {
return enums.SetString(i, s, _StepKindValueMap, "StepKind")
}
// Int64 returns the StepKind value as an int64.
func (i StepKind) Int64() int64 { return int64(i) }
// SetInt64 sets the StepKind value from an int64.
func (i *StepKind) SetInt64(in int64) { *i = StepKind(in) }
// Desc returns the description of the StepKind value.
func (i StepKind) Desc() string { return enums.Desc(i, _StepKindDescMap) }
// StepKindValues returns all possible values for the type StepKind.
func StepKindValues() []StepKind { return _StepKindValues }
// Values returns all possible values for the type StepKind.
func (i StepKind) Values() []enums.Enum { return enums.Values(_StepKindValues) }
// MarshalText implements the [encoding.TextMarshaler] interface.
func (i StepKind) MarshalText() ([]byte, error) { return []byte(i.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *StepKind) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "StepKind") }
var _ShapesValues = []Shapes{0, 1, 2, 3, 4, 5, 6, 7}
// ShapesN is the highest valid value for type Shapes, plus one.
const ShapesN Shapes = 8
var _ShapesValueMap = map[string]Shapes{`Circle`: 0, `Box`: 1, `Pyramid`: 2, `Plus`: 3, `Cross`: 4, `Ring`: 5, `Square`: 6, `Triangle`: 7}
var _ShapesDescMap = map[Shapes]string{0: `Circle is a solid circle`, 1: `Box is a filled square`, 2: `Pyramid is a filled triangle`, 3: `Plus is a plus sign`, 4: `Cross is a big X`, 5: `Ring is the outline of a circle`, 6: `Square is the outline of a square`, 7: `Triangle is the outline of a triangle`}
var _ShapesMap = map[Shapes]string{0: `Circle`, 1: `Box`, 2: `Pyramid`, 3: `Plus`, 4: `Cross`, 5: `Ring`, 6: `Square`, 7: `Triangle`}
// String returns the string representation of this Shapes value.
func (i Shapes) String() string { return enums.String(i, _ShapesMap) }
// SetString sets the Shapes value from its string representation,
// and returns an error if the string is invalid.
func (i *Shapes) SetString(s string) error { return enums.SetString(i, s, _ShapesValueMap, "Shapes") }
// Int64 returns the Shapes value as an int64.
func (i Shapes) Int64() int64 { return int64(i) }
// SetInt64 sets the Shapes value from an int64.
func (i *Shapes) SetInt64(in int64) { *i = Shapes(in) }
// Desc returns the description of the Shapes value.
func (i Shapes) Desc() string { return enums.Desc(i, _ShapesDescMap) }
// ShapesValues returns all possible values for the type Shapes.
func ShapesValues() []Shapes { return _ShapesValues }
// Values returns all possible values for the type Shapes.
func (i Shapes) Values() []enums.Enum { return enums.Values(_ShapesValues) }
// MarshalText implements the [encoding.TextMarshaler] interface.
func (i Shapes) MarshalText() ([]byte, error) { return []byte(i.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *Shapes) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "Shapes") }
var _DefaultOffOnValues = []DefaultOffOn{0, 1, 2}
// DefaultOffOnN is the highest valid value for type DefaultOffOn, plus one.
const DefaultOffOnN DefaultOffOn = 3
var _DefaultOffOnValueMap = map[string]DefaultOffOn{`Default`: 0, `Off`: 1, `On`: 2}
var _DefaultOffOnDescMap = map[DefaultOffOn]string{0: `Default means use the default value.`, 1: `Off means to override the default and turn Off.`, 2: `On means to override the default and turn On.`}
var _DefaultOffOnMap = map[DefaultOffOn]string{0: `Default`, 1: `Off`, 2: `On`}
// String returns the string representation of this DefaultOffOn value.
func (i DefaultOffOn) String() string { return enums.String(i, _DefaultOffOnMap) }
// SetString sets the DefaultOffOn value from its string representation,
// and returns an error if the string is invalid.
func (i *DefaultOffOn) SetString(s string) error {
return enums.SetString(i, s, _DefaultOffOnValueMap, "DefaultOffOn")
}
// Int64 returns the DefaultOffOn value as an int64.
func (i DefaultOffOn) Int64() int64 { return int64(i) }
// SetInt64 sets the DefaultOffOn value from an int64.
func (i *DefaultOffOn) SetInt64(in int64) { *i = DefaultOffOn(in) }
// Desc returns the description of the DefaultOffOn value.
func (i DefaultOffOn) Desc() string { return enums.Desc(i, _DefaultOffOnDescMap) }
// DefaultOffOnValues returns all possible values for the type DefaultOffOn.
func DefaultOffOnValues() []DefaultOffOn { return _DefaultOffOnValues }
// Values returns all possible values for the type DefaultOffOn.
func (i DefaultOffOn) Values() []enums.Enum { return enums.Values(_DefaultOffOnValues) }
// MarshalText implements the [encoding.TextMarshaler] interface.
func (i DefaultOffOn) MarshalText() ([]byte, error) { return []byte(i.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *DefaultOffOn) UnmarshalText(text []byte) error {
return enums.UnmarshalText(i, text, "DefaultOffOn")
}
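// A minimal sketch of the generated enum API (not from the original sources),
// assuming the cogentcore.org/lab/plot import path: values round-trip through
// their string form via SetString and String.
package plot_test
import (
"fmt"
"cogentcore.org/lab/plot"
)
func ExampleAxisScales() {
var s plot.AxisScales
if err := s.SetString("Log"); err != nil {
fmt.Println(err)
return
}
fmt.Println(s, s.Int64())
// Output: Log 1
}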
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Copied directly from gonum/plot:
// Copyright ©2017 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This is an implementation of the Talbot, Lin and Hanrahan algorithm
// described in doi:10.1109/TVCG.2010.130 with reference to the R
// implementation in the labeling package, ©2014 Justin Talbot (Licensed
// MIT+file LICENSE|Unlimited).
package plot
import "math"
const (
// dlamchE is the machine epsilon. For IEEE this is 2^{-53}.
dlamchE = 1.0 / (1 << 53)
// dlamchB is the radix of the machine (the base of the number system).
dlamchB = 2
// dlamchP is base * eps.
dlamchP = dlamchB * dlamchE
)
const (
// free indicates no restriction on label containment.
free = iota
// containData specifies that all the data range lies
// within the interval [label_min, label_max].
containData
// withinData specifies that all labels lie within the
// interval [dMin, dMax].
withinData
)
// talbotLinHanrahan returns an optimal set of approximately want label values
// for the data range [dMin, dMax], and the step and magnitude of the step between values.
// containment specifies the guarantees for label and data range containment; valid
// values are free, containData and withinData.
// The optional parameters Q, nice numbers, and w, weights, allow tuning of the
// algorithm but by default (when nil) are set to the parameters described in the
// paper.
// The legibility function allows tuning of the legibility assessment for labels.
// By default, when nil, legibility will set the legibility score for each candidate
// labelling scheme to 1.
// See the paper for an explanation of the function of Q, w and legibility.
func talbotLinHanrahan(dMin, dMax float64, want int, containment int, Q []float64, w *weights, legibility func(lMin, lMax, lStep float64) float64) (values []float64, step, q float64, magnitude int) {
const eps = dlamchP * 100
if dMin > dMax {
panic("labelling: invalid data range: min greater than max")
}
if Q == nil {
Q = []float64{1, 5, 2, 2.5, 4, 3}
}
if w == nil {
w = &weights{
simplicity: 0.25,
coverage: 0.2,
density: 0.5,
legibility: 0.05,
}
}
if legibility == nil {
legibility = unitLegibility
}
if r := dMax - dMin; r < eps {
l := make([]float64, want)
step := r / float64(want-1)
for i := range l {
l[i] = dMin + float64(i)*step
}
magnitude = minAbsMag(dMin, dMax)
return l, step, 0, magnitude
}
type selection struct {
// n is the number of labels selected.
n int
// lMin and lMax are the selected min
// and max label values. lq is the q
// chosen.
lMin, lMax, lStep, lq float64
// score is the score for the selection.
score float64
// magnitude is the magnitude of the
// label step distance.
magnitude int
}
best := selection{score: -2}
outer:
for skip := 1; ; skip++ {
for _, q := range Q {
sm := maxSimplicity(q, Q, skip)
if w.score(sm, 1, 1, 1) < best.score {
break outer
}
for have := 2; ; have++ {
dm := maxDensity(have, want)
if w.score(sm, 1, dm, 1) < best.score {
break
}
delta := (dMax - dMin) / float64(have+1) / float64(skip) / q
const maxExp = 309
for mag := int(math.Ceil(math.Log10(delta))); mag < maxExp; mag++ {
step := float64(skip) * q * math.Pow10(mag)
cm := maxCoverage(dMin, dMax, step*float64(have-1))
if w.score(sm, cm, dm, 1) < best.score {
break
}
fracStep := step / float64(skip)
kStep := step * float64(have-1)
minStart := (math.Floor(dMax/step) - float64(have-1)) * float64(skip)
maxStart := math.Ceil(dMax/step) * float64(skip)
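// Note: the "start != start-1" test below stops the loop once float64
// precision makes incrementing start a no-op, which would otherwise
// produce an infinite loop.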
for start := minStart; start <= maxStart && start != start-1; start++ {
lMin := start * fracStep
lMax := lMin + kStep
switch containment {
case containData:
if dMin < lMin || lMax < dMax {
continue
}
case withinData:
if lMin < dMin || dMax < lMax {
continue
}
case free:
// Free choice.
}
score := w.score(
simplicity(q, Q, skip, lMin, lMax, step),
coverage(dMin, dMax, lMin, lMax),
density(have, want, dMin, dMax, lMin, lMax),
legibility(lMin, lMax, step),
)
if score > best.score {
best = selection{
n: have,
lMin: lMin,
lMax: lMax,
lStep: float64(skip) * q,
lq: q,
score: score,
magnitude: mag,
}
}
}
}
}
}
}
if best.score == -2 {
l := make([]float64, want)
step := (dMax - dMin) / float64(want-1)
for i := range l {
l[i] = dMin + float64(i)*step
}
magnitude = minAbsMag(dMin, dMax)
return l, step, 0, magnitude
}
l := make([]float64, best.n)
step = best.lStep * math.Pow10(best.magnitude)
for i := range l {
l[i] = best.lMin + float64(i)*step
}
return l, best.lStep, best.lq, best.magnitude
}
// minAbsMag returns the minimum magnitude of the absolute values of a and b.
func minAbsMag(a, b float64) int {
return int(math.Min(math.Floor(math.Log10(math.Abs(a))), (math.Floor(math.Log10(math.Abs(b))))))
}
// simplicity returns the simplicity score for how well the current q, lMin, lMax,
// lStep and skip match the given nice numbers, Q.
func simplicity(q float64, Q []float64, skip int, lMin, lMax, lStep float64) float64 {
const eps = dlamchP * 100
for i, v := range Q {
if v == q {
m := math.Mod(lMin, lStep)
v = 0
if (m < eps || lStep-m < eps) && lMin <= 0 && 0 <= lMax {
v = 1
}
return 1 - float64(i)/(float64(len(Q))-1) - float64(skip) + v
}
}
panic("labelling: invalid q for Q")
}
// maxSimplicity returns the maximum simplicity for q, Q and skip.
func maxSimplicity(q float64, Q []float64, skip int) float64 {
for i, v := range Q {
if v == q {
return 1 - float64(i)/(float64(len(Q))-1) - float64(skip) + 1
}
}
panic("labelling: invalid q for Q")
}
// coverage returns the coverage score based on the average
// squared distance between the extreme labels, lMin and lMax, and
// the extreme data points, dMin and dMax.
func coverage(dMin, dMax, lMin, lMax float64) float64 {
r := 0.1 * (dMax - dMin)
max := dMax - lMax
min := dMin - lMin
return 1 - 0.5*(max*max+min*min)/(r*r)
}
// maxCoverage returns the maximum coverage achievable for the data
// range.
func maxCoverage(dMin, dMax, span float64) float64 {
r := dMax - dMin
if span <= r {
return 1
}
h := 0.5 * (span - r)
r *= 0.1
return 1 - (h*h)/(r*r)
}
// density returns the density score which measures the goodness of
// the labelling density compared to the user defined target
// based on the want parameter given to talbotLinHanrahan.
func density(have, want int, dMin, dMax, lMin, lMax float64) float64 {
rho := float64(have-1) / (lMax - lMin)
rhot := float64(want-1) / (math.Max(lMax, dMax) - math.Min(dMin, lMin))
if d := rho / rhot; d >= 1 {
return 2 - d
}
return 2 - rhot/rho
}
// maxDensity returns the maximum density score achievable for have and want.
func maxDensity(have, want int) float64 {
if have < want {
return 1
}
return 2 - float64(have-1)/float64(want-1)
}
// unitLegibility returns a default legibility score ignoring label
// spacing.
func unitLegibility(_, _, _ float64) float64 {
return 1
}
// weights is a helper type to calculate the labelling scheme's total score.
type weights struct {
simplicity, coverage, density, legibility float64
}
// score returns the score for a labelling scheme with simplicity, s,
// coverage, c, density, d and legibility l.
func (w *weights) score(s, c, d, l float64) float64 {
return w.simplicity*s + w.coverage*c + w.density*d + w.legibility*l
}
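// A minimal invocation sketch for the labelling entry point above (within
// package plot, since talbotLinHanrahan is unexported; not from the original
// sources). Passing nil for Q, w, and legibility selects the defaults from the
// Talbot, Lin and Hanrahan paper; withinData keeps all labels inside the data range.
package plot
func labellingSketch() []float64 {
values, step, _, mag := talbotLinHanrahan(0, 100, 5, withinData, nil, nil, nil)
_ = step // spacing between consecutive label values, before magnitude scaling
_ = mag  // power-of-ten magnitude of the step
return values
}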
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package plot
import (
"image"
"cogentcore.org/core/colors"
"cogentcore.org/core/colors/gradient"
"cogentcore.org/core/styles/units"
)
// LegendStyle has the styling properties for the Legend.
type LegendStyle struct { //types:add -setters
// Column is for table-based plotting, specifying the column with legend values.
Column string
// Text is the style given to the legend entry texts.
Text TextStyle `display:"add-fields"`
// Position specifies where to place the legend within the plot.
Position LegendPosition `display:"inline"`
// ThumbnailWidth is the width of legend thumbnails.
ThumbnailWidth units.Value `display:"inline"`
// Fill specifies the background fill color for the legend box,
// if non-nil.
Fill image.Image
}
func (ls *LegendStyle) Defaults() {
ls.Text.Defaults()
ls.Text.Padding.Dp(2)
ls.Text.Size.Dp(20)
ls.Position.Defaults()
ls.ThumbnailWidth.Pt(20)
ls.Fill = gradient.ApplyOpacity(colors.Scheme.Surface, 0.75)
}
// LegendPosition specifies where to put the legend
type LegendPosition struct {
// Top and Left specify the location of the legend.
Top, Left bool
// XOffs and YOffs are added to the legend's final position,
// relative to the relevant anchor position
XOffs, YOffs units.Value
}
func (lg *LegendPosition) Defaults() {
lg.Top = true
}
// A Legend gives a description of the meaning of different
// data elements of the plot. Each legend entry has a name
// and a thumbnail, where the thumbnail shows a small
// sample of the display style of the corresponding data.
type Legend struct {
// Style has the legend styling parameters.
Style LegendStyle
// Entries are all of the LegendEntries described by this legend.
Entries []LegendEntry
}
func (lg *Legend) Defaults() {
lg.Style.Defaults()
}
// Add adds an entry to the legend with the given name.
// The entry's thumbnail is drawn as the composite of all of the
// thumbnails.
func (lg *Legend) Add(name string, thumbs ...Thumbnailer) {
lg.Entries = append(lg.Entries, LegendEntry{Text: name, Thumbs: thumbs})
}
// LegendForPlotter returns the legend Text for given plotter,
// if it exists as a Thumbnailer in the legend entries.
// Otherwise returns empty string.
func (lg *Legend) LegendForPlotter(plt Plotter) string {
for _, e := range lg.Entries {
for _, tn := range e.Thumbs {
if tp, isp := tn.(Plotter); isp && tp == plt {
return e.Text
}
}
}
return ""
}
// Thumbnailer wraps the Thumbnail method, which draws the small
// image in a legend representing the style of data.
type Thumbnailer interface {
// Thumbnail draws a thumbnail representing a legend entry.
// The thumbnail will usually show a smaller representation
// of the style used to plot the corresponding data.
Thumbnail(pt *Plot)
}
// A LegendEntry represents a single line of a legend: it
// has a name and an icon (thumbnail).
type LegendEntry struct {
// Text is the text associated with this entry.
Text string
// Thumbs is a slice of all of the thumbnail styles.
Thumbs []Thumbnailer
}
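// A minimal sketch of the Legend API (not from the original sources), assuming
// the cogentcore.org/lab/plot import path. A stand-in Thumbnailer is used here
// for illustration; real plotters (e.g., the Bar type later in this repository)
// provide their own Thumbnail method, which lets LegendForPlotter recover the
// entry name for a plotter.
package plot_test
import "cogentcore.org/lab/plot"
// stubThumb is a placeholder Thumbnailer that draws nothing.
type stubThumb struct{}
func (stubThumb) Thumbnail(pt *plot.Plot) {}
func legendSketch(pt *plot.Plot) int {
pt.Legend.Add("condition A", stubThumb{})
pt.Legend.Add("condition B", stubThumb{})
return len(pt.Legend.Entries) // 2
}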
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package plot
import (
"image"
"cogentcore.org/core/colors"
"cogentcore.org/core/math32"
"cogentcore.org/core/styles/units"
)
// LineStyle has style properties for drawing lines.
type LineStyle struct { //types:add -setters
// On indicates whether to plot lines.
On DefaultOffOn
// Color is the stroke color image specification.
// Setting to nil turns line off.
Color image.Image
// Width is the line width, with a default of 1 Pt (point).
// Setting to 0 turns line off.
Width units.Value
// Dashes are the dashes of the stroke. Each pair of values specifies
// the amount to paint and then the amount to skip.
Dashes []float32
// Fill is the color to fill solid regions, in a plot-specific
// way (e.g., the area below a Line plot, the bar color).
// Use nil to disable filling.
Fill image.Image
// NegativeX specifies whether to draw lines that connect points with a negative
// X-axis direction; otherwise there is a break in the line.
// default is false, so that repeated series of data across the X axis
// are plotted separately.
NegativeX bool
// Step specifies how to step the line between points.
Step StepKind
}
func (ls *LineStyle) Defaults() {
ls.Color = colors.Scheme.OnSurface
ls.Width.Pt(1)
}
// SpacedColor sets the Color to a default spaced color based on index,
// if it still has the initial OnSurface default.
func (ls *LineStyle) SpacedColor(idx int) {
if ls.Color == colors.Scheme.OnSurface {
ls.Color = colors.Uniform(colors.Spaced(idx))
}
}
// SpacedFill sets the Fill to a default spaced color based on index,
// if it still has the initial nil default.
func (ls *LineStyle) SpacedFill(idx int) {
if ls.Fill == nil {
ls.Fill = colors.Uniform(colors.Spaced(idx))
}
}
// SetStroke sets the stroke style in the plot painter to the current line style.
// It returns false if either the Width is 0 or the Color is nil.
func (ls *LineStyle) SetStroke(pt *Plot) bool {
if ls.On == Off || ls.Color == nil {
return false
}
pc := pt.Painter
uc := pt.UnitContext()
ls.Width.ToDots(uc)
if ls.Width.Dots == 0 {
return false
}
pc.Stroke.Width = ls.Width
pc.Stroke.Color = ls.Color
pc.Stroke.ToDots(uc)
return true
}
// HasFill returns true if the Fill image is non-nil and is not fully transparent.
func (ls *LineStyle) HasFill() bool {
if ls.Fill == nil {
return false
}
clr := colors.ToUniform(ls.Fill)
if clr == colors.Transparent {
return false
}
return true
}
// Draw draws a line between given coordinates, setting the stroke style
// to current parameters. Returns false if either Width = 0 or Color = nil
func (ls *LineStyle) Draw(pt *Plot, start, end math32.Vector2) bool {
if !ls.SetStroke(pt) {
return false
}
pc := pt.Painter
pc.MoveTo(start.X, start.Y)
pc.LineTo(end.X, end.Y)
pc.Draw()
return true
}
// StepKind specifies a form of a connection of two consecutive points.
type StepKind int32 //enums:enum
const (
// NoStep connects two points by simple line.
NoStep StepKind = iota
// PreStep connects two points by following lines: vertical, horizontal.
PreStep
// MidStep connects two points by following lines: horizontal, vertical, horizontal.
// Vertical line is placed in the middle of the interval.
MidStep
// PostStep connects two points by following lines: horizontal, vertical.
PostStep
)
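// A minimal styling sketch (not from the original sources), assuming the
// cogentcore.org/lab/plot import path: LineStyle fields are set directly after
// Defaults, using only the color constructors that appear elsewhere in this
// repository.
package plot_test
import (
"cogentcore.org/core/colors"
"cogentcore.org/lab/plot"
)
func lineStyleSketch() plot.LineStyle {
var ls plot.LineStyle
ls.Defaults() // OnSurface color, 1pt width
ls.Color = colors.Uniform(colors.Spaced(2)) // pick a spaced palette color
ls.Width.Pt(2)
ls.Dashes = []float32{4, 2} // paint 4, then skip 2
ls.Step = plot.MidStep // horizontal, vertical, horizontal connection
return ls
}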
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Adapted from github.com/gonum/plot:
// Copyright ©2015 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package plot
//go:generate core generate -add-types
import (
"image"
"sync"
"cogentcore.org/core/colors"
"cogentcore.org/core/math32"
"cogentcore.org/core/math32/minmax"
"cogentcore.org/core/paint"
"cogentcore.org/core/paint/render"
_ "cogentcore.org/core/paint/renderers"
"cogentcore.org/core/styles"
"cogentcore.org/core/styles/sides"
"cogentcore.org/core/styles/units"
"cogentcore.org/core/text/shaped"
"cogentcore.org/core/text/text"
)
var (
// plotShaper is a shared text shaper.
plotShaper shaped.Shaper
// mutex for sharing the plotShaper.
shaperMu sync.Mutex
)
// XAxisStyle has overall plot level styling properties for the XAxis.
type XAxisStyle struct { //types:add -setters
// Column specifies the column to use for the common X axis,
// for [plot.NewTablePlot] [table.Table] driven plots.
// If empty, standard Group-based role binding is used: the last column
// within the same group with Role=X is used.
Column string
// Rotation is the rotation of the X Axis labels, in degrees.
Rotation float32
// Label is the optional label to use for the XAxis instead of the default.
Label string
// Range is the effective range of XAxis data to plot, where either end can be fixed.
Range minmax.Range64 `display:"inline"`
// Scale specifies how values are scaled along the X axis:
// Linear, Log, InverseLinear, or InverseLog.
Scale AxisScales
}
// PlotStyle has overall plot level styling properties.
// Some properties provide defaults for individual elements, which can
// then be overwritten by element-level properties.
type PlotStyle struct { //types:add -setters
// Title is the overall title of the plot.
Title string
// TitleStyle is the text styling parameters for the title.
TitleStyle TextStyle
// Background is the background of the plot.
// The default is [colors.Scheme.Surface].
Background image.Image
// Scale multiplies the plot DPI value, to change the overall scale
// of the rendered plot. Larger numbers produce larger scaling.
// Typically use larger numbers when generating plots for inclusion in
// documents or other cases where the overall plot size will be small.
Scale float32 `default:"1,2"`
// Legend has the styling properties for the Legend.
Legend LegendStyle `display:"add-fields"`
// Axis has the styling properties for the Axis associated with this Data.
Axis AxisStyle `display:"add-fields"`
// XAxis has plot-level properties specific to the XAxis.
XAxis XAxisStyle `display:"add-fields"`
// YAxisLabel is the optional label to use for the YAxis instead of the default.
YAxisLabel string
// SizeAxis has plot-level properties specific to the Size virtual axis.
SizeAxis VirtualAxisStyle
// LinesOn determines whether lines are plotted by default at the overall,
// Plot level, for elements that plot lines (e.g., plots.XY).
LinesOn DefaultOffOn
// LineWidth sets the default line width for data plotting lines at the
// overall Plot level.
LineWidth units.Value
// PointsOn determines whether points are plotted by default at the
// overall Plot level, for elements that plot points (e.g., plots.XY).
PointsOn DefaultOffOn
// PointSize sets the default point size at the overall Plot level.
PointSize units.Value
// LabelSize sets the default label text size at the overall Plot level.
LabelSize units.Value
// BarWidth for Bar plot sets the default width of the bars,
// which should be less than the Stride (1 typically) to prevent
// bar overlap. Defaults to .8.
BarWidth float64
// ShowErrors can be set to have Plot configuration errors reported.
// This is particularly important for table-driven plots (e.g., [plotcore.Editor]),
// but it is not on by default because often there are transitional states
// with known errors that can lead to false alarms.
ShowErrors bool
}
func (ps *PlotStyle) Defaults() {
ps.TitleStyle.Defaults()
ps.TitleStyle.Size.Dp(24)
ps.Background = colors.Scheme.Surface
ps.Scale = 1
ps.Legend.Defaults()
ps.Axis.Defaults()
ps.LineWidth.Pt(1)
ps.PointSize.Pt(3)
ps.LabelSize.Dp(16)
ps.BarWidth = .8
}
// SetElementStyle sets the properties for given element's style
// based on the global default settings in this PlotStyle.
func (ps *PlotStyle) SetElementStyle(es *Style) {
if ps.LinesOn != Default {
es.Line.On = ps.LinesOn
}
if ps.PointsOn != Default {
es.Point.On = ps.PointsOn
}
es.Line.Width = ps.LineWidth
es.Point.Size = ps.PointSize
es.Width.Width = ps.BarWidth
es.Text.Size = ps.LabelSize
}
// PanZoom provides post-styling pan and zoom range manipulation.
type PanZoom struct {
// XOffset adds offset to X range (pan).
XOffset float64
// XScale multiplies X range (zoom).
XScale float64
// YOffset adds offset to Y range (pan).
YOffset float64
// YScale multiplies Y range (zoom).
YScale float64
}
func (pz *PanZoom) Defaults() {
pz.XScale = 1
pz.YScale = 1
}
// Plot is the basic type representing a plot.
// It renders into its own image.RGBA Pixels image,
// and can also save a corresponding SVG version.
type Plot struct {
// Title of the plot
Title Text
// Style has the styling properties for the plot.
// All end-user configuration should be put in here,
// rather than modifying other fields directly on the plot.
Style PlotStyle
// standard text style with default options
StandardTextStyle text.Style
// X, Y, YR, and Z are the horizontal, vertical, right vertical, and depth axes
// of the plot respectively. These are the actual compiled
// state data and should not be used for styling: use Style.
X, Y, YR, Z Axis
// SizeAxis is a virtual axis for the Size data role.
SizeAxis VirtualAxis
// Legend is the plot's legend.
Legend Legend
// Plotters are drawn by calling their Plot method after the axes are drawn.
Plotters []Plotter
// PanZoom provides post-styling pan and zoom range factors.
PanZoom PanZoom
// HighlightPlotter is the Plotter to highlight. Used for mouse hovering for example.
// It is the responsibility of the Plotter Plot function to implement highlighting.
HighlightPlotter Plotter
// HighlightIndex is the index of the data point to highlight, for HighlightPlotter.
HighlightIndex int
// TextShaper for shaping text. It can be set to a shared external one,
// or else the shared plotShaper is used under a mutex lock during Render.
TextShaper shaped.Shaper
// PaintBox is the bounding box for the plot within the Paint.
// For standalone, it is the size of the image.
PaintBox image.Rectangle
// Current local plot bounding box in image coordinates, for computing
// plotting coordinates.
PlotBox math32.Box2
// Painter is the current painter being used,
// which is only valid during rendering, and is set by Draw function.
// It needs to be exported for different plot types in other packages.
Painter *paint.Painter
// unitContext is current unit context, only valid during rendering.
unitContext *units.Context
}
// New returns a new plot with some reasonable default settings.
func New() *Plot {
pt := &Plot{}
pt.Defaults()
return pt
}
// Defaults sets reasonable default settings for the plot.
func (pt *Plot) Defaults() {
pt.SetSize(image.Point{640, 480})
pt.Style.Defaults()
pt.Title.Defaults()
pt.Title.Style.Size.Dp(24)
pt.X.Defaults(math32.X)
pt.Y.Defaults(math32.Y)
pt.YR.Defaults(math32.Y)
pt.YR.RightY = true
pt.Legend.Defaults()
pt.PanZoom.Defaults()
pt.StandardTextStyle.Defaults()
pt.StandardTextStyle.WhiteSpace = text.WrapNever
}
// SetSize sets the size of the plot, typically in terms
// of actual device pixels (dots).
func (pt *Plot) SetSize(sz image.Point) {
pt.PaintBox.Max = sz
}
// UnitContext returns the [units.Context] to use for styling.
// This includes the scaling factor.
func (pt *Plot) UnitContext() *units.Context {
if pt.unitContext != nil {
return pt.unitContext
}
uc := &units.Context{}
*uc = pt.Painter.UnitContext
uc.DPI *= pt.Style.Scale
pt.unitContext = uc
return uc
}
// applyStyle applies all the style parameters
func (pt *Plot) applyStyle() {
hasYright := false
// first update the global plot style settings
var st Style
st.Defaults()
st.Plot = pt.Style
for _, plt := range pt.Plotters {
stlr := plt.Stylers()
stlr.Run(&st)
var pst Style
pst.Defaults()
stlr.Run(&pst)
if pst.RightY {
hasYright = true
}
if pst.Label != "" {
if pst.RightY {
pt.YR.Label.Text = pst.Label
} else {
pt.Y.Label.Text = pst.Label
}
}
}
pt.Style = st.Plot
// then apply to elements
for i, plt := range pt.Plotters {
plt.ApplyStyle(&pt.Style, i)
}
pt.Title.Style = pt.Style.TitleStyle
if pt.Style.Title != "" {
pt.Title.Text = pt.Style.Title
}
pt.Legend.Style = pt.Style.Legend
pt.X.Style = pt.Style.Axis
pt.X.Style.Scale = pt.Style.XAxis.Scale
if pt.Style.XAxis.Label != "" {
pt.X.Label.Text = pt.Style.XAxis.Label
}
pt.X.Label.Style = pt.Style.Axis.Text
pt.X.TickText.Style = pt.Style.Axis.TickText
pt.X.TickText.Style.Rotation = pt.Style.XAxis.Rotation
pt.SizeAxis.Style = st.Plot.SizeAxis
pt.Y.Style = pt.Style.Axis
pt.YR.Style = pt.Style.Axis
pt.YR.Style.On = hasYright
pt.Y.Label.Style = pt.Style.Axis.Text
pt.YR.Label.Style = pt.Style.Axis.Text
pt.Y.TickText.Style = pt.Style.Axis.TickText
pt.YR.TickText.Style = pt.Style.Axis.TickText
pt.Y.Label.Style.Rotation = -90
pt.Y.Style.TickText.Align = styles.End
pt.YR.Label.Style.Rotation = 90
pt.YR.Style.TickText.Align = styles.Start
pt.UpdateRange()
}
// Add adds Plotter element(s) to the plot.
// When drawing the plot, Plotters are drawn in the
// order in which they were added to the plot.
func (pt *Plot) Add(ps ...Plotter) {
pt.Plotters = append(pt.Plotters, ps...)
}
// CurBounds returns the current render bounds from Paint
func (pt *Plot) CurBounds() image.Rectangle {
return pt.Painter.Context().Bounds.Rect.ToRect()
}
// PushBounds pushes the given bounds rectangle as a new render context on the Painter.
func (pt *Plot) PushBounds(tb image.Rectangle) {
pt.Painter.PushContext(nil, render.NewBoundsRect(tb, sides.Floats{}))
}
// NominalX configures the plot to have a nominal X
// axis—an X axis with names instead of numbers. The
// X locations corresponding to the names are the integers,
// e.g., the x value 0 is centered above the first name and
// 1 is above the second name, etc. Labels for x values
// that do not end up in range of the X axis will not have
// tick marks.
func (pt *Plot) NominalX(names ...string) {
pt.X.Style.TickLine.Width.Pt(0)
pt.X.Style.TickLength.Pt(0)
pt.X.Style.Line.Width.Pt(0)
// pt.Y.Padding.Pt(pt.X.Style.Tick.Label.Width(names[0]) / 2)
ticks := make([]Tick, len(names))
for i, name := range names {
ticks[i] = Tick{float64(i), name}
}
pt.X.Ticker = ConstantTicks(ticks)
}
// HideX configures the X axis so that it will not be drawn.
func (pt *Plot) HideX() {
pt.X.Style.TickLength.Pt(0)
pt.X.Style.Line.Width.Pt(0)
pt.X.Ticker = ConstantTicks([]Tick{})
}
// HideY configures the Y axis so that it will not be drawn.
func (pt *Plot) HideY() {
pt.Y.Style.TickLength.Pt(0)
pt.Y.Style.Line.Width.Pt(0)
pt.Y.Ticker = ConstantTicks([]Tick{})
}
// HideYR configures the YR axis so that it will not be drawn.
func (pt *Plot) HideYR() {
pt.YR.Style.TickLength.Pt(0)
pt.YR.Style.Line.Width.Pt(0)
pt.YR.Ticker = ConstantTicks([]Tick{})
}
// HideAxes hides the X and Y axes.
func (pt *Plot) HideAxes() {
pt.HideX()
pt.HideY()
pt.HideYR()
}
// NominalY is like NominalX, but for the Y axis.
func (pt *Plot) NominalY(names ...string) {
pt.Y.Style.TickLine.Width.Pt(0)
pt.Y.Style.TickLength.Pt(0)
pt.Y.Style.Line.Width.Pt(0)
// pt.X.Padding = pt.Y.Tick.Label.Height(names[0]) / 2
ticks := make([]Tick, len(names))
for i, name := range names {
ticks[i] = Tick{float64(i), name}
}
pt.Y.Ticker = ConstantTicks(ticks)
}
// UpdateRange updates the axis range values based on current Plot values.
// This first resets the range so any fixed additional range values should
// be set after this point.
func (pt *Plot) UpdateRange() {
pt.X.Range.SetInfinity()
pt.Y.Range.SetInfinity()
pt.YR.Range.SetInfinity()
pt.Z.Range.SetInfinity()
pt.SizeAxis.Range.SetInfinity()
// note: putting this after allows it to override
if pt.Style.XAxis.Range.FixMin {
pt.X.Range.Min = pt.Style.XAxis.Range.Min
}
if pt.Style.XAxis.Range.FixMax {
pt.X.Range.Max = pt.Style.XAxis.Range.Max
}
for _, pl := range pt.Plotters {
pl.UpdateRange(pt, &pt.X.Range, &pt.Y.Range, &pt.YR.Range, &pt.Z.Range, &pt.SizeAxis.Range)
}
pt.X.Range.Sanitize()
pt.Y.Range.Sanitize()
pt.YR.Range.Sanitize()
pt.Z.Range.Sanitize()
pt.X.Range.Min *= pt.PanZoom.XScale
pt.X.Range.Max *= pt.PanZoom.XScale
pt.X.Range.Min += pt.PanZoom.XOffset
pt.X.Range.Max += pt.PanZoom.XOffset
pt.Y.Range.Min *= pt.PanZoom.YScale
pt.Y.Range.Max *= pt.PanZoom.YScale
pt.Y.Range.Min += pt.PanZoom.YOffset
pt.Y.Range.Max += pt.PanZoom.YOffset
pt.YR.Range.Min *= pt.PanZoom.YScale
pt.YR.Range.Max *= pt.PanZoom.YScale
pt.YR.Range.Min += pt.PanZoom.YOffset
pt.YR.Range.Max += pt.PanZoom.YOffset
}
// PX returns the X-axis plotting coordinate for given raw data value
// using the current plot bounding region
func (pt *Plot) PX(v float64) float32 {
return pt.PlotBox.ProjectX(float32(pt.X.Norm(v)))
}
// PY returns the Y-axis plotting coordinate for given raw data value
func (pt *Plot) PY(v float64) float32 {
return pt.PlotBox.ProjectY(float32(1 - pt.Y.Norm(v)))
}
// PYR returns the Y-axis plotting coordinate for given raw data value
func (pt *Plot) PYR(v float64) float32 {
return pt.PlotBox.ProjectY(float32(1 - pt.YR.Norm(v)))
}
// ClosestDataToPixel returns the Plotter data point closest to given pixel point,
// in the Pixels image.
func (pt *Plot) ClosestDataToPixel(px, py int) (plt Plotter, plotterIndex, pointIndex int, dist float32, pixel math32.Vector2, data Data, legend string) {
tp := math32.Vec2(float32(px), float32(py))
dist = float32(math32.MaxFloat32)
for pi, pl := range pt.Plotters {
dts, pxX, pxY := pl.Data()
if len(pxY) != len(pxX) {
continue
}
for i, ptx := range pxX {
pty := pxY[i]
pxy := math32.Vec2(ptx, pty)
d := pxy.DistanceTo(tp)
if d < dist {
dist = d
pixel = pxy
plt = pl
plotterIndex = pi
pointIndex = i
data = dts
legend = pt.Legend.LegendForPlotter(pl)
}
}
}
return
}
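// A minimal pan and zoom sketch (not from the original sources), assuming the
// cogentcore.org/lab/plot import path. The PanZoom factors are applied at the
// end of UpdateRange, so they should be set before the ranges are recomputed
// (UpdateRange is also called from applyStyle during Draw).
package plot_test
import "cogentcore.org/lab/plot"
func panZoomSketch(pt *plot.Plot) {
pt.PanZoom.XScale = 2   // zoom: doubles the X data range
pt.PanZoom.XOffset = 10 // pan: shifts the X range by +10 data units
pt.UpdateRange() // recompute ranges with the new factors applied
}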
// Copyright (c) 2019, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This is copied and modified directly from gonum to add better error-bar
// plotting for bar plots, along with multiple groups.
// Copyright ©2015 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package plots
import (
"fmt"
"math"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/math32"
"cogentcore.org/core/math32/minmax"
"cogentcore.org/lab/plot"
)
// BarType is used for specifying the type name.
const BarType = "Bar"
func init() {
plot.RegisterPlotter(BarType, "A Bar presents ordinally-organized data with rectangular bars with lengths proportional to the data values, and an optional error bar at the top of the bar using the High data role.", []plot.Roles{plot.Y}, []plot.Roles{plot.High, plot.X}, func(plt *plot.Plot, data plot.Data) plot.Plotter {
return NewBar(plt, data)
})
}
// A Bar presents ordinally-organized data with rectangular bars
// with lengths proportional to the data values, and an optional
// error bar ("handle") at the top of the bar using the High data role.
//
// Bars are plotted centered at integer multiples of Stride plus Start offset.
// Full data range also includes Pad value to extend range beyond edge bar centers.
// Bar Width is in data units, e.g., should be <= Stride.
// Defaults provide a unit-spaced plot.
type Bar struct {
// copies of data
Y, Err plot.Values
// actual plotting X, Y values in data coordinates, taking into account stacking etc.
X, Yp plot.Values
// optional labels for X axis.
XLabels plot.Labels
// PX, PY are the actual pixel plotting coordinates for each XY value.
PX, PY []float32
// Style has the properties used to render the bars.
Style plot.Style
// Horizontal dictates whether the bars should be in the vertical
// (default) or horizontal direction. If Horizontal is true, all
// X locations and distances referred to here will actually be Y
// locations and distances.
Horizontal bool
// stackedOn is the bar chart upon which this bar chart is stacked.
StackedOn *Bar
stylers plot.Stylers
}
// NewBar adds a new bar plotter with a single bar for each value, for given data
// which can either be a [plot.Valuer] (e.g., Tensor) with the Y values,
// or a [plot.Data] with roles and values defined.
// The bar heights correspond to the values, and their x locations correspond
// to the index of their value in the Valuer.
// Optional error-bar values can be provided using the High data role.
// Styler functions are obtained from the Y metadata if present.
func NewBar(plt *plot.Plot, data any) *Bar {
dt := errors.Log1(plot.DataOrValuer(data, plot.Y))
if dt == nil {
return nil
}
if dt.CheckLengths() != nil {
return nil
}
bc := &Bar{}
bc.Y = plot.MustCopyRole(dt, plot.Y)
if bc.Y == nil {
return nil
}
bc.XLabels = plot.CopyRoleLabels(dt, plot.X)
bc.stylers = plot.GetStylersFromData(dt, plot.Y)
bc.Err = plot.CopyRole(dt, plot.High)
bc.Defaults()
plt.Add(bc)
return bc
}
func (bc *Bar) Defaults() {
bc.Style.Defaults()
}
// Styler adds a style function to set style parameters.
func (bc *Bar) Styler(f func(s *plot.Style)) *Bar {
bc.stylers.Add(f)
return bc
}
func (bc *Bar) ApplyStyle(ps *plot.PlotStyle, idx int) {
bc.Style.Line.SpacedFill(idx)
ps.SetElementStyle(&bc.Style)
bc.stylers.Run(&bc.Style)
}
func (bc *Bar) Stylers() *plot.Stylers { return &bc.stylers }
func (bc *Bar) Data() (data plot.Data, pixX, pixY []float32) {
pixX = bc.PX
pixY = bc.PY
data = plot.Data{}
data[plot.X] = bc.X
data[plot.Y] = bc.Y
if bc.Err != nil {
data[plot.High] = bc.Err
}
return
}
// BarHeight returns the maximum y value of the
// ith bar, taking into account any bars upon
// which it is stacked.
func (bc *Bar) BarHeight(i int) float64 {
ht := float64(0.0)
if bc == nil {
return 0
}
if i >= 0 && i < len(bc.Y) {
ht += bc.Y[i]
}
if bc.StackedOn != nil {
ht += bc.StackedOn.BarHeight(i)
}
return ht
}
// StackOn stacks a bar chart on top of another,
// and sets the bar positioning options to that of the
// chart upon which it is being stacked.
func (bc *Bar) StackOn(on *Bar) {
bc.Style.Width = on.Style.Width
bc.StackedOn = on
}
// Plot implements the plot.Plotter interface.
func (bc *Bar) Plot(plt *plot.Plot) {
pc := plt.Painter
bc.Style.Line.SetStroke(plt)
pc.Fill.Color = bc.Style.Line.Fill
bw := bc.Style.Width
nv := len(bc.Y)
bc.X = make(plot.Values, nv)
bc.Yp = make(plot.Values, nv)
bc.PX = make([]float32, nv)
bc.PY = make([]float32, nv)
hw := 0.5 * bw.Width
ew := bw.Width / 3
for i, ht := range bc.Y {
cat := bw.Offset + float64(i)*bw.Stride
var bottom float64
var catVal, catMin, catMax, valMin, valMax float32
var box math32.Box2
if bc.Horizontal {
catVal = plt.PY(cat)
catMin = plt.PY(cat - hw)
catMax = plt.PY(cat + hw)
bottom = bc.StackedOn.BarHeight(i) // nil safe
valMin = plt.PX(bottom)
valMax = plt.PX(bottom + ht)
bc.X[i] = bottom + ht
bc.Yp[i] = cat
bc.PX[i] = valMax
bc.PY[i] = catVal
box.Min.Set(valMin, catMin)
box.Max.Set(valMax, catMax)
} else {
catVal = plt.PX(cat)
catMin = plt.PX(cat - hw)
catMax = plt.PX(cat + hw)
bottom = bc.StackedOn.BarHeight(i) // nil safe
valMin = plt.PY(bottom)
valMax = plt.PY(bottom + ht)
bc.X[i] = cat
bc.Yp[i] = bottom + ht
bc.PX[i] = catVal
bc.PY[i] = valMax
box.Min.Set(catMin, valMin)
box.Max.Set(catMax, valMax)
}
pc.Rectangle(box.Min.X, box.Min.Y, box.Size().X, box.Size().Y)
pc.Draw()
if i < len(bc.Err) {
errval := math.Abs(bc.Err[i])
if bc.Horizontal {
eVal := plt.PX(bottom + ht + math.Abs(errval))
pc.MoveTo(valMax, catVal)
pc.LineTo(eVal, catVal)
pc.MoveTo(eVal, plt.PY(cat-ew))
pc.LineTo(eVal, plt.PY(cat+ew))
} else {
eVal := plt.PY(bottom + ht + math.Abs(errval))
pc.MoveTo(catVal, valMax)
pc.LineTo(catVal, eVal)
pc.MoveTo(plt.PX(cat-ew), eVal)
pc.LineTo(plt.PX(cat+ew), eVal)
}
pc.Draw()
}
}
pc.Fill.Color = nil
}
// UpdateRange updates the given ranges.
func (bc *Bar) UpdateRange(plt *plot.Plot, x, y, yr, z, size *minmax.F64) {
bw := bc.Style.Width
catMin := bw.Offset - bw.Pad
catMax := bw.Offset + float64(len(bc.Y)-1)*bw.Stride + bw.Pad
if bc.Style.RightY {
y = yr
}
var ticks plot.ConstantTicks
if bc.XLabels != nil {
ticks = make(plot.ConstantTicks, len(bc.Y))
}
for i, val := range bc.Y {
valBot := bc.StackedOn.BarHeight(i)
valTop := valBot + val
if i < len(bc.Err) {
valTop += math.Abs(bc.Err[i])
}
if bc.Horizontal {
x.FitValInRange(valBot)
x.FitValInRange(valTop)
} else {
y.FitValInRange(valBot)
y.FitValInRange(valTop)
}
if bc.XLabels != nil {
cat := bw.Offset + float64(i)*bw.Stride
ticks[i].Value = cat
if bc.XLabels.Len() > i {
ticks[i].Label = bc.XLabels[i]
} else {
ticks[i].Label = fmt.Sprintf("%d", i)
}
}
}
if bc.Horizontal {
x.Min, x.Max = bc.Style.Range.Clamp(x.Min, x.Max)
y.FitInRange(minmax.F64{catMin, catMax})
if ticks != nil {
plt.Y.Ticker = ticks
}
} else {
y.Min, y.Max = bc.Style.Range.Clamp(y.Min, y.Max)
x.FitInRange(minmax.F64{catMin, catMax})
if ticks != nil {
plt.X.Ticker = ticks
}
}
}
// Thumbnail fulfills the plot.Thumbnailer interface.
func (bc *Bar) Thumbnail(plt *plot.Plot) {
pc := plt.Painter
bc.Style.Line.SetStroke(plt)
pc.Fill.Color = bc.Style.Line.Fill
ptb := plt.CurBounds()
pc.Rectangle(float32(ptb.Min.X), float32(ptb.Min.Y), float32(ptb.Size().X), float32(ptb.Size().Y))
pc.Draw()
pc.Fill.Color = nil
}
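// A minimal bar-plot sketch (not from the original sources) using the Y, High,
// and X (label) data roles described above. The cogentcore.org/lab/plots import
// path for this package is assumed here; the plot path matches the import used
// in this file.
package plots_test
import (
"cogentcore.org/lab/plot"
"cogentcore.org/lab/plots"
)
func barSketch() *plot.Plot {
pt := plot.New()
data := plot.Data{
plot.X:    plot.Labels{"A", "B", "C"}, // nominal X labels
plot.Y:    plot.Values{3, 5, 2},       // bar heights
plot.High: plot.Values{0.5, 1, 0.25},  // error-bar ("handle") extents
}
bar := plots.NewBar(pt, data)
if bar == nil {
return pt // data was missing a required role
}
bar.Styler(func(s *plot.Style) {
s.Width.Width = 0.6 // narrower than the 0.8 default, still less than Stride
})
return pt
}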
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package plots
import (
"math"
"cogentcore.org/core/math32/minmax"
"cogentcore.org/lab/plot"
)
const (
// YErrorBarsType is used for specifying the type name.
YErrorBarsType = "YErrorBars"
// XErrorBarsType is used for specifying the type name.
XErrorBarsType = "XErrorBars"
)
func init() {
plot.RegisterPlotter(YErrorBarsType, "draws vertical error bars, denoting error in Y values, using either High or Low & High data roles for error deviations around X, Y coordinates.", []plot.Roles{plot.X, plot.Y, plot.High}, []plot.Roles{plot.Low}, func(plt *plot.Plot, data plot.Data) plot.Plotter {
return NewYErrorBars(plt, data)
})
plot.RegisterPlotter(XErrorBarsType, "draws horizontal error bars, denoting error in X values, using either High or Low & High data roles for error deviations around X, Y coordinates.", []plot.Roles{plot.X, plot.Y, plot.High}, []plot.Roles{plot.Low}, func(plt *plot.Plot, data plot.Data) plot.Plotter {
return NewXErrorBars(plt, data)
})
}
// YErrorBars draws vertical error bars, denoting error in Y values,
// using ether High or Low, High data roles for error deviations
// around X, Y coordinates.
type YErrorBars struct {
// copies of data for this line
X, Y, Low, High plot.Values
// PX, PY are the actual pixel plotting coordinates for each XY value.
PX, PY []float32
// Style is the style for plotting.
Style plot.Style
stylers plot.Stylers
ystylers plot.Stylers
}
func (eb *YErrorBars) Defaults() {
eb.Style.Defaults()
}
// NewYErrorBars adds a new YErrorBars plotter to given plot,
// using Low, High data roles for error deviations around X, Y coordinates.
// Styler functions are obtained from the High data if present.
func NewYErrorBars(plt *plot.Plot, data plot.Data) *YErrorBars {
if data.CheckLengths() != nil {
return nil
}
eb := &YErrorBars{}
eb.X = plot.MustCopyRole(data, plot.X)
eb.Y = plot.MustCopyRole(data, plot.Y)
eb.Low = plot.CopyRole(data, plot.Low)
eb.High = plot.CopyRole(data, plot.High)
if eb.Low == nil && eb.High != nil {
eb.Low = eb.High
}
if eb.X == nil || eb.Y == nil || eb.Low == nil || eb.High == nil {
return nil
}
eb.stylers = plot.GetStylersFromData(data, plot.High)
eb.ystylers = plot.GetStylersFromData(data, plot.Y)
eb.Defaults()
plt.Add(eb)
return eb
}
// Styler adds a style function to set style parameters.
func (eb *YErrorBars) Styler(f func(s *plot.Style)) *YErrorBars {
eb.stylers.Add(f)
return eb
}
func (eb *YErrorBars) ApplyStyle(ps *plot.PlotStyle, idx int) {
eb.Style.Line.SpacedColor(idx)
ps.SetElementStyle(&eb.Style)
yst := &plot.Style{}
eb.ystylers.Run(yst)
eb.Style.Range = yst.Range // get range from y
eb.stylers.Run(&eb.Style)
}
func (eb *YErrorBars) Stylers() *plot.Stylers { return &eb.stylers }
func (eb *YErrorBars) Data() (data plot.Data, pixX, pixY []float32) {
pixX = eb.PX
pixY = eb.PY
data = plot.Data{}
data[plot.X] = eb.X
data[plot.Y] = eb.Y
data[plot.Low] = eb.Low
data[plot.High] = eb.High
return
}
func (eb *YErrorBars) Plot(plt *plot.Plot) {
pc := plt.Painter
uc := &pc.UnitContext
eb.Style.Width.Cap.ToDots(uc)
cw := 0.5 * eb.Style.Width.Cap.Dots
nv := len(eb.X)
eb.PX = make([]float32, nv)
eb.PY = make([]float32, nv)
eb.Style.Line.SetStroke(plt)
for i, y := range eb.Y {
x := plt.PX(eb.X.Float1D(i))
ylow := plt.PY(y - math.Abs(eb.Low[i]))
yhigh := plt.PY(y + math.Abs(eb.High[i]))
eb.PX[i] = x
eb.PY[i] = yhigh
pc.MoveTo(x, ylow)
pc.LineTo(x, yhigh)
pc.MoveTo(x-cw, ylow)
pc.LineTo(x+cw, ylow)
pc.MoveTo(x-cw, yhigh)
pc.LineTo(x+cw, yhigh)
pc.Draw()
}
}
// UpdateRange updates the given ranges.
func (eb *YErrorBars) UpdateRange(plt *plot.Plot, x, y, yr, z, size *minmax.F64) {
if eb.Style.RightY {
y = yr
}
plot.Range(eb.X, x)
plot.RangeClamp(eb.Y, y, &eb.Style.Range)
for i, yv := range eb.Y {
ylow := yv - math.Abs(eb.Low[i])
yhigh := yv + math.Abs(eb.High[i])
y.FitInRange(minmax.F64{ylow, yhigh})
}
return
}
//////// XErrorBars
// XErrorBars draws horizontal error bars, denoting error in X values,
// using either High or Low, High data roles for error deviations
// around X, Y coordinates.
type XErrorBars struct {
// copies of data for this line
X, Y, Low, High plot.Values
// PX, PY are the actual pixel plotting coordinates for each XY value.
PX, PY []float32
// Style is the style for plotting.
Style plot.Style
stylers plot.Stylers
ystylers plot.Stylers
yrange minmax.Range64
}
func (eb *XErrorBars) Defaults() {
eb.Style.Defaults()
}
// NewXErrorBars adds a new XErrorBars plotter to given plot,
// using Low, High data roles for error deviations around X, Y coordinates.
func NewXErrorBars(plt *plot.Plot, data plot.Data) *XErrorBars {
if data.CheckLengths() != nil {
return nil
}
eb := &XErrorBars{}
eb.X = plot.MustCopyRole(data, plot.X)
eb.Y = plot.MustCopyRole(data, plot.Y)
eb.Low = plot.CopyRole(data, plot.Low)
eb.High = plot.CopyRole(data, plot.High)
if eb.Low == nil && eb.High != nil {
eb.Low = eb.High
}
if eb.X == nil || eb.Y == nil || eb.Low == nil || eb.High == nil {
return nil
}
eb.stylers = plot.GetStylersFromData(data, plot.High)
eb.ystylers = plot.GetStylersFromData(data, plot.Y)
eb.Defaults()
plt.Add(eb)
return eb
}
// Styler adds a style function to set style parameters.
func (eb *XErrorBars) Styler(f func(s *plot.Style)) *XErrorBars {
eb.stylers.Add(f)
return eb
}
func (eb *XErrorBars) ApplyStyle(ps *plot.PlotStyle, idx int) {
eb.Style.Line.SpacedColor(idx)
ps.SetElementStyle(&eb.Style)
yst := &plot.Style{}
eb.ystylers.Run(yst)
eb.yrange = yst.Range // get range from y
eb.stylers.Run(&eb.Style)
}
func (eb *XErrorBars) Stylers() *plot.Stylers { return &eb.stylers }
func (eb *XErrorBars) Data() (data plot.Data, pixX, pixY []float32) {
pixX = eb.PX
pixY = eb.PY
data = plot.Data{}
data[plot.X] = eb.X
data[plot.Y] = eb.Y
data[plot.Low] = eb.Low
data[plot.High] = eb.High
return
}
func (eb *XErrorBars) Plot(plt *plot.Plot) {
pc := plt.Painter
uc := &pc.UnitContext
eb.Style.Width.Cap.ToDots(uc)
cw := 0.5 * eb.Style.Width.Cap.Dots
nv := len(eb.X)
eb.PX = make([]float32, nv)
eb.PY = make([]float32, nv)
eb.Style.Line.SetStroke(plt)
for i, x := range eb.X {
y := plt.PY(eb.Y.Float1D(i))
xlow := plt.PX(x - math.Abs(eb.Low[i]))
xhigh := plt.PX(x + math.Abs(eb.High[i]))
eb.PX[i] = xhigh
eb.PY[i] = y
pc.MoveTo(xlow, y)
pc.LineTo(xhigh, y)
pc.MoveTo(xlow, y-cw)
pc.LineTo(xlow, y+cw)
pc.MoveTo(xhigh, y-cw)
pc.LineTo(xhigh, y+cw)
pc.Draw()
}
}
// UpdateRange updates the given ranges.
func (eb *XErrorBars) UpdateRange(plt *plot.Plot, x, y, yr, z, size *minmax.F64) {
if eb.Style.RightY {
y = yr
}
plot.RangeClamp(eb.X, x, &eb.Style.Range)
plot.RangeClamp(eb.Y, y, &eb.yrange)
for i, xv := range eb.X {
xlow := xv - math.Abs(eb.Low[i])
xhigh := xv + math.Abs(eb.High[i])
x.FitInRange(minmax.F64{xlow, xhigh})
}
return
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package plots
import (
"image"
"cogentcore.org/core/math32"
"cogentcore.org/core/math32/minmax"
"cogentcore.org/lab/plot"
)
// LabelsType is used for specifying the type name.
const LabelsType = "Labels"
func init() {
plot.RegisterPlotter(LabelsType, "draws text labels at specified X, Y points.", []plot.Roles{plot.X, plot.Y, plot.Label}, []plot.Roles{}, func(plt *plot.Plot, data plot.Data) plot.Plotter {
return NewLabels(plt, data)
})
}
// Labels draws text labels at specified X, Y points.
type Labels struct {
// copies of data for this line
X, Y plot.Values
Labels plot.Labels
// PX, PY are the actual pixel plotting coordinates for each XY value.
PX, PY []float32
// Style is the style of the label text.
Style plot.Style
// plot size and number of TextStyle when styles last generated -- don't regen
styleSize image.Point
stylers plot.Stylers
ystylers plot.Stylers
}
// NewLabels adds a new Labels to given plot for given data,
// which must specify X, Y and Label roles.
// Styler functions are obtained from the Label metadata if present.
func NewLabels(plt *plot.Plot, data plot.Data) *Labels {
if data.CheckLengths() != nil {
return nil
}
lb := &Labels{}
lb.X = plot.MustCopyRole(data, plot.X)
lb.Y = plot.MustCopyRole(data, plot.Y)
if lb.X == nil || lb.Y == nil {
return nil
}
ld := data[plot.Label]
if ld == nil {
return nil
}
lb.Labels = make(plot.Labels, lb.X.Len())
for i := range ld.Len() {
lb.Labels[i] = ld.String1D(i)
}
lb.stylers = plot.GetStylersFromData(data, plot.Label)
lb.ystylers = plot.GetStylersFromData(data, plot.Y)
lb.Defaults()
plt.Add(lb)
return lb
}
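// Example usage (a minimal sketch; the data values are illustrative assumptions):
//
//	plt := plot.New()
//	_ = NewLabels(plt, plot.Data{
//		plot.X:     plot.Values{1, 2, 3},
//		plot.Y:     plot.Values{1, 4, 9},
//		plot.Label: plot.Labels{"a", "b", "c"},
//	})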
func (lb *Labels) Defaults() {
lb.Style.Defaults()
}
// Styler adds a style function to set style parameters.
func (lb *Labels) Styler(f func(s *plot.Style)) *Labels {
lb.stylers.Add(f)
return lb
}
func (lb *Labels) ApplyStyle(ps *plot.PlotStyle, idx int) {
lb.Style.Line.SpacedColor(idx)
ps.SetElementStyle(&lb.Style)
yst := &plot.Style{}
lb.ystylers.Run(yst)
lb.Style.Range = yst.Range // get range from y
lb.stylers.Run(&lb.Style) // can still override here
}
func (lb *Labels) Stylers() *plot.Stylers { return &lb.stylers }
func (lb *Labels) Data() (data plot.Data, pixX, pixY []float32) {
pixX = lb.PX
pixY = lb.PY
data = plot.Data{}
data[plot.X] = lb.X
data[plot.Y] = lb.Y
data[plot.Label] = lb.Labels
return
}
// Plot implements the Plotter interface, drawing labels.
func (lb *Labels) Plot(plt *plot.Plot) {
pc := plt.Painter
uc := &pc.UnitContext
lb.PX = plot.PlotX(plt, lb.X)
lb.PY = plot.PlotY(plt, lb.Y)
st := &lb.Style.Text
st.Offset.ToDots(uc)
var ltxt plot.Text
ltxt.Defaults()
ltxt.Style = *st
ltxt.ToDots(uc)
nskip := lb.Style.LabelSkip
skip := nskip // start with label
for i, label := range lb.Labels {
if label == "" {
continue
}
if skip != nskip {
skip++
continue
}
skip = 0
ltxt.Text = label
ltxt.Config(plt)
tht := ltxt.Size().Y
ltxt.Draw(plt, math32.Vec2(lb.PX[i]+st.Offset.X.Dots, lb.PY[i]+st.Offset.Y.Dots-tht))
}
}
// UpdateRange updates the given ranges.
func (lb *Labels) UpdateRange(plt *plot.Plot, x, y, yr, z, size *minmax.F64) {
if lb.Style.RightY {
y = yr
}
// todo: include point sizes!
plot.Range(lb.X, x)
plot.RangeClamp(lb.Y, y, &lb.Style.Range)
pxToData := math32.FromPoint(plt.PaintBox.Size())
pxToData.X = float32(x.Range()) / pxToData.X
pxToData.Y = float32(y.Range()) / pxToData.Y
st := &lb.Style.Text
var ltxt plot.Text
ltxt.Style = *st
for i, label := range lb.Labels {
if label == "" {
continue
}
ltxt.Text = label
ltxt.Config(plt)
tht := pxToData.Y * ltxt.Size().Y
twd := 1.1 * pxToData.X * ltxt.Size().X
xv := lb.X[i]
yv := lb.Y[i]
maxx := xv + float64(pxToData.X*st.Offset.X.Dots+twd)
maxy := yv + float64(pxToData.Y*st.Offset.Y.Dots+tht) // y is up here
x.FitInRange(minmax.F64{xv, maxx})
y.FitInRange(minmax.F64{yv, maxy})
}
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Adapted from github.com/gonum/plot:
// Copyright ©2015 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package plots
//go:generate core generate
import (
"fmt"
"cogentcore.org/core/math32"
"cogentcore.org/core/math32/minmax"
"cogentcore.org/lab/plot"
"cogentcore.org/lab/tensor"
)
// XYType is used for specifying the type name.
const XYType = "XY"
func init() {
plot.RegisterPlotter(XYType, "draws lines between and / or points for X,Y data values, using optional Size data for the points, for a bubble plot.", []plot.Roles{plot.Y}, []plot.Roles{plot.X, plot.Size}, func(plt *plot.Plot, data plot.Data) plot.Plotter {
return NewXY(plt, data)
})
}
// XY draws lines between and / or points for XY data values.
type XY struct {
// copies of data for this line
X, Y, Size plot.Values
// PX, PY are the actual pixel plotting coordinates for each XY value.
PX, PY []float32
// Style is the style for plotting.
Style plot.Style
stylers plot.Stylers
}
// NewXY adds a new XY plotter to given plot for given data,
// which can either be a [plot.Valuer] (e.g., Tensor) with the Y values,
// or a [plot.Data] with roles and values defined.
// Data can also include Size for the points.
// Styler functions are obtained from the Y metadata if present.
func NewXY(plt *plot.Plot, data any) *XY {
ln := &XY{}
err := ln.SetData(data)
if err != nil {
// errors.Log(err) not useful actually
return nil
}
ln.Defaults()
plt.Add(ln)
return ln
}
// SetData sets the plot data.
func (ln *XY) SetData(data any) error {
dt, err := plot.DataOrValuer(data, plot.Y)
if err != nil {
return err
}
if err := dt.CheckLengths(); err != nil {
return err
}
ln.Y = plot.MustCopyRole(dt, plot.Y)
if _, ok := dt[plot.X]; !ok {
ln.X, err = plot.CopyValues(tensor.NewIntRange(len(ln.Y)))
if err != nil {
return err
}
} else {
ln.X = plot.MustCopyRole(dt, plot.X)
}
if ln.X == nil || ln.Y == nil {
return fmt.Errorf("X or Y is nil")
}
ln.stylers = plot.GetStylersFromData(dt, plot.Y)
ln.Size = plot.CopyRole(dt, plot.Size)
return nil
}
// newXYWith is a simple helper function that creates a new XY plotter
// with lines and/or points.
func newXYWith(plt *plot.Plot, data any, line, point plot.DefaultOffOn) *XY {
ln := NewXY(plt, data)
if ln == nil {
return ln
}
ln.Style.Line.On = line
ln.Style.Point.On = point
return ln
}
// NewLine adds an XY plot drawing Lines only by default, for given data
// which can either be a [plot.Valuer] (e.g., Tensor) with the Y values,
// or a [plot.Data] with roles and values defined.
// See also [NewScatter] and [NewPointLine].
func NewLine(plt *plot.Plot, data any) *XY {
return newXYWith(plt, data, plot.On, plot.Off)
}
// NewScatter adds an XY scatter plot drawing Points only by default, for given data
// which can either be a [plot.Valuer] (e.g., Tensor) with the Y values,
// or a [plot.Data] with roles and values defined.
// See also [NewLine] and [NewPointLine].
func NewScatter(plt *plot.Plot, data any) *XY {
return newXYWith(plt, data, plot.Off, plot.On)
}
// NewPointLine adds an XY plot drawing both lines and points by default, for given data
// which can either be a [plot.Valuer] (e.g., Tensor) with the Y values,
// or a [plot.Data] with roles and values defined.
// See also [NewLine] and [NewScatter].
func NewPointLine(plt *plot.Plot, data any) *XY {
return newXYWith(plt, data, plot.On, plot.On)
}
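// Example usage (a minimal sketch; the data values and style settings are
// illustrative assumptions): NewLine, NewScatter, and NewPointLine all accept
// either a bare [plot.Valuer] for Y or a full [plot.Data] map.
//
//	plt := plot.New()
//	ln := NewLine(plt, plot.Data{
//		plot.X: plot.Values{0, 1, 2, 3},
//		plot.Y: plot.Values{0, 1, 4, 9},
//	})
//	ln.Styler(func(s *plot.Style) {
//		s.Line.Step = plot.PostStep // draw as a step line
//	})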
func (ln *XY) Defaults() {
ln.Style.Defaults()
}
// Styler adds a style function to set style parameters.
func (ln *XY) Styler(f func(s *plot.Style)) *XY {
ln.stylers.Add(f)
return ln
}
func (ln *XY) Stylers() *plot.Stylers { return &ln.stylers }
func (ln *XY) ApplyStyle(ps *plot.PlotStyle, idx int) {
ln.Style.Line.SpacedColor(idx)
ln.Style.Point.SpacedColor(idx)
ps.SetElementStyle(&ln.Style)
ln.stylers.Run(&ln.Style)
}
func (ln *XY) Data() (data plot.Data, pixX, pixY []float32) {
pixX = ln.PX
pixY = ln.PY
data = plot.Data{}
data[plot.X] = ln.X
data[plot.Y] = ln.Y
if ln.Size != nil {
data[plot.Size] = ln.Size
}
return
}
// Plot does the drawing, implementing the plot.Plotter interface.
func (ln *XY) Plot(plt *plot.Plot) {
ln.PX = plot.PlotX(plt, ln.X)
var minY float32
if ln.Style.RightY {
ln.PY = plot.PlotYR(plt, ln.Y)
minY = plt.PYR(plt.YR.Range.Min)
} else {
ln.PY = plot.PlotY(plt, ln.Y)
minY = plt.PY(plt.Y.Range.Min)
}
np := min(len(ln.PX), len(ln.PY))
if np == 0 {
return
}
pc := plt.Painter
if ln.Style.Line.HasFill() {
pc.Fill.Color = ln.Style.Line.Fill
prevX := ln.PX[0]
prevY := minY
pc.MoveTo(prevX, prevY)
for i, ptx := range ln.PX {
pty := ln.PY[i]
switch ln.Style.Line.Step {
case plot.NoStep:
if ptx < prevX {
pc.LineTo(prevX, minY)
pc.Close()
pc.MoveTo(ptx, minY)
}
pc.LineTo(ptx, pty)
case plot.PreStep:
if i == 0 {
continue
}
if ptx < prevX {
pc.LineTo(prevX, minY)
pc.Close()
pc.MoveTo(ptx, minY)
} else {
pc.LineTo(prevX, pty)
}
pc.LineTo(ptx, pty)
case plot.MidStep:
if ptx < prevX {
pc.LineTo(prevX, minY)
pc.Close()
pc.MoveTo(ptx, minY)
} else {
pc.LineTo(0.5*(prevX+ptx), prevY)
pc.LineTo(0.5*(prevX+ptx), pty)
}
pc.LineTo(ptx, pty)
case plot.PostStep:
if ptx < prevX {
pc.LineTo(prevX, minY)
pc.Close()
pc.MoveTo(ptx, minY)
} else {
pc.LineTo(ptx, prevY)
}
pc.LineTo(ptx, pty)
}
prevX, prevY = ptx, pty
}
pc.LineTo(prevX, minY)
pc.Close()
pc.Draw()
}
pc.Fill.Color = nil
if ln.Style.Line.SetStroke(plt) {
if plt.HighlightPlotter == ln {
pc.Stroke.Width.Dots *= 2
}
prevX, prevY := ln.PX[0], ln.PY[0]
pc.MoveTo(prevX, prevY)
for i := 1; i < np; i++ {
ptx, pty := ln.PX[i], ln.PY[i]
if ln.Style.Line.Step != plot.NoStep {
if ptx >= prevX {
switch ln.Style.Line.Step {
case plot.PreStep:
pc.LineTo(prevX, pty)
case plot.MidStep:
pc.LineTo(0.5*(prevX+ptx), prevY)
pc.LineTo(0.5*(prevX+ptx), pty)
case plot.PostStep:
pc.LineTo(ptx, prevY)
}
} else {
pc.MoveTo(ptx, pty)
}
}
if !ln.Style.Line.NegativeX && ptx < prevX {
pc.MoveTo(ptx, pty)
} else {
pc.LineTo(ptx, pty)
}
prevX, prevY = ptx, pty
}
pc.Draw()
}
if ln.Style.Point.SetStroke(plt) {
origWidth := ln.Style.Point.Width
origSize := ln.Style.Point.Size
for i, ptx := range ln.PX {
pty := ln.PY[i]
pc.Stroke.Width = origWidth
ln.Style.Point.Size = origSize
if plt.HighlightPlotter == ln {
if i == plt.HighlightIndex {
pc.Stroke.Width.Dots *= 2
ln.Style.Point.Size.Dots *= 1.5
}
}
if ln.Size != nil {
ln.Style.Point.Size.Dots = 1 + ln.Style.Point.Size.Dots*float32(plt.SizeAxis.Norm(ln.Size.Float1D(i)))
}
ln.Style.Point.SetColorIndex(pc, i)
ln.Style.Point.DrawShape(pc, math32.Vec2(ptx, pty))
}
ln.Style.Point.Size = origSize
} else if plt.HighlightPlotter == ln {
op := ln.Style.Point.On
origSize := ln.Style.Point.Size
ln.Style.Point.On = plot.On
ln.Style.Point.Width.Pt(2)
ln.Style.Point.Size.Pt(4.5)
ln.Style.Point.SetStroke(plt)
ptx := ln.PX[plt.HighlightIndex]
pty := ln.PY[plt.HighlightIndex]
ln.Style.Point.DrawShape(pc, math32.Vec2(ptx, pty))
ln.Style.Point.On = op
ln.Style.Point.Size = origSize
}
pc.Fill.Color = nil
}
// UpdateRange updates the given ranges.
func (ln *XY) UpdateRange(plt *plot.Plot, x, y, yr, z, size *minmax.F64) {
if ln.Style.RightY {
y = yr
}
plot.Range(ln.X, x)
if !ln.Style.Point.IsOn(plt) {
plot.RangeClamp(ln.Y, y, &ln.Style.Range)
return
}
plot.Range(ln.Y, y)
psz := ln.Style.Point.Size.Dots
ptb := plt.PaintBox
dy := (float64(psz) / float64(ptb.Size().Y)) * y.Range()
y.Min -= dy
y.Max += dy
y.Min, y.Max = ln.Style.Range.Clamp(y.Min, y.Max)
dx := (float64(psz) / float64(ptb.Size().X)) * x.Range()
x.Min -= dx
x.Max += dx
plot.Range(ln.Size, size)
}
// Thumbnail returns the thumbnail, implementing the plot.Thumbnailer interface.
func (ln *XY) Thumbnail(plt *plot.Plot) {
pc := plt.Painter
ptb := plt.CurBounds()
midY := 0.5 * float32(ptb.Min.Y+ptb.Max.Y)
if ln.Style.Line.Fill != nil {
tb := ptb
if ln.Style.Line.Width.Value > 0 {
tb.Min.Y = int(midY)
}
pc.FillBox(math32.FromPoint(tb.Min), math32.FromPoint(tb.Size()), ln.Style.Line.Fill)
}
if ln.Style.Line.SetStroke(plt) {
pc.MoveTo(float32(ptb.Min.X), midY)
pc.LineTo(float32(ptb.Max.X), midY)
pc.Draw()
}
if ln.Style.Point.SetStroke(plt) {
midX := 0.5 * float32(ptb.Min.X+ptb.Max.X)
ln.Style.Point.DrawShape(pc, math32.Vec2(midX, midY))
}
pc.Fill.Color = nil
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package plot
import (
"fmt"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/math32/minmax"
)
// Plotter is an interface that wraps the Plot method.
// Standard implementations of Plotter are in the [plots] package.
type Plotter interface {
// Plot draws the data to the Plot Paint.
Plot(pt *Plot)
// UpdateRange updates the given ranges.
UpdateRange(plt *Plot, x, y, yr, z, size *minmax.F64)
// Data returns the data by roles for this plot, for both the original
// data and the pixel-transformed X,Y coordinates for that data.
// This allows a GUI interface to inspect data etc.
Data() (data Data, pixX, pixY []float32)
// Stylers returns the styler functions for this element.
Stylers() *Stylers
// ApplyStyle applies any stylers to this element,
// first initializing from the given global plot style, which has
// already been styled with defaults and all the plot element stylers.
ApplyStyle(plotStyle *PlotStyle, idx int)
}
// PlotterType registers a Plotter so that it can be created with appropriate data.
type PlotterType struct {
// Name of the plot type.
Name string
// Doc is the documentation for this Plotter.
Doc string
// Required Data roles for this plot. Data for these Roles must be provided.
Required []Roles
// Optional Data roles for this plot.
Optional []Roles
// New returns a new plotter of this type with given data in given roles.
New func(plt *Plot, data Data) Plotter
}
// PlotterName is the name of a specific plotter type.
type PlotterName string
// Plotters is the registry of [Plotter] types.
var Plotters = map[string]PlotterType{}
// RegisterPlotter registers a plotter type.
func RegisterPlotter(name, doc string, required, optional []Roles, newFun func(plt *Plot, data Data) Plotter) {
Plotters[name] = PlotterType{Name: name, Doc: doc, Required: required, Optional: optional, New: newFun}
}
// PlotterByType returns [PlotterType] info for a registered [Plotter]
// of given type name, e.g., "XY", "Bar" etc,
// Returns an error and nil if type name is not a registered type.
func PlotterByType(typeName string) (*PlotterType, error) {
pt, ok := Plotters[typeName]
if !ok {
return nil, fmt.Errorf("plot.PlotterByType type name is not registered: %s", typeName)
}
return &pt, nil
}
// NewPlotter returns a new plotter of given type, e.g., "XY", "Bar" etc,
// for given data roles (which must include Required roles, and may include Optional ones).
// Logs an error and returns nil if type name is not a registered type.
func NewPlotter(plt *Plot, typeName string, data Data) Plotter {
pt, err := PlotterByType(typeName)
if errors.Log(err) != nil {
return nil
}
return pt.New(plt, data)
}
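// Example usage (a minimal sketch; assumes the plots package has been imported
// so that its init function has registered the standard plotter types):
//
//	plt := New()
//	pl := NewPlotter(plt, "XY", Data{Y: Values{1, 2, 3}})
//	if pl == nil {
//		// type was not registered or the data was invalid
//	}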
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package plot
import (
"image"
"cogentcore.org/core/colors"
"cogentcore.org/core/math32"
"cogentcore.org/core/paint"
"cogentcore.org/core/styles/units"
)
// PointStyle has style properties for drawing points as different shapes.
type PointStyle struct { //types:add -setters
// On indicates whether to plot points.
On DefaultOffOn
// Shape to draw.
Shape Shapes
// Color is the stroke color image specification.
// Setting to nil turns stroke off. See also [PointStyle.ColorFunc].
Color image.Image
// Fill is the color to fill points.
// Use nil to disable filling. See also [PointStyle.FillFunc].
Fill image.Image
// ColorFunc, if non-nil, is used instead of [PointStyle.Color].
// The function returns the stroke color to use for a given point index.
ColorFunc func(i int) image.Image
// FillFunc, if non-nil, is used instead of [PointStyle.Fill].
// The function returns the fill color to use for a given point index.
FillFunc func(i int) image.Image
// Width is the line width for point glyphs, with a default of 1 Pt (point).
// Setting to 0 turns line off.
Width units.Value
// Size of shape to draw for each point.
// Defaults to 3 Pt (point).
Size units.Value
}
func (ps *PointStyle) Defaults() {
ps.Shape = Circle
ps.Color = colors.Scheme.OnSurface
ps.Fill = colors.Scheme.OnSurface
ps.Width.Pt(1)
ps.Size.Pt(3)
}
// SpacedColor sets the Color to a default spaced color based on index,
// if it still has the initial OnSurface default.
func (ps *PointStyle) SpacedColor(idx int) {
if ps.Color == colors.Scheme.OnSurface {
ps.Color = colors.Uniform(colors.Spaced(idx))
}
if ps.Fill == colors.Scheme.OnSurface {
ps.Fill = colors.Uniform(colors.Spaced(idx))
}
}
// IsOn returns true if points are to be drawn.
// Also computes the dots sizes at this point.
func (ps *PointStyle) IsOn(pt *Plot) bool {
uc := pt.UnitContext()
ps.Width.ToDots(uc)
ps.Size.ToDots(uc)
if ps.On == Off || (ps.Color == nil && ps.Fill == nil && ps.ColorFunc == nil && ps.FillFunc == nil) || ps.Width.Dots == 0 || ps.Size.Dots == 0 {
return false
}
return true
}
// SetStroke sets the stroke style in plot paint to the current point style.
// Returns false if the point style is not on (see [PointStyle.IsOn]).
func (ps *PointStyle) SetStroke(pt *Plot) bool {
if !ps.IsOn(pt) {
return false
}
uc := pt.UnitContext()
pc := pt.Painter
pc.Stroke.Width = ps.Width
pc.Stroke.Color = ps.Color
pc.Stroke.ToDots(uc)
if ps.Shape <= Pyramid {
pc.Fill.Color = ps.Fill
} else {
pc.Fill.Color = nil
}
return true
}
// SetColorIndex sets the stroke and fill colors based on index-specific
// color functions if applicable ([PointStyle.ColorFunc] and
// [PointStyle.FillFunc]).
func (ps *PointStyle) SetColorIndex(pc *paint.Painter, i int) {
if ps.ColorFunc != nil {
pc.Stroke.Color = ps.ColorFunc(i)
}
if ps.FillFunc != nil && ps.Shape <= Pyramid {
pc.Fill.Color = ps.FillFunc(i)
}
}
// DrawShape draws the given shape
func (ps *PointStyle) DrawShape(pc *paint.Painter, pos math32.Vector2) {
size := ps.Size.Dots
if size == 0 {
return
}
switch ps.Shape {
case Ring:
DrawRing(pc, pos, size)
case Circle:
DrawCircle(pc, pos, size)
case Square:
DrawSquare(pc, pos, size)
case Box:
DrawBox(pc, pos, size)
case Triangle:
DrawTriangle(pc, pos, size)
case Pyramid:
DrawPyramid(pc, pos, size)
case Plus:
DrawPlus(pc, pos, size)
case Cross:
DrawCross(pc, pos, size)
}
}
func DrawRing(pc *paint.Painter, pos math32.Vector2, size float32) {
pc.Circle(pos.X, pos.Y, size)
pc.Draw()
}
func DrawCircle(pc *paint.Painter, pos math32.Vector2, size float32) {
pc.Circle(pos.X, pos.Y, size)
pc.Draw()
}
func DrawSquare(pc *paint.Painter, pos math32.Vector2, size float32) {
x := size * 0.9
pc.MoveTo(pos.X-x, pos.Y-x)
pc.LineTo(pos.X+x, pos.Y-x)
pc.LineTo(pos.X+x, pos.Y+x)
pc.LineTo(pos.X-x, pos.Y+x)
pc.Close()
pc.Draw()
}
func DrawBox(pc *paint.Painter, pos math32.Vector2, size float32) {
x := size * 0.9
pc.MoveTo(pos.X-x, pos.Y-x)
pc.LineTo(pos.X+x, pos.Y-x)
pc.LineTo(pos.X+x, pos.Y+x)
pc.LineTo(pos.X-x, pos.Y+x)
pc.Close()
pc.Draw()
}
func DrawTriangle(pc *paint.Painter, pos math32.Vector2, size float32) {
x := size * 0.9
pc.MoveTo(pos.X, pos.Y-x)
pc.LineTo(pos.X-x, pos.Y+x)
pc.LineTo(pos.X+x, pos.Y+x)
pc.Close()
pc.Draw()
}
func DrawPyramid(pc *paint.Painter, pos math32.Vector2, size float32) {
x := size * 0.9
pc.MoveTo(pos.X, pos.Y-x)
pc.LineTo(pos.X-x, pos.Y+x)
pc.LineTo(pos.X+x, pos.Y+x)
pc.Close()
pc.Draw()
}
func DrawPlus(pc *paint.Painter, pos math32.Vector2, size float32) {
x := size * 1.05
pc.MoveTo(pos.X-x, pos.Y)
pc.LineTo(pos.X+x, pos.Y)
pc.MoveTo(pos.X, pos.Y-x)
pc.LineTo(pos.X, pos.Y+x)
pc.Close()
pc.Draw()
}
func DrawCross(pc *paint.Painter, pos math32.Vector2, size float32) {
x := size * 0.9
pc.MoveTo(pos.X-x, pos.Y-x)
pc.LineTo(pos.X+x, pos.Y+x)
pc.MoveTo(pos.X+x, pos.Y-x)
pc.LineTo(pos.X-x, pos.Y+x)
pc.Close()
pc.Draw()
}
// Shapes has the options for how to draw points in the plot.
type Shapes int32 //enums:enum
const (
// Circle is a solid circle
Circle Shapes = iota
// Box is a filled square
Box
// Pyramid is a filled triangle
Pyramid
// Plus is a plus sign
Plus
// Cross is a big X
Cross
// Ring is the outline of a circle
Ring
// Square is the outline of a square
Square
// Triangle is the outline of a triangle
Triangle
)
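// Example usage (a minimal sketch; s is assumed to be the *Style passed to a
// styler function, and the specific settings are illustrative):
//
//	s.Point.On = On
//	s.Point.Shape = Triangle
//	s.Point.Size.Pt(5)
//	s.Point.ColorFunc = func(i int) image.Image {
//		return colors.Uniform(colors.Spaced(i)) // per-point spaced colors
//	}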
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package plot
import (
"cogentcore.org/core/base/metadata"
"cogentcore.org/core/math32/minmax"
"cogentcore.org/core/styles/units"
"cogentcore.org/lab/table"
)
// Style contains the plot styling properties relevant across
// most plot types. These properties apply to individual plot elements
// while the Plot properties applies to the overall plot itself.
type Style struct { //types:add -setters
// Plot has overall plot-level properties, which can be set by any
// plot element, and are updated first, before applying element-wise styles.
Plot PlotStyle `display:"-"`
// On specifies whether to plot this item, for table-based plots.
On bool
// Plotter is the type of plotter to use in plotting this data,
// for [plot.NewTablePlot] [table.Table] driven plots.
// Blank means use default ([plots.XY] is overall default).
Plotter PlotterName
// Role specifies how a particular column of data should be used,
// for [plot.NewTablePlot] [table.Table] driven plots.
Role Roles
// Group specifies a group of related data items,
// for [plot.NewTablePlot] [table.Table] driven plots,
// where different columns of data within the same Group play different Roles.
Group string
// Range is the effective range of data to plot, where either end can be fixed.
Range minmax.Range64 `display:"inline"`
// Label provides an alternative label to use for axis, if set.
Label string
// NoLegend excludes this item from the legend when it otherwise would be included,
// for [plot.NewTablePlot] [table.Table] driven plots.
// Role = Y values are included in the Legend by default.
NoLegend bool
// RightY specifies that this should use the right-side alternate Y axis.
RightY bool
// NTicks sets the desired number of ticks for the axis, if > 0.
NTicks int
// LabelSkip is the number of data points to skip between Labels.
// 0 means plot the Label at every point.
LabelSkip int
// Line has style properties for drawing lines.
Line LineStyle `display:"add-fields"`
// Point has style properties for drawing points.
Point PointStyle `display:"add-fields"`
// Text has style properties for rendering text.
Text TextStyle `display:"add-fields"`
// Width has various plot width properties.
Width WidthStyle `display:"inline"`
}
// NewStyle returns a new Style object with defaults applied.
func NewStyle() *Style {
st := &Style{}
st.Defaults()
return st
}
func (st *Style) Defaults() {
st.Plot.Defaults()
st.Line.Defaults()
st.Point.Defaults()
st.Text.Defaults()
st.Width.Defaults()
}
// WidthStyle contains various plot width properties relevant across
// different plot types.
type WidthStyle struct { //types:add -setters
// Cap is the width of the caps drawn at the top of error bars.
// The default is 10dp.
Cap units.Value
// Offset for Bar plot is the offset added to each X axis value
// relative to the Stride computed value (X = offset + index * Stride).
// Defaults to 1.
Offset float64
// Stride for Bar plot is distance between bars. Defaults to 1.
Stride float64
// Width for Bar plot is the width of the bars, as a fraction of the Stride,
// to prevent bar overlap. Defaults to .8.
Width float64 `min:"0.01" max:"1" default:"0.8"`
// Pad for Bar plot is additional space at start / end of data range,
// to keep bars from overflowing ends. This amount is subtracted from Offset
// and added to (len(Values)-1)*Stride -- no other accommodation for bar
// width is provided, so that should be built into this value as well.
// Defaults to 1.
Pad float64
}
func (ws *WidthStyle) Defaults() {
ws.Cap.Dp(10)
ws.Offset = 1
ws.Stride = 1
ws.Width = .8
ws.Pad = 1
}
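// Example (a minimal sketch; s1 and s2 are assumed to be the Styles of two Bar
// series): bar i of a series is placed at X = Offset + i*Stride, so multiple
// series can be interleaved by sharing Stride and staggering Offset, which is
// what [NewTablePlot] does automatically for multiple Bar columns:
//
//	s1.Width.Stride, s1.Width.Offset, s1.Width.Width = 1, 0.0, 0.32
//	s2.Width.Stride, s2.Width.Offset, s2.Width.Width = 1, 0.4, 0.32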
// Stylers is a list of styling functions that set Style properties.
// These are called in the order added.
type Stylers []func(s *Style)
// Add adds a styling function to the list.
func (st *Stylers) Add(f func(s *Style)) {
*st = append(*st, f)
}
// Run runs the list of styling functions on given [Style] object.
func (st *Stylers) Run(s *Style) {
if st == nil {
return
}
for _, f := range *st {
f(s)
}
}
// NewStyle returns a new Style object with styling functions applied
// on top of Style defaults.
func (st *Stylers) NewStyle(ps *PlotStyle) *Style {
s := NewStyle()
ps.SetElementStyle(s)
st.Run(s)
return s
}
// SetStyler sets the [Stylers] function(s) into given object's [metadata].
// This overwrites any existing styler functions. The [plotcore.Editor]
// depends on adding a styler function on top of any existing ones,
// so it is better to use [SetFirstStyler] if that is being used.
func SetStyler(obj any, st ...func(s *Style)) {
metadata.Set(obj, "PlotStylers", Stylers(st))
}
// GetStylers returns [Stylers] functions from given object's [metadata].
// Returns nil if none or no metadata.
func GetStylers(obj any) Stylers {
st, _ := metadata.Get[Stylers](obj, "PlotStylers")
return st
}
// SetFirstStyler sets the [Styler] function into given object's [metadata],
// only if there are no other stylers present. This is important for cases
// where code may be run multiple times on the same object, and you don't want
// to add multiple redundant style functions (and [plotcore.Editor] is being used).
func SetFirstStyler(obj any, f func(s *Style)) {
st := GetStylers(obj)
if len(st) > 0 {
return
}
metadata.Set(obj, "PlotStylers", Stylers{f})
}
// Styler adds the given [Styler] function into given object's [metadata].
func Styler(obj any, f func(s *Style)) {
st := GetStylers(obj)
st.Add(f)
SetStyler(obj, st...)
}
// GetStylersFromData returns [Stylers] from given role
// in given [Data]. Returns nil if not present. Mostly used internally
// for Plotters implementations.
func GetStylersFromData(data Data, role Roles) Stylers {
vr, ok := data[role]
if !ok {
return nil
}
return GetStylers(vr)
}
// BasicStylers returns a basic set of [Stylers] that can be used with
// functions like [Editor.SetSlice]. They make the first column the x-axis,
// and turn on the second column.
func BasicStylers() Stylers {
return Stylers{
func(s *Style) {
s.Role = X
},
func(s *Style) {
s.Role = Y
s.On = true
},
}
}
// SetBasicStylers applies [BasicStylers] to the first two columns of the given
// table.
func SetBasicStylers(dt *table.Table) {
bs := BasicStylers()
Styler(dt.Columns.Values[0], bs[0])
Styler(dt.Columns.Values[1], bs[1])
}
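// Example usage (a minimal sketch; dt is assumed to be an existing [table.Table]
// with at least two columns): attach stylers to the column tensors via metadata,
// then build a plot with [NewTablePlot].
//
//	Styler(dt.Columns.Values[0], func(s *Style) { s.Role = X })
//	Styler(dt.Columns.Values[1], func(s *Style) {
//		s.On = true
//		s.Role = Y
//		s.Plotter = "XY"
//	})
//	plt, err := NewTablePlot(dt)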
////////
// DefaultOffOn specifies whether to use the default value for a bool option,
// or to override the default and set Off or On.
type DefaultOffOn int32 //enums:enum
const (
// Default means use the default value.
Default DefaultOffOn = iota
// Off means to override the default and turn Off.
Off
// On means to override the default and turn On.
On
)
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package plot
import (
"fmt"
"image"
"reflect"
"slices"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/metadata"
"cogentcore.org/core/base/reflectx"
"cogentcore.org/core/colors"
"cogentcore.org/lab/stats/stats"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
"cogentcore.org/lab/tensorfs"
"golang.org/x/exp/maps"
)
// NewTablePlot returns a new Plot with all configuration based on given
// [table.Table] set of columns and associated metadata, which must have
// [Stylers] functions set (e.g., [SetStyler]) that at least set basic
// table parameters, including:
// - On: Set the main (typically Role = Y) column On to include in plot.
// - Role: Set the appropriate [Roles] role for this column (Y, X, etc).
// - Group: Multiple columns used for a given Plotter type must be grouped
// together with a common name (typically the name of the main Y axis),
// e.g., for Low, High error bars, Size, Color, etc. If only one On column,
// then Group can be empty and all other such columns will be grouped.
// - Plotter: Determines the type of Plotter element to use, which in turn
// determines the additional Roles that can be used within a Group.
//
// Returns nil if no valid plot elements were present.
func NewTablePlot(dt *table.Table) (*Plot, error) {
nc := len(dt.Columns.Values)
if nc == 0 {
return nil, errors.New("plot.NewTablePlot: no columns in data table")
}
csty := make([]*Style, nc)
gps := make(map[string][]int, nc)
xi := -1 // get the _last_ Role = X column -- the most specific one
var errs []error
var pstySt Style // overall PlotStyle accumulator
pstySt.Defaults()
for ci, cl := range dt.Columns.Values {
st := &Style{}
st.Defaults()
stl := GetStylers(cl)
if stl != nil {
stl.Run(st)
}
csty[ci] = st
stl.Run(&pstySt)
gps[st.Group] = append(gps[st.Group], ci)
if st.Role == X {
xi = ci
}
}
psty := pstySt.Plot
globalX := false
xidxs := map[int]bool{} // map of all the _unique_ x indexes used
if psty.XAxis.Column != "" {
xc := dt.Columns.IndexByKey(psty.XAxis.Column)
if xc >= 0 {
xi = xc
globalX = true
xidxs[xi] = true
} else {
errs = append(errs, errors.New("XAxis.Column name not found: "+psty.XAxis.Column))
}
}
type pitem struct {
ptyp string
pt *PlotterType
data Data
lbl string
ci int
clr image.Image // if set, set styler
}
var ptrs []*pitem // accumulate in case of grouping
doneGps := map[string]bool{}
var split *tensor.Rows
nLegends := 0
for ci := range dt.Columns.Values {
cl := dt.ColumnByIndex(ci)
cnm := dt.Columns.Keys[ci]
st := csty[ci]
if !st.On || st.Role == X {
continue
}
if st.Role == Split {
if split != nil {
errs = append(errs, errors.New("NewTablePlot: Only 1 Split role can be defined, using the first one"))
}
split = cl
continue
}
lbl := cnm
if st.Label != "" {
lbl = st.Label
}
gp := st.Group
if doneGps[gp] {
continue
}
if gp != "" {
doneGps[gp] = true
}
ptyp := "XY"
if st.Plotter != "" {
ptyp = string(st.Plotter)
}
pt, err := PlotterByType(ptyp)
if err != nil {
errs = append(errs, err)
continue
}
data := Data{st.Role: cl}
gcols := gps[gp]
gotReq := true
gotX := -1
if globalX {
data[X] = dt.ColumnByIndex(xi)
gotX = xi
}
for _, rl := range pt.Required {
if rl == st.Role || (rl == X && globalX) {
continue
}
got := false
for _, gi := range gcols {
gst := csty[gi]
if gst.Role == rl {
if rl == Y {
if !gst.On {
continue
}
}
data[rl] = dt.ColumnByIndex(gi)
got = true
if rl == X {
gotX = gi // fallthrough so we get the last X
} else {
break
}
}
}
if !got {
if rl == X && xi >= 0 {
gotX = xi
data[rl] = dt.ColumnByIndex(xi)
} else {
err = fmt.Errorf("plot.NewTablePlot: Required Role %q not found in Group %q, Plotter %q not added for Column: %q", rl.String(), gp, ptyp, cnm)
errs = append(errs, err)
gotReq = false
}
}
}
if !gotReq {
continue
}
for _, rl := range pt.Optional {
if rl == st.Role || (rl == X && globalX) {
continue
}
got := false
for _, gi := range gcols {
gst := csty[gi]
if gst.Role == rl {
data[rl] = dt.ColumnByIndex(gi)
got = true
if rl == X {
gotX = gi // fallthrough so we get the last X
} else {
break
}
}
}
if !got && rl == X && xi >= 0 {
gotX = xi
data[rl] = dt.ColumnByIndex(xi)
}
}
if gotX >= 0 {
xidxs[gotX] = true
}
ptrs = append(ptrs, &pitem{ptyp: ptyp, pt: pt, data: data, lbl: lbl, ci: ci})
if !st.NoLegend {
nLegends++
}
}
if len(ptrs) == 0 {
return nil, errors.Join(errs...)
}
plt := New()
// do splits here, make a new list of ptrs
if split != nil {
spnm := metadata.Name(split)
if spnm == "" {
spnm = "0"
}
dir := errors.Log1(tensorfs.NewDir("TablePlot"))
err := stats.Groups(dir, split)
if err != nil {
errs = append(errs, err) // todo maybe bail here
}
sdir := dir.Dir("Groups").Dir(spnm)
gps := errors.Log1(sdir.Values())
// generate tensor.Rows indexed views of the original data
// for each unique element in pt.data.* -- the x axis is shared
// so we need a map to just do this once.
// [gp][pt.data.*]sliced
subd := make(map[tensor.Tensor]map[*tensor.Rows]*tensor.Rows)
for _, gp := range gps {
sv := make(map[*tensor.Rows]*tensor.Rows)
idxs := slices.Clone(gp.(*tensor.Int).Values)
for _, pt := range ptrs {
for _, dd := range pt.data {
dv := dd.(*tensor.Rows)
rv, ok := sv[dv]
if !ok {
rv = tensor.NewRows(dv.Tensor, idxs...)
}
sv[dv] = rv
}
}
subd[gp] = sv
}
// now go in plotter item order, then groups within, and make the new
// plot items
nptrs := make([]*pitem, 0, len(gps)*len(ptrs))
nLegends = len(gps) * nLegends
idx := 0
for _, pt := range ptrs {
for _, gp := range gps {
nd := Data{}
for rl, dd := range pt.data {
dv := dd.(*tensor.Rows)
rv := subd[gp][dv]
nd[rl] = rv
}
npt := *pt
npt.clr = colors.Uniform(colors.Spaced(idx))
npt.data = nd
npt.lbl = metadata.Name(gp) + " " + pt.lbl
nptrs = append(nptrs, &npt)
idx++
}
}
ptrs = nptrs
}
var barCols []int // column indexes of bar plots
var barPlots []int // plotter indexes of bar plots
for _, pt := range ptrs {
pl := pt.pt.New(plt, pt.data)
if reflectx.IsNil(reflect.ValueOf(pl)) {
err := fmt.Errorf("plot.NewTablePlot: error in creating plotter type: %q", pt.ptyp)
errs = append(errs, err)
continue
}
if pt.clr != nil {
pl.Stylers().Add(func(s *Style) {
s.Line.Color = pt.clr
if pt.ptyp == "Bar" {
s.Line.Fill = pt.clr
}
s.Point.Color = pt.clr
s.Point.Fill = pt.clr
})
}
plt.Add(pl)
st := csty[pt.ci]
if !st.NoLegend && nLegends > 1 {
if tn, ok := pl.(Thumbnailer); ok {
plt.Legend.Add(pt.lbl, tn)
}
}
if pt.ptyp == "Bar" {
barCols = append(barCols, pt.ci)
barPlots = append(barPlots, len(plt.Plotters)-1)
}
}
// Get XAxis label from actual x axis.
// todo: probably range from here too.
if psty.XAxis.Label == "" && len(xidxs) == 1 {
xi := maps.Keys(xidxs)[0]
lbl := dt.Columns.Keys[xi]
if csty[xi].Label != "" {
lbl = csty[xi].Label
}
if len(plt.Plotters) > 0 {
pl0 := plt.Plotters[0]
if pl0 != nil {
pl0.Stylers().Add(func(s *Style) {
s.Plot.XAxis.Label = lbl
})
}
}
}
// Set bar spacing based on total number of bars present.
nbar := len(barCols)
if nbar > 1 {
sz := 1.0 / (float64(nbar) + 0.5)
for bi, bp := range barPlots {
pl := plt.Plotters[bp]
pl.Stylers().Add(func(s *Style) {
s.Width.Stride = 1
s.Width.Offset = float64(bi) * sz
s.Width.Width = psty.BarWidth * sz
})
}
}
return plt, errors.Join(errs...)
}
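// Example (a minimal sketch; dt and its column layout are assumed): a column
// styled with Role = Split produces one plot element per unique value in that
// column, with spaced colors and combined legend labels.
//
//	Styler(dt.Columns.Values[2], func(s *Style) {
//		s.On = true
//		s.Role = Split
//	})
//	plt, err := NewTablePlot(dt)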
// todo: bar chart rows, if needed
//
// netn := pl.table.NumRows() * stride
// xc := pl.table.ColumnByIndex(xi)
// vals := make([]string, netn)
// for i, dx := range pl.table.Indexes {
// pi := mid + i*stride
// if pi < netn && dx < xc.Len() {
// vals[pi] = xc.String1D(dx)
// }
// }
// plt.NominalX(vals...)
// todo:
// Use string labels for X axis if X is a string
// xc := pl.table.ColumnByIndex(xi)
// if xc.Tensor.IsString() {
// xcs := xc.Tensor.(*tensor.String)
// vals := make([]string, pl.table.NumRows())
// for i, dx := range pl.table.Indexes {
// vals[i] = xcs.Values[dx]
// }
// plt.NominalX(vals...)
// }
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package plot
import (
"image"
"cogentcore.org/core/colors"
"cogentcore.org/core/math32"
"cogentcore.org/core/styles"
"cogentcore.org/core/styles/units"
"cogentcore.org/core/text/htmltext"
"cogentcore.org/core/text/rich"
"cogentcore.org/core/text/shaped"
"cogentcore.org/core/text/text"
)
// DefaultFontFamily specifies a default font family for plotting.
// If not set, the standard Cogent Core default font is used.
var DefaultFontFamily = rich.SansSerif
// TextStyle specifies styling parameters for Text elements.
type TextStyle struct { //types:add -setters
// Size of font to render. Default is 16dp
Size units.Value
// Family name for font (inherited): ordered list of comma-separated names,
// from more general to more specific. Use split on ',' to parse.
Family rich.Family
// Color of text.
Color image.Image
// Align specifies how to align text along the relevant
// dimension for the text element.
Align styles.Aligns
// Padding is used in a case-dependent manner to add
// space around text elements.
Padding units.Value
// Rotation of the text, in degrees.
Rotation float32
// Offset is added directly to the final label location.
Offset units.XY
}
func (ts *TextStyle) Defaults() {
ts.Size.Dp(16)
ts.Color = colors.Scheme.OnSurface
ts.Align = styles.Center
ts.Family = DefaultFontFamily
}
// Text specifies a single text element in a plot
type Text struct {
// text string, which can use HTML formatting
Text string
// styling for this text element
Style TextStyle
// font has the font rendering styles.
font rich.Style
// textStyle has the text rendering styles.
textStyle text.Style
// PaintText is the [shaped.Lines] for painting the text.
PaintText *shaped.Lines
}
func (tx *Text) Defaults() {
tx.Style.Defaults()
}
// config is called during the layout of the plot, prior to drawing
func (tx *Text) Config(pt *Plot) {
uc := pt.UnitContext()
ts := &tx.textStyle
fs := &tx.font
fs.Defaults()
ts.Defaults()
ts.FontSize = tx.Style.Size
ts.WhiteSpace = text.WrapNever
fs.Family = tx.Style.Family
if tx.Style.Color != colors.Scheme.OnSurface {
fs.SetFillColor(colors.ToUniform(tx.Style.Color))
}
if math32.Abs(tx.Style.Rotation) > 10 {
tx.Style.Align = styles.End
}
ts.ToDots(uc)
// fmt.Printf("%p\n", pt.Painter)
// fmt.Println("tdots:", ts.FontSize.Dots)
tx.Style.Padding.ToDots(uc)
txln := float32(len(tx.Text))
fht := tx.textStyle.FontHeight(fs)
hsz := float32(12) * txln
// txs := &pt.StandardTextStyle
rt, _ := htmltext.HTMLToRich([]byte(tx.Text), fs, nil)
tx.PaintText = pt.TextShaper.WrapLines(rt, fs, ts, math32.Vec2(hsz, fht))
}
func (tx *Text) ToDots(uc *units.Context) {
tx.textStyle.ToDots(uc)
tx.Style.Padding.ToDots(uc)
}
// Size returns the actual render size of the text.
func (tx *Text) Size() math32.Vector2 {
if tx.PaintText == nil {
return math32.Vector2{}
}
bb := tx.PaintText.Bounds
if tx.Style.Rotation != 0 {
bb = bb.MulMatrix2(math32.Rotate2D(math32.DegToRad(tx.Style.Rotation)))
}
return bb.Size().Ceil()
}
// PosX returns the starting position for a horizontally-aligned text element,
// based on given width. Text must have been config'd already.
func (tx *Text) PosX(width float32) math32.Vector2 {
rsz := tx.Size()
pos := math32.Vector2{}
pos.X = styles.AlignFactor(tx.Style.Align) * width
switch tx.Style.Align {
case styles.Center:
pos.X -= 0.5 * rsz.X
case styles.End:
pos.X -= rsz.X
}
if math32.Abs(tx.Style.Rotation) > 10 {
pos.Y += 0.5 * rsz.Y
}
return pos
}
// PosY returns the starting position for a vertically-rotated text element,
// based on given height. Text must have been config'd already.
func (tx *Text) PosY(height float32) math32.Vector2 {
rsz := tx.Size() // rotated size
pos := math32.Vector2{}
pos.Y = styles.AlignFactor(tx.Style.Align) * height
switch tx.Style.Align {
case styles.Center:
pos.Y -= 0.5 * rsz.Y
case styles.End:
pos.Y -= rsz.Y
}
return pos
}
// Draw renders the text at given upper left position
func (tx *Text) Draw(pt *Plot, pos math32.Vector2) {
if tx.Style.Rotation == 0 {
pt.Painter.DrawText(tx.PaintText, pos)
return
}
m := pt.Painter.Paint.Transform
rotx := math32.Rotate2DAround(math32.DegToRad(tx.Style.Rotation), pos)
pt.Painter.Paint.Transform = m.Mul(rotx)
pt.Painter.DrawText(tx.PaintText, pos)
pt.Painter.Paint.Transform = m
}
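// Example usage (a minimal sketch; assumes plt is a *Plot with an active
// Painter and TextShaper, and the position is illustrative):
//
//	var tx Text
//	tx.Defaults()
//	tx.Text = "<b>peak</b>" // HTML formatting is supported
//	tx.Style.Rotation = 45
//	tx.Config(plt)
//	tx.Draw(plt, math32.Vec2(100, 50))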
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package plot
import (
"math"
"strconv"
"time"
)
// A Tick is a single tick mark on an axis.
type Tick struct {
// Value is the data value marked by this Tick.
Value float64
// Label is the text to display at the tick mark.
// If Label is an empty string then this is a minor tick mark.
Label string
}
// IsMinor returns true if this is a minor tick mark.
func (tk *Tick) IsMinor() bool {
return tk.Label == ""
}
// Ticker creates Ticks in a specified range
type Ticker interface {
// Ticks returns Ticks in a specified range, with desired number of ticks,
// which can be ignored depending on the ticker type.
Ticks(mn, mx float64, nticks int) []Tick
}
// DefaultTicks is suitable for the Ticker field of an Axis;
// it returns a reasonable default set of tick marks.
type DefaultTicks struct{}
var _ Ticker = DefaultTicks{}
// Ticks returns Ticks in the specified range.
func (DefaultTicks) Ticks(mn, mx float64, nticks int) []Tick {
if mx <= mn {
panic("illegal range")
}
if nticks < 2 {
return nil
}
labels, step, q, mag := talbotLinHanrahan(mn, mx, nticks, withinData, nil, nil, nil)
majorDelta := step * math.Pow10(mag)
if q == 0 {
// Simple fall back was chosen, so
// majorDelta is the label distance.
majorDelta = labels[1] - labels[0]
}
// Choose a reasonable, but ad
// hoc formatting for labels.
fc := byte('f')
var off int
if mag < -1 || 6 < mag {
off = 1
fc = 'g'
}
mag10 := math.Pow10(mag)
if math.Trunc(q*mag10) != q*mag10 {
off += 2
}
prec := min(6, max(off, -mag))
ticks := make([]Tick, len(labels))
for i, v := range labels {
ticks[i] = Tick{Value: v, Label: strconv.FormatFloat(float64(v), fc, prec, 64)}
}
var minorDelta float64
// See talbotLinHanrahan for the values used here.
switch step {
case 1, 2.5:
minorDelta = majorDelta / 5
case 2, 3, 4, 5:
minorDelta = majorDelta / step
default:
if majorDelta/2 < dlamchP {
return ticks
}
minorDelta = majorDelta / 2
}
// Find the first minor tick not greater
// than the lowest data value.
var i float64
for labels[0]+(i-1)*minorDelta > mn {
i--
}
// Add ticks at minorDelta intervals when
// they are not within minorDelta/2 of a
// labelled tick.
for {
val := labels[0] + i*minorDelta
if val > mx {
break
}
found := false
for _, t := range ticks {
if math.Abs(t.Value-val) < minorDelta/2 {
found = true
}
}
if !found {
ticks = append(ticks, Tick{Value: val})
}
i++
}
return ticks
}
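// Example usage (a minimal sketch): generate roughly 5 default ticks for the
// range [0, 10]; each Tick has a Value (data position) and a Label, which is
// empty for minor tick marks.
//
//	ticks := DefaultTicks{}.Ticks(0, 10, 5)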
// LogTicks is suitable for the Ticker field of an Axis;
// it returns tick marks suitable for a log-scale axis.
type LogTicks struct {
// Prec specifies the precision of tick rendering
// according to the documentation for strconv.FormatFloat.
Prec int
}
var _ Ticker = LogTicks{}
// Ticks returns Ticks in a specified range
func (t LogTicks) Ticks(mn, mx float64, nticks int) []Tick {
if mn <= 0 || mx <= 0 {
panic("Values must be greater than 0 for a log scale.")
}
if nticks < 2 {
return nil
}
val := math.Pow10(int(math.Log10(mn)))
mx = math.Pow10(int(math.Ceil(math.Log10(mx))))
var ticks []Tick
for val < mx {
for i := 1; i < 10; i++ {
if i == 1 {
ticks = append(ticks, Tick{Value: val, Label: formatFloatTick(val, t.Prec)})
}
ticks = append(ticks, Tick{Value: val * float64(i)})
}
val *= 10
}
ticks = append(ticks, Tick{Value: val, Label: formatFloatTick(val, t.Prec)})
return ticks
}
// ConstantTicks is suitable for the Ticker field of an Axis.
// It returns the given fixed set of ticks.
type ConstantTicks []Tick
var _ Ticker = ConstantTicks{}
// Ticks returns Ticks in a specified range
func (ts ConstantTicks) Ticks(float64, float64, int) []Tick {
return ts
}
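// Example usage (a minimal sketch; assumes plt is a *Plot): fixed categorical
// tick labels for the X axis.
//
//	plt.X.Ticker = ConstantTicks{
//		{Value: 0, Label: "Control"},
//		{Value: 1, Label: "Test"},
//	}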
// UnixTimeIn returns a time conversion function for the given location.
func UnixTimeIn(loc *time.Location) func(t float64) time.Time {
return func(t float64) time.Time {
return time.Unix(int64(t), 0).In(loc)
}
}
// UTCUnixTime is the default time conversion for TimeTicks.
var UTCUnixTime = UnixTimeIn(time.UTC)
// TimeTicks is suitable for axes representing time values.
type TimeTicks struct {
// Ticker is used to generate a set of ticks.
// If nil, DefaultTicks will be used.
Ticker Ticker
// Format is the textual representation of the time value.
// If empty, time.RFC3339 will be used
Format string
// Time takes a float64 value and converts it into a time.Time.
// If nil, UTCUnixTime is used.
Time func(t float64) time.Time
}
var _ Ticker = TimeTicks{}
// Ticks implements plot.Ticker.
func (t TimeTicks) Ticks(mn, mx float64, nticks int) []Tick {
if t.Ticker == nil {
t.Ticker = DefaultTicks{}
}
if t.Format == "" {
t.Format = time.RFC3339
}
if t.Time == nil {
t.Time = UTCUnixTime
}
if nticks < 2 {
return nil
}
ticks := t.Ticker.Ticks(mn, mx, nticks)
for i := range ticks {
tick := &ticks[i]
if tick.Label == "" {
continue
}
tick.Label = t.Time(tick.Value).Format(t.Format)
}
return ticks
}
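// Example usage (a minimal sketch; assumes plt is a *Plot with Unix-time X
// values): format the X axis ticks as dates.
//
//	plt.X.Ticker = TimeTicks{Format: "2006-01-02"}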
/*
// lengthOffset returns an offset that should be added to the
// tick mark's line to account for its length. I.e., the start of
// the line for a minor tick mark must be shifted by half of
// the length.
func (t Tick) lengthOffset(len vg.Length) vg.Length {
if t.IsMinor() {
return len / 2
}
return 0
}
// tickLabelHeight returns height of the tick mark labels.
func tickLabelHeight(sty text.Style, ticks []Tick) vg.Length {
maxHeight := vg.Length(0)
for _, t := range ticks {
if t.IsMinor() {
continue
}
r := sty.Rectangle(t.Label)
h := r.Max.Y - r.Min.Y
if h > maxHeight {
maxHeight = h
}
}
return maxHeight
}
// tickLabelWidth returns the width of the widest tick mark label.
func tickLabelWidth(sty text.Style, ticks []Tick) vg.Length {
maxWidth := vg.Length(0)
for _, t := range ticks {
if t.IsMinor() {
continue
}
r := sty.Rectangle(t.Label)
w := r.Max.X - r.Min.X
if w > maxWidth {
maxWidth = w
}
}
return maxWidth
}
*/
// formatFloatTick returns a g-formatted string representation of v
// to the specified precision.
func formatFloatTick(v float64, prec int) string {
return strconv.FormatFloat(float64(v), 'g', prec, 64)
}
// // TickerFunc is suitable for the Ticker field of an Axis.
// // It is an adapter that allows quickly setting up a Ticker using a function with an appropriate signature.
// type TickerFunc func(min, max float64) []Tick
//
// var _ Ticker = TickerFunc(nil)
//
// // Ticks implements plot.Ticker.
// func (f TickerFunc) Ticks(min, max float64) []Tick {
// return f(min, max)
// }
// Code generated by "core generate -add-types"; DO NOT EDIT.
package plot
import (
"image"
"cogentcore.org/core/math32/minmax"
"cogentcore.org/core/styles"
"cogentcore.org/core/styles/units"
"cogentcore.org/core/text/rich"
"cogentcore.org/core/types"
)
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.AxisScales", IDName: "axis-scales", Doc: "AxisScales are the scaling options for how values are distributed\nalong an axis: Linear, Log, etc."})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.AxisStyle", IDName: "axis-style", Doc: "AxisStyle has style properties for the axis.", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"-setters"}}}, Fields: []types.Field{{Name: "On", Doc: "On determines whether the axis is rendered."}, {Name: "Text", Doc: "Text has the text style parameters for the text label."}, {Name: "Line", Doc: "Line has styling properties for the axis line."}, {Name: "Padding", Doc: "Padding between the axis line and the data. Having\nnon-zero padding ensures that the data is never drawn\non the axis, thus making it easier to see."}, {Name: "NTicks", Doc: "NTicks is the desired number of ticks (actual likely\nwill be different). If < 2 then the axis will not be drawn."}, {Name: "Scale", Doc: "Scale specifies how values are scaled along the axis:\nLinear, Log, Inverted"}, {Name: "TickText", Doc: "TickText has the text style for rendering tick labels,\nand is shared for actual rendering."}, {Name: "TickLine", Doc: "TickLine has line style for drawing tick lines."}, {Name: "TickLength", Doc: "TickLength is the length of tick lines."}}})
// SetOn sets the [AxisStyle.On]:
// On determines whether the axis is rendered.
func (t *AxisStyle) SetOn(v bool) *AxisStyle { t.On = v; return t }
// SetText sets the [AxisStyle.Text]:
// Text has the text style parameters for the text label.
func (t *AxisStyle) SetText(v TextStyle) *AxisStyle { t.Text = v; return t }
// SetLine sets the [AxisStyle.Line]:
// Line has styling properties for the axis line.
func (t *AxisStyle) SetLine(v LineStyle) *AxisStyle { t.Line = v; return t }
// SetPadding sets the [AxisStyle.Padding]:
// Padding between the axis line and the data. Having
// non-zero padding ensures that the data is never drawn
// on the axis, thus making it easier to see.
func (t *AxisStyle) SetPadding(v units.Value) *AxisStyle { t.Padding = v; return t }
// SetNTicks sets the [AxisStyle.NTicks]:
// NTicks is the desired number of ticks (actual likely
// will be different). If < 2 then the axis will not be drawn.
func (t *AxisStyle) SetNTicks(v int) *AxisStyle { t.NTicks = v; return t }
// SetScale sets the [AxisStyle.Scale]:
// Scale specifies how values are scaled along the axis:
// Linear, Log, Inverted
func (t *AxisStyle) SetScale(v AxisScales) *AxisStyle { t.Scale = v; return t }
// SetTickText sets the [AxisStyle.TickText]:
// TickText has the text style for rendering tick labels,
// and is shared for actual rendering.
func (t *AxisStyle) SetTickText(v TextStyle) *AxisStyle { t.TickText = v; return t }
// SetTickLine sets the [AxisStyle.TickLine]:
// TickLine has line style for drawing tick lines.
func (t *AxisStyle) SetTickLine(v LineStyle) *AxisStyle { t.TickLine = v; return t }
// SetTickLength sets the [AxisStyle.TickLength]:
// TickLength is the length of tick lines.
func (t *AxisStyle) SetTickLength(v units.Value) *AxisStyle { t.TickLength = v; return t }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.Axis", IDName: "axis", Doc: "Axis represents either a horizontal or vertical axis of a plot.\nThis is the \"internal\" data structure and should not be used for styling.", Fields: []types.Field{{Name: "Range", Doc: "Range has the Min, Max range of values for the axis (in raw data units.)"}, {Name: "Axis", Doc: "specifies which axis this is: X, Y or Z."}, {Name: "RightY", Doc: "For a Y axis, this puts the axis on the right (i.e., the second Y axis)."}, {Name: "Label", Doc: "Label for the axis."}, {Name: "Style", Doc: "Style has the style parameters for the Axis,\ncopied from [PlotStyle] source."}, {Name: "TickText", Doc: "TickText is used for rendering the tick text labels."}, {Name: "Ticker", Doc: "Ticker generates the tick marks. Any tick marks\nreturned by the Marker function that are not in\nrange of the axis are not drawn."}, {Name: "Scale", Doc: "Scale transforms a value given in the data coordinate system\nto the normalized coordinate system of the axis—its distance\nalong the axis as a fraction of the axis range."}, {Name: "AutoRescale", Doc: "AutoRescale enables an axis to automatically adapt its minimum\nand maximum boundaries, according to its underlying Ticker."}, {Name: "ticks", Doc: "cached list of ticks, set in size"}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.Normalizer", IDName: "normalizer", Doc: "Normalizer rescales values from the data coordinate system to the\nnormalized coordinate system.", Methods: []types.Method{{Name: "Normalize", Doc: "Normalize transforms a value x in the data coordinate system to\nthe normalized coordinate system.", Args: []string{"min", "max", "x"}, Returns: []string{"float64"}}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.LinearScale", IDName: "linear-scale", Doc: "LinearScale an be used as the value of an Axis.Scale function to\nset the axis to a standard linear scale."})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.LogScale", IDName: "log-scale", Doc: "LogScale can be used as the value of an Axis.Scale function to\nset the axis to a log scale."})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.InvertedScale", IDName: "inverted-scale", Doc: "InvertedScale can be used as the value of an Axis.Scale function to\ninvert the axis using any Normalizer.", Embeds: []types.Field{{Name: "Normalizer"}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.VirtualAxisStyle", IDName: "virtual-axis-style", Doc: "VirtualAxisStyle has style properties for a virtual (non-plotted) axis.", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"-setters"}}}, Fields: []types.Field{{Name: "Scale", Doc: "Scale specifies how values are scaled along the axis:\nLinear, Log, Inverted"}}})
// SetScale sets the [VirtualAxisStyle.Scale]:
// Scale specifies how values are scaled along the axis:
// Linear, Log, Inverted
func (t *VirtualAxisStyle) SetScale(v AxisScales) *VirtualAxisStyle { t.Scale = v; return t }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.VirtualAxis", IDName: "virtual-axis", Doc: "VirtualAxis represents a data role that is not plotted as a visible axis,\nsuch as the Size role controlling size of points.\nThis is the \"internal\" data structure and should not be used for styling.", Fields: []types.Field{{Name: "Range", Doc: "Range has the Min, Max range of values for the axis (in raw data units.)"}, {Name: "Style", Doc: "Style has the style parameters for the Axis,\ncopied from [PlotStyle] source."}, {Name: "Scale", Doc: "Scale transforms a value given in the data coordinate system\nto the normalized coordinate system of the axis—its distance\nalong the axis as a fraction of the axis range."}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.Data", IDName: "data", Doc: "Data is a map of Roles and Data for that Role, providing the\nprimary way of passing data to a Plotter"})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.Valuer", IDName: "valuer", Doc: "Valuer is the data interface for plotting, supporting either\nfloat64 or string representations. It is satisfied by the tensor.Tensor\ninterface, so a tensor can be used directly for plot Data.", Methods: []types.Method{{Name: "Len", Doc: "Len returns the number of values.", Returns: []string{"int"}}, {Name: "Float1D", Doc: "Float1D(i int) returns float64 value at given index.", Args: []string{"i"}, Returns: []string{"float64"}}, {Name: "String1D", Doc: "String1D(i int) returns string value at given index.", Args: []string{"i"}, Returns: []string{"string"}}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.Roles", IDName: "roles", Doc: "Roles are the roles that a given set of data values can play,\ndesigned to be sufficiently generalizable across all different\ntypes of plots, even if sometimes it is a bit of a stretch."})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.Values", IDName: "values", Doc: "Values provides a minimal implementation of the Data interface\nusing a slice of float64."})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.Labels", IDName: "labels", Doc: "Labels provides a minimal implementation of the Data interface\nusing a slice of string. It always returns 0 for Float1D."})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.selection", IDName: "selection", Fields: []types.Field{{Name: "n", Doc: "n is the number of labels selected."}, {Name: "lMin", Doc: "lMin and lMax are the selected min\nand max label values. lq is the q\nchosen."}, {Name: "lMax", Doc: "lMin and lMax are the selected min\nand max label values. lq is the q\nchosen."}, {Name: "lStep", Doc: "lMin and lMax are the selected min\nand max label values. lq is the q\nchosen."}, {Name: "lq", Doc: "lMin and lMax are the selected min\nand max label values. lq is the q\nchosen."}, {Name: "score", Doc: "score is the score for the selection."}, {Name: "magnitude", Doc: "magnitude is the magnitude of the\nlabel step distance."}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.weights", IDName: "weights", Doc: "weights is a helper type to calcuate the labelling scheme's total score.", Fields: []types.Field{{Name: "simplicity"}, {Name: "coverage"}, {Name: "density"}, {Name: "legibility"}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.LegendStyle", IDName: "legend-style", Doc: "LegendStyle has the styling properties for the Legend.", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"-setters"}}}, Fields: []types.Field{{Name: "Column", Doc: "Column is for table-based plotting, specifying the column with legend values."}, {Name: "Text", Doc: "Text is the style given to the legend entry texts."}, {Name: "Position", Doc: "position of the legend"}, {Name: "ThumbnailWidth", Doc: "ThumbnailWidth is the width of legend thumbnails."}, {Name: "Fill", Doc: "Fill specifies the background fill color for the legend box,\nif non-nil."}}})
// SetColumn sets the [LegendStyle.Column]:
// Column is for table-based plotting, specifying the column with legend values.
func (t *LegendStyle) SetColumn(v string) *LegendStyle { t.Column = v; return t }
// SetText sets the [LegendStyle.Text]:
// Text is the style given to the legend entry texts.
func (t *LegendStyle) SetText(v TextStyle) *LegendStyle { t.Text = v; return t }
// SetPosition sets the [LegendStyle.Position]:
// position of the legend
func (t *LegendStyle) SetPosition(v LegendPosition) *LegendStyle { t.Position = v; return t }
// SetThumbnailWidth sets the [LegendStyle.ThumbnailWidth]:
// ThumbnailWidth is the width of legend thumbnails.
func (t *LegendStyle) SetThumbnailWidth(v units.Value) *LegendStyle { t.ThumbnailWidth = v; return t }
// SetFill sets the [LegendStyle.Fill]:
// Fill specifies the background fill color for the legend box,
// if non-nil.
func (t *LegendStyle) SetFill(v image.Image) *LegendStyle { t.Fill = v; return t }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.LegendPosition", IDName: "legend-position", Doc: "LegendPosition specifies where to put the legend", Fields: []types.Field{{Name: "Top", Doc: "Top and Left specify the location of the legend."}, {Name: "Left", Doc: "Top and Left specify the location of the legend."}, {Name: "XOffs", Doc: "XOffs and YOffs are added to the legend's final position,\nrelative to the relevant anchor position"}, {Name: "YOffs", Doc: "XOffs and YOffs are added to the legend's final position,\nrelative to the relevant anchor position"}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.Legend", IDName: "legend", Doc: "A Legend gives a description of the meaning of different\ndata elements of the plot. Each legend entry has a name\nand a thumbnail, where the thumbnail shows a small\nsample of the display style of the corresponding data.", Fields: []types.Field{{Name: "Style", Doc: "Style has the legend styling parameters."}, {Name: "Entries", Doc: "Entries are all of the LegendEntries described by this legend."}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.Thumbnailer", IDName: "thumbnailer", Doc: "Thumbnailer wraps the Thumbnail method, which draws the small\nimage in a legend representing the style of data.", Methods: []types.Method{{Name: "Thumbnail", Doc: "Thumbnail draws an thumbnail representing a legend entry.\nThe thumbnail will usually show a smaller representation\nof the style used to plot the corresponding data.", Args: []string{"pt"}}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.LegendEntry", IDName: "legend-entry", Doc: "A LegendEntry represents a single line of a legend, it\nhas a name and an icon.", Fields: []types.Field{{Name: "Text", Doc: "text is the text associated with this entry."}, {Name: "Thumbs", Doc: "thumbs is a slice of all of the thumbnails styles"}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.LineStyle", IDName: "line-style", Doc: "LineStyle has style properties for drawing lines.", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"-setters"}}}, Fields: []types.Field{{Name: "On", Doc: "On indicates whether to plot lines."}, {Name: "Color", Doc: "Color is the stroke color image specification.\nSetting to nil turns line off."}, {Name: "Width", Doc: "Width is the line width, with a default of 1 Pt (point).\nSetting to 0 turns line off."}, {Name: "Dashes", Doc: "Dashes are the dashes of the stroke. Each pair of values specifies\nthe amount to paint and then the amount to skip."}, {Name: "Fill", Doc: "Fill is the color to fill solid regions, in a plot-specific\nway (e.g., the area below a Line plot, the bar color).\nUse nil to disable filling."}, {Name: "NegativeX", Doc: "NegativeX specifies whether to draw lines that connect points with a negative\nX-axis direction; otherwise there is a break in the line.\ndefault is false, so that repeated series of data across the X axis\nare plotted separately."}, {Name: "Step", Doc: "Step specifies how to step the line between points."}}})
// SetOn sets the [LineStyle.On]:
// On indicates whether to plot lines.
func (t *LineStyle) SetOn(v DefaultOffOn) *LineStyle { t.On = v; return t }
// SetColor sets the [LineStyle.Color]:
// Color is the stroke color image specification.
// Setting to nil turns line off.
func (t *LineStyle) SetColor(v image.Image) *LineStyle { t.Color = v; return t }
// SetWidth sets the [LineStyle.Width]:
// Width is the line width, with a default of 1 Pt (point).
// Setting to 0 turns line off.
func (t *LineStyle) SetWidth(v units.Value) *LineStyle { t.Width = v; return t }
// SetDashes sets the [LineStyle.Dashes]:
// Dashes are the dashes of the stroke. Each pair of values specifies
// the amount to paint and then the amount to skip.
func (t *LineStyle) SetDashes(v ...float32) *LineStyle { t.Dashes = v; return t }
// SetFill sets the [LineStyle.Fill]:
// Fill is the color to fill solid regions, in a plot-specific
// way (e.g., the area below a Line plot, the bar color).
// Use nil to disable filling.
func (t *LineStyle) SetFill(v image.Image) *LineStyle { t.Fill = v; return t }
// SetNegativeX sets the [LineStyle.NegativeX]:
// NegativeX specifies whether to draw lines that connect points with a negative
// X-axis direction; otherwise there is a break in the line.
// default is false, so that repeated series of data across the X axis
// are plotted separately.
func (t *LineStyle) SetNegativeX(v bool) *LineStyle { t.NegativeX = v; return t }
// SetStep sets the [LineStyle.Step]:
// Step specifies how to step the line between points.
func (t *LineStyle) SetStep(v StepKind) *LineStyle { t.Step = v; return t }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.StepKind", IDName: "step-kind", Doc: "StepKind specifies a form of a connection of two consecutive points."})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.XAxisStyle", IDName: "x-axis-style", Doc: "XAxisStyle has overall plot level styling properties for the XAxis.", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"-setters"}}}, Fields: []types.Field{{Name: "Column", Doc: "Column specifies the column to use for the common X axis,\nfor [plot.NewTablePlot] [table.Table] driven plots.\nIf empty, standard Group-based role binding is used: the last column\nwithin the same group with Role=X is used."}, {Name: "Rotation", Doc: "Rotation is the rotation of the X Axis labels, in degrees."}, {Name: "Label", Doc: "Label is the optional label to use for the XAxis instead of the default."}, {Name: "Range", Doc: "Range is the effective range of XAxis data to plot, where either end can be fixed."}, {Name: "Scale", Doc: "Scale specifies how values are scaled along the X axis:\nLinear, Log, Inverted"}}})
// SetColumn sets the [XAxisStyle.Column]:
// Column specifies the column to use for the common X axis,
// for [plot.NewTablePlot] [table.Table] driven plots.
// If empty, standard Group-based role binding is used: the last column
// within the same group with Role=X is used.
func (t *XAxisStyle) SetColumn(v string) *XAxisStyle { t.Column = v; return t }
// SetRotation sets the [XAxisStyle.Rotation]:
// Rotation is the rotation of the X Axis labels, in degrees.
func (t *XAxisStyle) SetRotation(v float32) *XAxisStyle { t.Rotation = v; return t }
// SetLabel sets the [XAxisStyle.Label]:
// Label is the optional label to use for the XAxis instead of the default.
func (t *XAxisStyle) SetLabel(v string) *XAxisStyle { t.Label = v; return t }
// SetRange sets the [XAxisStyle.Range]:
// Range is the effective range of XAxis data to plot, where either end can be fixed.
func (t *XAxisStyle) SetRange(v minmax.Range64) *XAxisStyle { t.Range = v; return t }
// SetScale sets the [XAxisStyle.Scale]:
// Scale specifies how values are scaled along the X axis:
// Linear, Log, Inverted
func (t *XAxisStyle) SetScale(v AxisScales) *XAxisStyle { t.Scale = v; return t }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.PlotStyle", IDName: "plot-style", Doc: "PlotStyle has overall plot level styling properties.\nSome properties provide defaults for individual elements, which can\nthen be overwritten by element-level properties.", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"-setters"}}}, Fields: []types.Field{{Name: "Title", Doc: "Title is the overall title of the plot."}, {Name: "TitleStyle", Doc: "TitleStyle is the text styling parameters for the title."}, {Name: "Background", Doc: "Background is the background of the plot.\nThe default is [colors.Scheme.Surface]."}, {Name: "Scale", Doc: "Scale multiplies the plot DPI value, to change the overall scale\nof the rendered plot. Larger numbers produce larger scaling.\nTypically use larger numbers when generating plots for inclusion in\ndocuments or other cases where the overall plot size will be small."}, {Name: "Legend", Doc: "Legend has the styling properties for the Legend."}, {Name: "Axis", Doc: "Axis has the styling properties for the Axis associated with this Data."}, {Name: "XAxis", Doc: "XAxis has plot-level properties specific to the XAxis."}, {Name: "YAxisLabel", Doc: "YAxisLabel is the optional label to use for the YAxis instead of the default."}, {Name: "SizeAxis", Doc: "SizeAxis has plot-level properties specific to the Size virtual axis."}, {Name: "LinesOn", Doc: "LinesOn determines whether lines are plotted by default at the overall,\nPlot level, for elements that plot lines (e.g., plots.XY)."}, {Name: "LineWidth", Doc: "LineWidth sets the default line width for data plotting lines at the\noverall Plot level."}, {Name: "PointsOn", Doc: "PointsOn determines whether points are plotted by default at the\noverall Plot level, for elements that plot points (e.g., plots.XY)."}, {Name: "PointSize", Doc: "PointSize sets the default point size at the overall Plot level."}, {Name: "LabelSize", Doc: "LabelSize sets the default label text size at the overall Plot level."}, {Name: "BarWidth", Doc: "BarWidth for Bar plot sets the default width of the bars,\nwhich should be less than the Stride (1 typically) to prevent\nbar overlap. Defaults to .8."}, {Name: "ShowErrors", Doc: "ShowErrors can be set to have Plot configuration errors reported.\nThis is particularly important for table-driven plots (e.g., [plotcore.Editor]),\nbut it is not on by default because often there are transitional states\nwith known errors that can lead to false alarms."}}})
// SetTitle sets the [PlotStyle.Title]:
// Title is the overall title of the plot.
func (t *PlotStyle) SetTitle(v string) *PlotStyle { t.Title = v; return t }
// SetTitleStyle sets the [PlotStyle.TitleStyle]:
// TitleStyle is the text styling parameters for the title.
func (t *PlotStyle) SetTitleStyle(v TextStyle) *PlotStyle { t.TitleStyle = v; return t }
// SetBackground sets the [PlotStyle.Background]:
// Background is the background of the plot.
// The default is [colors.Scheme.Surface].
func (t *PlotStyle) SetBackground(v image.Image) *PlotStyle { t.Background = v; return t }
// SetScale sets the [PlotStyle.Scale]:
// Scale multiplies the plot DPI value, to change the overall scale
// of the rendered plot. Larger numbers produce larger scaling.
// Typically use larger numbers when generating plots for inclusion in
// documents or other cases where the overall plot size will be small.
func (t *PlotStyle) SetScale(v float32) *PlotStyle { t.Scale = v; return t }
// SetLegend sets the [PlotStyle.Legend]:
// Legend has the styling properties for the Legend.
func (t *PlotStyle) SetLegend(v LegendStyle) *PlotStyle { t.Legend = v; return t }
// SetAxis sets the [PlotStyle.Axis]:
// Axis has the styling properties for the Axis associated with this Data.
func (t *PlotStyle) SetAxis(v AxisStyle) *PlotStyle { t.Axis = v; return t }
// SetXAxis sets the [PlotStyle.XAxis]:
// XAxis has plot-level properties specific to the XAxis.
func (t *PlotStyle) SetXAxis(v XAxisStyle) *PlotStyle { t.XAxis = v; return t }
// SetYAxisLabel sets the [PlotStyle.YAxisLabel]:
// YAxisLabel is the optional label to use for the YAxis instead of the default.
func (t *PlotStyle) SetYAxisLabel(v string) *PlotStyle { t.YAxisLabel = v; return t }
// SetSizeAxis sets the [PlotStyle.SizeAxis]:
// SizeAxis has plot-level properties specific to the Size virtual axis.
func (t *PlotStyle) SetSizeAxis(v VirtualAxisStyle) *PlotStyle { t.SizeAxis = v; return t }
// SetLinesOn sets the [PlotStyle.LinesOn]:
// LinesOn determines whether lines are plotted by default at the overall,
// Plot level, for elements that plot lines (e.g., plots.XY).
func (t *PlotStyle) SetLinesOn(v DefaultOffOn) *PlotStyle { t.LinesOn = v; return t }
// SetLineWidth sets the [PlotStyle.LineWidth]:
// LineWidth sets the default line width for data plotting lines at the
// overall Plot level.
func (t *PlotStyle) SetLineWidth(v units.Value) *PlotStyle { t.LineWidth = v; return t }
// SetPointsOn sets the [PlotStyle.PointsOn]:
// PointsOn determines whether points are plotted by default at the
// overall Plot level, for elements that plot points (e.g., plots.XY).
func (t *PlotStyle) SetPointsOn(v DefaultOffOn) *PlotStyle { t.PointsOn = v; return t }
// SetPointSize sets the [PlotStyle.PointSize]:
// PointSize sets the default point size at the overall Plot level.
func (t *PlotStyle) SetPointSize(v units.Value) *PlotStyle { t.PointSize = v; return t }
// SetLabelSize sets the [PlotStyle.LabelSize]:
// LabelSize sets the default label text size at the overall Plot level.
func (t *PlotStyle) SetLabelSize(v units.Value) *PlotStyle { t.LabelSize = v; return t }
// SetBarWidth sets the [PlotStyle.BarWidth]:
// BarWidth for Bar plot sets the default width of the bars,
// which should be less than the Stride (1 typically) to prevent
// bar overlap. Defaults to .8.
func (t *PlotStyle) SetBarWidth(v float64) *PlotStyle { t.BarWidth = v; return t }
// SetShowErrors sets the [PlotStyle.ShowErrors]:
// ShowErrors can be set to have Plot configuration errors reported.
// This is particularly important for table-driven plots (e.g., [plotcore.Editor]),
// but it is not on by default because often there are transitional states
// with known errors that can lead to false alarms.
func (t *PlotStyle) SetShowErrors(v bool) *PlotStyle { t.ShowErrors = v; return t }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.PanZoom", IDName: "pan-zoom", Doc: "PanZoom provides post-styling pan and zoom range manipulation.", Fields: []types.Field{{Name: "XOffset", Doc: "XOffset adds offset to X range (pan)."}, {Name: "XScale", Doc: "XScale multiplies X range (zoom)."}, {Name: "YOffset", Doc: "YOffset adds offset to Y range (pan)."}, {Name: "YScale", Doc: "YScale multiplies Y range (zoom)."}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.Plot", IDName: "plot", Doc: "Plot is the basic type representing a plot.\nIt renders into its own image.RGBA Pixels image,\nand can also save a corresponding SVG version.", Fields: []types.Field{{Name: "Title", Doc: "Title of the plot"}, {Name: "Style", Doc: "Style has the styling properties for the plot.\nAll end-user configuration should be put in here,\nrather than modifying other fields directly on the plot."}, {Name: "StandardTextStyle", Doc: "standard text style with default options"}, {Name: "X", Doc: "X, Y, YR, and Z are the horizontal, vertical, right vertical, and depth axes\nof the plot respectively. These are the actual compiled\nstate data and should not be used for styling: use Style."}, {Name: "Y", Doc: "X, Y, YR, and Z are the horizontal, vertical, right vertical, and depth axes\nof the plot respectively. These are the actual compiled\nstate data and should not be used for styling: use Style."}, {Name: "YR", Doc: "X, Y, YR, and Z are the horizontal, vertical, right vertical, and depth axes\nof the plot respectively. These are the actual compiled\nstate data and should not be used for styling: use Style."}, {Name: "Z", Doc: "X, Y, YR, and Z are the horizontal, vertical, right vertical, and depth axes\nof the plot respectively. These are the actual compiled\nstate data and should not be used for styling: use Style."}, {Name: "SizeAxis", Doc: "SizeAxis is a virtual axis for the Size data role."}, {Name: "Legend", Doc: "Legend is the plot's legend."}, {Name: "Plotters", Doc: "Plotters are drawn by calling their Plot method after the axes are drawn."}, {Name: "PanZoom", Doc: "PanZoom provides post-styling pan and zoom range factors."}, {Name: "HighlightPlotter", Doc: "HighlightPlotter is the Plotter to highlight. Used for mouse hovering for example.\nIt is the responsibility of the Plotter Plot function to implement highlighting."}, {Name: "HighlightIndex", Doc: "HighlightIndex is the index of the data point to highlight, for HighlightPlotter."}, {Name: "TextShaper", Doc: "TextShaper for shaping text. Can set to a shared external one,\nor else the shared plotShaper is used under a mutex lock during Render."}, {Name: "PaintBox", Doc: "PaintBox is the bounding box for the plot within the Paint.\nFor standalone, it is the size of the image."}, {Name: "PlotBox", Doc: "Current local plot bounding box in image coordinates, for computing\nplotting coordinates."}, {Name: "Painter", Doc: "Painter is the current painter being used,\nwhich is only valid during rendering, and is set by Draw function.\nIt needs to be exported for different plot types in other packages."}, {Name: "unitContext", Doc: "unitContext is current unit context, only valid during rendering."}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.Plotter", IDName: "plotter", Doc: "Plotter is an interface that wraps the Plot method.\nStandard implementations of Plotter are in the [plots] package.", Methods: []types.Method{{Name: "Plot", Doc: "Plot draws the data to the Plot Paint.", Args: []string{"pt"}}, {Name: "UpdateRange", Doc: "UpdateRange updates the given ranges.", Args: []string{"plt", "x", "y", "yr", "z", "size"}}, {Name: "Data", Doc: "Data returns the data by roles for this plot, for both the original\ndata and the pixel-transformed X,Y coordinates for that data.\nThis allows a GUI interface to inspect data etc.", Returns: []string{"data", "pixX", "pixY"}}, {Name: "Stylers", Doc: "Stylers returns the styler functions for this element.", Returns: []string{"Stylers"}}, {Name: "ApplyStyle", Doc: "ApplyStyle applies any stylers to this element,\nfirst initializing from the given global plot style, which has\nalready been styled with defaults and all the plot element stylers.", Args: []string{"plotStyle", "idx"}}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.PlotterType", IDName: "plotter-type", Doc: "PlotterType registers a Plotter so that it can be created with appropriate data.", Fields: []types.Field{{Name: "Name", Doc: "Name of the plot type."}, {Name: "Doc", Doc: "Doc is the documentation for this Plotter."}, {Name: "Required", Doc: "Required Data roles for this plot. Data for these Roles must be provided."}, {Name: "Optional", Doc: "Optional Data roles for this plot."}, {Name: "New", Doc: "New returns a new plotter of this type with given data in given roles."}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.PlotterName", IDName: "plotter-name", Doc: "PlotterName is the name of a specific plotter type."})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.PointStyle", IDName: "point-style", Doc: "PointStyle has style properties for drawing points as different shapes.", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"-setters"}}}, Fields: []types.Field{{Name: "On", Doc: "On indicates whether to plot points."}, {Name: "Shape", Doc: "Shape to draw."}, {Name: "Color", Doc: "Color is the stroke color image specification.\nSetting to nil turns stroke off. See also [PointStyle.ColorFunc]."}, {Name: "Fill", Doc: "Fill is the color to fill points.\nUse nil to disable filling. See also [PointStyle.FillFunc]."}, {Name: "ColorFunc", Doc: "ColorFunc, if non-nil, is used instead of [PointStyle.Color].\nThe function returns the stroke color to use for a given point index."}, {Name: "FillFunc", Doc: "FillFunc, if non-nil, is used instead of [PointStyle.Fill].\nThe function returns the fill color to use for a given point index."}, {Name: "Width", Doc: "Width is the line width for point glyphs, with a default of 1 Pt (point).\nSetting to 0 turns line off."}, {Name: "Size", Doc: "Size of shape to draw for each point.\nDefaults to 3 Pt (point)."}}})
// SetOn sets the [PointStyle.On]:
// On indicates whether to plot points.
func (t *PointStyle) SetOn(v DefaultOffOn) *PointStyle { t.On = v; return t }
// SetShape sets the [PointStyle.Shape]:
// Shape to draw.
func (t *PointStyle) SetShape(v Shapes) *PointStyle { t.Shape = v; return t }
// SetColor sets the [PointStyle.Color]:
// Color is the stroke color image specification.
// Setting to nil turns stroke off. See also [PointStyle.ColorFunc].
func (t *PointStyle) SetColor(v image.Image) *PointStyle { t.Color = v; return t }
// SetFill sets the [PointStyle.Fill]:
// Fill is the color to fill points.
// Use nil to disable filling. See also [PointStyle.FillFunc].
func (t *PointStyle) SetFill(v image.Image) *PointStyle { t.Fill = v; return t }
// SetColorFunc sets the [PointStyle.ColorFunc]:
// ColorFunc, if non-nil, is used instead of [PointStyle.Color].
// The function returns the stroke color to use for a given point index.
func (t *PointStyle) SetColorFunc(v func(i int) image.Image) *PointStyle { t.ColorFunc = v; return t }
// SetFillFunc sets the [PointStyle.FillFunc]:
// FillFunc, if non-nil, is used instead of [PointStyle.Fill].
// The function returns the fill color to use for a given point index.
func (t *PointStyle) SetFillFunc(v func(i int) image.Image) *PointStyle { t.FillFunc = v; return t }
// SetWidth sets the [PointStyle.Width]:
// Width is the line width for point glyphs, with a default of 1 Pt (point).
// Setting to 0 turns line off.
func (t *PointStyle) SetWidth(v units.Value) *PointStyle { t.Width = v; return t }
// SetSize sets the [PointStyle.Size]:
// Size of shape to draw for each point.
// Defaults to 3 Pt (point).
func (t *PointStyle) SetSize(v units.Value) *PointStyle { t.Size = v; return t }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.Shapes", IDName: "shapes", Doc: "Shapes has the options for how to draw points in the plot."})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.Style", IDName: "style", Doc: "Style contains the plot styling properties relevant across\nmost plot types. These properties apply to individual plot elements\nwhile the Plot properties applies to the overall plot itself.", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"-setters"}}}, Fields: []types.Field{{Name: "Plot", Doc: "Plot has overall plot-level properties, which can be set by any\nplot element, and are updated first, before applying element-wise styles."}, {Name: "On", Doc: "On specifies whether to plot this item, for table-based plots."}, {Name: "Plotter", Doc: "Plotter is the type of plotter to use in plotting this data,\nfor [plot.NewTablePlot] [table.Table] driven plots.\nBlank means use default ([plots.XY] is overall default)."}, {Name: "Role", Doc: "Role specifies how a particular column of data should be used,\nfor [plot.NewTablePlot] [table.Table] driven plots."}, {Name: "Group", Doc: "Group specifies a group of related data items,\nfor [plot.NewTablePlot] [table.Table] driven plots,\nwhere different columns of data within the same Group play different Roles."}, {Name: "Range", Doc: "Range is the effective range of data to plot, where either end can be fixed."}, {Name: "Label", Doc: "Label provides an alternative label to use for axis, if set."}, {Name: "NoLegend", Doc: "NoLegend excludes this item from the legend when it otherwise would be included,\nfor [plot.NewTablePlot] [table.Table] driven plots.\nRole = Y values are included in the Legend by default."}, {Name: "RightY", Doc: "RightY specifies that this should use the right-side alternate Y axis."}, {Name: "NTicks", Doc: "NTicks sets the desired number of ticks for the axis, if > 0."}, {Name: "LabelSkip", Doc: "LabelSkip is the number of data points to skip between Labels.\n0 means plot the Label at every point."}, {Name: "Line", Doc: "Line has style properties for drawing lines."}, {Name: "Point", Doc: "Point has style properties for drawing points."}, {Name: "Text", Doc: "Text has style properties for rendering text."}, {Name: "Width", Doc: "Width has various plot width properties."}}})
// SetPlot sets the [Style.Plot]:
// Plot has overall plot-level properties, which can be set by any
// plot element, and are updated first, before applying element-wise styles.
func (t *Style) SetPlot(v PlotStyle) *Style { t.Plot = v; return t }
// SetOn sets the [Style.On]:
// On specifies whether to plot this item, for table-based plots.
func (t *Style) SetOn(v bool) *Style { t.On = v; return t }
// SetPlotter sets the [Style.Plotter]:
// Plotter is the type of plotter to use in plotting this data,
// for [plot.NewTablePlot] [table.Table] driven plots.
// Blank means use default ([plots.XY] is overall default).
func (t *Style) SetPlotter(v PlotterName) *Style { t.Plotter = v; return t }
// SetRole sets the [Style.Role]:
// Role specifies how a particular column of data should be used,
// for [plot.NewTablePlot] [table.Table] driven plots.
func (t *Style) SetRole(v Roles) *Style { t.Role = v; return t }
// SetGroup sets the [Style.Group]:
// Group specifies a group of related data items,
// for [plot.NewTablePlot] [table.Table] driven plots,
// where different columns of data within the same Group play different Roles.
func (t *Style) SetGroup(v string) *Style { t.Group = v; return t }
// SetRange sets the [Style.Range]:
// Range is the effective range of data to plot, where either end can be fixed.
func (t *Style) SetRange(v minmax.Range64) *Style { t.Range = v; return t }
// SetLabel sets the [Style.Label]:
// Label provides an alternative label to use for axis, if set.
func (t *Style) SetLabel(v string) *Style { t.Label = v; return t }
// SetNoLegend sets the [Style.NoLegend]:
// NoLegend excludes this item from the legend when it otherwise would be included,
// for [plot.NewTablePlot] [table.Table] driven plots.
// Role = Y values are included in the Legend by default.
func (t *Style) SetNoLegend(v bool) *Style { t.NoLegend = v; return t }
// SetRightY sets the [Style.RightY]:
// RightY specifies that this should use the right-side alternate Y axis.
func (t *Style) SetRightY(v bool) *Style { t.RightY = v; return t }
// SetNTicks sets the [Style.NTicks]:
// NTicks sets the desired number of ticks for the axis, if > 0.
func (t *Style) SetNTicks(v int) *Style { t.NTicks = v; return t }
// SetLabelSkip sets the [Style.LabelSkip]:
// LabelSkip is the number of data points to skip between Labels.
// 0 means plot the Label at every point.
func (t *Style) SetLabelSkip(v int) *Style { t.LabelSkip = v; return t }
// SetLine sets the [Style.Line]:
// Line has style properties for drawing lines.
func (t *Style) SetLine(v LineStyle) *Style { t.Line = v; return t }
// SetPoint sets the [Style.Point]:
// Point has style properties for drawing points.
func (t *Style) SetPoint(v PointStyle) *Style { t.Point = v; return t }
// SetText sets the [Style.Text]:
// Text has style properties for rendering text.
func (t *Style) SetText(v TextStyle) *Style { t.Text = v; return t }
// SetWidth sets the [Style.Width]:
// Width has various plot width properties.
func (t *Style) SetWidth(v WidthStyle) *Style { t.Width = v; return t }
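// The chained setters above are typically used inside styler functions attached
// to data columns, which run before plotting to configure each element. A minimal
// sketch from a client package (hedged: the column variable cl and the specific
// values are hypothetical illustrations; SetStyler, Style, and these setters
// appear elsewhere in this repository):
//
//	plot.SetStyler(cl, func(s *plot.Style) {
//		s.On = true                          // plot this column
//		s.SetLabel("Accuracy").SetNTicks(5)  // axis label and tick count
//		s.Line.SetWidth(units.Dp(2))         // thicker data line
//		s.Point.SetSize(units.Dp(4))         // larger point glyphs
//	})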
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.WidthStyle", IDName: "width-style", Doc: "WidthStyle contains various plot width properties relevant across\ndifferent plot types.", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"-setters"}}}, Fields: []types.Field{{Name: "Cap", Doc: "Cap is the width of the caps drawn at the top of error bars.\nThe default is 10dp"}, {Name: "Offset", Doc: "Offset for Bar plot is the offset added to each X axis value\nrelative to the Stride computed value (X = offset + index * Stride)\nDefaults to 0."}, {Name: "Stride", Doc: "Stride for Bar plot is distance between bars. Defaults to 1."}, {Name: "Width", Doc: "Width for Bar plot is the width of the bars, as a fraction of the Stride,\nto prevent bar overlap. Defaults to .8."}, {Name: "Pad", Doc: "Pad for Bar plot is additional space at start / end of data range,\nto keep bars from overflowing ends. This amount is subtracted from Offset\nand added to (len(Values)-1)*Stride -- no other accommodation for bar\nwidth is provided, so that should be built into this value as well.\nDefaults to 1."}}})
// SetCap sets the [WidthStyle.Cap]:
// Cap is the width of the caps drawn at the top of error bars.
// The default is 10dp
func (t *WidthStyle) SetCap(v units.Value) *WidthStyle { t.Cap = v; return t }
// SetOffset sets the [WidthStyle.Offset]:
// Offset for Bar plot is the offset added to each X axis value
// relative to the Stride computed value (X = offset + index * Stride)
// Defaults to 0.
func (t *WidthStyle) SetOffset(v float64) *WidthStyle { t.Offset = v; return t }
// SetStride sets the [WidthStyle.Stride]:
// Stride for Bar plot is distance between bars. Defaults to 1.
func (t *WidthStyle) SetStride(v float64) *WidthStyle { t.Stride = v; return t }
// SetWidth sets the [WidthStyle.Width]:
// Width for Bar plot is the width of the bars, as a fraction of the Stride,
// to prevent bar overlap. Defaults to .8.
func (t *WidthStyle) SetWidth(v float64) *WidthStyle { t.Width = v; return t }
// SetPad sets the [WidthStyle.Pad]:
// Pad for Bar plot is additional space at start / end of data range,
// to keep bars from overflowing ends. This amount is subtracted from Offset
// and added to (len(Values)-1)*Stride -- no other accommodation for bar
// width is provided, so that should be built into this value as well.
// Defaults to 1.
func (t *WidthStyle) SetPad(v float64) *WidthStyle { t.Pad = v; return t }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.Stylers", IDName: "stylers", Doc: "Stylers is a list of styling functions that set Style properties.\nThese are called in the order added."})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.DefaultOffOn", IDName: "default-off-on", Doc: "DefaultOffOn specifies whether to use the default value for a bool option,\nor to override the default and set Off or On."})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.pitem", IDName: "pitem", Fields: []types.Field{{Name: "ptyp"}, {Name: "pt"}, {Name: "data"}, {Name: "lbl"}, {Name: "ci"}, {Name: "clr"}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.TextStyle", IDName: "text-style", Doc: "TextStyle specifies styling parameters for Text elements.", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"-setters"}}}, Fields: []types.Field{{Name: "Size", Doc: "Size of font to render. Default is 16dp"}, {Name: "Family", Doc: "Family name for font (inherited): ordered list of comma-separated names\nfrom more general to more specific to use. Use split on, to parse."}, {Name: "Color", Doc: "Color of text."}, {Name: "Align", Doc: "Align specifies how to align text along the relevant\ndimension for the text element."}, {Name: "Padding", Doc: "Padding is used in a case-dependent manner to add\nspace around text elements."}, {Name: "Rotation", Doc: "Rotation of the text, in degrees."}, {Name: "Offset", Doc: "Offset is added directly to the final label location."}}})
// SetSize sets the [TextStyle.Size]:
// Size of font to render. Default is 16dp
func (t *TextStyle) SetSize(v units.Value) *TextStyle { t.Size = v; return t }
// SetFamily sets the [TextStyle.Family]:
// Family name for font (inherited): ordered list of comma-separated names
// from more general to more specific to use. Use split on, to parse.
func (t *TextStyle) SetFamily(v rich.Family) *TextStyle { t.Family = v; return t }
// SetColor sets the [TextStyle.Color]:
// Color of text.
func (t *TextStyle) SetColor(v image.Image) *TextStyle { t.Color = v; return t }
// SetAlign sets the [TextStyle.Align]:
// Align specifies how to align text along the relevant
// dimension for the text element.
func (t *TextStyle) SetAlign(v styles.Aligns) *TextStyle { t.Align = v; return t }
// SetPadding sets the [TextStyle.Padding]:
// Padding is used in a case-dependent manner to add
// space around text elements.
func (t *TextStyle) SetPadding(v units.Value) *TextStyle { t.Padding = v; return t }
// SetRotation sets the [TextStyle.Rotation]:
// Rotation of the text, in degrees.
func (t *TextStyle) SetRotation(v float32) *TextStyle { t.Rotation = v; return t }
// SetOffset sets the [TextStyle.Offset]:
// Offset is added directly to the final label location.
func (t *TextStyle) SetOffset(v units.XY) *TextStyle { t.Offset = v; return t }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.Text", IDName: "text", Doc: "Text specifies a single text element in a plot", Fields: []types.Field{{Name: "Text", Doc: "text string, which can use HTML formatting"}, {Name: "Style", Doc: "styling for this text element"}, {Name: "font", Doc: "font has the font rendering styles."}, {Name: "textStyle", Doc: "textStyle has the text rendering styles."}, {Name: "PaintText", Doc: "PaintText is the [shaped.Lines] for painting the text."}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.Tick", IDName: "tick", Doc: "A Tick is a single tick mark on an axis.", Fields: []types.Field{{Name: "Value", Doc: "Value is the data value marked by this Tick."}, {Name: "Label", Doc: "Label is the text to display at the tick mark.\nIf Label is an empty string then this is a minor tick mark."}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.Ticker", IDName: "ticker", Doc: "Ticker creates Ticks in a specified range", Methods: []types.Method{{Name: "Ticks", Doc: "Ticks returns Ticks in a specified range, with desired number of ticks,\nwhich can be ignored depending on the ticker type.", Args: []string{"mn", "mx", "nticks"}, Returns: []string{"Tick"}}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.DefaultTicks", IDName: "default-ticks", Doc: "DefaultTicks is suitable for the Ticker field of an Axis,\nit returns a reasonable default set of tick marks."})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.LogTicks", IDName: "log-ticks", Doc: "LogTicks is suitable for the Ticker field of an Axis,\nit returns tick marks suitable for a log-scale axis.", Fields: []types.Field{{Name: "Prec", Doc: "Prec specifies the precision of tick rendering\naccording to the documentation for strconv.FormatFloat."}}})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.ConstantTicks", IDName: "constant-ticks", Doc: "ConstantTicks is suitable for the Ticker field of an Axis.\nThis function returns the given set of ticks."})
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plot.TimeTicks", IDName: "time-ticks", Doc: "TimeTicks is suitable for axes representing time values.", Fields: []types.Field{{Name: "Ticker", Doc: "Ticker is used to generate a set of ticks.\nIf nil, DefaultTicks will be used."}, {Name: "Format", Doc: "Format is the textual representation of the time value.\nIf empty, time.RFC3339 will be used"}, {Name: "Time", Doc: "Time takes a float32 value and converts it into a time.Time.\nIf nil, UTCUnixTime is used."}}})
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package plotcore provides Cogent Core widgets for viewing and editing plots.
package plotcore
//go:generate core generate
import (
"fmt"
"image"
"io/fs"
"log/slog"
"os"
"path/filepath"
"slices"
"strings"
"time"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/fsx"
"cogentcore.org/core/base/metadata"
"cogentcore.org/core/base/reflectx"
"cogentcore.org/core/colors"
"cogentcore.org/core/core"
"cogentcore.org/core/events"
"cogentcore.org/core/icons"
"cogentcore.org/core/math32"
"cogentcore.org/core/paint"
"cogentcore.org/core/styles"
"cogentcore.org/core/styles/states"
"cogentcore.org/core/system"
"cogentcore.org/core/tree"
"cogentcore.org/lab/plot"
"cogentcore.org/lab/plot/plots"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
"cogentcore.org/lab/tensorcore"
"golang.org/x/exp/maps"
)
// Editor is a widget that provides an interactive 2D plot
// of selected columns of tabular data, represented by a [table.Table].
// Other types of tabular data can be converted into this format.
// The user can change various options for the plot and also modify the underlying data.
type Editor struct { //types:add
core.Frame
// table is the table of data being plotted.
table *table.Table
// PlotStyle has the overall plot style parameters.
PlotStyle plot.PlotStyle
// plot is the plot object.
plot *plot.Plot
// current svg file
svgFile core.Filename
// current csv data file
dataFile core.Filename
// currently doing a plot
inPlot bool
columnsFrame *core.Frame
plotWidget *Plot
plotStyleModified map[string]bool
}
func (pl *Editor) CopyFieldsFrom(frm tree.Node) {
fr := frm.(*Editor)
pl.Frame.CopyFieldsFrom(&fr.Frame)
pl.PlotStyle = fr.PlotStyle
pl.setTable(fr.table)
}
// NewSubPlot returns an [Editor] with its own separate [core.Toolbar],
// suitable for a tab or other element that is not the main plot.
func NewSubPlot(parent ...tree.Node) *Editor {
fr := core.NewFrame(parent...)
tb := core.NewToolbar(fr)
pl := NewEditor(fr)
fr.Styler(func(s *styles.Style) {
s.Direction = styles.Column
s.Grow.Set(1, 1)
})
tb.Maker(pl.MakeToolbar)
return pl
}
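// A usage sketch (hedged: parent and dt are hypothetical placeholders for an
// existing parent widget and a *table.Table):
//
//	ed := NewSubPlot(parent) // frame containing its own toolbar and editor
//	ed.SetTable(dt)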
func (pl *Editor) Init() {
pl.Frame.Init()
pl.PlotStyle.Defaults()
pl.Styler(func(s *styles.Style) {
s.Grow.Set(1, 1)
if pl.SizeClass() == core.SizeCompact {
s.Direction = styles.Column
}
})
pl.OnShow(func(e events.Event) {
pl.UpdatePlot()
})
pl.Updater(func() {
if pl.table != nil {
pl.plotStyleFromTable(pl.table)
}
})
tree.AddChildAt(pl, "columns", func(w *core.Frame) {
pl.columnsFrame = w
w.Styler(func(s *styles.Style) {
s.Direction = styles.Column
s.Background = colors.Scheme.SurfaceContainerLow
if w.SizeClass() == core.SizeCompact {
s.Grow.Set(1, 0)
} else {
s.Grow.Set(0, 1)
s.Overflow.Y = styles.OverflowAuto
}
})
w.Maker(pl.makeColumns)
})
tree.AddChildAt(pl, "plot", func(w *Plot) {
pl.plotWidget = w
w.Plot = pl.plot
w.Styler(func(s *styles.Style) {
s.Grow.Set(1, 1)
})
})
}
// setTable sets the table to view and does UpdatePlot.
func (pl *Editor) setTable(tab *table.Table) *Editor {
pl.table = tab
pl.UpdatePlot()
return pl
}
// SetTable sets the table to a new view of given table,
// and does UpdatePlot.
func (pl *Editor) SetTable(tab *table.Table) *Editor {
pl.table = nil
pl.Update() // reset
pl.table = table.NewView(tab)
pl.UpdatePlot()
pl.Update() // update to new table
return pl
}
// SetSlice sets the table to a [table.NewSliceTable] from the given slice.
// Optional styler functions are used for each struct field in sequence,
// and any can contain global plot style. See [BasicStylers] for example.
func (pl *Editor) SetSlice(sl any, stylers ...func(s *plot.Style)) *Editor {
dt, err := table.NewSliceTable(sl)
errors.Log(err)
if dt == nil {
return nil
}
mx := min(dt.NumColumns(), len(stylers))
for i := range mx {
plot.SetStyler(dt.Columns.Values[i], stylers[i])
}
return pl.SetTable(dt)
}
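// A sketch of SetSlice with per-field stylers (hedged: the struct type, field
// names, and values are hypothetical illustrations):
//
//	type point struct {
//		Time  float64
//		Score float64
//	}
//	pts := []point{{0, 0.1}, {1, 0.5}, {2, 0.9}}
//	ed := NewEditor(parent)
//	ed.SetSlice(pts,
//		func(s *plot.Style) { s.Role = plot.X }, // first field: X axis
//		func(s *plot.Style) { s.On = true },     // second field: plotted Y series
//	)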
// SaveSVG saves the plot to an SVG file.
func (pl *Editor) SaveSVG(fname core.Filename) { //types:add
plt := pl.plotWidget.Plot
mp := plt.PaintBox.Min
plt.PaintBox = plt.PaintBox.Sub(mp)
ptr := paint.NewPainter(math32.FromPoint(plt.PaintBox.Size()))
ptr.Paint.UnitContext = pl.Styles.UnitContext // preserve DPI from current
sv := paint.RenderToSVG(plt.Draw(ptr))
err := os.WriteFile(string(fname), sv, 0666)
plt.PaintBox = plt.PaintBox.Add(mp)
if err != nil {
core.ErrorSnackbar(pl, err)
}
pl.svgFile = fname
}
// SavePDF saves the plot to a PDF file.
func (pl *Editor) SavePDF(fname core.Filename) { //types:add
plt := pl.plotWidget.Plot
mp := plt.PaintBox.Min
plt.PaintBox = plt.PaintBox.Sub(mp)
ptr := paint.NewPainter(math32.FromPoint(plt.PaintBox.Size()))
ptr.Paint.UnitContext = pl.Styles.UnitContext // preserve DPI from current
pd := paint.RenderToPDF(plt.Draw(ptr))
err := os.WriteFile(string(fname), pd, 0666)
plt.PaintBox = plt.PaintBox.Add(mp)
if err != nil {
core.ErrorSnackbar(pl, err)
}
}
// SaveImage saves the current plot as an image (e.g., png).
func (pl *Editor) SaveImage(fname core.Filename) { //types:add
plt := pl.plotWidget.Plot
mp := plt.PaintBox.Min
plt.PaintBox = plt.PaintBox.Sub(mp)
err := pl.plotWidget.Plot.SaveImage(string(fname))
plt.PaintBox = plt.PaintBox.Add(mp)
if err != nil {
core.ErrorSnackbar(pl, err)
}
}
// SaveCSV saves the Table data to a csv (comma-separated values) file with headers,
// using the given delimiter (which need not be a comma).
func (pl *Editor) SaveCSV(fname core.Filename, delim tensor.Delims) { //types:add
pl.table.SaveCSV(fsx.Filename(fname), delim, table.Headers)
pl.dataFile = fname
}
// SaveAll saves the current plot to png, svg, and pdf files, and the data to a tsv file.
// Any extension on fname is removed and the appropriate extensions are added.
func (pl *Editor) SaveAll(fname core.Filename) { //types:add
fn := string(fname)
fn = strings.TrimSuffix(fn, filepath.Ext(fn))
pl.SaveCSV(core.Filename(fn+".tsv"), tensor.Tab)
pl.SaveImage(core.Filename(fn + ".png"))
pl.SaveSVG(core.Filename(fn + ".svg"))
pl.SavePDF(core.Filename(fn + ".pdf"))
}
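// For example (hedged: the path is hypothetical):
//
//	pl.SaveAll("runs/run1") // writes runs/run1.tsv, .png, .svg, and .pdf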
// OpenCSV opens the Table data from a csv (comma-separated values) file,
// using the given delimiter (which need not be a comma).
func (pl *Editor) OpenCSV(filename core.Filename, delim tensor.Delims) { //types:add
dt := table.New()
dt.OpenCSV(fsx.Filename(filename), delim)
pl.dataFile = filename
pl.SetTable(dt)
}
// OpenFS opens the Table data from a csv (comma-separated values) file,
// using the given delimiter, from the given filesystem.
func (pl *Editor) OpenFS(fsys fs.FS, filename core.Filename, delim tensor.Delims) {
dt := table.New()
dt.OpenFS(fsys, string(filename), delim)
pl.SetTable(dt)
}
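// A sketch of loading embedded data via OpenFS (hedged: the embed variable,
// file name, and helper function are hypothetical):
//
//	//go:embed data/results.tsv
//	var dataFS embed.FS
//
//	func loadEmbedded(pl *Editor) {
//		pl.OpenFS(dataFS, "data/results.tsv", tensor.Tab)
//	}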
// GoUpdatePlot updates the display based on current Indexed view into table.
// This version can be called from goroutines. It does Sequential() on
// the [table.Table], under the assumption that it is used for tracking
// the latest updates of a running process.
func (pl *Editor) GoUpdatePlot() {
if pl == nil || pl.This == nil {
return
}
if core.TheApp.Platform() == system.Web {
time.Sleep(time.Millisecond) // critical to prevent hanging!
}
if !pl.IsVisible() || pl.table == nil || pl.inPlot {
return
}
pl.Scene.AsyncLock()
pl.table.Sequential()
pl.genPlot()
pl.NeedsRender()
pl.Scene.AsyncUnlock()
}
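// A sketch of driving the plot from a worker goroutine (hedged: the loop and
// the addRowToTable helper are hypothetical stand-ins for a running process
// that appends rows to the table dt being viewed by the editor ed):
//
//	go func() {
//		for step := 0; step < 100; step++ {
//			addRowToTable(dt, step) // hypothetical helper that appends data
//			ed.GoUpdatePlot()       // safe to call from this goroutine
//			time.Sleep(100 * time.Millisecond)
//		}
//	}()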
// UpdatePlot updates the display based on current Indexed view into table.
// It does not automatically update the [table.Table] unless it is
// nil or out of date.
func (pl *Editor) UpdatePlot() {
if pl == nil || pl.This == nil {
return
}
if pl.table == nil || pl.inPlot {
return
}
if len(pl.Children) != 2 { // || len(pl.Columns) != pl.table.NumColumns() { // todo:
pl.Update()
}
if pl.table.NumRows() == 0 {
pl.table.Sequential()
}
pl.genPlot()
}
// genPlot generates a new plot from the current table.
// It brackets the operation with inPlot true / false to prevent multiple updates.
func (pl *Editor) genPlot() {
if pl.inPlot {
slog.Error("plot: in plot already") // note: this never seems to happen -- could probably nuke
return
}
pl.inPlot = true
if pl.table == nil {
pl.inPlot = false
return
}
if len(pl.table.Indexes) == 0 {
pl.table.Sequential()
} else {
lsti := pl.table.Indexes[pl.table.NumRows()-1]
if lsti >= pl.table.Columns.Rows { // out of date
pl.table.Sequential()
}
}
var err error
pl.plot, err = plot.NewTablePlot(pl.table)
if pl.plot != nil && pl.plot.Style.ShowErrors && err != nil {
core.ErrorSnackbar(pl, fmt.Errorf("%s: %w", pl.PlotStyle.Title, err))
}
if pl.plot != nil {
pl.plotWidget.SetPlot(pl.plot)
// } else {
// errors.Log(fmt.Errorf("%s: nil plot: %w", pl.PlotStyle.Title, err))
}
// pl.plotWidget.updatePlot()
pl.plotWidget.NeedsRender()
pl.inPlot = false
}
const plotColumnsHeaderN = 3
// allColumnsOff turns all columns off.
func (pl *Editor) allColumnsOff() {
fr := pl.columnsFrame
for i, cli := range fr.Children {
if i < plotColumnsHeaderN {
continue
}
cl := cli.(*core.Frame)
sw := cl.Child(0).(*core.Switch)
sw.SetChecked(false)
sw.SendChange()
}
pl.Update()
}
// setColumnsByName turns columns on or off if their name contains
// the given string.
func (pl *Editor) setColumnsByName(nameContains string, on bool) { //types:add
fr := pl.columnsFrame
for i, cli := range fr.Children {
if i < plotColumnsHeaderN {
continue
}
cl := cli.(*core.Frame)
if !strings.Contains(cl.Name, nameContains) {
continue
}
sw := cl.Child(0).(*core.Switch)
sw.SetChecked(on)
sw.SendChange()
}
pl.Update()
}
// makeColumns makes the tree Plan for the column control widgets.
func (pl *Editor) makeColumns(p *tree.Plan) {
tree.Add(p, func(w *core.Frame) {
tree.AddChild(w, func(w *core.Button) {
w.SetText("Clear").SetIcon(icons.ClearAll).SetType(core.ButtonAction)
w.SetTooltip("Turn all columns off")
w.OnClick(func(e events.Event) {
pl.allColumnsOff()
})
})
tree.AddChild(w, func(w *core.Button) {
w.SetText("Search").SetIcon(icons.Search).SetType(core.ButtonAction)
w.SetTooltip("Select columns by column name")
w.OnClick(func(e events.Event) {
core.CallFunc(pl, pl.setColumnsByName)
})
})
})
hasSplit := false // split uses different color styling
colorIdx := 0 // index for color sequence -- skips various types
tree.Add(p, func(w *core.Separator) {})
if pl.table == nil {
return
}
for ci, cl := range pl.table.Columns.Values {
cnm := pl.table.Columns.Keys[ci]
tree.AddAt(p, cnm, func(w *core.Frame) {
psty := plot.GetStylers(cl)
cst, mods, clr := pl.defaultColumnStyle(cl, ci, &colorIdx, &hasSplit, psty)
isSplit := cst.Role == plot.Split
stys := psty
stys.Add(func(s *plot.Style) {
mf := modFields(mods)
errors.Log(reflectx.CopyFields(s, cst, mf...))
errors.Log(reflectx.CopyFields(&s.Plot, &pl.PlotStyle, modFields(pl.plotStyleModified)...))
})
plot.SetStyler(cl, stys...)
w.Styler(func(s *styles.Style) {
s.CenterAll()
})
tree.AddChild(w, func(w *core.Switch) {
w.SetType(core.SwitchCheckbox).SetTooltip("Turn this column on or off")
w.Styler(func(s *styles.Style) {
s.Color = clr
})
tree.AddChildInit(w, "stack", func(w *core.Frame) {
f := func(name string) {
tree.AddChildInit(w, name, func(w *core.Icon) {
w.Styler(func(s *styles.Style) {
s.Color = clr
})
})
}
f("icon-on")
f("icon-off")
f("icon-indeterminate")
})
w.OnChange(func(e events.Event) {
mods["On"] = true
cst.On = w.IsChecked()
pl.UpdatePlot()
})
w.Updater(func() {
xaxis := cst.Role == plot.X // || cp.Column == pl.Options.Legend
w.SetState(xaxis, states.Disabled, states.Indeterminate)
if xaxis {
cst.On = false
} else {
w.SetChecked(cst.On)
}
if cst.Role == plot.Split {
isSplit = true
hasSplit = true // update global flag
} else {
if isSplit && cst.Role != plot.Split {
isSplit = false
hasSplit = false
}
}
})
})
tree.AddChild(w, func(w *core.Button) {
tt := "[Edit all styling options for this column] " + metadata.Doc(cl)
w.SetText(cnm).SetType(core.ButtonAction).SetTooltip(tt)
w.OnClick(func(e events.Event) {
update := func() {
if core.TheApp.Platform().IsMobile() {
pl.Update()
return
}
// we must be async on multi-window platforms since
// it is coming from a separate window
pl.AsyncLock()
pl.Update()
pl.AsyncUnlock()
}
d := core.NewBody(cnm + " style properties")
fm := core.NewForm(d).SetStruct(cst)
fm.Modified = mods
fm.OnChange(func(e events.Event) {
update()
})
// d.AddTopBar(func(bar *core.Frame) {
// core.NewToolbar(bar).Maker(func(p *tree.Plan) {
// tree.Add(p, func(w *core.Button) {
// w.SetText("Set x-axis").OnClick(func(e events.Event) {
// pl.Options.XAxis = cp.Column
// update()
// })
// })
// tree.Add(p, func(w *core.Button) {
// w.SetText("Set legend").OnClick(func(e events.Event) {
// pl.Options.Legend = cp.Column
// update()
// })
// })
// })
// })
d.RunWindowDialog(pl)
})
})
})
}
}
// defaultColumnStyle initializes the column style with any existing stylers
// plus additional general defaults, returning the style, the map of initially
// modified field names, and the color to use for the column.
func (pl *Editor) defaultColumnStyle(cl tensor.Values, ci int, colorIdx *int, hasSplit *bool, psty plot.Stylers) (*plot.Style, map[string]bool, image.Image) {
cst := &plot.Style{}
cst.Defaults()
if psty != nil {
psty.Run(cst)
}
if cst.On && cst.Role == plot.Split {
*hasSplit = true
}
mods := map[string]bool{}
isfloat := reflectx.KindIsFloat(cl.DataType())
if cst.Plotter == "" {
if isfloat {
cst.Plotter = plot.PlotterName(plots.XYType)
mods["Plotter"] = true
} else if cl.IsString() {
cst.Plotter = plot.PlotterName(plots.LabelsType)
mods["Plotter"] = true
}
}
if cst.Role == plot.NoRole {
mods["Role"] = true
if isfloat {
cst.Role = plot.Y
} else if cl.IsString() {
cst.Role = plot.Label
} else {
cst.Role = plot.X
}
}
clr := cst.Line.Color
if clr == colors.Scheme.OnSurface {
if cst.Role == plot.Y && isfloat {
clr = colors.Uniform(colors.Spaced(*colorIdx))
(*colorIdx)++
if !*hasSplit {
cst.Line.Color = clr
mods["Line.Color"] = true
cst.Point.Color = clr
mods["Point.Color"] = true
cst.Point.Fill = clr
mods["Point.Fill"] = true
if cst.Plotter == plots.BarType {
cst.Line.Fill = clr
mods["Line.Fill"] = true
}
}
}
}
return cst, mods, clr
}
func (pl *Editor) plotStyleFromTable(dt *table.Table) {
if pl.plotStyleModified != nil { // already set
return
}
pst := &pl.PlotStyle
mods := map[string]bool{}
pl.plotStyleModified = mods
tst := &plot.Style{}
tst.Defaults()
tst.Plot.Defaults()
for _, cl := range pl.table.Columns.Values {
stl := plot.GetStylers(cl)
if stl == nil {
continue
}
stl.Run(tst)
}
*pst = tst.Plot
if pst.PointsOn == plot.Default {
pst.PointsOn = plot.Off
mods["PointsOn"] = true
}
if pst.Title == "" {
pst.Title = metadata.Name(pl.table)
if pst.Title != "" {
mods["Title"] = true
}
}
}
// modFields returns the modified fields as field paths using . separators
func modFields(mods map[string]bool) []string {
fns := maps.Keys(mods)
rf := make([]string, 0, len(fns))
for _, f := range fns {
if mods[f] == false {
continue
}
fc := strings.ReplaceAll(f, " • ", ".")
rf = append(rf, fc)
}
slices.Sort(rf)
return rf
}
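// For example, a mods map of {"Line.Color": true, "On": true, "Role": false}
// yields []string{"Line.Color", "On"}: only entries marked true are kept,
// and the result is sorted.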
func (pl *Editor) MakeToolbar(p *tree.Plan) {
if pl.table == nil {
return
}
tree.Add(p, func(w *core.Button) {
w.SetIcon(icons.PanTool).
SetTooltip("toggle the ability to zoom and pan the view").OnClick(func(e events.Event) {
pw := pl.plotWidget
pw.SetReadOnly(!pw.IsReadOnly())
pw.Restyle()
})
})
// tree.Add(p, func(w *core.Button) {
// w.SetIcon(icons.ArrowForward).
// SetTooltip("turn on select mode for selecting Plot elements").
// OnClick(func(e events.Event) {
// fmt.Println("this will select select mode")
// })
// })
tree.Add(p, func(w *core.Separator) {})
tree.Add(p, func(w *core.Button) {
w.SetText("Update").SetIcon(icons.Update).
SetTooltip("update fully redraws display, reflecting any new settings etc").
OnClick(func(e events.Event) {
pl.UpdatePlot()
pl.Update()
})
})
tree.Add(p, func(w *core.Button) {
w.SetText("Style").SetIcon(icons.Settings).
SetTooltip("Style for how the plot is rendered").
OnClick(func(e events.Event) {
d := core.NewBody("Plot style")
fm := core.NewForm(d).SetStruct(&pl.PlotStyle)
fm.Modified = pl.plotStyleModified
fm.OnChange(func(e events.Event) {
pl.GoUpdatePlot()
})
d.RunWindowDialog(pl)
})
})
tree.Add(p, func(w *core.Button) {
w.SetText("Table").SetIcon(icons.Edit).
SetTooltip("open a Table window of the data").
OnClick(func(e events.Event) {
d := core.NewBody(pl.Name + " Data")
tv := tensorcore.NewTable(d).SetTable(pl.table)
d.AddTopBar(func(bar *core.Frame) {
core.NewToolbar(bar).Maker(tv.MakeToolbar)
})
d.RunWindowDialog(pl)
})
})
tree.Add(p, func(w *core.Separator) {})
tree.Add(p, func(w *core.Button) {
w.SetText("Save").SetIcon(icons.Save).SetMenu(func(m *core.Scene) {
core.NewFuncButton(m).SetFunc(pl.SaveSVG).SetIcon(icons.Save)
core.NewFuncButton(m).SetFunc(pl.SavePDF).SetIcon(icons.Save)
core.NewFuncButton(m).SetFunc(pl.SaveImage).SetIcon(icons.Save)
core.NewFuncButton(m).SetFunc(pl.SaveCSV).SetIcon(icons.Save)
core.NewSeparator(m)
core.NewFuncButton(m).SetFunc(pl.SaveAll).SetIcon(icons.Save)
})
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(pl.OpenCSV).SetIcon(icons.Open)
})
tree.Add(p, func(w *core.Separator) {})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(pl.table.FilterString).SetText("Filter").SetIcon(icons.FilterAlt)
w.SetAfterFunc(pl.UpdatePlot)
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(pl.table.Sequential).SetText("Unfilter").SetIcon(icons.FilterAltOff)
w.SetAfterFunc(pl.UpdatePlot)
})
}
func (pt *Editor) SizeFinal() {
pt.Frame.SizeFinal()
pt.UpdatePlot()
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package plotcore
import (
"fmt"
"image"
"cogentcore.org/core/core"
"cogentcore.org/core/cursors"
"cogentcore.org/core/events"
"cogentcore.org/core/events/key"
"cogentcore.org/core/styles"
"cogentcore.org/core/styles/abilities"
"cogentcore.org/core/styles/states"
"cogentcore.org/core/styles/units"
"cogentcore.org/lab/plot"
)
// Plot is a widget that renders a [plot.Plot] object to the [core.Scene]
// of this widget. If it is not [states.ReadOnly], the user can pan and zoom
// the graph. See [Editor] for an interactive interface for selecting columns to view.
type Plot struct {
core.WidgetBase
// Plot is the Plot to display in this widget.
Plot *plot.Plot `set:"-"`
// SetRangesFunc, if set, is called to adjust the data ranges
// after they have been updated based on the plot data.
SetRangesFunc func()
}
// SetPlot sets the plot to the given [plot.Plot]. You must still call [core.WidgetBase.Update]
// to trigger a redrawing of the plot.
func (pt *Plot) SetPlot(pl *plot.Plot) *Plot {
pt.Plot = pl
pt.Plot.SetSize(pt.Geom.ContentBBox.Size())
pt.Plot.TextShaper = pt.Scene.TextShaper()
return pt
}
func (pt *Plot) Init() {
pt.WidgetBase.Init()
pt.Styler(func(s *styles.Style) {
s.Min.Set(units.Dp(512), units.Dp(384))
ro := pt.IsReadOnly()
s.SetAbilities(!ro, abilities.Slideable, abilities.Activatable, abilities.Scrollable)
if !ro {
if s.Is(states.Active) {
s.Cursor = cursors.Grabbing
s.StateLayer = 0
} else {
s.Cursor = cursors.Grab
}
}
})
pt.On(events.SlideMove, func(e events.Event) {
e.SetHandled()
if pt.Plot == nil {
return
}
xf, yf := 1.0, 1.0
if e.HasAnyModifier(key.Shift) {
yf = 0
} else if e.HasAnyModifier(key.Alt) {
xf = 0
}
del := e.PrevDelta()
dx := -float64(del.X) * (pt.Plot.X.Range.Range()) * 0.0008 * xf
dy := float64(del.Y) * (pt.Plot.Y.Range.Range()) * 0.0008 * yf
pt.Plot.PanZoom.XOffset += dx
pt.Plot.PanZoom.YOffset += dy
pt.NeedsRender()
})
pt.On(events.Scroll, func(e events.Event) {
e.SetHandled()
if pt.Plot == nil {
return
}
se := e.(*events.MouseScroll)
sc := 1 + (float64(se.Delta.Y) * 0.002)
xsc, ysc := sc, sc
if e.HasAnyModifier(key.Shift) {
ysc = 1
} else if e.HasAnyModifier(key.Alt) {
xsc = 1
}
pt.Plot.PanZoom.XScale *= xsc
pt.Plot.PanZoom.YScale *= ysc
pt.NeedsRender()
})
}
func (pt *Plot) WidgetTooltip(pos image.Point) (string, image.Point) {
if pos == image.Pt(-1, -1) {
return "_", image.Point{}
}
if pt.Plot == nil {
return pt.Tooltip, pt.DefaultTooltipPos()
}
// note: plot pixel coords are in scene coordinates, so we use pos directly.
plt, _, idx, dist, _, data, legend := pt.Plot.ClosestDataToPixel(pos.X, pos.Y)
if dist <= 10 {
pt.Plot.HighlightPlotter = plt
pt.Plot.HighlightIndex = idx
pt.NeedsRender()
dx := 0.0
if data[plot.X] != nil {
dx = data[plot.X].Float1D(idx)
}
dy := 0.0
if data[plot.Y] != nil {
dy = data[plot.Y].Float1D(idx)
}
return fmt.Sprintf("%s[%d]: (%g, %g)", legend, idx, dx, dy), pos
} else {
if pt.Plot.HighlightPlotter != nil {
pt.Plot.HighlightPlotter = nil
pt.NeedsRender()
}
}
return pt.Tooltip, pt.DefaultTooltipPos()
}
// renderPlot draws the current plot into the scene.
func (pt *Plot) renderPlot() {
if pt.Plot == nil {
return
}
if pt.SetRangesFunc != nil {
pt.SetRangesFunc()
}
pt.Plot.TextShaper = pt.Scene.TextShaper()
pt.Plot.PaintBox = pt.Geom.ContentBBox
pt.Plot.Draw(&pt.Scene.Painter)
}
func (pt *Plot) Render() {
pt.WidgetBase.Render()
pt.renderPlot()
}
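// examplePlotWidget is a minimal, illustrative sketch (not used by the
// package) of hooking a [Plot] widget up to a [plot.Plot]. The [plot.New]
// constructor and the [core.Body] OnShow / RunMainWindow lifecycle calls are
// assumed here; adapt to your actual application setup.
func examplePlotWidget() {
	b := core.NewBody("Plot demo")
	pw := NewPlot(b)
	b.OnShow(func(e events.Event) {
		pl := plot.New() // assumed constructor for an empty plot; add plotters to it here
		pw.SetPlot(pl)
		pw.Update() // SetPlot does not redraw by itself
	})
	b.RunMainWindow()
}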
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package plotcore
import (
"slices"
"cogentcore.org/core/core"
"cogentcore.org/lab/plot"
_ "cogentcore.org/lab/plot/plots"
"golang.org/x/exp/maps"
)
func init() {
core.AddValueType[plot.PlotterName, PlotterChooser]()
}
// PlotterChooser represents a [plot.PlotterName] value with a [core.Chooser]
// for selecting a plotter.
type PlotterChooser struct {
core.Chooser
}
func (fc *PlotterChooser) Init() {
fc.Chooser.Init()
pnms := maps.Keys(plot.Plotters)
slices.Sort(pnms)
fc.SetStrings(pnms...)
}
// Code generated by "core generate"; DO NOT EDIT.
package plotcore
import (
"cogentcore.org/core/tree"
"cogentcore.org/core/types"
"cogentcore.org/lab/plot"
)
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plotcore.Editor", IDName: "editor", Doc: "Editor is a widget that provides an interactive 2D plot\nof selected columns of tabular data, represented by a [table.Table] into\na [table.Table]. Other types of tabular data can be converted into this format.\nThe user can change various options for the plot and also modify the underlying data.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Methods: []types.Method{{Name: "SaveSVG", Doc: "SaveSVG saves the plot to an SVG file.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"fname"}}, {Name: "SavePDF", Doc: "SavePDF saves the plot to a PDF file.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"fname"}}, {Name: "SaveImage", Doc: "SaveImage saves the current plot as an image (e.g., png).", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"fname"}}, {Name: "SaveCSV", Doc: "SaveCSV saves the Table data to a csv (comma-separated values) file with headers (any delim)", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"fname", "delim"}}, {Name: "SaveAll", Doc: "SaveAll saves the current plot to a png, svg, and the data to a tsv -- full save\nAny extension is removed and appropriate extensions are added", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"fname"}}, {Name: "OpenCSV", Doc: "OpenCSV opens the Table data from a csv (comma-separated values) file (or any delim)", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"filename", "delim"}}, {Name: "setColumnsByName", Doc: "setColumnsByName turns columns on or off if their name contains\nthe given string.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"nameContains", "on"}}}, Embeds: []types.Field{{Name: "Frame"}}, Fields: []types.Field{{Name: "table", Doc: "table is the table of data being plotted."}, {Name: "PlotStyle", Doc: "PlotStyle has the overall plot style parameters."}, {Name: "plot", Doc: "plot is the plot object."}, {Name: "svgFile", Doc: "current svg file"}, {Name: "dataFile", Doc: "current csv data file"}, {Name: "inPlot", Doc: "currently doing a plot"}, {Name: "columnsFrame"}, {Name: "plotWidget"}, {Name: "plotStyleModified"}}})
// NewEditor returns a new [Editor] with the given optional parent:
// Editor is a widget that provides an interactive 2D plot
// of selected columns of tabular data, represented by a [table.Table] into
// a [table.Table]. Other types of tabular data can be converted into this format.
// The user can change various options for the plot and also modify the underlying data.
func NewEditor(parent ...tree.Node) *Editor { return tree.New[Editor](parent...) }
// SetPlotStyle sets the [Editor.PlotStyle]:
// PlotStyle has the overall plot style parameters.
func (t *Editor) SetPlotStyle(v plot.PlotStyle) *Editor { t.PlotStyle = v; return t }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plotcore.Plot", IDName: "plot", Doc: "Plot is a widget that renders a [plot.Plot] object to the [core.Scene]\nof this widget. If it is not [states.ReadOnly], the user can pan and zoom\nthe graph. See [Editor] for an interactive interface for selecting columns to view.", Embeds: []types.Field{{Name: "WidgetBase"}}, Fields: []types.Field{{Name: "Plot", Doc: "Plot is the Plot to display in this widget."}, {Name: "SetRangesFunc", Doc: "SetRangesFunc, if set, is called to adjust the data ranges\nafter the point when these ranges are updated based on the plot data."}}})
// NewPlot returns a new [Plot] with the given optional parent:
// Plot is a widget that renders a [plot.Plot] object to the [core.Scene]
// of this widget. If it is not [states.ReadOnly], the user can pan and zoom
// the graph. See [Editor] for an interactive interface for selecting columns to view.
func NewPlot(parent ...tree.Node) *Plot { return tree.New[Plot](parent...) }
// SetSetRangesFunc sets the [Plot.SetRangesFunc]:
// SetRangesFunc, if set, is called to adjust the data ranges
// after the point when these ranges are updated based on the plot data.
func (t *Plot) SetSetRangesFunc(v func()) *Plot { t.SetRangesFunc = v; return t }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/plotcore.PlotterChooser", IDName: "plotter-chooser", Doc: "PlotterChooser represents a [Plottername] value with a [core.Chooser]\nfor selecting a plotter.", Embeds: []types.Field{{Name: "Chooser"}}})
// NewPlotterChooser returns a new [PlotterChooser] with the given optional parent:
// PlotterChooser represents a [Plottername] value with a [core.Chooser]
// for selecting a plotter.
func NewPlotterChooser(parent ...tree.Node) *PlotterChooser {
return tree.New[PlotterChooser](parent...)
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cluster
//go:generate core generate
import (
"fmt"
"math"
"math/rand"
"cogentcore.org/core/base/indent"
"cogentcore.org/lab/stats/stats"
"cogentcore.org/lab/tensor"
)
// todo: all of this data goes into the tensorfs
// Cluster makes a new dir, stuffs results in there!
// need a global "cwd" that it uses, so basically you cd
// to a dir, then call it.
// Node is one node in the cluster
type Node struct {
// index into original distance matrix; only valid for terminal leaves.
Index int
// Dist is the distance value for this node, i.e., how far apart all the kids were from
// each other when this node was created. It is 0 for leaf nodes.
Dist float64
// ParDist is the total aggregate distance from parents; the X axis offset at which our cluster starts.
ParDist float64
// Y is y-axis value for this node; if a parent, it is the average of its kids Y's,
// otherwise it counts down.
Y float64
// Kids are child nodes under this one.
Kids []*Node
}
// IsLeaf returns true if node is a leaf of the tree with no kids
func (nn *Node) IsLeaf() bool {
return len(nn.Kids) == 0
}
// Sprint prints to string
func (nn *Node) Sprint(labels tensor.Tensor, depth int) string {
if nn.IsLeaf() && labels != nil {
return labels.String1D(nn.Index) + " "
}
sv := fmt.Sprintf("\n%v%v: ", indent.Tabs(depth), nn.Dist)
for _, kn := range nn.Kids {
sv += kn.Sprint(labels, depth+1)
}
return sv
}
// Indexes collects all the indexes in this node
func (nn *Node) Indexes(ix []int, ctr *int) {
if nn.IsLeaf() {
ix[*ctr] = nn.Index
(*ctr)++
} else {
for _, kn := range nn.Kids {
kn.Indexes(ix, ctr)
}
}
}
// NewNode merges two nodes into a new node
func NewNode(na, nb *Node, dst float64) *Node {
nn := &Node{Dist: dst}
nn.Kids = []*Node{na, nb}
return nn
}
// TODO: this call signature does not fit with standard
// not sure how one might pack Node into a tensor
// Cluster implements agglomerative clustering, based on a
// distance matrix dmat, e.g., as computed by [metric.Matrix] method,
// using a metric that increases in value with greater dissimilarity.
// labels provides an optional String tensor list of labels for the elements
// of the distance matrix.
// This calls InitAllLeaves to initialize the root node with all of the leaves,
// and then Glom to do the iterative agglomerative clustering process.
// If you want to start with pre-defined initial clusters,
// then call Glom with a root node so-initialized.
func Cluster(metric Metrics, dmat, labels tensor.Tensor) *Node {
ntot := dmat.DimSize(0) // number of leaves
root := InitAllLeaves(ntot)
return Glom(root, metric, dmat)
}
// InitAllLeaves returns a standard root node initialized with all of the leaves.
func InitAllLeaves(ntot int) *Node {
root := &Node{}
root.Kids = make([]*Node, ntot)
for i := 0; i < ntot; i++ {
root.Kids[i] = &Node{Index: i}
}
return root
}
// Glom does the iterative agglomerative clustering,
// based on a raw similarity matrix as given,
// using a root node that has already been initialized
// with the starting clusters, which is all of the
// leaves by default, but could be anything if you want
// to start with predefined clusters.
func Glom(root *Node, metric Metrics, dmat tensor.Tensor) *Node {
ntot := dmat.DimSize(0) // number of leaves
mout := tensor.NewFloat64Scalar(0)
stats.MaxOut(tensor.As1D(dmat), mout)
maxd := mout.Float1D(0)
// indexes in each group
aidx := make([]int, ntot)
bidx := make([]int, ntot)
for {
var ma, mb []int
mval := math.MaxFloat64
for ai, ka := range root.Kids {
actr := 0
ka.Indexes(aidx, &actr)
aix := aidx[0:actr]
for bi := 0; bi < ai; bi++ {
kb := root.Kids[bi]
bctr := 0
kb.Indexes(bidx, &bctr)
bix := bidx[0:bctr]
dv := metric.Call(aix, bix, ntot, maxd, dmat)
if dv < mval {
mval = dv
ma = []int{ai}
mb = []int{bi}
} else if dv == mval { // do all ties at same time
ma = append(ma, ai)
mb = append(mb, bi)
}
}
}
ni := 0
if len(ma) > 1 {
ni = rand.Intn(len(ma))
}
na := ma[ni]
nb := mb[ni]
nn := NewNode(root.Kids[na], root.Kids[nb], mval)
for i := len(root.Kids) - 1; i >= 0; i-- {
if i == na || i == nb {
root.Kids = append(root.Kids[:i], root.Kids[i+1:]...)
}
}
root.Kids = append(root.Kids, nn)
if len(root.Kids) == 1 {
break
}
}
return root
}
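// exampleCluster is an illustrative sketch (not used by the package) of
// running [Cluster] on a tiny symmetric distance matrix. The sized
// [tensor.NewFloat64] constructor and SetFloat setter are assumed to work as
// their accessors are used elsewhere in this package; labels are omitted (nil).
func exampleCluster() {
	dmat := tensor.NewFloat64(3, 3) // 3 items, pairwise distances, 0 on diagonal
	set := func(i, j int, d float64) {
		dmat.SetFloat(d, i, j)
		dmat.SetFloat(d, j, i)
	}
	set(0, 1, 0.1) // items 0 and 1 are close
	set(0, 2, 0.9)
	set(1, 2, 0.8)
	root := Cluster(Avg, dmat, nil) // merges {0,1} first, then adds {2}
	fmt.Println(root.Sprint(nil, 0))
}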
// Code generated by "core generate"; DO NOT EDIT.
package cluster
import (
"cogentcore.org/core/enums"
)
var _MetricsValues = []Metrics{0, 1, 2, 3}
// MetricsN is the highest valid value for type Metrics, plus one.
const MetricsN Metrics = 4
var _MetricsValueMap = map[string]Metrics{`Min`: 0, `Max`: 1, `Avg`: 2, `Contrast`: 3}
var _MetricsDescMap = map[Metrics]string{0: `Min is the minimum-distance or single-linkage weighting function.`, 1: `Max is the maximum-distance or complete-linkage weighting function.`, 2: `Avg is the average-distance or average-linkage weighting function.`, 3: `Contrast computes maxd + (average within distance - average between distance).`}
var _MetricsMap = map[Metrics]string{0: `Min`, 1: `Max`, 2: `Avg`, 3: `Contrast`}
// String returns the string representation of this Metrics value.
func (i Metrics) String() string { return enums.String(i, _MetricsMap) }
// SetString sets the Metrics value from its string representation,
// and returns an error if the string is invalid.
func (i *Metrics) SetString(s string) error {
return enums.SetString(i, s, _MetricsValueMap, "Metrics")
}
// Int64 returns the Metrics value as an int64.
func (i Metrics) Int64() int64 { return int64(i) }
// SetInt64 sets the Metrics value from an int64.
func (i *Metrics) SetInt64(in int64) { *i = Metrics(in) }
// Desc returns the description of the Metrics value.
func (i Metrics) Desc() string { return enums.Desc(i, _MetricsDescMap) }
// MetricsValues returns all possible values for the type Metrics.
func MetricsValues() []Metrics { return _MetricsValues }
// Values returns all possible values for the type Metrics.
func (i Metrics) Values() []enums.Enum { return enums.Values(_MetricsValues) }
// MarshalText implements the [encoding.TextMarshaler] interface.
func (i Metrics) MarshalText() ([]byte, error) { return []byte(i.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *Metrics) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "Metrics") }
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cluster
import (
"math"
"cogentcore.org/core/base/errors"
"cogentcore.org/lab/tensor"
)
func init() {
tensor.AddFunc(Min.FuncName(), MinFunc)
tensor.AddFunc(Max.FuncName(), MaxFunc)
tensor.AddFunc(Avg.FuncName(), AvgFunc)
tensor.AddFunc(Contrast.FuncName(), ContrastFunc)
}
// Metrics are standard clustering distance metric functions,
// specifying how a node computes its distance based on its leaves.
type Metrics int32 //enums:enum
const (
// Min is the minimum-distance or single-linkage weighting function.
Min Metrics = iota
// Max is the maximum-distance or complete-linkage weighting function.
Max
// Avg is the average-distance or average-linkage weighting function.
Avg
// Contrast computes maxd + (average within distance - average between distance).
Contrast
)
// MetricFunc is a clustering distance metric function that evaluates aggregate distance
// between nodes, given the indexes of leaves in a and b clusters
// which are indexes into an ntot x ntot distance matrix dmat.
// maxd is the maximum distance value in the dmat, which is needed by the
// ContrastDist function and perhaps others.
type MetricFunc func(aix, bix []int, ntot int, maxd float64, dmat tensor.Tensor) float64
// FuncName returns the package-qualified function name to use
// in tensor.Call to call this function.
func (m Metrics) FuncName() string {
return "cluster." + m.String()
}
// Func returns function for given metric.
func (m Metrics) Func() MetricFunc {
fn := errors.Log1(tensor.FuncByName(m.FuncName()))
return fn.Fun.(func(aix, bix []int, ntot int, maxd float64, dmat tensor.Tensor) float64)
}
// Call calls the metric function for this Metrics value on the given
// cluster index lists and distance matrix, returning the aggregate distance.
func (m Metrics) Call(aix, bix []int, ntot int, maxd float64, dmat tensor.Tensor) float64 {
return m.Func()(aix, bix, ntot, maxd, dmat)
}
// MinFunc is the minimum-distance or single-linkage weighting function for comparing
// two clusters a and b, given by their list of indexes.
// ntot is total number of nodes, and dmat is the square similarity matrix [ntot x ntot].
func MinFunc(aix, bix []int, ntot int, maxd float64, dmat tensor.Tensor) float64 {
md := math.MaxFloat64
for _, ai := range aix {
for _, bi := range bix {
d := dmat.Float(ai, bi)
if d < md {
md = d
}
}
}
return md
}
// MaxFunc is the maximum-distance or complete-linkage weighting function for comparing
// two clusters a and b, given by their list of indexes.
// ntot is total number of nodes, and dmat is the square similarity matrix [ntot x ntot].
func MaxFunc(aix, bix []int, ntot int, maxd float64, dmat tensor.Tensor) float64 {
md := -math.MaxFloat64
for _, ai := range aix {
for _, bi := range bix {
d := dmat.Float(ai, bi)
if d > md {
md = d
}
}
}
return md
}
// AvgFunc is the average-distance or average-linkage weighting function for comparing
// two clusters a and b, given by their list of indexes.
// ntot is total number of nodes, and dmat is the square similarity matrix [ntot x ntot].
func AvgFunc(aix, bix []int, ntot int, maxd float64, dmat tensor.Tensor) float64 {
md := 0.0
n := 0
for _, ai := range aix {
for _, bi := range bix {
d := dmat.Float(ai, bi)
md += d
n++
}
}
if n > 0 {
md /= float64(n)
}
return md
}
// ContrastFunc computes maxd + (average within distance - average between distance)
// for two clusters a and b, given by their list of indexes.
// avg between is average distance between all items in a & b versus all outside that.
// ntot is total number of nodes, and dmat is the square similarity matrix [ntot x ntot].
// maxd is the maximum distance and is needed to ensure distances are positive.
func ContrastFunc(aix, bix []int, ntot int, maxd float64, dmat tensor.Tensor) float64 {
wd := AvgFunc(aix, bix, ntot, maxd, dmat)
nab := len(aix) + len(bix)
abix := append(aix, bix...)
abmap := make(map[int]struct{}, ntot-nab)
for _, ix := range abix {
abmap[ix] = struct{}{}
}
oix := make([]int, ntot-nab)
octr := 0
for ix := 0; ix < ntot; ix++ {
if _, has := abmap[ix]; !has {
oix[octr] = ix
octr++
}
}
bd := AvgFunc(abix, oix, ntot, maxd, dmat)
return maxd + (wd - bd)
}
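// exampleLinkages is an illustrative sketch (not used by the package) showing
// how the basic linkage functions differ on the same pair of clusters:
// a = {0, 1}, b = {2} in a 3x3 distance matrix. The sized [tensor.NewFloat64]
// constructor and SetFloat setter are assumed, mirroring the Float accessor
// used above.
func exampleLinkages() (mn, mx, av float64) {
	dmat := tensor.NewFloat64(3, 3)
	dmat.SetFloat(0.2, 0, 2)
	dmat.SetFloat(0.2, 2, 0)
	dmat.SetFloat(0.6, 1, 2)
	dmat.SetFloat(0.6, 2, 1)
	aix, bix := []int{0, 1}, []int{2}
	mn = MinFunc(aix, bix, 3, 0.6, dmat) // 0.2: closest pair
	mx = MaxFunc(aix, bix, 3, 0.6, dmat) // 0.6: farthest pair
	av = AvgFunc(aix, bix, 3, 0.6, dmat) // 0.4: mean over all pairs
	return
}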
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cluster
import (
"cogentcore.org/lab/plot"
"cogentcore.org/lab/plotcore"
"cogentcore.org/lab/stats/metric"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
)
// PlotFromTable creates a cluster plot in given [plotcore.Editor] plot
// using data from given data table, in column dataColumn,
// and labels from labelColumn, with given distance metric
// and cluster metric functions.
func PlotFromTable(plt *plotcore.Editor, dt *table.Table, distMetric metric.Metrics, clustMetric Metrics, dataColumn, labelColumn string) {
pt := table.New()
PlotFromTableToTable(pt, dt, distMetric, clustMetric, dataColumn, labelColumn)
plt.SetTable(pt)
}
// PlotFromTableToTable creates a cluster plot data to output data table (pt)
// using data from given data table (dt), in column dataColumn,
// and labels from labelColumn, with given distance metric
// and cluster metric functions.
func PlotFromTableToTable(pt *table.Table, dt *table.Table, distMetric metric.Metrics, clustMetric Metrics, dataColumn, labelColumn string) {
dm := metric.Matrix(distMetric.Func(), dt.Column(dataColumn))
labels := dt.Column(labelColumn)
cnd := Cluster(clustMetric, dm, labels)
Plot(pt, cnd, dm, labels)
}
// Plot sets the rows of given data table to trace out lines with labels that
// will render cluster plot starting at root node when plotted with a standard plotting package.
// The lines double-back on themselves to form a continuous line to be plotted.
func Plot(pt *table.Table, root *Node, dmat, labels tensor.Tensor) {
pt.DeleteAll()
pt.AddFloat64Column("X")
pt.AddFloat64Column("Y")
pt.AddStringColumn("Label")
nextY := 0.5
root.SetYs(&nextY)
root.SetParDist(0.0)
root.Plot(pt, dmat, labels)
plot.SetFirstStyler(pt.Columns.Values[0], func(s *plot.Style) {
s.Role = plot.X
s.Plot.PointsOn = plot.Off
})
plot.SetFirstStyler(pt.Columns.Values[1], func(s *plot.Style) {
s.On = true
s.Role = plot.Y
s.Plot.PointsOn = plot.Off
s.Range.FixMin = true
s.NoLegend = true
})
plot.SetFirstStyler(pt.Columns.At("Label"), func(s *plot.Style) {
s.On = true
s.Role = plot.Label
s.Plotter = "Labels"
s.Plot.PointsOn = plot.Off
s.Text.Offset.Y.Dp(8)
s.Text.Offset.X.Dp(2)
})
}
// Plot sets the rows of given data table to trace out lines with labels that
// will render this node in a cluster plot when plotted with a standard plotting package.
// The lines double-back on themselves to form a continuous line to be plotted.
func (nn *Node) Plot(pt *table.Table, dmat, labels tensor.Tensor) {
row := pt.NumRows()
xc := pt.ColumnByIndex(0)
yc := pt.ColumnByIndex(1)
lbl := pt.ColumnByIndex(2)
if nn.IsLeaf() {
pt.SetNumRows(row + 1)
xc.SetFloatRow(nn.ParDist, row, 0)
yc.SetFloatRow(nn.Y, row, 0)
if labels.Len() > nn.Index {
lbl.SetStringRow(labels.StringValue(nn.Index), row, 0)
}
} else {
for _, kn := range nn.Kids {
pt.SetNumRows(row + 2)
xc.SetFloatRow(nn.ParDist, row, 0)
yc.SetFloatRow(kn.Y, row, 0)
row++
xc.SetFloatRow(nn.ParDist+nn.Dist, row, 0)
yc.SetFloatRow(kn.Y, row, 0)
kn.Plot(pt, dmat, labels)
row = pt.NumRows()
pt.SetNumRows(row + 1)
xc.SetFloatRow(nn.ParDist, row, 0)
yc.SetFloatRow(kn.Y, row, 0)
row++
}
pt.SetNumRows(row + 1)
xc.SetFloatRow(nn.ParDist, row, 0)
yc.SetFloatRow(nn.Y, row, 0)
}
}
// SetYs sets the Y-axis values for the nodes in preparation for plotting.
func (nn *Node) SetYs(nextY *float64) {
if nn.IsLeaf() {
nn.Y = *nextY
(*nextY) += 1.0
} else {
avgy := 0.0
for _, kn := range nn.Kids {
kn.SetYs(nextY)
avgy += kn.Y
}
avgy /= float64(len(nn.Kids))
nn.Y = avgy
}
}
// SetParDist sets the parent distance for the nodes in preparation for plotting.
func (nn *Node) SetParDist(pard float64) {
nn.ParDist = pard
if !nn.IsLeaf() {
pard += nn.Dist
for _, kn := range nn.Kids {
kn.SetParDist(pard)
}
}
}
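// examplePlotFromTable is an illustrative sketch (not used by the package) of
// generating cluster plot data from a source table. The column names "Data"
// and "Name" are placeholders for a float tensor column and a string label
// column in dt.
func examplePlotFromTable(dt *table.Table) *table.Table {
	pt := table.New()
	// L2Norm distances between rows of "Data", average-linkage clustering:
	PlotFromTableToTable(pt, dt, metric.L2Norm, Avg, "Data", "Name")
	return pt // plot pt with a plotcore.Editor or another line-plotting front end
}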
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package convolve
//go:generate core generate
import (
"errors"
"cogentcore.org/core/base/slicesx"
)
// Slice32 convolves given kernel with given source slice, putting results in
// destination, which is ensured to be the same size as the source slice,
// using existing capacity if available, and otherwise making a new slice.
// The kernel should be normalized, and odd-sized so it is symmetric about 0.
// Returns an error if sizes are not valid.
// No parallelization is used -- see Slice32Parallel for very large slices.
// Edges are handled separately with renormalized kernels -- they can be
// clipped from dest by excluding the kernel half-width from each end.
func Slice32(dest *[]float32, src []float32, kern []float32) error {
sz := len(src)
ksz := len(kern)
if ksz == 0 || sz == 0 {
return errors.New("convolve.Slice32: kernel or source are empty")
}
if ksz%2 == 0 {
return errors.New("convolve.Slice32: kernel is not odd sized")
}
if sz < ksz {
return errors.New("convolve.Slice32: source must be > kernel in size")
}
khalf := (ksz - 1) / 2
*dest = slicesx.SetLength(*dest, sz)
for i := khalf; i < sz-khalf; i++ {
var sum float32
for j := 0; j < ksz; j++ {
sum += src[(i-khalf)+j] * kern[j]
}
(*dest)[i] = sum
}
for i := 0; i < khalf; i++ {
var sum, ksum float32
for j := 0; j <= khalf+i; j++ {
ki := (j + khalf) - i // 0: 1+kh, 1: etc
si := i + (ki - khalf)
// fmt.Printf("i: %d j: %d ki: %d si: %d\n", i, j, ki, si)
sum += src[si] * kern[ki]
ksum += kern[ki]
}
(*dest)[i] = sum / ksum
}
for i := sz - khalf; i < sz; i++ {
var sum, ksum float32
ei := sz - i - 1
for j := 0; j <= khalf+ei; j++ {
ki := ((ksz - 1) - (j + khalf)) + ei
si := i + (ki - khalf)
// fmt.Printf("i: %d j: %d ki: %d si: %d ei: %d\n", i, j, ki, si, ei)
sum += src[si] * kern[ki]
ksum += kern[ki]
}
(*dest)[i] = sum / ksum
}
return nil
}
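// exampleSlice32 is an illustrative sketch (not used by the package) of
// smoothing a small series with a normalized 3-point box kernel. Interior
// points use the full kernel; the first and last points are renormalized
// over the part of the kernel that fits, as described above.
func exampleSlice32() []float32 {
	src := []float32{0, 0, 1, 0, 0}
	kern := []float32{1.0 / 3, 1.0 / 3, 1.0 / 3}
	var dest []float32
	if err := Slice32(&dest, src, kern); err != nil {
		return nil
	}
	// dest is approximately [0, 1/3, 1/3, 1/3, 0]
	return dest
}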
// Slice64 convolves given kernel with given source slice, putting results in
// destination, which is ensured to be the same size as the source slice,
// using existing capacity if available, and otherwise making a new slice.
// The kernel should be normalized, and odd-sized so it is symmetric about 0.
// Returns an error if sizes are not valid.
// No parallelization is used -- see Slice64Parallel for very large slices.
// Edges are handled separately with renormalized kernels -- they can be
// clipped from dest by excluding the kernel half-width from each end.
func Slice64(dest *[]float64, src []float64, kern []float64) error {
sz := len(src)
ksz := len(kern)
if ksz == 0 || sz == 0 {
return errors.New("convolve.Slice64: kernel or source are empty")
}
if ksz%2 == 0 {
return errors.New("convolve.Slice64: kernel is not odd sized")
}
if sz < ksz {
return errors.New("convolve.Slice64: source must be > kernel in size")
}
khalf := (ksz - 1) / 2
*dest = slicesx.SetLength(*dest, sz)
for i := khalf; i < sz-khalf; i++ {
var sum float64
for j := 0; j < ksz; j++ {
sum += src[(i-khalf)+j] * kern[j]
}
(*dest)[i] = sum
}
for i := 0; i < khalf; i++ {
var sum, ksum float64
for j := 0; j <= khalf+i; j++ {
ki := (j + khalf) - i // 0: 1+kh, 1: etc
si := i + (ki - khalf)
// fmt.Printf("i: %d j: %d ki: %d si: %d\n", i, j, ki, si)
sum += src[si] * kern[ki]
ksum += kern[ki]
}
(*dest)[i] = sum / ksum
}
for i := sz - khalf; i < sz; i++ {
var sum, ksum float64
ei := sz - i - 1
for j := 0; j <= khalf+ei; j++ {
ki := ((ksz - 1) - (j + khalf)) + ei
si := i + (ki - khalf)
// fmt.Printf("i: %d j: %d ki: %d si: %d ei: %d\n", i, j, ki, si, ei)
sum += src[si] * kern[ki]
ksum += kern[ki]
}
(*dest)[i] = sum / ksum
}
return nil
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package convolve
import (
"math"
"cogentcore.org/core/math32"
)
// GaussianKernel32 returns a normalized gaussian kernel for smoothing
// with given half-width and normalized sigma (actual sigma = khalf * sigma).
// A sigma value of .5 is typical for smaller half-widths for containing
// most of the gaussian efficiently -- anything lower than .33 is inefficient --
// generally just use a lower half-width instead.
func GaussianKernel32(khalf int, sigma float32) []float32 {
ksz := khalf*2 + 1
kern := make([]float32, ksz)
sigdiv := 1 / (sigma * float32(khalf))
var sum float32
for i := 0; i < ksz; i++ {
x := sigdiv * float32(i-khalf)
kv := math32.Exp(-0.5 * x * x)
kern[i] = kv
sum += kv
}
nfac := 1 / sum
for i := 0; i < ksz; i++ {
kern[i] *= nfac
}
return kern
}
// GaussianKernel64 returns a normalized gaussian kernel
// with given half-width and normalized sigma (actual sigma = khalf * sigma)
// A sigma value of .5 is typical for smaller half-widths for containing
// most of the gaussian efficiently -- anything lower than .33 is inefficient --
// generally just use a lower half-width instead.
func GaussianKernel64(khalf int, sigma float64) []float64 {
ksz := khalf*2 + 1
kern := make([]float64, ksz)
sigdiv := 1 / (sigma * float64(khalf))
var sum float64
for i := 0; i < ksz; i++ {
x := sigdiv * float64(i-khalf)
kv := math.Exp(-0.5 * x * x)
kern[i] = kv
sum += kv
}
nfac := 1 / sum
for i := 0; i < ksz; i++ {
kern[i] *= nfac
}
return kern
}
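// exampleGaussianSmooth is an illustrative sketch (not used by the package) of
// the typical pairing of [GaussianKernel64] with [Slice64]: build a normalized
// kernel once, then convolve it with a data series.
func exampleGaussianSmooth(src []float64) []float64 {
	kern := GaussianKernel64(3, 0.5) // half-width 3, actual sigma = 3 * 0.5 = 1.5 points
	var dest []float64
	if err := Slice64(&dest, src, kern); err != nil {
		return nil // src must be at least as long as the 7-point kernel
	}
	return dest
}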
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package glm
import (
"fmt"
"math"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
)
// todo: add tests
// GLM contains results and parameters for running a general
// linear model, which is a general form of multivariate linear
// regression, supporting multiple independent and dependent
// variables. Make a NewGLM and then do Run() on a tensor
// table.Table with the relevant data in columns of the table.
// Batch-mode gradient descent is used and the relevant parameters
// can be altered from defaults before calling Run as needed.
type GLM struct {
// Coeff are the coefficients to map from input independent variables
// to the dependent variables. The first, outer dimension is number of
// dependent variables, and the second, inner dimension is number of
// independent variables plus one for the offset (b) (last element).
Coeff tensor.Float64
// mean squared error of the fitted values relative to data
MSE float64
// R2 is the r^2 total variance accounted for by the linear model,
// for each dependent variable = 1 - (ErrVariance / ObsVariance)
R2 []float64
// Observed variance of each of the dependent variables to be predicted.
ObsVariance []float64
// Variance of the error residuals per dependent variables
ErrVariance []float64
// optional names of the independent variables, for reporting results
IndepNames []string
// optional names of the dependent variables, for reporting results
DepNames []string
//////// Parameters for the GLM model fitting:
// ZeroOffset restricts the offset of the linear function to 0,
// forcing it to pass through the origin. Otherwise, a constant offset "b"
// is fit during the model fitting process.
ZeroOffset bool
// learning rate parameter, which can be adjusted to reduce iterations based on
// specific properties of the data, but the default is reasonable for most "typical" data.
LRate float64 `default:"0.1"`
// tolerance on difference in mean squared error (MSE) across iterations to stop
// iterating and consider the result to be converged.
StopTolerance float64 `default:"0.0001"`
// Constant cost factor subtracted from weights, for the L1 norm or "Lasso"
// regression. This is good for producing sparse results but can arbitrarily
// select one of multiple correlated independent variables.
L1Cost float64
// Cost factor proportional to the coefficient value, for the L2 norm or "Ridge"
// regression. This is good for generally keeping weights small and equally
// penalizes correlated independent variables.
L2Cost float64
// CostStartIter is the iteration when we start applying the L1, L2 Cost factors.
// It is often a good idea to have a few unconstrained iterations prior to
// applying the cost factors.
CostStartIter int `default:"5"`
// maximum number of iterations to perform
MaxIters int `default:"50"`
//////// Cached values from the table
// Table of data
Table *table.Table
// tensor columns from table with the respective variables
IndepVars, DepVars, PredVars, ErrVars tensor.RowMajor
// Number of independent and dependent variables
NIndepVars, NDepVars int
}
func NewGLM() *GLM {
glm := &GLM{}
glm.Defaults()
return glm
}
func (glm *GLM) Defaults() {
glm.LRate = 0.1
glm.StopTolerance = 0.0001
glm.MaxIters = 50
glm.CostStartIter = 5
}
func (glm *GLM) init(nIv, nDv int) {
glm.NIndepVars = nIv
glm.NDepVars = nDv
glm.Coeff.SetShapeSizes(nDv, nIv+1)
// glm.Coeff.SetNames("DepVars", "IndepVars")
glm.R2 = make([]float64, nDv)
glm.ObsVariance = make([]float64, nDv)
glm.ErrVariance = make([]float64, nDv)
glm.IndepNames = make([]string, nIv)
glm.DepNames = make([]string, nDv)
}
// SetTable sets the data to use from the given table, where
// each of the Vars args specifies a column in the table, which can have either a
// single scalar value for each row, or a tensor cell with multiple values.
// predVars and errVars (predicted values and error values) are optional.
func (glm *GLM) SetTable(dt *table.Table, indepVars, depVars, predVars, errVars string) error {
iv := dt.Column(indepVars)
dv := dt.Column(depVars)
var pv, ev *tensor.Rows
if predVars != "" {
pv = dt.Column(predVars)
}
if errVars != "" {
ev = dt.Column(errVars)
}
if pv != nil && !pv.Shape().IsEqual(dv.Shape()) {
return fmt.Errorf("predVars must have same shape as depVars")
}
if ev != nil && !ev.Shape().IsEqual(dv.Shape()) {
return fmt.Errorf("errVars must have same shape as depVars")
}
_, nIv := iv.Shape().RowCellSize()
_, nDv := dv.Shape().RowCellSize()
glm.init(nIv, nDv)
glm.Table = dt
glm.IndepVars = iv
glm.DepVars = dv
glm.PredVars = pv
glm.ErrVars = ev
return nil
}
// Run performs the multivariate linear regression on the data set by the SetTable function,
// learning linear coefficients and an overall static offset that best
// fits the observed dependent variables as a function of the independent variables.
// Initial values of the coefficients, and other parameters for the regression,
// should be set prior to running.
func (glm *GLM) Run() {
dt := glm.Table
iv := glm.IndepVars
dv := glm.DepVars
pv := glm.PredVars
ev := glm.ErrVars
if pv == nil {
pv = tensor.Clone(dv)
}
if ev == nil {
ev = tensor.Clone(dv)
}
nDv := glm.NDepVars
nIv := glm.NIndepVars
nCi := nIv + 1
dc := glm.Coeff.Clone().(*tensor.Float64)
lastItr := false
sse := 0.0
prevmse := 0.0
n := dt.NumRows()
norm := 1.0 / float64(n)
lrate := norm * glm.LRate
for itr := 0; itr < glm.MaxIters; itr++ {
for i := range dc.Values {
dc.Values[i] = 0
}
sse = 0
if (itr+1)%10 == 0 {
lrate *= 0.5
}
for i := 0; i < n; i++ {
row := dt.RowIndex(i)
for di := 0; di < nDv; di++ {
pred := 0.0
for ii := 0; ii < nIv; ii++ {
pred += glm.Coeff.Float(di, ii) * iv.FloatRow(row, ii)
}
if !glm.ZeroOffset {
pred += glm.Coeff.Float(di, nIv)
}
targ := dv.FloatRow(row, di)
err := targ - pred
sse += err * err
for ii := 0; ii < nIv; ii++ {
dc.Values[di*nCi+ii] += err * iv.FloatRow(row, ii)
}
if !glm.ZeroOffset {
dc.Values[di*nCi+nIv] += err
}
if lastItr {
pv.SetFloatRow(pred, row, di)
if ev != nil {
ev.SetFloatRow(err, row, di)
}
}
}
}
for di := 0; di < nDv; di++ {
for ii := 0; ii <= nIv; ii++ {
if glm.ZeroOffset && ii == nIv {
continue
}
idx := di*nCi + ii
w := glm.Coeff.Values[idx]
d := dc.Values[idx]
sgn := 1.0
if w < 0 {
sgn = -1.0
} else if w == 0 {
sgn = 0
}
glm.Coeff.Values[idx] += lrate * (d - glm.L1Cost*sgn - glm.L2Cost*w)
}
}
glm.MSE = norm * sse
if lastItr {
break
}
if itr > 0 {
dmse := glm.MSE - prevmse
if math.Abs(dmse) < glm.StopTolerance || itr == glm.MaxIters-2 {
lastItr = true
}
}
fmt.Println(itr, glm.MSE)
prevmse = glm.MSE
}
obsMeans := make([]float64, nDv)
errMeans := make([]float64, nDv)
for i := 0; i < n; i++ {
row := i
if dt.Indexes != nil {
row = dt.Indexes[i]
}
for di := 0; di < nDv; di++ {
obsMeans[di] += dv.FloatRow(row, di)
errMeans[di] += ev.FloatRow(row, di)
}
}
for di := 0; di < nDv; di++ {
obsMeans[di] *= norm
errMeans[di] *= norm
glm.ObsVariance[di] = 0
glm.ErrVariance[di] = 0
}
for i := 0; i < n; i++ {
row := i
if dt.Indexes != nil {
row = dt.Indexes[i]
}
for di := 0; di < nDv; di++ {
o := dv.FloatRow(row, di) - obsMeans[di]
glm.ObsVariance[di] += o * o
e := ev.FloatRow(row, di) - errMeans[di]
glm.ErrVariance[di] += e * e
}
}
for di := 0; di < nDv; di++ {
glm.ObsVariance[di] *= norm
glm.ErrVariance[di] *= norm
glm.R2[di] = 1.0 - (glm.ErrVariance[di] / glm.ObsVariance[di])
}
}
// Variance returns a description of the variance accounted for by the regression
// equation, R^2, for each dependent variable, along with the variances of
// observed and errors (residuals), which are used to compute it.
func (glm *GLM) Variance() string {
str := ""
for di := range glm.R2 {
if len(glm.DepNames) > di && glm.DepNames[di] != "" {
str += glm.DepNames[di]
} else {
str += fmt.Sprintf("DV %d", di)
}
str += fmt.Sprintf("\tR^2: %8.6g\tR: %8.6g\tVar Err: %8.4g\t Obs: %8.4g\n", glm.R2[di], math.Sqrt(glm.R2[di]), glm.ErrVariance[di], glm.ObsVariance[di])
}
return str
}
// Coeffs returns a string describing the coefficients
func (glm *GLM) Coeffs() string {
str := ""
for di := range glm.NDepVars {
if len(glm.DepNames) > di && glm.DepNames[di] != "" {
str += glm.DepNames[di]
} else {
str += fmt.Sprintf("DV %d", di)
}
str += " = "
for ii := 0; ii <= glm.NIndepVars; ii++ {
str += fmt.Sprintf("\t%8.6g", glm.Coeff.Float(di, ii))
if ii < glm.NIndepVars {
str += " * "
if len(glm.IndepNames) > ii && glm.IndepNames[ii] != "" {
str += glm.IndepNames[ii]
} else {
str += fmt.Sprintf("IV_%d", ii)
}
str += " + "
}
}
str += "\n"
}
return str
}
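// exampleGLM is an illustrative sketch (not used by the package) of fitting a
// single-variable regression to the noiseless relationship y = 2x + 1, using
// the table API as used elsewhere in this codebase. The column names "X" and
// "Y" are placeholders; prediction and error columns are omitted.
func exampleGLM() *GLM {
	dt := table.New()
	dt.AddFloat64Column("X")
	dt.AddFloat64Column("Y")
	n := 20
	dt.SetNumRows(n)
	xc, yc := dt.Column("X"), dt.Column("Y")
	for i := 0; i < n; i++ {
		x := float64(i) / float64(n) // keep inputs in [0, 1) for stable gradient descent
		xc.SetFloatRow(x, i, 0)
		yc.SetFloatRow(2*x+1, i, 0)
	}
	glm := NewGLM()
	if err := glm.SetTable(dt, "X", "Y", "", ""); err != nil {
		return nil
	}
	glm.Run()
	fmt.Println(glm.Variance())
	fmt.Println(glm.Coeffs())
	return glm
}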
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package histogram
//go:generate core generate
import (
"cogentcore.org/core/base/slicesx"
"cogentcore.org/core/math32"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
)
// F64 generates a histogram of counts of values within the given
// number of bins and min / max range. hist is resized to nBins.
// Values < min or > max are ignored.
func F64(hist *[]float64, vals []float64, nBins int, min, max float64) {
*hist = slicesx.SetLength(*hist, nBins)
h := *hist
// 0.1.2.3 = 3-0 = 4 bins
inc := (max - min) / float64(nBins)
for i := 0; i < nBins; i++ {
h[i] = 0
}
for _, v := range vals {
if v < min || v > max {
continue
}
bin := int((v - min) / inc)
if bin >= nBins {
bin = nBins - 1
}
h[bin] += 1
}
}
// F64Table generates a table with a histogram of counts of values within the given
// number of bins and min / max range. The table has columns: Value, Count.
// Values < min or > max are ignored.
// The Value column represents the min value for each bin, with the max being
// the value of the next bin, or the max if at the end.
func F64Table(dt *table.Table, vals []float64, nBins int, min, max float64) {
dt.DeleteAll()
dt.AddFloat64Column("Value")
dt.AddFloat64Column("Count")
dt.SetNumRows(nBins)
ct := dt.Columns.Values[1].(*tensor.Float64)
F64(&ct.Values, vals, nBins, min, max)
inc := (max - min) / float64(nBins)
vls := dt.Columns.Values[0].(*tensor.Float64).Values
for i := 0; i < nBins; i++ {
vls[i] = math32.Truncate64(min+float64(i)*inc, 4)
}
}
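// exampleHistogram is an illustrative sketch (not used by the package) of
// binning a handful of values into 4 bins over the range [0, 1].
func exampleHistogram() []float64 {
	vals := []float64{0.1, 0.2, 0.3, 0.6, 0.61, 0.9, 1.0, 1.5}
	var hist []float64
	F64(&hist, vals, 4, 0, 1)
	// bins are [0,.25) [.25,.5) [.5,.75) [.75,1]; 1.5 is out of range and ignored,
	// so hist is [2 1 2 2]
	return hist
}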
// Code generated by "core generate"; DO NOT EDIT.
package metric
import (
"cogentcore.org/core/enums"
)
var _MetricsValues = []Metrics{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}
// MetricsN is the highest valid value for type Metrics, plus one.
const MetricsN Metrics = 13
var _MetricsValueMap = map[string]Metrics{`L2Norm`: 0, `SumSquares`: 1, `L1Norm`: 2, `Hamming`: 3, `L2NormBinTol`: 4, `SumSquaresBinTol`: 5, `InvCosine`: 6, `InvCorrelation`: 7, `CrossEntropy`: 8, `DotProduct`: 9, `Covariance`: 10, `Correlation`: 11, `Cosine`: 12}
var _MetricsDescMap = map[Metrics]string{0: `L2Norm is the square root of the sum of squares differences between tensor values, aka the Euclidean distance.`, 1: `SumSquares is the sum of squares differences between tensor values.`, 2: `L1Norm is the sum of the absolute value of differences between tensor values, the L1 Norm.`, 3: `Hamming is the sum of 1s for every element that is different, i.e., "city block" distance.`, 4: `L2NormBinTol is the [L2Norm] square root of the sum of squares differences between tensor values, with binary tolerance: differences < 0.5 are thresholded to 0.`, 5: `SumSquaresBinTol is the [SumSquares] differences between tensor values, with binary tolerance: differences < 0.5 are thresholded to 0.`, 6: `InvCosine is 1-[Cosine], which is useful to convert it to an Increasing metric where more different vectors have larger metric values.`, 7: `InvCorrelation is 1-[Correlation], which is useful to convert it to an Increasing metric where more different vectors have larger metric values.`, 8: `CrossEntropy is a standard measure of the difference between two probabilty distributions, reflecting the additional entropy (uncertainty) associated with measuring probabilities under distribution b when in fact they come from distribution a. It is also the entropy of a plus the divergence between a from b, using Kullback-Leibler (KL) divergence. It is computed as: a * log(a/b) + (1-a) * log(1-a/1-b).`, 9: `DotProduct is the sum of the co-products of the tensor values.`, 10: `Covariance is co-variance between two vectors, i.e., the mean of the co-product of each vector element minus the mean of that vector: cov(A,B) = E[(A - E(A))(B - E(B))].`, 11: `Correlation is the standardized [Covariance] in the range (-1..1), computed as the mean of the co-product of each vector element minus the mean of that vector, normalized by the product of their standard deviations: cor(A,B) = E[(A - E(A))(B - E(B))] / sigma(A) sigma(B). Equivalent to the [Cosine] of mean-normalized vectors.`, 12: `Cosine is high-dimensional angle between two vectors, in range (-1..1) as the normalized [DotProduct]: inner product / sqrt(ssA * ssB). See also [Correlation].`}
var _MetricsMap = map[Metrics]string{0: `L2Norm`, 1: `SumSquares`, 2: `L1Norm`, 3: `Hamming`, 4: `L2NormBinTol`, 5: `SumSquaresBinTol`, 6: `InvCosine`, 7: `InvCorrelation`, 8: `CrossEntropy`, 9: `DotProduct`, 10: `Covariance`, 11: `Correlation`, 12: `Cosine`}
// String returns the string representation of this Metrics value.
func (i Metrics) String() string { return enums.String(i, _MetricsMap) }
// SetString sets the Metrics value from its string representation,
// and returns an error if the string is invalid.
func (i *Metrics) SetString(s string) error {
return enums.SetString(i, s, _MetricsValueMap, "Metrics")
}
// Int64 returns the Metrics value as an int64.
func (i Metrics) Int64() int64 { return int64(i) }
// SetInt64 sets the Metrics value from an int64.
func (i *Metrics) SetInt64(in int64) { *i = Metrics(in) }
// Desc returns the description of the Metrics value.
func (i Metrics) Desc() string { return enums.Desc(i, _MetricsDescMap) }
// MetricsValues returns all possible values for the type Metrics.
func MetricsValues() []Metrics { return _MetricsValues }
// Values returns all possible values for the type Metrics.
func (i Metrics) Values() []enums.Enum { return enums.Values(_MetricsValues) }
// MarshalText implements the [encoding.TextMarshaler] interface.
func (i Metrics) MarshalText() ([]byte, error) { return []byte(i.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *Metrics) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "Metrics") }
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package metric
import (
"math"
"cogentcore.org/core/math32"
"cogentcore.org/lab/stats/stats"
"cogentcore.org/lab/tensor"
)
// MetricFunc is the function signature for a metric function,
// which is computed over the outermost row dimension and the
// output is the shape of the remaining inner cells (a scalar for 1D inputs).
// Use [tensor.As1D], [tensor.NewRowCellsView], [tensor.Cells1D] etc
// to reshape and reslice the data as needed.
// All metric functions skip over NaN's, as a missing value,
// and use the min of the length of the two tensors.
// Metric functions cannot be computed in parallel,
// e.g., using VectorizeThreaded or GPU, due to shared writing
// to the same output values. Special implementations are required
// if that is needed.
type MetricFunc = func(a, b tensor.Tensor) tensor.Values
// MetricOutFunc is the function signature for a metric function,
// that takes output values as the final argument. See [MetricFunc].
// This version is for computationally demanding cases and saves
// reallocation of output.
type MetricOutFunc = func(a, b tensor.Tensor, out tensor.Values) error
// SumSquaresScaleOut64 computes the sum of squares differences between tensor values,
// returning scale and ss factors aggregated separately for better numerical stability, per BLAS.
func SumSquaresScaleOut64(a, b tensor.Tensor) (scale64, ss64 *tensor.Float64, err error) {
if err = tensor.MustBeSameShape(a, b); err != nil {
return
}
scale64, ss64 = Vectorize2Out64(a, b, 0, 1, func(a, b, scale, ss float64) (float64, float64) {
if math.IsNaN(a) || math.IsNaN(b) {
return scale, ss
}
d := a - b
if d == 0 {
return scale, ss
}
absxi := math.Abs(d)
if scale < absxi {
ss = 1 + ss*(scale/absxi)*(scale/absxi)
scale = absxi
} else {
ss = ss + (absxi/scale)*(absxi/scale)
}
return scale, ss
})
return
}
// SumSquaresOut64 computes the sum of squares differences between tensor values,
// and returns the Float64 output values for use in subsequent computations.
func SumSquaresOut64(a, b tensor.Tensor, out tensor.Values) (*tensor.Float64, error) {
scale64, ss64, err := SumSquaresScaleOut64(a, b)
if err != nil {
return nil, err
}
osz := tensor.CellsSize(a.ShapeSizes())
out.SetShapeSizes(osz...)
nsub := out.Len()
for i := range nsub {
scale := scale64.Float1D(i)
ss := ss64.Float1D(i)
v := 0.0
if math.IsInf(scale, 1) {
v = math.Inf(1)
} else {
v = scale * scale * ss
}
scale64.SetFloat1D(v, i)
out.SetFloat1D(v, i)
}
return scale64, err
}
// SumSquaresOut computes the sum of squares differences between tensor values,
// See [MetricOutFunc] for general information.
func SumSquaresOut(a, b tensor.Tensor, out tensor.Values) error {
_, err := SumSquaresOut64(a, b, out)
return err
}
// SumSquares computes the sum of squares differences between tensor values,
// See [MetricFunc] for general information.
func SumSquares(a, b tensor.Tensor) tensor.Values {
return tensor.CallOut2(SumSquaresOut, a, b)
}
// L2NormOut computes the L2 Norm: square root of the sum of squares
// differences between tensor values, aka the Euclidean distance.
func L2NormOut(a, b tensor.Tensor, out tensor.Values) error {
scale64, ss64, err := SumSquaresScaleOut64(a, b)
if err != nil {
return err
}
osz := tensor.CellsSize(a.ShapeSizes())
out.SetShapeSizes(osz...)
nsub := out.Len()
for i := range nsub {
scale := scale64.Float1D(i)
ss := ss64.Float1D(i)
v := 0.0
if math.IsInf(scale, 1) {
v = math.Inf(1)
} else {
v = scale * math.Sqrt(ss)
}
scale64.SetFloat1D(v, i)
out.SetFloat1D(v, i)
}
return nil
}
// L2Norm computes the L2 Norm: square root of the sum of squares
// differences between tensor values, aka the Euclidean distance.
// See [MetricFunc] for general information.
func L2Norm(a, b tensor.Tensor) tensor.Values {
return tensor.CallOut2(L2NormOut, a, b)
}
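// exampleL2Norm is an illustrative sketch (not used by the package) comparing
// [SumSquares] and [L2Norm] on two small 1D tensors. The sized
// [tensor.NewFloat64] constructor is assumed; SetFloat1D is used as elsewhere
// in this file.
func exampleL2Norm() (ss, l2 float64) {
	a := tensor.NewFloat64(3)
	b := tensor.NewFloat64(3)
	for i, v := range []float64{1, 2, 3} {
		a.SetFloat1D(v, i)
		b.SetFloat1D(v+1, i) // b differs from a by 1 in every element
	}
	ss = SumSquares(a, b).Float1D(0) // 3 = 1 + 1 + 1
	l2 = L2Norm(a, b).Float1D(0)     // sqrt(3)
	return
}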
// L1NormOut computes the sum of the absolute value of differences between the
// tensor values, the L1 Norm.
// See [MetricOutFunc] for general information.
func L1NormOut(a, b tensor.Tensor, out tensor.Values) error {
if err := tensor.MustBeSameShape(a, b); err != nil {
return err
}
VectorizeOut64(a, b, out, 0, func(a, b, agg float64) float64 {
if math.IsNaN(a) || math.IsNaN(b) {
return agg
}
return agg + math.Abs(a-b)
})
return nil
}
// L1Norm computes the sum of the absolute value of differences between the
// tensor values, the L1 Norm.
// See [MetricFunc] for general information.
func L1Norm(a, b tensor.Tensor) tensor.Values {
return tensor.CallOut2(L1NormOut, a, b)
}
// HammingOut computes the sum of 1s for every element that is different,
// i.e., "city block" distance.
// See [MetricOutFunc] for general information.
func HammingOut(a, b tensor.Tensor, out tensor.Values) error {
if err := tensor.MustBeSameShape(a, b); err != nil {
return err
}
VectorizeOut64(a, b, out, 0, func(a, b, agg float64) float64 {
if math.IsNaN(a) || math.IsNaN(b) {
return agg
}
if a != b {
agg += 1
}
return agg
})
return nil
}
// Hamming computes the sum of 1s for every element that is different,
// i.e., "city block" distance.
// See [MetricFunc] for general information.
func Hamming(a, b tensor.Tensor) tensor.Values {
return tensor.CallOut2(HammingOut, a, b)
}
// SumSquaresBinTolScaleOut64 computes the sum of squares differences between tensor values,
// with binary tolerance: differences < 0.5 are thresholded to 0.
// returning scale and ss factors aggregated separately for better numerical stability, per BLAS.
func SumSquaresBinTolScaleOut64(a, b tensor.Tensor) (scale64, ss64 *tensor.Float64, err error) {
if err = tensor.MustBeSameShape(a, b); err != nil {
return
}
scale64, ss64 = Vectorize2Out64(a, b, 0, 1, func(a, b, scale, ss float64) (float64, float64) {
if math.IsNaN(a) || math.IsNaN(b) {
return scale, ss
}
d := a - b
if math.Abs(d) < 0.5 {
return scale, ss
}
absxi := math.Abs(d)
if scale < absxi {
ss = 1 + ss*(scale/absxi)*(scale/absxi)
scale = absxi
} else {
ss = ss + (absxi/scale)*(absxi/scale)
}
return scale, ss
})
return
}
// L2NormBinTolOut computes the L2 Norm: square root of the sum of squares
// differences between tensor values (aka Euclidean distance), with binary tolerance:
// differences < 0.5 are thresholded to 0.
func L2NormBinTolOut(a, b tensor.Tensor, out tensor.Values) error {
scale64, ss64, err := SumSquaresBinTolScaleOut64(a, b)
if err != nil {
return err
}
osz := tensor.CellsSize(a.ShapeSizes())
out.SetShapeSizes(osz...)
nsub := out.Len()
for i := range nsub {
scale := scale64.Float1D(i)
ss := ss64.Float1D(i)
v := 0.0
if math.IsInf(scale, 1) {
v = math.Inf(1)
} else {
v = scale * math.Sqrt(ss)
}
scale64.SetFloat1D(v, i)
out.SetFloat1D(v, i)
}
return nil
}
// L2NormBinTol computes the L2 Norm: square root of the sum of squares
// differences between tensor values (aka Euclidean distance), with binary tolerance:
// differences < 0.5 are thresholded to 0.
// See [MetricFunc] for general information.
func L2NormBinTol(a, b tensor.Tensor) tensor.Values {
return tensor.CallOut2(L2NormBinTolOut, a, b)
}
// SumSquaresBinTolOut computes the sum of squares differences between tensor values,
// with binary tolerance: differences < 0.5 are thresholded to 0.
func SumSquaresBinTolOut(a, b tensor.Tensor, out tensor.Values) error {
scale64, ss64, err := SumSquaresBinTolScaleOut64(a, b)
if err != nil {
return err
}
osz := tensor.CellsSize(a.ShapeSizes())
out.SetShapeSizes(osz...)
nsub := out.Len()
for i := range nsub {
scale := scale64.Float1D(i)
ss := ss64.Float1D(i)
v := 0.0
if math.IsInf(scale, 1) {
v = math.Inf(1)
} else {
v = scale * scale * ss
}
scale64.SetFloat1D(v, i)
out.SetFloat1D(v, i)
}
return nil
}
// SumSquaresBinTol computes the sum of squares differences between tensor values,
// with binary tolerance: differences < 0.5 are thresholded to 0.
// See [MetricFunc] for general information.
func SumSquaresBinTol(a, b tensor.Tensor) tensor.Values {
return tensor.CallOut2(SumSquaresBinTolOut, a, b)
}
// CrossEntropyOut is a standard measure of the difference between two
// probability distributions, reflecting the additional entropy (uncertainty) associated
// with measuring probabilities under distribution b when in fact they come from
// distribution a. It is also the entropy of a plus the divergence between a from b,
// using Kullback-Leibler (KL) divergence. It is computed as:
// a * log(a/b) + (1-a) * log(1-a/1-b).
// See [MetricOutFunc] for general information.
func CrossEntropyOut(a, b tensor.Tensor, out tensor.Values) error {
if err := tensor.MustBeSameShape(a, b); err != nil {
return err
}
VectorizeOut64(a, b, out, 0, func(a, b, agg float64) float64 {
if math.IsNaN(a) || math.IsNaN(b) {
return agg
}
b = math32.Clamp(b, 0.000001, 0.999999)
if a >= 1.0 {
agg += -math.Log(b)
} else if a <= 0.0 {
agg += -math.Log(1.0 - b)
} else {
agg += a*math.Log(a/b) + (1-a)*math.Log((1-a)/(1-b))
}
return agg
})
return nil
}
// CrossEntropy is a standard measure of the difference between two
// probability distributions, reflecting the additional entropy (uncertainty) associated
// with measuring probabilities under distribution b when in fact they come from
// distribution a. It is also the entropy of a plus the divergence between a from b,
// using Kullback-Leibler (KL) divergence. It is computed as:
// a * log(a/b) + (1-a) * log(1-a/1-b).
// See [MetricFunc] for general information.
func CrossEntropy(a, b tensor.Tensor) tensor.Values {
return tensor.CallOut2(CrossEntropyOut, a, b)
}
// DotProductOut computes the sum of the element-wise products of the
// two tensors (aka the inner product).
// See [MetricOutFunc] for general information.
func DotProductOut(a, b tensor.Tensor, out tensor.Values) error {
if err := tensor.MustBeSameShape(a, b); err != nil {
return err
}
VectorizeOut64(a, b, out, 0, func(a, b, agg float64) float64 {
if math.IsNaN(a) || math.IsNaN(b) {
return agg
}
return agg + a*b
})
return nil
}
// DotProduct computes the sum of the element-wise products of the
// two tensors (aka the inner product).
// See [MetricFunc] for general information.
func DotProduct(a, b tensor.Tensor) tensor.Values {
return tensor.CallOut2(DotProductOut, a, b)
}
// CovarianceOut computes the co-variance between two vectors,
// i.e., the mean of the co-product of each vector element minus
// the mean of that vector: cov(A,B) = E[(A - E(A))(B - E(B))].
func CovarianceOut(a, b tensor.Tensor, out tensor.Values) error {
if err := tensor.MustBeSameShape(a, b); err != nil {
return err
}
amean, acount := stats.MeanOut64(a, out)
bmean, _ := stats.MeanOut64(b, out)
cov64 := VectorizePreOut64(a, b, out, 0, amean, bmean, func(a, b, am, bm, agg float64) float64 {
if math.IsNaN(a) || math.IsNaN(b) {
return agg
}
return agg + (a-am)*(b-bm)
})
osz := tensor.CellsSize(a.ShapeSizes())
out.SetShapeSizes(osz...)
nsub := out.Len()
for i := range nsub {
c := acount.Float1D(i)
if c == 0 {
continue
}
cov := cov64.Float1D(i) / c
out.SetFloat1D(cov, i)
}
return nil
}
// Covariance computes the co-variance between two vectors,
// i.e., the mean of the co-product of each vector element minus
// the mean of that vector: cov(A,B) = E[(A - E(A))(B - E(B))].
// See [MetricFunc] for general information.
func Covariance(a, b tensor.Tensor) tensor.Values {
return tensor.CallOut2(CovarianceOut, a, b)
}
// CorrelationOut64 computes the correlation between two vectors,
// in range (-1..1) as the mean of the co-product of each vector
// element minus the mean of that vector, normalized by the product of their
// standard deviations: cor(A,B) = E[(A - E(A))(B - E(B))] / sigma(A) sigma(B).
// (i.e., the standardized covariance).
// Equivalent to the cosine of mean-normalized vectors.
// Returns the Float64 output values for subsequent use.
func CorrelationOut64(a, b tensor.Tensor, out tensor.Values) (*tensor.Float64, error) {
if err := tensor.MustBeSameShape(a, b); err != nil {
return nil, err
}
amean, _ := stats.MeanOut64(a, out)
bmean, _ := stats.MeanOut64(b, out)
ss64, avar64, bvar64 := VectorizePre3Out64(a, b, 0, 0, 0, amean, bmean, func(a, b, am, bm, ss, avar, bvar float64) (float64, float64, float64) {
if math.IsNaN(a) || math.IsNaN(b) {
return ss, avar, bvar
}
ad := a - am
bd := b - bm
ss += ad * bd // between
avar += ad * ad // within
bvar += bd * bd
return ss, avar, bvar
})
osz := tensor.CellsSize(a.ShapeSizes())
out.SetShapeSizes(osz...)
nsub := out.Len()
for i := range nsub {
ss := ss64.Float1D(i)
vp := math.Sqrt(avar64.Float1D(i) * bvar64.Float1D(i))
if vp > 0 {
ss /= vp
}
ss64.SetFloat1D(ss, i)
out.SetFloat1D(ss, i)
}
return ss64, nil
}
// CorrelationOut computes the correlation between two vectors,
// in range (-1..1) as the mean of the co-product of each vector
// element minus the mean of that vector, normalized by the product of their
// standard deviations: cor(A,B) = E[(A - E(A))(B - E(B))] / sigma(A) sigma(B).
// (i.e., the standardized [Covariance]).
// Equivalent to the [Cosine] of mean-normalized vectors.
func CorrelationOut(a, b tensor.Tensor, out tensor.Values) error {
_, err := CorrelationOut64(a, b, out)
return err
}
// Correlation computes the correlation between two vectors,
// in range (-1..1) as the mean of the co-product of each vector
// element minus the mean of that vector, normalized by the product of their
// standard deviations: cor(A,B) = E[(A - E(A))(B - E(B))] / sigma(A) sigma(B).
// (i.e., the standardized [Covariance]).
// Equivalent to the [Cosine] of mean-normalized vectors.
// See [MetricFunc] for general information.
func Correlation(a, b tensor.Tensor) tensor.Values {
return tensor.CallOut2(CorrelationOut, a, b)
}
// InvCorrelationOut computes 1 minus the correlation between two vectors,
// in range (-1..1) as the mean of the co-product of each vector
// element minus the mean of that vector, normalized by the product of their
// standard deviations: cor(A,B) = E[(A - E(A))(B - E(B))] / sigma(A) sigma(B).
// (i.e., the standardized covariance).
// Equivalent to the [Cosine] of mean-normalized vectors.
// This is useful for a difference measure instead of similarity,
// where more different vectors have larger metric values.
func InvCorrelationOut(a, b tensor.Tensor, out tensor.Values) error {
cor64, err := CorrelationOut64(a, b, out)
if err != nil {
return err
}
nsub := out.Len()
for i := range nsub {
cor := cor64.Float1D(i)
out.SetFloat1D(1-cor, i)
}
return nil
}
// InvCorrelation computes 1 minus the correlation between two vectors.
// The correlation, in range (-1..1), is the mean of the co-product of each
// vector element minus the mean of that vector, normalized by the product
// of their standard deviations:
// cor(A,B) = E[(A - E(A))(B - E(B))] / sigma(A) sigma(B)
// (i.e., the standardized covariance).
// It is equivalent to the [Cosine] of mean-normalized vectors.
// This is useful for a difference measure instead of similarity,
// where more different vectors have larger metric values.
// See [MetricFunc] for general information.
func InvCorrelation(a, b tensor.Tensor) tensor.Values {
return tensor.CallOut2(InvCorrelationOut, a, b)
}
// CosineOut64 computes the cosine of the high-dimensional angle between
// two vectors, in range (-1..1), as the normalized [DotProduct]:
// dot product / sqrt(ssA * ssB). See also [Correlation].
// Returns the Float64 output values for subsequent use.
func CosineOut64(a, b tensor.Tensor, out tensor.Values) (*tensor.Float64, error) {
if err := tensor.MustBeSameShape(a, b); err != nil {
return nil, err
}
ss64, avar64, bvar64 := Vectorize3Out64(a, b, 0, 0, 0, func(a, b, ss, avar, bvar float64) (float64, float64, float64) {
if math.IsNaN(a) || math.IsNaN(b) {
return ss, avar, bvar
}
ss += a * b
avar += a * a
bvar += b * b
return ss, avar, bvar
})
osz := tensor.CellsSize(a.ShapeSizes())
out.SetShapeSizes(osz...)
nsub := out.Len()
for i := range nsub {
ss := ss64.Float1D(i)
vp := math.Sqrt(avar64.Float1D(i) * bvar64.Float1D(i))
if vp > 0 {
ss /= vp
}
ss64.SetFloat1D(ss, i)
out.SetFloat1D(ss, i)
}
return ss64, nil
}
// CosineOut computes the cosine of the high-dimensional angle between
// two vectors, in range (-1..1), as the normalized dot product:
// dot product / sqrt(ssA * ssB). See also [Correlation].
func CosineOut(a, b tensor.Tensor, out tensor.Values) error {
_, err := CosineOut64(a, b, out)
return err
}
// Cosine computes the cosine of the high-dimensional angle between
// two vectors, in range (-1..1), as the normalized dot product:
// dot product / sqrt(ssA * ssB). See also [Correlation].
// See [MetricFunc] for general information.
func Cosine(a, b tensor.Tensor) tensor.Values {
return tensor.CallOut2(CosineOut, a, b)
}
// InvCosineOut computes 1 minus the cosine between two vectors.
// The cosine, in range (-1..1), is the normalized dot product:
// dot product / sqrt(ssA * ssB).
// This is useful for a difference measure instead of similarity,
// where more different vectors have larger metric values.
func InvCosineOut(a, b tensor.Tensor, out tensor.Values) error {
cos64, err := CosineOut64(a, b, out)
if err != nil {
return err
}
nsub := out.Len()
for i := range nsub {
cos := cos64.Float1D(i)
out.SetFloat1D(1-cos, i)
}
return nil
}
// InvCosine computes 1 minus the cosine between two vectors.
// The cosine, in range (-1..1), is the normalized dot product:
// dot product / sqrt(ssA * ssB).
// This is useful for a difference measure instead of similarity,
// where more different vectors have larger metric values.
// See [MetricFunc] for general information.
func InvCosine(a, b tensor.Tensor) tensor.Values {
return tensor.CallOut2(InvCosineOut, a, b)
}
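// exampleCosine is an illustrative sketch (hypothetical helper and values)
// of the relation between [Cosine] and [InvCosine]: InvCosine = 1 - Cosine,
// so identical vectors yield 0 and exactly opposite vectors yield 2.
func exampleCosine() (cos, invCos float64) {
	a := tensor.NewFloat64(3)
	b := tensor.NewFloat64(3)
	for i := range 3 {
		a.SetFloat1D(float64(i+1), i)  // a = 1, 2, 3
		b.SetFloat1D(-float64(i+1), i) // b = -1, -2, -3: opposite direction
	}
	cos = Cosine(a, b).Float1D(0)       // -1 for exactly opposite vectors
	invCos = InvCosine(a, b).Float1D(0) // 1 - (-1) = 2
	return
}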
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package metric
import (
"cogentcore.org/lab/matrix"
"cogentcore.org/lab/tensor"
)
// MatrixOut computes the rows x rows square distance / similarity matrix
// between the patterns for each row of the given higher dimensional input tensor,
// which must have at least 2 dimensions: the outermost rows,
// and within that, 1+dimensional patterns (cells). Use [tensor.NewRowCellsView]
// to organize data into the desired split between a 1D outermost Row dimension
// and the remaining Cells dimension.
// The metric function must have the [MetricFunc] signature.
// The results fill in the elements of the output matrix, which is symmetric,
// and only the lower triangular part is computed, with results copied
// to the upper triangular region, for maximum efficiency.
func MatrixOut(fun any, in tensor.Tensor, out tensor.Values) error {
mfun, err := AsMetricFunc(fun)
if err != nil {
return err
}
rows, cells := in.Shape().RowCellSize()
if rows == 0 || cells == 0 {
return nil
}
out.SetShapeSizes(rows, rows)
coords := matrix.TriLIndicies(rows)
nc := coords.DimSize(0)
// note: flops estimating 3 per item on average -- different for different metrics.
tensor.VectorizeThreaded(cells*3, func(tsr ...tensor.Tensor) int { return nc },
func(idx int, tsr ...tensor.Tensor) {
cx := coords.Int(idx, 0)
cy := coords.Int(idx, 1)
sa := tensor.Cells1D(tsr[0], cx)
sb := tensor.Cells1D(tsr[0], cy)
mout := mfun(sa, sb)
tsr[1].SetFloat(mout.Float1D(0), cx, cy)
}, in, out)
for idx := range nc { // copy to upper
cx := coords.Int(idx, 0)
cy := coords.Int(idx, 1)
if cx == cy { // exclude diag
continue
}
out.SetFloat(out.Float(cx, cy), cy, cx)
}
return nil
}
// Matrix computes the rows x rows square distance / similarity matrix
// between the patterns for each row of the given higher dimensional input tensor,
// which must have at least 2 dimensions: the outermost rows,
// and within that, 1+dimensional patterns (cells). Use [tensor.NewRowCellsView]
// to organize data into the desired split between a 1D outermost Row dimension
// and the remaining Cells dimension.
// The metric function must have the [MetricFunc] signature.
// The results fill in the elements of the output matrix, which is symmetric,
// and only the lower triangular part is computed, with results copied
// to the upper triangular region, for maximum efficiency.
func Matrix(fun any, in tensor.Tensor) tensor.Values {
return tensor.CallOut1Gen1(MatrixOut, fun, in)
}
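// exampleMatrix is an illustrative sketch (hypothetical helper and values,
// assuming tensor.NewFloat64 accepts variadic shape sizes): a 3x2 input
// (3 rows, 2 cells per row) produces a symmetric 3x3 matrix of pairwise
// [L2Norm] distances via [Matrix].
func exampleMatrix() tensor.Values {
	in := tensor.NewFloat64(3, 2)
	vals := []float64{0, 0, 3, 4, 6, 8} // rows: (0,0), (3,4), (6,8)
	for i, v := range vals {
		in.SetFloat1D(v, i)
	}
	dm := Matrix(L2Norm, in) // 3x3; e.g., dm.Float(0, 1) == 5
	return dm
}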
// CrossMatrixOut computes the distance / similarity matrix between
// two different sets of patterns in the two input tensors, where
// the patterns are in the sub-space cells of the tensors,
// which must have at least 2 dimensions: the outermost rows,
// and within that, 1+dimensional patterns that the given distance metric
// function is applied to, with the results filling in the cells of the output matrix.
// The metric function must have the [MetricFunc] signature.
// The rows of the output matrix are the rows of the first input tensor,
// and the columns of the output are the rows of the second input tensor.
func CrossMatrixOut(fun any, a, b tensor.Tensor, out tensor.Values) error {
mfun, err := AsMetricFunc(fun)
if err != nil {
return err
}
arows, acells := a.Shape().RowCellSize()
if arows == 0 || acells == 0 {
return nil
}
brows, bcells := b.Shape().RowCellSize()
if brows == 0 || bcells == 0 {
return nil
}
out.SetShapeSizes(arows, brows)
// note: flops estimating 3 per item on average -- different for different metrics.
flops := min(acells, bcells) * 3
nc := arows * brows
tensor.VectorizeThreaded(flops, func(tsr ...tensor.Tensor) int { return nc },
func(idx int, tsr ...tensor.Tensor) {
ar := idx / brows
br := idx % brows
sa := tensor.Cells1D(tsr[0], ar)
sb := tensor.Cells1D(tsr[1], br)
mout := mfun(sa, sb)
tsr[2].SetFloat(mout.Float1D(0), ar, br)
}, a, b, out)
return nil
}
// CrossMatrix computes the distance / similarity matrix between
// two different sets of patterns in the two input tensors, where
// the patterns are in the sub-space cells of the tensors,
// which must have at least 2 dimensions: the outermost rows,
// and within that, 1+dimensional patterns that the given distance metric
// function is applied to, with the results filling in the cells of the output matrix.
// The metric function must have the [MetricFunc] signature.
// The rows of the output matrix are the rows of the first input tensor,
// and the columns of the output are the rows of the second input tensor.
func CrossMatrix(fun any, a, b tensor.Tensor) tensor.Values {
return tensor.CallOut2Gen1(CrossMatrixOut, fun, a, b)
}
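// exampleCrossMatrix is an illustrative sketch (hypothetical helper and values):
// with a 2-row a and a 3-row b (2 cells each), [CrossMatrix] produces a 2x3
// output where element (i, j) is the metric between row i of a and row j of b.
func exampleCrossMatrix() tensor.Values {
	a := tensor.NewFloat64(2, 2)
	b := tensor.NewFloat64(3, 2)
	for i := range a.Len() {
		a.SetFloat1D(float64(i), i)
	}
	for i := range b.Len() {
		b.SetFloat1D(float64(i), i)
	}
	return CrossMatrix(L2Norm, a, b) // shape: 2 x 3
}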
// CovarianceMatrixOut generates the cells x cells square covariance matrix
// for all per-row cells of the given higher dimensional input tensor,
// which must have at least 2 dimensions: the outermost rows,
// and within that, 1+dimensional patterns (cells).
// Each value in the resulting matrix represents the extent to which the
// value of a given cell covaries across the rows of the tensor with the
// value of another cell.
// Uses the given metric function, typically [Covariance] or [Correlation];
// the metric function must have the [MetricFunc] signature.
// Use Covariance if the variables have similar overall scaling,
// which is typical in neural network models, and use
// Correlation if they are on very different scales, because it effectively rescales them.
// The resulting matrix can be used as the input to PCA or SVD eigenvalue decomposition.
func CovarianceMatrixOut(fun any, in tensor.Tensor, out tensor.Values) error {
mfun, err := AsMetricFunc(fun)
if err != nil {
return err
}
rows, cells := in.Shape().RowCellSize()
if rows == 0 || cells == 0 {
return nil
}
out.SetShapeSizes(cells, cells)
flatvw := tensor.NewReshaped(in, rows, cells)
coords := matrix.TriLIndicies(cells)
nc := coords.DimSize(0)
// note: flops estimating 3 per item on average -- different for different metrics.
tensor.VectorizeThreaded(rows*3, func(tsr ...tensor.Tensor) int { return nc },
func(idx int, tsr ...tensor.Tensor) {
cx := coords.Int(idx, 0)
cy := coords.Int(idx, 1)
av := tensor.Reslice(tsr[0], tensor.FullAxis, cx)
bv := tensor.Reslice(tsr[0], tensor.FullAxis, cy)
mout := mfun(av, bv)
tsr[1].SetFloat(mout.Float1D(0), cx, cy)
}, flatvw, out)
for idx := range nc { // copy to upper
cx := coords.Int(idx, 0)
cy := coords.Int(idx, 1)
if cx == cy { // exclude diag
continue
}
out.SetFloat(out.Float(cx, cy), cy, cx)
}
return nil
}
// CovarianceMatrix generates the cells x cells square covariance matrix
// for all per-row cells of the given higher dimensional input tensor,
// which must have at least 2 dimensions: the outermost rows,
// and within that, 1+dimensional patterns (cells).
// Each value in the resulting matrix represents the extent to which the
// value of a given cell covaries across the rows of the tensor with the
// value of another cell.
// Uses the given metric function, typically [Covariance] or [Correlation];
// the metric function must have the [MetricFunc] signature.
// Use Covariance if the variables have similar overall scaling,
// which is typical in neural network models, and use
// Correlation if they are on very different scales, because it effectively rescales them.
// The resulting matrix can be used as the input to PCA or SVD eigenvalue decomposition.
func CovarianceMatrix(fun any, in tensor.Tensor) tensor.Values {
return tensor.CallOut1Gen1(CovarianceMatrixOut, fun, in)
}
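// exampleCovarianceMatrix is an illustrative sketch (hypothetical helper and
// values): for an input with 4 rows and 3 cells per row, [CovarianceMatrix]
// produces a 3x3 matrix where entry (i, j) is the metric between cell i and
// cell j across the rows; with [Correlation], diagonal entries are 1.
func exampleCovarianceMatrix() tensor.Values {
	in := tensor.NewFloat64(4, 3)
	for i := range in.Len() {
		in.SetFloat1D(float64(i%5), i) // arbitrary non-constant values
	}
	return CovarianceMatrix(Correlation, in) // 3x3
}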
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:generate core generate
package metric
import (
"cogentcore.org/core/base/errors"
"cogentcore.org/lab/tensor"
)
func init() {
tensor.AddFunc(MetricL2Norm.FuncName(), L2Norm)
tensor.AddFunc(MetricSumSquares.FuncName(), SumSquares)
tensor.AddFunc(MetricL1Norm.FuncName(), L1Norm)
tensor.AddFunc(MetricHamming.FuncName(), Hamming)
tensor.AddFunc(MetricL2NormBinTol.FuncName(), L2NormBinTol)
tensor.AddFunc(MetricSumSquaresBinTol.FuncName(), SumSquaresBinTol)
tensor.AddFunc(MetricInvCosine.FuncName(), InvCosine)
tensor.AddFunc(MetricInvCorrelation.FuncName(), InvCorrelation)
tensor.AddFunc(MetricDotProduct.FuncName(), DotProduct)
tensor.AddFunc(MetricCrossEntropy.FuncName(), CrossEntropy)
tensor.AddFunc(MetricCovariance.FuncName(), Covariance)
tensor.AddFunc(MetricCorrelation.FuncName(), Correlation)
tensor.AddFunc(MetricCosine.FuncName(), Cosine)
}
// Metrics are standard metric functions
type Metrics int32 //enums:enum -trim-prefix Metric
const (
// L2Norm is the square root of the sum of squares differences
// between tensor values, aka the Euclidean distance.
MetricL2Norm Metrics = iota
// SumSquares is the sum of squares differences between tensor values.
MetricSumSquares
// L1Norm is the sum of the absolute value of differences
// between tensor values, the L1 Norm.
MetricL1Norm
// Hamming is the sum of 1s for every element that is different,
// i.e., "city block" distance.
MetricHamming
// L2NormBinTol is the [L2Norm] square root of the sum of squares
// differences between tensor values, with binary tolerance:
// differences < 0.5 are thresholded to 0.
MetricL2NormBinTol
// SumSquaresBinTol is the [SumSquares] differences between tensor values,
// with binary tolerance: differences < 0.5 are thresholded to 0.
MetricSumSquaresBinTol
// InvCosine is 1-[Cosine], which is useful to convert it
// to an Increasing metric where more different vectors have larger metric values.
MetricInvCosine
// InvCorrelation is 1-[Correlation], which is useful to convert it
// to an Increasing metric where more different vectors have larger metric values.
MetricInvCorrelation
// CrossEntropy is a standard measure of the difference between two
// probability distributions, reflecting the additional entropy (uncertainty)
// associated with measuring probabilities under distribution b when in fact
// they come from distribution a. It is also the entropy of a plus the
// Kullback-Leibler (KL) divergence of a from b. It is computed as:
// a * log(a/b) + (1-a) * log((1-a)/(1-b)).
MetricCrossEntropy
//////// Everything below here is !Increasing -- larger = closer, not farther
// DotProduct is the sum of the co-products of the tensor values.
MetricDotProduct
// Covariance is co-variance between two vectors,
// i.e., the mean of the co-product of each vector element minus
// the mean of that vector: cov(A,B) = E[(A - E(A))(B - E(B))].
MetricCovariance
// Correlation is the standardized [Covariance] in the range (-1..1),
// computed as the mean of the co-product of each vector
// element minus the mean of that vector, normalized by the product of their
// standard deviations: cor(A,B) = E[(A - E(A))(B - E(B))] / sigma(A) sigma(B).
// Equivalent to the [Cosine] of mean-normalized vectors.
MetricCorrelation
// Cosine is the cosine of the high-dimensional angle between two vectors,
// in range (-1..1), as the normalized [DotProduct]:
// inner product / sqrt(ssA * ssB). See also [Correlation].
MetricCosine
)
// FuncName returns the package-qualified function name to use
// in tensor.Call to call this function.
func (m Metrics) FuncName() string {
return "metric." + m.String()
}
// Func returns function for given metric.
func (m Metrics) Func() MetricFunc {
fn := errors.Log1(tensor.FuncByName(m.FuncName()))
return fn.Fun.(MetricFunc)
}
// Call calls a standard Metrics enum function on given tensors,
// returning the output values.
func (m Metrics) Call(a, b tensor.Tensor) tensor.Values {
return m.Func()(a, b)
}
// Increasing returns true if the distance metric is such that metric
// values increase as a function of distance (e.g., L2Norm),
// and false if metric values decrease as a function of distance
// (e.g., Cosine, Correlation).
func (m Metrics) Increasing() bool {
if m >= MetricDotProduct {
return false
}
return true
}
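// exampleMetricsEnum is an illustrative sketch (hypothetical helper and values)
// of using the [Metrics] enum: Call invokes the registered function by enum
// value, and Increasing indicates whether larger values mean more different.
func exampleMetricsEnum() (dist float64, increasing bool) {
	a := tensor.NewFloat64(3) // a = (0, 0, 0)
	b := tensor.NewFloat64(3)
	b.SetFloat1D(1, 0)                        // b = (1, 0, 0)
	dist = MetricL2Norm.Call(a, b).Float1D(0) // = 1
	increasing = MetricL2Norm.Increasing()    // true: larger = more different
	return
}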
// AsMetricFunc returns given function as a [MetricFunc] function,
// or an error if it does not fit that signature.
func AsMetricFunc(fun any) (MetricFunc, error) {
mfun, ok := fun.(MetricFunc)
if !ok {
return nil, errors.New("metric.AsMetricFunc: function does not fit the MetricFunc signature")
}
return mfun, nil
}
// AsMetricOutFunc returns given function as a [MetricOutFunc] function,
// or an error if it does not fit that signature.
func AsMetricOutFunc(fun any) (MetricOutFunc, error) {
mfun, ok := fun.(MetricOutFunc)
if !ok {
return nil, errors.New("metric.AsMetricOutFunc: function does not fit the MetricOutFunc signature")
}
return mfun, nil
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package metric
import (
"math"
"cogentcore.org/lab/tensor"
)
// ClosestRow returns the closest fit between probe pattern and patterns in
// a "vocabulary" tensor with outermost row dimension, using given metric
// function, which must fit the MetricFunc signature.
// The metric *must have the Increasing property*, i.e., larger = further.
// Output is a 1D tensor with 2 elements: the row index and metric value for that row.
// Note: this does _not_ use any existing Indexes for the probe,
// but does for the vocab, and the returned index is the logical index
// into any existing Indexes.
func ClosestRow(fun any, probe, vocab tensor.Tensor) tensor.Values {
return tensor.CallOut2Gen1(ClosestRowOut, fun, probe, vocab)
}
// ClosestRowOut returns the closest fit between probe pattern and patterns in
// a "vocabulary" tensor with outermost row dimension, using given metric
// function, which must fit the MetricFunc signature.
// The metric *must have the Increasing property*, i.e., larger = further.
// Output is a 1D tensor with 2 elements: the row index and metric value for that row.
// Note: this does _not_ use any existing Indexes for the probe,
// but does for the vocab, and the returned index is the logical index
// into any existing Indexes.
func ClosestRowOut(fun any, probe, vocab tensor.Tensor, out tensor.Values) error {
out.SetShapeSizes(2)
mfun, err := AsMetricFunc(fun)
if err != nil {
return err
}
rows, _ := vocab.Shape().RowCellSize()
mi := -1
mind := math.MaxFloat64
pr1d := tensor.As1D(probe)
for ri := range rows {
sub := tensor.Cells1D(vocab, ri)
mout := mfun(pr1d, sub)
d := mout.Float1D(0)
if d < mind {
mi = ri
mind = d
}
}
out.SetFloat1D(float64(mi), 0)
out.SetFloat1D(mind, 1)
return nil
}
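// exampleClosestRow is an illustrative sketch (hypothetical helper and values):
// [ClosestRow] finds which row of a 3-row vocabulary is closest to a probe,
// using an Increasing metric such as [L2Norm].
func exampleClosestRow() (row int, dist float64) {
	vocab := tensor.NewFloat64(3, 2)
	vals := []float64{0, 0, 1, 1, 4, 4} // rows: (0,0), (1,1), (4,4)
	for i, v := range vals {
		vocab.SetFloat1D(v, i)
	}
	probe := tensor.NewFloat64(2)
	probe.SetFloat1D(1.2, 0)
	probe.SetFloat1D(0.9, 1)
	res := ClosestRow(L2Norm, probe, vocab)
	return int(res.Float1D(0)), res.Float1D(1) // row 1 is closest
}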
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package metric
import (
"cogentcore.org/lab/tensor"
)
// VectorizeOut64 is the general compute function for the metric functions.
// This version makes a Float64 output tensor for aggregating
// and computing values, and then copies the results back to the
// original output. This allows metric functions to operate directly
// on integer valued inputs and produce sensible results.
// It returns the Float64 output tensor for further processing as needed.
// a and b are already enforced to be the same shape.
func VectorizeOut64(a, b tensor.Tensor, out tensor.Values, ini float64, fun func(a, b, agg float64) float64) *tensor.Float64 {
rows, cells := a.Shape().RowCellSize()
o64 := tensor.NewFloat64(cells)
if rows <= 0 {
return o64
}
if cells == 1 {
out.SetShapeSizes(1)
agg := ini
switch x := a.(type) {
case *tensor.Float64:
switch y := b.(type) {
case *tensor.Float64:
for i := range rows {
agg = fun(x.Float1D(i), y.Float1D(i), agg)
}
case *tensor.Float32:
for i := range rows {
agg = fun(x.Float1D(i), y.Float1D(i), agg)
}
default:
for i := range rows {
agg = fun(x.Float1D(i), b.Float1D(i), agg)
}
}
case *tensor.Float32:
switch y := b.(type) {
case *tensor.Float64:
for i := range rows {
agg = fun(x.Float1D(i), y.Float1D(i), agg)
}
case *tensor.Float32:
for i := range rows {
agg = fun(x.Float1D(i), y.Float1D(i), agg)
}
default:
for i := range rows {
agg = fun(x.Float1D(i), b.Float1D(i), agg)
}
}
default:
switch y := b.(type) {
case *tensor.Float64:
for i := range rows {
agg = fun(a.Float1D(i), y.Float1D(i), agg)
}
case *tensor.Float32:
for i := range rows {
agg = fun(a.Float1D(i), y.Float1D(i), agg)
}
default:
for i := range rows {
agg = fun(a.Float1D(i), b.Float1D(i), agg)
}
}
}
o64.SetFloat1D(agg, 0)
out.SetFloat1D(agg, 0)
return o64
}
osz := tensor.CellsSize(a.ShapeSizes())
out.SetShapeSizes(osz...)
for i := range cells {
o64.SetFloat1D(ini, i)
}
switch x := a.(type) {
case *tensor.Float64:
switch y := b.(type) {
case *tensor.Float64:
for i := range rows {
si := i * cells
for j := range cells {
o64.SetFloat1D(fun(x.Float1D(si+j), y.Float1D(si+j), o64.Float1D(j)), j)
}
}
case *tensor.Float32:
for i := range rows {
si := i * cells
for j := range cells {
o64.SetFloat1D(fun(x.Float1D(si+j), y.Float1D(si+j), o64.Float1D(j)), j)
}
}
default:
for i := range rows {
si := i * cells
for j := range cells {
o64.SetFloat1D(fun(x.Float1D(si+j), b.Float1D(si+j), o64.Float1D(j)), j)
}
}
}
case *tensor.Float32:
switch y := b.(type) {
case *tensor.Float64:
for i := range rows {
si := i * cells
for j := range cells {
o64.SetFloat1D(fun(x.Float1D(si+j), y.Float1D(si+j), o64.Float1D(j)), j)
}
}
case *tensor.Float32:
for i := range rows {
si := i * cells
for j := range cells {
o64.SetFloat1D(fun(x.Float1D(si+j), y.Float1D(si+j), o64.Float1D(j)), j)
}
}
default:
for i := range rows {
si := i * cells
for j := range cells {
o64.SetFloat1D(fun(x.Float1D(si+j), b.Float1D(si+j), o64.Float1D(j)), j)
}
}
}
default:
switch y := b.(type) {
case *tensor.Float64:
for i := range rows {
si := i * cells
for j := range cells {
o64.SetFloat1D(fun(a.Float1D(si+j), y.Float1D(si+j), o64.Float1D(j)), j)
}
}
case *tensor.Float32:
for i := range rows {
si := i * cells
for j := range cells {
o64.SetFloat1D(fun(a.Float1D(si+j), y.Float1D(si+j), o64.Float1D(j)), j)
}
}
default:
for i := range rows {
si := i * cells
for j := range cells {
o64.SetFloat1D(fun(a.Float1D(si+j), b.Float1D(si+j), o64.Float1D(j)), j)
}
}
}
}
for j := range cells {
out.SetFloat1D(o64.Float1D(j), j)
}
return o64
}
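// exampleAbsDiffOut is an illustrative sketch (hypothetical helper) of how a
// simple metric can be written in terms of [VectorizeOut64]: an absolute
// difference (L1-style) aggregation. The real metric functions above also
// skip NaN values via math.IsNaN.
func exampleAbsDiffOut(a, b tensor.Tensor, out tensor.Values) error {
	if err := tensor.MustBeSameShape(a, b); err != nil {
		return err
	}
	VectorizeOut64(a, b, out, 0, func(a, b, agg float64) float64 {
		d := a - b
		if d < 0 {
			d = -d
		}
		return agg + d
	})
	return nil
}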
// VectorizePreOut64 is a version of [VectorizeOut64] that takes additional
// tensor.Float64 inputs of pre-computed values, e.g., the means of each output cell.
func VectorizePreOut64(a, b tensor.Tensor, out tensor.Values, ini float64, preA, preB *tensor.Float64, fun func(a, b, preA, preB, agg float64) float64) *tensor.Float64 {
rows, cells := a.Shape().RowCellSize()
o64 := tensor.NewFloat64(cells)
if rows <= 0 {
return o64
}
if cells == 1 {
out.SetShapeSizes(1)
agg := ini
prevA := preA.Float1D(0)
prevB := preB.Float1D(0)
switch x := a.(type) {
case *tensor.Float64:
switch y := b.(type) {
case *tensor.Float64:
for i := range rows {
agg = fun(x.Float1D(i), y.Float1D(i), prevA, prevB, agg)
}
case *tensor.Float32:
for i := range rows {
agg = fun(x.Float1D(i), y.Float1D(i), prevA, prevB, agg)
}
default:
for i := range rows {
agg = fun(x.Float1D(i), b.Float1D(i), prevA, prevB, agg)
}
}
case *tensor.Float32:
switch y := b.(type) {
case *tensor.Float64:
for i := range rows {
agg = fun(x.Float1D(i), y.Float1D(i), prevA, prevB, agg)
}
case *tensor.Float32:
for i := range rows {
agg = fun(x.Float1D(i), y.Float1D(i), prevA, prevB, agg)
}
default:
for i := range rows {
agg = fun(x.Float1D(i), b.Float1D(i), prevA, prevB, agg)
}
}
default:
switch y := b.(type) {
case *tensor.Float64:
for i := range rows {
agg = fun(a.Float1D(i), y.Float1D(i), prevA, prevB, agg)
}
case *tensor.Float32:
for i := range rows {
agg = fun(a.Float1D(i), y.Float1D(i), prevA, prevB, agg)
}
default:
for i := range rows {
agg = fun(a.Float1D(i), b.Float1D(i), prevA, prevB, agg)
}
}
}
o64.SetFloat1D(agg, 0)
out.SetFloat1D(agg, 0)
return o64
}
osz := tensor.CellsSize(a.ShapeSizes())
out.SetShapeSizes(osz...)
for j := range cells {
o64.SetFloat1D(ini, j)
}
switch x := a.(type) {
case *tensor.Float64:
switch y := b.(type) {
case *tensor.Float64:
for i := range rows {
si := i * cells
for j := range cells {
o64.SetFloat1D(fun(x.Float1D(si+j), y.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), o64.Float1D(j)), j)
}
}
case *tensor.Float32:
for i := range rows {
si := i * cells
for j := range cells {
o64.SetFloat1D(fun(x.Float1D(si+j), y.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), o64.Float1D(j)), j)
}
}
default:
for i := range rows {
si := i * cells
for j := range cells {
o64.SetFloat1D(fun(x.Float1D(si+j), b.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), o64.Float1D(j)), j)
}
}
}
case *tensor.Float32:
switch y := b.(type) {
case *tensor.Float64:
for i := range rows {
si := i * cells
for j := range cells {
o64.SetFloat1D(fun(x.Float1D(si+j), y.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), o64.Float1D(j)), j)
}
}
case *tensor.Float32:
for i := range rows {
si := i * cells
for j := range cells {
o64.SetFloat1D(fun(x.Float1D(si+j), y.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), o64.Float1D(j)), j)
}
}
default:
for i := range rows {
si := i * cells
for j := range cells {
o64.SetFloat1D(fun(x.Float1D(si+j), b.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), o64.Float1D(j)), j)
}
}
}
default:
switch y := b.(type) {
case *tensor.Float64:
for i := range rows {
si := i * cells
for j := range cells {
o64.SetFloat1D(fun(a.Float1D(si+j), y.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), o64.Float1D(j)), j)
}
}
case *tensor.Float32:
for i := range rows {
si := i * cells
for j := range cells {
o64.SetFloat1D(fun(a.Float1D(si+j), y.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), o64.Float1D(j)), j)
}
}
default:
for i := range rows {
si := i * cells
for j := range cells {
o64.SetFloat1D(fun(a.Float1D(si+j), b.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), o64.Float1D(j)), j)
}
}
}
}
for i := range cells {
out.SetFloat1D(o64.Float1D(i), i)
}
return o64
}
// Vectorize2Out64 is a version of [VectorizeOut64] that separately aggregates
// two output values, x and y as tensor.Float64.
func Vectorize2Out64(a, b tensor.Tensor, iniX, iniY float64, fun func(a, b, ox, oy float64) (float64, float64)) (ox64, oy64 *tensor.Float64) {
rows, cells := a.Shape().RowCellSize()
ox64 = tensor.NewFloat64(cells)
oy64 = tensor.NewFloat64(cells)
if rows <= 0 {
return
}
if cells == 1 {
ox := iniX
oy := iniY
switch x := a.(type) {
case *tensor.Float64:
switch y := b.(type) {
case *tensor.Float64:
for i := range rows {
ox, oy = fun(x.Float1D(i), y.Float1D(i), ox, oy)
}
case *tensor.Float32:
for i := range rows {
ox, oy = fun(x.Float1D(i), y.Float1D(i), ox, oy)
}
default:
for i := range rows {
ox, oy = fun(x.Float1D(i), b.Float1D(i), ox, oy)
}
}
case *tensor.Float32:
switch y := b.(type) {
case *tensor.Float64:
for i := range rows {
ox, oy = fun(x.Float1D(i), y.Float1D(i), ox, oy)
}
case *tensor.Float32:
for i := range rows {
ox, oy = fun(x.Float1D(i), y.Float1D(i), ox, oy)
}
default:
for i := range rows {
ox, oy = fun(x.Float1D(i), b.Float1D(i), ox, oy)
}
}
default:
switch y := b.(type) {
case *tensor.Float64:
for i := range rows {
ox, oy = fun(a.Float1D(i), y.Float1D(i), ox, oy)
}
case *tensor.Float32:
for i := range rows {
ox, oy = fun(a.Float1D(i), y.Float1D(i), ox, oy)
}
default:
for i := range rows {
ox, oy = fun(a.Float1D(i), b.Float1D(i), ox, oy)
}
}
}
ox64.SetFloat1D(ox, 0)
oy64.SetFloat1D(oy, 0)
return
}
for j := range cells {
ox64.SetFloat1D(iniX, j)
oy64.SetFloat1D(iniY, j)
}
switch x := a.(type) {
case *tensor.Float64:
switch y := b.(type) {
case *tensor.Float64:
for i := range rows {
si := i * cells
for j := range cells {
ox, oy := fun(x.Float1D(si+j), y.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j))
ox64.SetFloat1D(ox, j)
oy64.SetFloat1D(oy, j)
}
}
case *tensor.Float32:
for i := range rows {
si := i * cells
for j := range cells {
ox, oy := fun(x.Float1D(si+j), y.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j))
ox64.SetFloat1D(ox, j)
oy64.SetFloat1D(oy, j)
}
}
default:
for i := range rows {
si := i * cells
for j := range cells {
ox, oy := fun(x.Float1D(si+j), b.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j))
ox64.SetFloat1D(ox, j)
oy64.SetFloat1D(oy, j)
}
}
}
case *tensor.Float32:
switch y := b.(type) {
case *tensor.Float64:
for i := range rows {
si := i * cells
for j := range cells {
ox, oy := fun(x.Float1D(si+j), y.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j))
ox64.SetFloat1D(ox, j)
oy64.SetFloat1D(oy, j)
}
}
case *tensor.Float32:
for i := range rows {
si := i * cells
for j := range cells {
ox, oy := fun(x.Float1D(si+j), y.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j))
ox64.SetFloat1D(ox, j)
oy64.SetFloat1D(oy, j)
}
}
default:
for i := range rows {
si := i * cells
for j := range cells {
ox, oy := fun(x.Float1D(si+j), b.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j))
ox64.SetFloat1D(ox, j)
oy64.SetFloat1D(oy, j)
}
}
}
default:
switch y := b.(type) {
case *tensor.Float64:
for i := range rows {
si := i * cells
for j := range cells {
ox, oy := fun(a.Float1D(si+j), y.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j))
ox64.SetFloat1D(ox, j)
oy64.SetFloat1D(oy, j)
}
}
case *tensor.Float32:
for i := range rows {
si := i * cells
for j := range cells {
ox, oy := fun(a.Float1D(si+j), y.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j))
ox64.SetFloat1D(ox, j)
oy64.SetFloat1D(oy, j)
}
}
default:
for i := range rows {
si := i * cells
for j := range cells {
ox, oy := fun(a.Float1D(si+j), b.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j))
ox64.SetFloat1D(ox, j)
oy64.SetFloat1D(oy, j)
}
}
}
}
return
}
// Vectorize3Out64 is a version of [VectorizeOut64] that has 3 outputs instead of 1.
func Vectorize3Out64(a, b tensor.Tensor, iniX, iniY, iniZ float64, fun func(a, b, ox, oy, oz float64) (float64, float64, float64)) (ox64, oy64, oz64 *tensor.Float64) {
rows, cells := a.Shape().RowCellSize()
ox64 = tensor.NewFloat64(cells)
oy64 = tensor.NewFloat64(cells)
oz64 = tensor.NewFloat64(cells)
if rows <= 0 {
return
}
if cells == 1 {
ox := iniX
oy := iniY
oz := iniZ
switch x := a.(type) {
case *tensor.Float64:
switch y := b.(type) {
case *tensor.Float64:
for i := range rows {
ox, oy, oz = fun(x.Float1D(i), y.Float1D(i), ox, oy, oz)
}
case *tensor.Float32:
for i := range rows {
ox, oy, oz = fun(x.Float1D(i), y.Float1D(i), ox, oy, oz)
}
default:
for i := range rows {
ox, oy, oz = fun(x.Float1D(i), b.Float1D(i), ox, oy, oz)
}
}
case *tensor.Float32:
switch y := b.(type) {
case *tensor.Float64:
for i := range rows {
ox, oy, oz = fun(x.Float1D(i), y.Float1D(i), ox, oy, oz)
}
case *tensor.Float32:
for i := range rows {
ox, oy, oz = fun(x.Float1D(i), y.Float1D(i), ox, oy, oz)
}
default:
for i := range rows {
ox, oy, oz = fun(x.Float1D(i), b.Float1D(i), ox, oy, oz)
}
}
default:
switch y := b.(type) {
case *tensor.Float64:
for i := range rows {
ox, oy, oz = fun(a.Float1D(i), y.Float1D(i), ox, oy, oz)
}
case *tensor.Float32:
for i := range rows {
ox, oy, oz = fun(a.Float1D(i), y.Float1D(i), ox, oy, oz)
}
default:
for i := range rows {
ox, oy, oz = fun(a.Float1D(i), b.Float1D(i), ox, oy, oz)
}
}
}
ox64.SetFloat1D(ox, 0)
oy64.SetFloat1D(oy, 0)
oz64.SetFloat1D(oz, 0)
return
}
for j := range cells {
ox64.SetFloat1D(iniX, j)
oy64.SetFloat1D(iniY, j)
oz64.SetFloat1D(iniZ, j)
}
switch x := a.(type) {
case *tensor.Float64:
switch y := b.(type) {
case *tensor.Float64:
for i := range rows {
si := i * cells
for j := range cells {
ox, oy, oz := fun(x.Float1D(si+j), y.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j))
ox64.SetFloat1D(ox, j)
oy64.SetFloat1D(oy, j)
oz64.SetFloat1D(oz, j)
}
}
case *tensor.Float32:
for i := range rows {
si := i * cells
for j := range cells {
ox, oy, oz := fun(x.Float1D(si+j), y.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j))
ox64.SetFloat1D(ox, j)
oy64.SetFloat1D(oy, j)
oz64.SetFloat1D(oz, j)
}
}
default:
for i := range rows {
si := i * cells
for j := range cells {
ox, oy, oz := fun(x.Float1D(si+j), b.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j))
ox64.SetFloat1D(ox, j)
oy64.SetFloat1D(oy, j)
oz64.SetFloat1D(oz, j)
}
}
}
case *tensor.Float32:
switch y := b.(type) {
case *tensor.Float64:
for i := range rows {
si := i * cells
for j := range cells {
ox, oy, oz := fun(x.Float1D(si+j), y.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j))
ox64.SetFloat1D(ox, j)
oy64.SetFloat1D(oy, j)
oz64.SetFloat1D(oz, j)
}
}
case *tensor.Float32:
for i := range rows {
si := i * cells
for j := range cells {
ox, oy, oz := fun(x.Float1D(si+j), y.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j))
ox64.SetFloat1D(ox, j)
oy64.SetFloat1D(oy, j)
oz64.SetFloat1D(oz, j)
}
}
default:
for i := range rows {
si := i * cells
for j := range cells {
ox, oy, oz := fun(x.Float1D(si+j), b.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j))
ox64.SetFloat1D(ox, j)
oy64.SetFloat1D(oy, j)
oz64.SetFloat1D(oz, j)
}
}
}
default:
switch y := b.(type) {
case *tensor.Float64:
for i := range rows {
si := i * cells
for j := range cells {
ox, oy, oz := fun(a.Float1D(si+j), y.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j))
ox64.SetFloat1D(ox, j)
oy64.SetFloat1D(oy, j)
oz64.SetFloat1D(oz, j)
}
}
case *tensor.Float32:
for i := range rows {
si := i * cells
for j := range cells {
ox, oy, oz := fun(a.Float1D(si+j), y.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j))
ox64.SetFloat1D(ox, j)
oy64.SetFloat1D(oy, j)
oz64.SetFloat1D(oz, j)
}
}
default:
for i := range rows {
si := i * cells
for j := range cells {
ox, oy, oz := fun(a.Float1D(si+j), b.Float1D(si+j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j))
ox64.SetFloat1D(ox, j)
oy64.SetFloat1D(oy, j)
oz64.SetFloat1D(oz, j)
}
}
}
}
return
}
// VectorizePre3Out64 is a version of [VectorizePreOut64] that takes additional
// tensor.Float64 inputs of pre-computed values, e.g., the means of each output cell,
// and has 3 outputs instead of 1.
func VectorizePre3Out64(a, b tensor.Tensor, iniX, iniY, iniZ float64, preA, preB *tensor.Float64, fun func(a, b, preA, preB, ox, oy, oz float64) (float64, float64, float64)) (ox64, oy64, oz64 *tensor.Float64) {
rows, cells := a.Shape().RowCellSize()
ox64 = tensor.NewFloat64(cells)
oy64 = tensor.NewFloat64(cells)
oz64 = tensor.NewFloat64(cells)
if rows <= 0 {
return
}
if cells == 1 {
ox := iniX
oy := iniY
oz := iniZ
prevA := preA.Float1D(0)
prevB := preB.Float1D(0)
switch x := a.(type) {
case *tensor.Float64:
switch y := b.(type) {
case *tensor.Float64:
for i := range rows {
ox, oy, oz = fun(x.Float1D(i), y.Float1D(i), prevA, prevB, ox, oy, oz)
}
case *tensor.Float32:
for i := range rows {
ox, oy, oz = fun(x.Float1D(i), y.Float1D(i), prevA, prevB, ox, oy, oz)
}
default:
for i := range rows {
ox, oy, oz = fun(x.Float1D(i), b.Float1D(i), prevA, prevB, ox, oy, oz)
}
}
case *tensor.Float32:
switch y := b.(type) {
case *tensor.Float64:
for i := range rows {
ox, oy, oz = fun(x.Float1D(i), y.Float1D(i), prevA, prevB, ox, oy, oz)
}
case *tensor.Float32:
for i := range rows {
ox, oy, oz = fun(x.Float1D(i), y.Float1D(i), prevA, prevB, ox, oy, oz)
}
default:
for i := range rows {
ox, oy, oz = fun(x.Float1D(i), b.Float1D(i), prevA, prevB, ox, oy, oz)
}
}
default:
switch y := b.(type) {
case *tensor.Float64:
for i := range rows {
ox, oy, oz = fun(a.Float1D(i), y.Float1D(i), prevA, prevB, ox, oy, oz)
}
case *tensor.Float32:
for i := range rows {
ox, oy, oz = fun(a.Float1D(i), y.Float1D(i), prevA, prevB, ox, oy, oz)
}
default:
for i := range rows {
ox, oy, oz = fun(a.Float1D(i), b.Float1D(i), prevA, prevB, ox, oy, oz)
}
}
}
ox64.SetFloat1D(ox, 0)
oy64.SetFloat1D(oy, 0)
oz64.SetFloat1D(oz, 0)
return
}
for j := range cells {
ox64.SetFloat1D(iniX, j)
oy64.SetFloat1D(iniY, j)
oz64.SetFloat1D(iniZ, j)
}
switch x := a.(type) {
case *tensor.Float64:
switch y := b.(type) {
case *tensor.Float64:
for i := range rows {
si := i * cells
for j := range cells {
ox, oy, oz := fun(x.Float1D(si+j), y.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j))
ox64.SetFloat1D(ox, j)
oy64.SetFloat1D(oy, j)
oz64.SetFloat1D(oz, j)
}
}
case *tensor.Float32:
for i := range rows {
si := i * cells
for j := range cells {
ox, oy, oz := fun(x.Float1D(si+j), y.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j))
ox64.SetFloat1D(ox, j)
oy64.SetFloat1D(oy, j)
oz64.SetFloat1D(oz, j)
}
}
default:
for i := range rows {
si := i * cells
for j := range cells {
ox, oy, oz := fun(x.Float1D(si+j), b.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j))
ox64.SetFloat1D(ox, j)
oy64.SetFloat1D(oy, j)
oz64.SetFloat1D(oz, j)
}
}
}
case *tensor.Float32:
switch y := b.(type) {
case *tensor.Float64:
for i := range rows {
si := i * cells
for j := range cells {
ox, oy, oz := fun(x.Float1D(si+j), y.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j))
ox64.SetFloat1D(ox, j)
oy64.SetFloat1D(oy, j)
oz64.SetFloat1D(oz, j)
}
}
case *tensor.Float32:
for i := range rows {
si := i * cells
for j := range cells {
ox, oy, oz := fun(x.Float1D(si+j), y.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j))
ox64.SetFloat1D(ox, j)
oy64.SetFloat1D(oy, j)
oz64.SetFloat1D(oz, j)
}
}
default:
for i := range rows {
si := i * cells
for j := range cells {
ox, oy, oz := fun(x.Float1D(si+j), b.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j))
ox64.SetFloat1D(ox, j)
oy64.SetFloat1D(oy, j)
oz64.SetFloat1D(oz, j)
}
}
}
default:
switch y := b.(type) {
case *tensor.Float64:
for i := range rows {
si := i * cells
for j := range cells {
ox, oy, oz := fun(a.Float1D(si+j), y.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j))
ox64.SetFloat1D(ox, j)
oy64.SetFloat1D(oy, j)
oz64.SetFloat1D(oz, j)
}
}
case *tensor.Float32:
for i := range rows {
si := i * cells
for j := range cells {
ox, oy, oz := fun(a.Float1D(si+j), y.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j))
ox64.SetFloat1D(ox, j)
oy64.SetFloat1D(oy, j)
oz64.SetFloat1D(oz, j)
}
}
default:
for i := range rows {
si := i * cells
for j := range cells {
ox, oy, oz := fun(a.Float1D(si+j), b.Float1D(si+j), preA.Float1D(j), preB.Float1D(j), ox64.Float1D(j), oy64.Float1D(j), oz64.Float1D(j))
ox64.SetFloat1D(ox, j)
oy64.SetFloat1D(oy, j)
oz64.SetFloat1D(oz, j)
}
}
}
}
return
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package stats
import (
"strconv"
"cogentcore.org/core/base/metadata"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
"cogentcore.org/lab/tensorfs"
)
// DescriptiveStats are the standard descriptive stats used in the [Describe] function.
// The sort-based stats (Q1, Median, Q3) cannot be applied to higher-dimensional data.
var DescriptiveStats = []Stats{StatCount, StatMean, StatStd, StatSem, StatMin, StatQ1, StatMedian, StatQ3, StatMax}
// Describe adds standard descriptive statistics for given tensor
// to the given [tensorfs] directory, adding a directory for each tensor
// and result tensor stats for each result.
// This is an easy way to provide a comprehensive description of data.
// The [DescriptiveStats] list is: [Count], [Mean], [Std], [Sem],
// [Min], [Q1], [Median], [Q3], [Max].
func Describe(dir *tensorfs.Node, tsrs ...tensor.Tensor) {
dd := dir.Dir("Describe")
for i, tsr := range tsrs {
nr := tsr.DimSize(0)
if nr == 0 {
continue
}
nm := metadata.Name(tsr)
if nm == "" {
nm = strconv.Itoa(i)
}
td := dd.Dir(nm)
for _, st := range DescriptiveStats {
stnm := st.String()
sv := tensorfs.Scalar[float64](td, stnm)
stout := st.Call(tsr)
sv.CopyFrom(stout)
}
}
}
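// exampleDescribe is an illustrative sketch (hypothetical helper and values),
// assuming an existing [tensorfs.Node] directory dir: results land under
// dir/Describe/<name>/<stat>.
func exampleDescribe(dir *tensorfs.Node) {
	data := tensor.NewFloat64(5)
	for i := range 5 {
		data.SetFloat1D(float64(i), i) // 0, 1, 2, 3, 4
	}
	Describe(dir, data) // e.g., Mean = 2, Min = 0, Max = 4
}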
// DescribeTable runs [Describe] on given columns in table.
func DescribeTable(dir *tensorfs.Node, dt *table.Table, columns ...string) {
Describe(dir, dt.ColumnList(columns...)...)
}
// DescribeTableAll runs [Describe] on all numeric columns in given table.
func DescribeTableAll(dir *tensorfs.Node, dt *table.Table) {
var cols []string
for i, cl := range dt.Columns.Values {
if !cl.IsString() {
cols = append(cols, dt.ColumnName(i))
}
}
Describe(dir, dt.ColumnList(cols...)...)
}
// Code generated by "core generate"; DO NOT EDIT.
package stats
import (
"cogentcore.org/core/enums"
)
var _StatsValues = []Stats{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21}
// StatsN is the highest valid value for type Stats, plus one.
const StatsN Stats = 22
var _StatsValueMap = map[string]Stats{`Count`: 0, `Sum`: 1, `L1Norm`: 2, `Prod`: 3, `Min`: 4, `Max`: 5, `MinAbs`: 6, `MaxAbs`: 7, `Mean`: 8, `Var`: 9, `Std`: 10, `Sem`: 11, `SumSq`: 12, `L2Norm`: 13, `VarPop`: 14, `StdPop`: 15, `SemPop`: 16, `Median`: 17, `Q1`: 18, `Q3`: 19, `First`: 20, `Final`: 21}
var _StatsDescMap = map[Stats]string{0: `count of number of elements.`, 1: `sum of elements.`, 2: `L1 Norm: sum of absolute values of elements.`, 3: `product of elements.`, 4: `minimum value.`, 5: `maximum value.`, 6: `minimum of absolute values.`, 7: `maximum of absolute values.`, 8: `mean value = sum / count.`, 9: `sample variance (squared deviations from mean, divided by n-1).`, 10: `sample standard deviation (sqrt of Var).`, 11: `sample standard error of the mean (Std divided by sqrt(n)).`, 12: `sum of squared values.`, 13: `L2 Norm: square-root of sum-of-squares.`, 14: `population variance (squared diffs from mean, divided by n).`, 15: `population standard deviation (sqrt of VarPop).`, 16: `population standard error of the mean (StdPop divided by sqrt(n)).`, 17: `middle value in sorted ordering.`, 18: `Q1 first quartile = 25%ile value = .25 quantile value.`, 19: `Q3 third quartile = 75%ile value = .75 quantile value.`, 20: `first item in the set of data: for data with a natural ordering.`, 21: `final item in the set of data: for data with a natural ordering.`}
var _StatsMap = map[Stats]string{0: `Count`, 1: `Sum`, 2: `L1Norm`, 3: `Prod`, 4: `Min`, 5: `Max`, 6: `MinAbs`, 7: `MaxAbs`, 8: `Mean`, 9: `Var`, 10: `Std`, 11: `Sem`, 12: `SumSq`, 13: `L2Norm`, 14: `VarPop`, 15: `StdPop`, 16: `SemPop`, 17: `Median`, 18: `Q1`, 19: `Q3`, 20: `First`, 21: `Final`}
// String returns the string representation of this Stats value.
func (i Stats) String() string { return enums.String(i, _StatsMap) }
// SetString sets the Stats value from its string representation,
// and returns an error if the string is invalid.
func (i *Stats) SetString(s string) error { return enums.SetString(i, s, _StatsValueMap, "Stats") }
// Int64 returns the Stats value as an int64.
func (i Stats) Int64() int64 { return int64(i) }
// SetInt64 sets the Stats value from an int64.
func (i *Stats) SetInt64(in int64) { *i = Stats(in) }
// Desc returns the description of the Stats value.
func (i Stats) Desc() string { return enums.Desc(i, _StatsDescMap) }
// StatsValues returns all possible values for the type Stats.
func StatsValues() []Stats { return _StatsValues }
// Values returns all possible values for the type Stats.
func (i Stats) Values() []enums.Enum { return enums.Values(_StatsValues) }
// MarshalText implements the [encoding.TextMarshaler] interface.
func (i Stats) MarshalText() ([]byte, error) { return []byte(i.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *Stats) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "Stats") }
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package stats
import (
"math"
"cogentcore.org/lab/tensor"
)
// StatsFunc is the function signature for a stats function that
// returns a new output vector. This can be less efficient for repeated
// computations where the output can be re-used: see [StatsOutFunc].
// But this version can be directly chained with other function calls.
// Function is computed over the outermost row dimension and the
// output is the shape of the remaining inner cells (a scalar for 1D inputs).
// Use [tensor.As1D], [tensor.NewRowCellsView], [tensor.Cells1D], etc.,
// to reshape and reslice the data as needed.
// All stats functions skip over NaNs, treating them as missing values.
// Stats functions cannot be computed in parallel,
// e.g., using VectorizeThreaded or GPU, due to shared writing
// to the same output values. Special implementations are required
// if that is needed.
type StatsFunc = func(in tensor.Tensor) tensor.Values
// StatsOutFunc is the function signature for a stats function
// that takes the output values as its final argument. See [StatsFunc].
// This version is for computationally demanding cases and saves
// reallocation of the output.
type StatsOutFunc = func(in tensor.Tensor, out tensor.Values) error
// CountOut64 computes the count of non-NaN tensor values,
// and returns the Float64 output values for subsequent use.
func CountOut64(in tensor.Tensor, out tensor.Values) *tensor.Float64 {
return VectorizeOut64(in, out, 0, func(val, agg float64) float64 {
if math.IsNaN(val) {
return agg
}
return agg + 1
})
}
// Count computes the count of non-NaN tensor values.
// See [StatsFunc] for general information.
func Count(in tensor.Tensor) tensor.Values {
return tensor.CallOut1(CountOut, in)
}
// CountOut computes the count of non-NaN tensor values.
// See [StatsOutFunc] for general information.
func CountOut(in tensor.Tensor, out tensor.Values) error {
CountOut64(in, out)
return nil
}
// SumOut64 computes the sum of tensor values,
// and returns the Float64 output values for subsequent use.
func SumOut64(in tensor.Tensor, out tensor.Values) *tensor.Float64 {
return VectorizeOut64(in, out, 0, func(val, agg float64) float64 {
if math.IsNaN(val) {
return agg
}
return agg + val
})
}
// SumOut computes the sum of tensor values.
// See [StatsOutFunc] for general information.
func SumOut(in tensor.Tensor, out tensor.Values) error {
SumOut64(in, out)
return nil
}
// Sum computes the sum of tensor values.
// See [StatsFunc] for general information.
func Sum(in tensor.Tensor) tensor.Values {
out := tensor.NewOfType(in.DataType())
SumOut64(in, out)
return out
}
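// exampleNaNSkipping is an illustrative sketch (hypothetical helper and values)
// of the NaN handling described in [StatsFunc]: NaN values are treated as
// missing, so they are excluded from [Count] and [Sum].
func exampleNaNSkipping() (count, sum float64) {
	in := tensor.NewFloat64(3)
	in.SetFloat1D(1, 0)
	in.SetFloat1D(math.NaN(), 1)
	in.SetFloat1D(2, 2)
	count = Count(in).Float1D(0) // 2: the NaN is not counted
	sum = Sum(in).Float1D(0)     // 3: the NaN is not added
	return
}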
// L1Norm computes the sum of absolute-value-of tensor values.
// See [StatsFunc] for general information.
func L1Norm(in tensor.Tensor) tensor.Values {
return tensor.CallOut1(L1NormOut, in)
}
// L1NormOut computes the sum of absolute-value-of tensor values.
// See [StatsFunc] for general information.
func L1NormOut(in tensor.Tensor, out tensor.Values) error {
VectorizeOut64(in, out, 0, func(val, agg float64) float64 {
if math.IsNaN(val) {
return agg
}
return agg + math.Abs(val)
})
return nil
}
// Prod computes the product of tensor values.
// See [StatsFunc] for general information.
func Prod(in tensor.Tensor) tensor.Values {
return tensor.CallOut1(ProdOut, in)
}
// ProdOut computes the product of tensor values.
// See [StatsOutFunc] for general information.
func ProdOut(in tensor.Tensor, out tensor.Values) error {
VectorizeOut64(in, out, 1, func(val, agg float64) float64 {
if math.IsNaN(val) {
return agg
}
return agg * val
})
return nil
}
// Min computes the min of tensor values.
// See [StatsFunc] for general information.
func Min(in tensor.Tensor) tensor.Values {
return tensor.CallOut1(MinOut, in)
}
// MinOut computes the min of tensor values.
// See [StatsOutFunc] for general information.
func MinOut(in tensor.Tensor, out tensor.Values) error {
VectorizeOut64(in, out, math.MaxFloat64, func(val, agg float64) float64 {
if math.IsNaN(val) {
return agg
}
return math.Min(agg, val)
})
return nil
}
// Max computes the max of tensor values.
// See [StatsFunc] for general information.
func Max(in tensor.Tensor) tensor.Values {
return tensor.CallOut1(MaxOut, in)
}
// MaxOut computes the max of tensor values.
// See [StatsOutFunc] for general information.
func MaxOut(in tensor.Tensor, out tensor.Values) error {
VectorizeOut64(in, out, -math.MaxFloat64, func(val, agg float64) float64 {
if math.IsNaN(val) {
return agg
}
return math.Max(agg, val)
})
return nil
}
// MinAbs computes the min of absolute-value-of tensor values.
// See [StatsFunc] for general information.
func MinAbs(in tensor.Tensor) tensor.Values {
return tensor.CallOut1(MinAbsOut, in)
}
// MinAbsOut computes the min of absolute-value-of tensor values.
// See [StatsOutFunc] for general information.
func MinAbsOut(in tensor.Tensor, out tensor.Values) error {
VectorizeOut64(in, out, math.MaxFloat64, func(val, agg float64) float64 {
if math.IsNaN(val) {
return agg
}
return math.Min(agg, math.Abs(val))
})
return nil
}
// MaxAbs computes the max of absolute-value-of tensor values.
// See [StatsFunc] for general information.
func MaxAbs(in tensor.Tensor) tensor.Values {
return tensor.CallOut1(MaxAbsOut, in)
}
// MaxAbsOut computes the max of absolute-value-of tensor values.
// See [StatsOutFunc] for general information.
func MaxAbsOut(in tensor.Tensor, out tensor.Values) error {
VectorizeOut64(in, out, -math.MaxFloat64, func(val, agg float64) float64 {
if math.IsNaN(val) {
return agg
}
return math.Max(agg, math.Abs(val))
})
return nil
}
// MeanOut64 computes the mean of tensor values,
// and returns the Float64 output values for subsequent use.
func MeanOut64(in tensor.Tensor, out tensor.Values) (mean64, count64 *tensor.Float64) {
var sum64 *tensor.Float64
sum64, count64 = Vectorize2Out64(in, 0, 0, func(val, sum, count float64) (float64, float64) {
if math.IsNaN(val) {
return sum, count
}
count += 1
sum += val
return sum, count
})
osz := tensor.CellsSize(in.ShapeSizes())
out.SetShapeSizes(osz...)
nsub := out.Len()
for i := range nsub {
c := count64.Float1D(i)
if c == 0 {
continue
}
mean := sum64.Float1D(i) / c
sum64.SetFloat1D(mean, i)
out.SetFloat1D(mean, i)
}
return sum64, count64
}
// Mean computes the mean of tensor values.
// See [StatsFunc] for general information.
func Mean(in tensor.Tensor) tensor.Values {
return tensor.CallOut1(MeanOut, in)
}
// MeanOut computes the mean of tensor values.
// See [StatsOutFunc] for general information.
func MeanOut(in tensor.Tensor, out tensor.Values) error {
MeanOut64(in, out)
return nil
}
// SumSqDevOut64 computes the sum of squared mean deviates of tensor values,
// and returns the Float64 output values for subsequent use.
func SumSqDevOut64(in tensor.Tensor, out tensor.Values) (ssd64, mean64, count64 *tensor.Float64) {
mean64, count64 = MeanOut64(in, out)
ssd64 = VectorizePreOut64(in, out, 0, mean64, func(val, mean, agg float64) float64 {
if math.IsNaN(val) {
return agg
}
dv := val - mean
return agg + dv*dv
})
return
}
// VarOut64 computes the sample variance of tensor values,
// and returns the Float64 output values for subsequent use.
func VarOut64(in tensor.Tensor, out tensor.Values) (var64, mean64, count64 *tensor.Float64) {
var64, mean64, count64 = SumSqDevOut64(in, out)
nsub := out.Len()
for i := range nsub {
c := count64.Float1D(i)
if c < 2 {
continue
}
vr := var64.Float1D(i) / (c - 1)
var64.SetFloat1D(vr, i)
out.SetFloat1D(vr, i)
}
return
}
// Var computes the sample variance of tensor values.
// Squared deviations from mean, divided by n-1. See also [VarPop].
// See [StatsFunc] for general information.
func Var(in tensor.Tensor) tensor.Values {
return tensor.CallOut1(VarOut, in)
}
// VarOut computes the sample variance of tensor values.
// Squared deviations from mean, divided by n-1. See also [VarPop].
// See [StatsOutFunc] for general information.
func VarOut(in tensor.Tensor, out tensor.Values) error {
VarOut64(in, out)
return nil
}
// StdOut64 computes the sample standard deviation of tensor values,
// and returns the Float64 output values for subsequent use.
func StdOut64(in tensor.Tensor, out tensor.Values) (std64, mean64, count64 *tensor.Float64) {
std64, mean64, count64 = VarOut64(in, out)
nsub := out.Len()
for i := range nsub {
std := math.Sqrt(std64.Float1D(i))
std64.SetFloat1D(std, i)
out.SetFloat1D(std, i)
}
return
}
// Std computes the sample standard deviation of tensor values.
// Sqrt of variance from [Var]. See also [StdPop].
// See [StatsFunc] for general information.
func Std(in tensor.Tensor) tensor.Values {
return tensor.CallOut1(StdOut, in)
}
// StdOut computes the sample standard deviation of tensor values.
// Sqrt of variance from [Var]. See also [StdPop].
// See [StatsOutFunc] for general information.
func StdOut(in tensor.Tensor, out tensor.Values) error {
StdOut64(in, out)
return nil
}
// Sem computes the sample standard error of the mean of tensor values.
// Standard deviation [Std] / sqrt(n). See also [SemPop].
// See [StatsFunc] for general information.
func Sem(in tensor.Tensor) tensor.Values {
return tensor.CallOut1(SemOut, in)
}
// SemOut computes the sample standard error of the mean of tensor values.
// Standard deviation [Std] / sqrt(n). See also [SemPop].
// See [StatsOutFunc] for general information.
func SemOut(in tensor.Tensor, out tensor.Values) error {
var64, _, count64 := VarOut64(in, out)
nsub := out.Len()
for i := range nsub {
c := count64.Float1D(i)
if c < 2 {
out.SetFloat1D(math.Sqrt(var64.Float1D(i)), i)
} else {
out.SetFloat1D(math.Sqrt(var64.Float1D(i))/math.Sqrt(c), i)
}
}
return nil
}
// VarPopOut64 computes the population variance of tensor values,
// and returns the Float64 output values for subsequent use.
func VarPopOut64(in tensor.Tensor, out tensor.Values) (var64, mean64, count64 *tensor.Float64) {
var64, mean64, count64 = SumSqDevOut64(in, out)
nsub := out.Len()
for i := range nsub {
c := count64.Float1D(i)
if c == 0 {
continue
}
var64.SetFloat1D(var64.Float1D(i)/c, i)
out.SetFloat1D(var64.Float1D(i), i)
}
return
}
// VarPop computes the population variance of tensor values.
// Squared deviations from mean, divided by n. See also [Var].
// See [StatsFunc] for general information.
func VarPop(in tensor.Tensor) tensor.Values {
return tensor.CallOut1(VarPopOut, in)
}
// VarPopOut computes the population variance of tensor values.
// Squared deviations from mean, divided by n. See also [Var].
// See [StatsOutFunc] for general information.
func VarPopOut(in tensor.Tensor, out tensor.Values) error {
VarPopOut64(in, out)
return nil
}
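// exampleVariance is an illustrative sketch (hypothetical helper and values)
// of the sample vs. population variance distinction: for the values 1, 2, 3,
// [Var] divides the sum of squared deviations by n-1, [VarPop] divides by n.
func exampleVariance() (sample, population float64) {
	in := tensor.NewFloat64(3)
	for i := range 3 {
		in.SetFloat1D(float64(i+1), i) // 1, 2, 3
	}
	sample = Var(in).Float1D(0)        // 2 / (3-1) = 1
	population = VarPop(in).Float1D(0) // 2 / 3
	return
}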
// StdPop computes the population standard deviation of tensor values.
// Sqrt of variance from [VarPop]. See also [Std].
// See [StatsFunc] for general information.
func StdPop(in tensor.Tensor) tensor.Values {
return tensor.CallOut1(StdPopOut, in)
}
// StdPopOut computes the population standard deviation of tensor values.
// Sqrt of variance from [VarPop]. See also [Std].
// See [StatsOutFunc] for general information.
func StdPopOut(in tensor.Tensor, out tensor.Values) error {
var64, _, _ := VarPopOut64(in, out)
nsub := out.Len()
for i := range nsub {
out.SetFloat1D(math.Sqrt(var64.Float1D(i)), i)
}
return nil
}
// SemPop computes the population standard error of the mean of tensor values.
// Standard deviation [StdPop] / sqrt(n). See also [Sem].
// See [StatsFunc] for general information.
func SemPop(in tensor.Tensor) tensor.Values {
return tensor.CallOut1(SemPopOut, in)
}
// SemPopOut computes the population standard error of the mean of tensor values.
// Standard deviation [StdPop] / sqrt(n). See also [Sem].
// See [StatsOutFunc] for general information.
func SemPopOut(in tensor.Tensor, out tensor.Values) error {
var64, _, count64 := VarPopOut64(in, out)
nsub := out.Len()
for i := range nsub {
c := count64.Float1D(i)
if c < 2 {
out.SetFloat1D(math.Sqrt(var64.Float1D(i)), i)
} else {
out.SetFloat1D(math.Sqrt(var64.Float1D(i))/math.Sqrt(c), i)
}
}
return nil
}
// SumSqScaleOut64 is a helper for sum-of-squares, returning scale and ss
// factors aggregated separately for better numerical stability, per BLAS.
// Returns the Float64 output values for subsequent use.
func SumSqScaleOut64(in tensor.Tensor) (scale64, ss64 *tensor.Float64) {
scale64, ss64 = Vectorize2Out64(in, 0, 1, func(val, scale, ss float64) (float64, float64) {
if math.IsNaN(val) || val == 0 {
return scale, ss
}
absxi := math.Abs(val)
if scale < absxi {
ss = 1 + ss*(scale/absxi)*(scale/absxi)
scale = absxi
} else {
ss = ss + (absxi/scale)*(absxi/scale)
}
return scale, ss
})
return
}
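// exampleSumSqScale is an illustrative sketch (hypothetical helper and values)
// of the BLAS-style scaling used by [SumSqScaleOut64]: the sum of squares is
// recovered as scale*scale*ss, which avoids overflow and underflow when
// values span a very wide range.
func exampleSumSqScale() float64 {
	in := tensor.NewFloat64(2)
	in.SetFloat1D(3, 0)
	in.SetFloat1D(4, 1)
	scale64, ss64 := SumSqScaleOut64(in)
	return scale64.Float1D(0) * scale64.Float1D(0) * ss64.Float1D(0) // 3*3 + 4*4 = 25
}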
// SumSqOut64 computes the sum of squares of tensor values,
// and returns the Float64 output values for subsequent use.
func SumSqOut64(in tensor.Tensor, out tensor.Values) *tensor.Float64 {
scale64, ss64 := SumSqScaleOut64(in)
osz := tensor.CellsSize(in.ShapeSizes())
out.SetShapeSizes(osz...)
nsub := out.Len()
for i := range nsub {
scale := scale64.Float1D(i)
ss := ss64.Float1D(i)
v := 0.0
if math.IsInf(scale, 1) {
v = math.Inf(1)
} else {
v = scale * scale * ss
}
scale64.SetFloat1D(v, i)
out.SetFloat1D(v, i)
}
return scale64
}
// SumSq computes the sum of squares of tensor values.
// See [StatsFunc] for general information.
func SumSq(in tensor.Tensor) tensor.Values {
return tensor.CallOut1(SumSqOut, in)
}
// SumSqOut computes the sum of squares of tensor values.
// See [StatsOutFunc] for general information.
func SumSqOut(in tensor.Tensor, out tensor.Values) error {
SumSqOut64(in, out)
return nil
}
// L2NormOut64 computes the square root of the sum of squares of tensor values,
// known as the L2 norm, and returns the Float64 output values for
// use in subsequent computations.
func L2NormOut64(in tensor.Tensor, out tensor.Values) *tensor.Float64 {
scale64, ss64 := SumSqScaleOut64(in)
osz := tensor.CellsSize(in.ShapeSizes())
out.SetShapeSizes(osz...)
nsub := out.Len()
for i := range nsub {
scale := scale64.Float1D(i)
ss := ss64.Float1D(i)
v := 0.0
if math.IsInf(scale, 1) {
v = math.Inf(1)
} else {
v = scale * math.Sqrt(ss)
}
scale64.SetFloat1D(v, i)
out.SetFloat1D(v, i)
}
return scale64
}
// L2Norm computes the square root of the sum of squares of tensor values,
// known as the L2 norm.
// See [StatsFunc] for general information.
func L2Norm(in tensor.Tensor) tensor.Values {
return tensor.CallOut1(L2NormOut, in)
}
// L2NormOut computes the square root of the sum of squares of tensor values,
// known as the L2 norm.
// See [StatsOutFunc] for general information.
func L2NormOut(in tensor.Tensor, out tensor.Values) error {
L2NormOut64(in, out)
return nil
}
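// A minimal usage sketch (not part of the package API), assuming a 1D
// float64 input tensor; it shows how SumSq and L2Norm relate:
//
//    x := tensor.NewFloat64(3)
//    x.SetFloat1D(3, 0)
//    x.SetFloat1D(4, 1)
//    x.SetFloat1D(0, 2)
//    ss := stats.SumSq(x)  // sum of squares: 25
//    l2 := stats.L2Norm(x) // sqrt of that: 5
//    _, _ = ss, l2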
// First returns the first tensor value(s), as a stats function,
// for the starting point in a naturally-ordered set of data.
// See [StatsFunc] for general information.
func First(in tensor.Tensor) tensor.Values {
return tensor.CallOut1(FirstOut, in)
}
// FirstOut returns the first tensor value(s), as a stats function,
// for the starting point in a naturally-ordered set of data.
// See [StatsOutFunc] for general information.
func FirstOut(in tensor.Tensor, out tensor.Values) error {
rows, cells := in.Shape().RowCellSize()
if cells == 1 {
out.SetShapeSizes(1)
if rows > 0 {
out.SetFloat1D(in.Float1D(0), 0)
}
return nil
}
osz := tensor.CellsSize(in.ShapeSizes())
out.SetShapeSizes(osz...)
if rows == 0 {
return nil
}
for i := range cells {
out.SetFloat1D(in.Float1D(i), i)
}
return nil
}
// Final returns the final tensor value(s), as a stats function,
// for the ending point in a naturally-ordered set of data.
// See [StatsFunc] for general information.
func Final(in tensor.Tensor) tensor.Values {
return tensor.CallOut1(FinalOut, in)
}
// FinalOut returns the final tensor value(s), as a stats function,
// for the ending point in a naturally-ordered set of data.
// See [StatsOutFunc] for general information.
func FinalOut(in tensor.Tensor, out tensor.Values) error {
rows, cells := in.Shape().RowCellSize()
if cells == 1 {
out.SetShapeSizes(1)
if rows > 0 {
out.SetFloat1D(in.Float1D(rows-1), 0)
}
return nil
}
osz := tensor.CellsSize(in.ShapeSizes())
out.SetShapeSizes(osz...)
if rows == 0 {
return nil
}
st := (rows - 1) * cells
for i := range cells {
out.SetFloat1D(in.Float1D(st+i), i)
}
return nil
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package stats
import (
"strconv"
"strings"
"cogentcore.org/core/base/metadata"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
"cogentcore.org/lab/tensorfs"
)
// Groups generates indexes for each unique value in each of the given tensors.
// One can then use the resulting indexes for the [tensor.Rows] indexes to
// perform computations restricted to grouped subsets of data, as in the
// [GroupStats] function. See [GroupCombined] for a function that makes a
// "Combined" Group that has a unique group for each _combination_ of
// the separate, independent groups created by this function.
// It creates subdirectories in a "Groups" directory within given [tensorfs],
// for each tensor passed in here, using the metadata Name property for
// names (index if empty).
// Within each subdirectory there are int tensors for each unique 1D
// row-wise value of elements in the input tensor, named as the string
// representation of the value, where the int tensor contains a list of
// row-wise indexes corresponding to the source rows having that value.
// Note that these indexes are directly in terms of the underlying [Tensor] data
// rows, indirected through any existing indexes on the inputs, so that
// the results can be used directly as Indexes into the corresponding tensor data.
// Uses a stable sort on columns, so ordering of other dimensions is preserved.
func Groups(dir *tensorfs.Node, tsrs ...tensor.Tensor) error {
gd := dir.Dir("Groups")
makeIdxs := func(dir *tensorfs.Node, srt *tensor.Rows, val string, start, r int) {
n := r - start
it := tensorfs.Value[int](dir, val, n)
for j := range n {
it.SetIntRow(srt.Indexes[start+j], j, 0) // key to indirect through sort indexes
}
}
for i, tsr := range tsrs {
nr := tsr.DimSize(0)
if nr == 0 {
continue
}
nm := metadata.Name(tsr)
if nm == "" {
nm = strconv.Itoa(i)
}
td := gd.Dir(nm)
srt := tensor.AsRows(tsr).CloneIndexes()
srt.SortStable(tensor.Ascending)
start := 0
if tsr.IsString() {
lastVal := srt.StringRow(0, 0)
for r := range nr {
v := srt.StringRow(r, 0)
if v != lastVal {
makeIdxs(td, srt, lastVal, start, r)
start = r
lastVal = v
}
}
if start != nr-1 {
makeIdxs(td, srt, lastVal, start, nr)
}
} else {
lastVal := srt.FloatRow(0, 0)
for r := range nr {
v := srt.FloatRow(r, 0)
if v != lastVal {
makeIdxs(td, srt, tensor.Float64ToString(lastVal), start, r)
start = r
lastVal = v
}
}
if start != nr-1 {
makeIdxs(td, srt, tensor.Float64ToString(lastVal), start, nr)
}
}
}
return nil
}
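// A minimal sketch of calling Groups (assumes dir is an existing
// *tensorfs.Node directory and lab is a 1D string tensor that was named
// "Label" via metadata.SetName; neither is constructed here):
//
//    err := stats.Groups(dir, lab)
//    // dir now contains Groups/Label/<value> int tensors, one per unique
//    // string value in lab, each holding the source row indexes.
//    _ = err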
// TableGroups runs [Groups] on the given columns from given [table.Table].
func TableGroups(dir *tensorfs.Node, dt *table.Table, columns ...string) error {
dv := table.NewView(dt)
// important for consistency across columns, to do full outer product sort first.
dv.SortColumns(tensor.Ascending, tensor.StableSort, columns...)
return Groups(dir, dv.ColumnList(columns...)...)
}
// GroupAll copies all indexes from the first given tensor,
// into an "All/All" tensor in the given [tensorfs], which can then
// be used with [GroupStats] to generate summary statistics across
// all the data. See [Groups] for more general documentation.
func GroupAll(dir *tensorfs.Node, tsrs ...tensor.Tensor) error {
gd := dir.Dir("Groups")
tsr := tensor.AsRows(tsrs[0])
nr := tsr.NumRows()
if nr == 0 {
return nil
}
td := gd.Dir("All")
it := tensorfs.Value[int](td, "All", nr)
for j := range nr {
it.SetIntRow(tsr.RowIndex(j), j, 0) // key to indirect through any existing indexes
}
return nil
}
// todo: GroupCombined
// GroupStats computes the given stats function on the unique grouped indexes
// produced by the [Groups] function, in the given [tensorfs] directory,
// applied to each of the tensors passed here.
// It creates a "Stats" subdirectory in given directory, with
// subdirectories with the name of each value tensor (if it does not
// yet exist), and then creates a subdirectory within that
// for the statistic name. Within that statistic directory, it creates
// a String tensor with the unique values of each source [Groups] tensor,
// and an aligned Float64 tensor with the statistics results for each such
// unique group value. See the README.md file for a diagram of the results.
func GroupStats(dir *tensorfs.Node, stat Stats, tsrs ...tensor.Tensor) error {
gd := dir.Dir("Groups")
sd := dir.Dir("Stats")
stnm := StripPackage(stat.String())
groups, _ := gd.Nodes()
for _, gp := range groups {
gpnm := gp.Name()
ggd := gd.Dir(gpnm)
vals := ggd.ValuesFunc(nil)
nv := len(vals)
if nv == 0 {
continue
}
sgd := sd.Dir(gpnm)
gv := sgd.Node(gpnm)
if gv == nil {
gtsr := tensorfs.Value[string](sgd, gpnm, nv)
for i, v := range vals {
gtsr.SetStringRow(metadata.Name(v), i, 0)
}
}
for _, tsr := range tsrs {
vd := sgd.Dir(metadata.Name(tsr))
sv := tensorfs.Value[float64](vd, stnm, nv)
for i, v := range vals {
idx := tensor.AsIntSlice(v)
sg := tensor.NewRows(tsr.AsValues(), idx...)
stout := stat.Call(sg)
sv.SetFloatRow(stout.Float1D(0), i, 0)
}
}
}
return nil
}
// TableGroupStats runs [GroupStats] using standard [Stats]
// on the given columns from given [table.Table].
func TableGroupStats(dir *tensorfs.Node, stat Stats, dt *table.Table, columns ...string) error {
return GroupStats(dir, stat, dt.ColumnList(columns...)...)
}
// GroupDescribe runs standard descriptive statistics on given tensor data
// using [GroupStats] function, with [DescriptiveStats] list of stats.
func GroupDescribe(dir *tensorfs.Node, tsrs ...tensor.Tensor) error {
for _, st := range DescriptiveStats {
err := GroupStats(dir, st, tsrs...)
if err != nil {
return err
}
}
return nil
}
// TableGroupDescribe runs [GroupDescribe] on the given columns from given [table.Table].
func TableGroupDescribe(dir *tensorfs.Node, dt *table.Table, columns ...string) error {
return GroupDescribe(dir, dt.ColumnList(columns...)...)
}
// GroupStatsAsTable returns the results from [GroupStats] in given directory
// as a [table.Table], using [tensorfs.DirTable] function.
func GroupStatsAsTable(dir *tensorfs.Node) *table.Table {
return tensorfs.DirTable(dir.Node("Stats"), nil)
}
// GroupStatsAsTableNoStatName returns the results from [GroupStats]
// in given directory as a [table.Table], using [tensorfs.DirTable] function.
// Column names are updated to not include the stat name when there is only
// one statistic, such that the resulting name will still be unique.
// Otherwise, column names are Value/Stat.
func GroupStatsAsTableNoStatName(dir *tensorfs.Node) *table.Table {
dt := tensorfs.DirTable(dir.Node("Stats"), nil)
cols := make(map[string]string)
for _, nm := range dt.Columns.Keys {
vn := nm
si := strings.Index(nm, "/")
if si > 0 {
vn = nm[:si]
}
if _, exists := cols[vn]; exists {
continue
}
cols[vn] = nm
}
for k, v := range cols {
ci := dt.Columns.IndexByKey(v)
dt.Columns.RenameIndex(ci, k)
}
return dt
}
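// A sketch of the full grouping workflow on a table (assumes dir is an
// existing *tensorfs.Node and dt is a *table.Table with "Group" and
// "Value" columns):
//
//    stats.TableGroups(dir, dt, "Group")
//    stats.TableGroupStats(dir, stats.StatMean, dt, "Value")
//    res := stats.GroupStatsAsTable(dir) // one row per unique Group value
//    _ = res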
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package stats
import (
"cogentcore.org/core/math32"
"cogentcore.org/lab/tensor"
"cogentcore.org/lab/tensor/tmath"
)
// ZScore computes Z-normalized values,
// subtracting the Mean and dividing by the standard deviation.
func ZScore(a tensor.Tensor) tensor.Values {
return tensor.CallOut1(ZScoreOut, a)
}
// ZScoreOut computes Z-normalized values into the given output tensor,
// subtracting the Mean and dividing by the standard deviation.
func ZScoreOut(a tensor.Tensor, out tensor.Values) error {
mout := tensor.NewFloat64()
std, mean, _ := StdOut64(a, mout)
tmath.SubOut(a, mean, out)
tmath.DivOut(out, std, out)
return nil
}
// UnitNorm computes unit normalized values,
// subtracting the Min value and dividing by the Max of the remaining numbers.
func UnitNorm(a tensor.Tensor) tensor.Values {
return tensor.CallOut1(UnitNormOut, a)
}
// UnitNormOut computes unit normalized values into given output tensor,
// subtracting the Min value and dividing by the Max of the remaining numbers.
func UnitNormOut(a tensor.Tensor, out tensor.Values) error {
mout := tensor.NewFloat64()
err := MinOut(a, mout)
if err != nil {
return err
}
tmath.SubOut(a, mout, out)
MaxOut(out, mout)
tmath.DivOut(out, mout, out)
return nil
}
// Clamp ensures that all values are within min, max limits, clamping
// values to those bounds if they exceed them. min and max args are
// treated as scalars (first value used).
func Clamp(in, minv, maxv tensor.Tensor) tensor.Values {
return tensor.CallOut3(ClampOut, in, minv, maxv)
}
// ClampOut ensures that all values are within min, max limits, clamping
// values to those bounds if they exceed them. min and max args are
// treated as scalars (first value used).
func ClampOut(in, minv, maxv tensor.Tensor, out tensor.Values) error {
tensor.SetShapeFrom(out, in)
mn := minv.Float1D(0)
mx := maxv.Float1D(0)
tensor.VectorizeThreaded(1, tensor.NFirstLen, func(idx int, tsr ...tensor.Tensor) {
tsr[1].SetFloat1D(math32.Clamp(tsr[0].Float1D(idx), mn, mx), idx)
}, in, out)
return nil
}
// Binarize results in a binary-valued output by setting
// values >= the threshold to 1, else 0. threshold is
// treated as a scalar (first value used).
func Binarize(in, threshold tensor.Tensor) tensor.Values {
return tensor.CallOut2(BinarizeOut, in, threshold)
}
// BinarizeOut results in a binary-valued output by setting
// values >= the threshold to 1, else 0. threshold is
// treated as a scalar (first value used).
func BinarizeOut(in, threshold tensor.Tensor, out tensor.Values) error {
tensor.SetShapeFrom(out, in)
thr := threshold.Float1D(0)
tensor.VectorizeThreaded(1, tensor.NFirstLen, func(idx int, tsr ...tensor.Tensor) {
v := tsr[0].Float1D(idx)
if v >= thr {
v = 1
} else {
v = 0
}
tsr[1].SetFloat1D(v, idx)
}, in, out)
return nil
}
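// A usage sketch for these normalization helpers (assumes x is a float
// tensor):
//
//    z := stats.ZScore(x)   // subtract Mean, divide by Std
//    u := stats.UnitNorm(x) // subtract Min, divide by remaining Max
//    c := stats.Clamp(x, tensor.NewFloat64Scalar(0), tensor.NewFloat64Scalar(1))
//    b := stats.Binarize(x, tensor.NewFloat64Scalar(0.5))
//    _, _, _, _ = z, u, c, b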
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package stats
import (
"math"
"cogentcore.org/core/base/errors"
"cogentcore.org/lab/tensor"
)
// Quantiles returns the given quantile(s) of non-NaN elements in given
// 1D tensor. Because sorting uses indexes, this only works for 1D case.
// If needed for a sub-space of values, that can be extracted through slicing
// and then used. Logs an error if not 1D.
// qs are 0-1 values, 0 = min, 1 = max, .5 = median, etc.
// Uses linear interpolation.
// Because this requires a sort, it is more efficient to get as many quantiles
// as needed in one pass.
func Quantiles(in, qs tensor.Tensor) tensor.Values {
return tensor.CallOut2(QuantilesOut, in, qs)
}
// QuantilesOut returns the given quantile(s) of non-NaN elements in given
// 1D tensor. Because sorting uses indexes, this only works for 1D case.
// If needed for a sub-space of values, that can be extracted through slicing
// and then used. Returns and logs an error if not 1D.
// qs are 0-1 values, 0 = min, 1 = max, .5 = median, etc.
// Uses linear interpolation.
// Because this requires a sort, it is more efficient to get as many quantiles
// as needed in one pass.
func QuantilesOut(in, qs tensor.Tensor, out tensor.Values) error {
if in.NumDims() != 1 {
return errors.Log(errors.New("stats.QuantilesFunc: only 1D input tensors allowed"))
}
if qs.NumDims() != 1 {
return errors.Log(errors.New("stats.QuantilesFunc: only 1D quantile tensors allowed"))
}
tensor.SetShapeFrom(out, in)
sin := tensor.AsRows(in.AsValues())
sin.ExcludeMissing()
sin.Sort(tensor.Ascending)
sz := len(sin.Indexes) - 1 // index of the last sorted, non-NaN value
if sz <= 0 {
out.(tensor.Values).SetZeros()
return nil
}
fsz := float64(sz)
nq := qs.Len()
for i := range nq {
q := qs.Float1D(i)
val := 0.0
qi := q * fsz
lwi := math.Floor(qi)
lwii := int(lwi)
if lwii >= sz {
val = sin.FloatRow(sz, 0)
} else if lwii < 0 {
val = sin.FloatRow(0, 0)
} else {
phi := qi - lwi
lwv := sin.FloatRow(lwii, 0)
hiv := sin.FloatRow(lwii+1, 0)
val = (1-phi)*lwv + phi*hiv
}
out.SetFloat1D(val, i)
}
return nil
}
// Median computes the median (50% quantile) of tensor values.
// See [StatsFunc] for general information.
func Median(in tensor.Tensor) tensor.Values {
return Quantiles(in, tensor.NewFloat64Scalar(.5))
}
// Q1 computes the first quantile (25%) of tensor values.
// See [StatsFunc] for general information.
func Q1(in tensor.Tensor) tensor.Values {
return Quantiles(in, tensor.NewFloat64Scalar(.25))
}
// Q3 computes the third quantile (75%) of tensor values.
// See [StatsFunc] for general information.
func Q3(in tensor.Tensor) tensor.Values {
return Quantiles(in, tensor.NewFloat64Scalar(.75))
}
// MedianOut computes the median (50% quantile) of tensor values.
// See [StatsFunc] for general information.
func MedianOut(in tensor.Tensor, out tensor.Values) error {
return QuantilesOut(in, tensor.NewFloat64Scalar(.5), out)
}
// Q1Out computes the first quantile (25%) of tensor values.
// See [StatsFunc] for general information.
func Q1Out(in tensor.Tensor, out tensor.Values) error {
return QuantilesOut(in, tensor.NewFloat64Scalar(.25), out)
}
// Q3Out computes the third quantile (75%) of tensor values.
// See [StatsFunc] for general information.
func Q3Out(in tensor.Tensor, out tensor.Values) error {
return QuantilesOut(in, tensor.NewFloat64Scalar(.75), out)
}
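// A usage sketch for quantiles (assumes x is a 1D float tensor):
//
//    med := stats.Median(x) // same as Quantiles with a .5 scalar
//    qs := tensor.NewFloat64(2)
//    qs.SetFloat1D(0.05, 0)
//    qs.SetFloat1D(0.95, 1)
//    tails := stats.Quantiles(x, qs) // both quantiles from one sort
//    _, _ = med, tails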
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package stats
import (
"strings"
"cogentcore.org/core/base/errors"
"cogentcore.org/lab/tensor"
)
//go:generate core generate
func init() {
tensor.AddFunc(StatCount.FuncName(), Count)
tensor.AddFunc(StatSum.FuncName(), Sum)
tensor.AddFunc(StatL1Norm.FuncName(), L1Norm)
tensor.AddFunc(StatProd.FuncName(), Prod)
tensor.AddFunc(StatMin.FuncName(), Min)
tensor.AddFunc(StatMax.FuncName(), Max)
tensor.AddFunc(StatMinAbs.FuncName(), MinAbs)
tensor.AddFunc(StatMaxAbs.FuncName(), MaxAbs)
tensor.AddFunc(StatMean.FuncName(), Mean)
tensor.AddFunc(StatVar.FuncName(), Var)
tensor.AddFunc(StatStd.FuncName(), Std)
tensor.AddFunc(StatSem.FuncName(), Sem)
tensor.AddFunc(StatSumSq.FuncName(), SumSq)
tensor.AddFunc(StatL2Norm.FuncName(), L2Norm)
tensor.AddFunc(StatVarPop.FuncName(), VarPop)
tensor.AddFunc(StatStdPop.FuncName(), StdPop)
tensor.AddFunc(StatSemPop.FuncName(), SemPop)
tensor.AddFunc(StatMedian.FuncName(), Median)
tensor.AddFunc(StatQ1.FuncName(), Q1)
tensor.AddFunc(StatQ3.FuncName(), Q3)
tensor.AddFunc(StatFirst.FuncName(), First)
tensor.AddFunc(StatFinal.FuncName(), Final)
}
// Stats is a list of different standard aggregation functions, which can be
// used to choose an aggregation function.
type Stats int32 //enums:enum -trim-prefix Stat
const (
// count of number of elements.
StatCount Stats = iota
// sum of elements.
StatSum
// L1 Norm: sum of absolute values of elements.
StatL1Norm
// product of elements.
StatProd
// minimum value.
StatMin
// maximum value.
StatMax
// minimum of absolute values.
StatMinAbs
// maximum of absolute values.
StatMaxAbs
// mean value = sum / count.
StatMean
// sample variance (squared deviations from mean, divided by n-1).
StatVar
// sample standard deviation (sqrt of Var).
StatStd
// sample standard error of the mean (Std divided by sqrt(n)).
StatSem
// sum of squared values.
StatSumSq
// L2 Norm: square-root of sum-of-squares.
StatL2Norm
// population variance (squared diffs from mean, divided by n).
StatVarPop
// population standard deviation (sqrt of VarPop).
StatStdPop
// population standard error of the mean (StdPop divided by sqrt(n)).
StatSemPop
// middle value in sorted ordering.
StatMedian
// Q1 first quartile = 25%ile value = .25 quantile value.
StatQ1
// Q3 third quartile = 75%ile value = .75 quantile value.
StatQ3
// first item in the set of data: for data with a natural ordering.
StatFirst
// final item in the set of data: for data with a natural ordering.
StatFinal
)
// FuncName returns the package-qualified function name to use
// in tensor.Call to call this function.
func (s Stats) FuncName() string {
return "stats." + s.String()
}
// Func returns function for given stat.
func (s Stats) Func() StatsFunc {
fn := errors.Log1(tensor.FuncByName(s.FuncName()))
return fn.Fun.(StatsFunc)
}
// Call calls this statistic function on the given tensor,
// returning the output as a newly created tensor.
func (s Stats) Call(in tensor.Tensor) tensor.Values {
return s.Func()(in)
}
// StripPackage removes any package name from given string,
// used for naming based on FuncName() which could be custom
// or have a package prefix.
func StripPackage(name string) string {
spl := strings.Split(name, ".")
if len(spl) > 1 {
return spl[len(spl)-1]
}
return name
}
// AsStatsFunc returns given function as a [StatsFunc] function,
// or an error if it does not fit that signature.
func AsStatsFunc(fun any) (StatsFunc, error) {
sfun, ok := fun.(StatsFunc)
if !ok {
return nil, errors.New("metric.AsStatsFunc: function does not fit the StatsFunc signature")
}
return sfun, nil
}
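// A usage sketch for the Stats enum (assumes x is a float tensor):
//
//    m := stats.StatMean.Call(x) // same result as stats.Mean(x)
//    fn := stats.StatStd.Func()  // look up the registered StatsFunc
//    sd := fn(x)
//    name := stats.StripPackage(stats.StatStd.FuncName()) // stat name without package prefix
//    _, _, _ = m, sd, name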
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package stats
import (
"reflect"
"cogentcore.org/lab/table"
)
// MeanTables returns a [table.Table] with the mean values across all float
// columns of the input tables, which must have the same columns but not
// necessarily the same number of rows.
func MeanTables(dts []*table.Table) *table.Table {
nt := len(dts)
if nt == 0 {
return nil
}
maxRows := 0
var maxdt *table.Table
for _, dt := range dts {
nr := dt.NumRows()
if nr > maxRows {
maxRows = nr
maxdt = dt
}
}
if maxRows == 0 {
return nil
}
ot := maxdt.Clone()
// N samples per row
rns := make([]int, maxRows)
for _, dt := range dts {
dnr := dt.NumRows()
mx := min(dnr, maxRows)
for ri := 0; ri < mx; ri++ {
rns[ri]++
}
}
for ci := range ot.Columns.Values {
cl := ot.ColumnByIndex(ci)
if cl.DataType() != reflect.Float32 && cl.DataType() != reflect.Float64 {
continue
}
_, cells := cl.RowCellSize()
for di, dt := range dts {
if di == 0 {
continue
}
dc := dt.ColumnByIndex(ci)
dnr := dt.NumRows()
mx := min(dnr, maxRows)
for ri := 0; ri < mx; ri++ {
for j := 0; j < cells; j++ {
cv := cl.FloatRow(ri, j)
cv += dc.FloatRow(ri, j)
cl.SetFloatRow(cv, ri, j)
}
}
}
for ri := 0; ri < maxRows; ri++ {
for j := 0; j < cells; j++ {
cv := cl.FloatRow(ri, j)
if rns[ri] > 0 {
cv /= float64(rns[ri])
cl.SetFloatRow(cv, ri, j)
}
}
}
}
return ot
}
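// A sketch of averaging matching log tables (assumes run1, run2, run3 are
// tables with identical float columns, e.g. from repeated simulation runs):
//
//    avg := stats.MeanTables([]*table.Table{run1, run2, run3})
//    // avg has the row count of the largest input; each float cell is the
//    // mean over the inputs that have that row.
//    _ = avg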
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package stats
import (
"cogentcore.org/lab/tensor"
)
// VectorizeOut64 is the general compute function for stats.
// This version makes a Float64 output tensor for aggregating
// and computing values, and then copies the results back to the
// original output. This allows stats functions to operate directly
// on integer valued inputs and produce sensible results.
// It returns the Float64 output tensor for further processing as needed.
func VectorizeOut64(in tensor.Tensor, out tensor.Values, ini float64, fun func(val, agg float64) float64) *tensor.Float64 {
rows, cells := in.Shape().RowCellSize()
o64 := tensor.NewFloat64(cells)
if rows <= 0 {
return o64
}
if cells == 1 {
out.SetShapeSizes(1)
agg := ini
switch x := in.(type) {
case *tensor.Float64:
for i := range rows {
agg = fun(x.Float1D(i), agg)
}
case *tensor.Float32:
for i := range rows {
agg = fun(x.Float1D(i), agg)
}
default:
for i := range rows {
agg = fun(in.Float1D(i), agg)
}
}
o64.SetFloat1D(agg, 0)
out.SetFloat1D(agg, 0)
return o64
}
osz := tensor.CellsSize(in.ShapeSizes())
out.SetShapeSizes(osz...)
for i := range cells {
o64.SetFloat1D(ini, i)
}
switch x := in.(type) {
case *tensor.Float64:
for i := range rows {
for j := range cells {
o64.SetFloat1D(fun(x.Float1D(i*cells+j), o64.Float1D(j)), j)
}
}
case *tensor.Float32:
for i := range rows {
for j := range cells {
o64.SetFloat1D(fun(x.Float1D(i*cells+j), o64.Float1D(j)), j)
}
}
default:
for i := range rows {
for j := range cells {
o64.SetFloat1D(fun(in.Float1D(i*cells+j), o64.Float1D(j)), j)
}
}
}
for j := range cells {
out.SetFloat1D(o64.Float1D(j), j)
}
return o64
}
// VectorizePreOut64 is a version of [VectorizeOut64] that takes an additional
// tensor.Float64 input of pre-computed values, e.g., the means of each output cell.
func VectorizePreOut64(in tensor.Tensor, out tensor.Values, ini float64, pre *tensor.Float64, fun func(val, pre, agg float64) float64) *tensor.Float64 {
rows, cells := in.Shape().RowCellSize()
o64 := tensor.NewFloat64(cells)
if rows <= 0 {
return o64
}
if cells == 1 {
out.SetShapeSizes(1)
agg := ini
prev := pre.Float1D(0)
switch x := in.(type) {
case *tensor.Float64:
for i := range rows {
agg = fun(x.Float1D(i), prev, agg)
}
case *tensor.Float32:
for i := range rows {
agg = fun(x.Float1D(i), prev, agg)
}
default:
for i := range rows {
agg = fun(in.Float1D(i), prev, agg)
}
}
o64.SetFloat1D(agg, 0)
out.SetFloat1D(agg, 0)
return o64
}
osz := tensor.CellsSize(in.ShapeSizes())
out.SetShapeSizes(osz...)
for j := range cells {
o64.SetFloat1D(ini, j)
}
switch x := in.(type) {
case *tensor.Float64:
for i := range rows {
for j := range cells {
o64.SetFloat1D(fun(x.Float1D(i*cells+j), pre.Float1D(j), o64.Float1D(j)), j)
}
}
case *tensor.Float32:
for i := range rows {
for j := range cells {
o64.SetFloat1D(fun(x.Float1D(i*cells+j), pre.Float1D(j), o64.Float1D(j)), j)
}
}
default:
for i := range rows {
for j := range cells {
o64.SetFloat1D(fun(in.Float1D(i*cells+j), pre.Float1D(j), o64.Float1D(j)), j)
}
}
}
for i := range cells {
out.SetFloat1D(o64.Float1D(i), i)
}
return o64
}
// Vectorize2Out64 is a version of [VectorizeOut64] that separately aggregates
// two output values, x and y as tensor.Float64.
func Vectorize2Out64(in tensor.Tensor, iniX, iniY float64, fun func(val, ox, oy float64) (float64, float64)) (ox64, oy64 *tensor.Float64) {
rows, cells := in.Shape().RowCellSize()
ox64 = tensor.NewFloat64(cells)
oy64 = tensor.NewFloat64(cells)
if rows <= 0 {
return ox64, oy64
}
if cells == 1 {
ox := iniX
oy := iniY
switch x := in.(type) {
case *tensor.Float64:
for i := range rows {
ox, oy = fun(x.Float1D(i), ox, oy)
}
case *tensor.Float32:
for i := range rows {
ox, oy = fun(x.Float1D(i), ox, oy)
}
default:
for i := range rows {
ox, oy = fun(in.Float1D(i), ox, oy)
}
}
ox64.SetFloat1D(ox, 0)
oy64.SetFloat1D(oy, 0)
return
}
for j := range cells {
ox64.SetFloat1D(iniX, j)
oy64.SetFloat1D(iniY, j)
}
switch x := in.(type) {
case *tensor.Float64:
for i := range rows {
for j := range cells {
ox, oy := fun(x.Float1D(i*cells+j), ox64.Float1D(j), oy64.Float1D(j))
ox64.SetFloat1D(ox, j)
oy64.SetFloat1D(oy, j)
}
}
case *tensor.Float32:
for i := range rows {
for j := range cells {
ox, oy := fun(x.Float1D(i*cells+j), ox64.Float1D(j), oy64.Float1D(j))
ox64.SetFloat1D(ox, j)
oy64.SetFloat1D(oy, j)
}
}
default:
for i := range rows {
for j := range cells {
ox, oy := fun(in.Float1D(i*cells+j), ox64.Float1D(j), oy64.Float1D(j))
ox64.SetFloat1D(ox, j)
oy64.SetFloat1D(oy, j)
}
}
}
return
}
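// A sketch of writing a custom stat with VectorizeOut64: counting values
// above zero per output cell (hypothetical function, not part of this package):
//
//    func CountPositiveOut(in tensor.Tensor, out tensor.Values) error {
//        VectorizeOut64(in, out, 0, func(val, agg float64) float64 {
//            if val > 0 {
//                return agg + 1
//            }
//            return agg
//        })
//        return nil
//    }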
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package table
import (
"cogentcore.org/core/base/keylist"
"cogentcore.org/core/base/metadata"
"cogentcore.org/lab/tensor"
)
// Columns is the underlying column list and number of rows for Table.
// Each column is a raw [tensor.Values] tensor, and [Table]
// provides a [tensor.Rows] indexed view onto the Columns.
type Columns struct {
keylist.List[string, tensor.Values]
// number of rows, which is enforced to be the size of the
// outermost row dimension of the column tensors.
Rows int `edit:"-"`
}
// NewColumns returns a new Columns.
func NewColumns() *Columns {
return &Columns{}
}
// SetNumRows sets the number of rows in the table, across all columns.
// It is safe to set this to 0. For incrementally growing tables (e.g., a log)
// it is best to first set the anticipated full size, which allocates the
// full amount of memory, and then set to 0 and grow incrementally.
func (cl *Columns) SetNumRows(rows int) *Columns { //types:add
cl.Rows = rows // can be 0
for _, tsr := range cl.Values {
tsr.SetNumRows(rows)
}
return cl
}
// AddColumn adds the given tensor (as a [tensor.Values]) as a column,
// returning an error and not adding if the name is not unique.
// Automatically adjusts the shape to fit the current number of rows
// (setting Rows if this is the first column added),
// and calls the metadata SetName with the column name.
func (cl *Columns) AddColumn(name string, tsr tensor.Values) error {
if cl.Len() == 0 {
cl.Rows = tsr.DimSize(0)
}
err := cl.Add(name, tsr)
if err != nil {
return err
}
tsr.SetNumRows(cl.Rows)
metadata.SetName(tsr, name)
return nil
}
// InsertColumn inserts the given tensor as a column at given index,
// returning an error and not adding if the name is not unique.
// Automatically adjusts the shape to fit the current number of rows.
func (cl *Columns) InsertColumn(idx int, name string, tsr tensor.Values) error {
cl.Insert(idx, name, tsr)
tsr.SetNumRows(cl.Rows)
return nil
}
// Clone returns a complete copy of this set of columns.
func (cl *Columns) Clone() *Columns {
cp := NewColumns().SetNumRows(cl.Rows)
for i, nm := range cl.Keys {
tsr := cl.Values[i]
cp.AddColumn(nm, tsr.Clone())
}
return cp
}
// AppendRows appends rows from the given columns onto this one,
// for all columns shared by both (matched by name).
func (cl *Columns) AppendRows(cl2 *Columns) {
for i, nm := range cl.Keys {
c2 := cl2.At(nm)
if c2 == nil {
continue
}
c1 := cl.Values[i]
c1.AppendFrom(c2)
}
cl.SetNumRows(cl.Rows + cl2.Rows)
}
// UpdateRows updates the current Rows count based on length of the first column.
func (cl *Columns) UpdateRows() {
if cl.Len() == 0 {
return
}
cl.Rows = cl.Values[0].DimSize(0)
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package table
import (
"math/rand"
"slices"
"sort"
"cogentcore.org/lab/tensor"
)
// RowIndex returns the actual index into the underlying tensor rows for the
// given index value. If Indexes == nil, the index is passed through.
func (dt *Table) RowIndex(idx int) int {
if dt.Indexes == nil {
return idx
}
return dt.Indexes[idx]
}
// NumRows returns the number of rows, which is the number of Indexes if present,
// else actual number of [Columns.Rows].
func (dt *Table) NumRows() int {
if dt.Indexes == nil {
return dt.Columns.Rows
}
return len(dt.Indexes)
}
// Sequential sets Indexes to nil, resulting in sequential row-wise access into tensor.
func (dt *Table) Sequential() { //types:add
dt.Indexes = nil
dt.Columns.UpdateRows()
}
// IndexesNeeded is called prior to an operation that needs actual indexes,
// e.g., Sort, Filter. If Indexes == nil, they are set to all rows, otherwise
// current indexes are left as is. Use Sequential, then IndexesNeeded to ensure
// all rows are represented.
func (dt *Table) IndexesNeeded() {
if dt.Indexes != nil {
return
}
dt.Columns.UpdateRows()
dt.Indexes = make([]int, dt.Columns.Rows)
for i := range dt.Indexes {
dt.Indexes[i] = i
}
}
// IndexesFromTensor copies Indexes from the given [tensor.Rows] tensor,
// including if they are nil. This allows column-specific Sort, Filter and
// other such methods to be applied to the entire table.
func (dt *Table) IndexesFromTensor(ix *tensor.Rows) {
dt.Indexes = ix.Indexes
}
// ValidIndexes deletes all invalid indexes from the list.
// Call this if rows may have been deleted from the table.
func (dt *Table) ValidIndexes() {
dt.Columns.UpdateRows()
if dt.Columns.Rows <= 0 || dt.Indexes == nil {
dt.Indexes = nil
return
}
ni := dt.NumRows()
for i := ni - 1; i >= 0; i-- {
if dt.Indexes[i] >= dt.Columns.Rows {
dt.Indexes = append(dt.Indexes[:i], dt.Indexes[i+1:]...)
}
}
}
// Permuted sets indexes to a permuted order. If indexes already exist,
// the existing list of indexes is permuted; otherwise a new set of
// permuted indexes is generated.
func (dt *Table) Permuted() {
dt.Columns.UpdateRows()
if dt.Columns.Rows <= 0 {
dt.Indexes = nil
return
}
if dt.Indexes == nil {
dt.Indexes = rand.Perm(dt.Columns.Rows)
} else {
rand.Shuffle(len(dt.Indexes), func(i, j int) {
dt.Indexes[i], dt.Indexes[j] = dt.Indexes[j], dt.Indexes[i]
})
}
}
// SortColumn sorts the indexes into our Table according to values in
// given column, using either ascending or descending order
// (use [tensor.Ascending] or [tensor.Descending] for self-documentation).
// Uses first cell of higher dimensional data.
// Returns error if column name not found.
func (dt *Table) SortColumn(columnName string, ascending bool) error { //types:add
dt.IndexesNeeded()
cl, err := dt.ColumnTry(columnName)
if err != nil {
return err
}
cl.Sort(ascending)
dt.IndexesFromTensor(cl)
return nil
}
// SortFunc sorts the indexes into our Table using given compare function.
// The compare function operates directly on row numbers into the Table
// as these row numbers have already been projected through the indexes.
// cmp(a, b) should return a negative number when a < b, a positive
// number when a > b and zero when a == b.
func (dt *Table) SortFunc(cmp func(dt *Table, i, j int) int) {
dt.IndexesNeeded()
slices.SortFunc(dt.Indexes, func(a, b int) int {
return cmp(dt, a, b) // key point: these are already indirected through indexes!!
})
}
// SortStableFunc stably sorts the indexes into our Table using given compare function.
// The compare function operates directly on row numbers into the Table
// as these row numbers have already been projected through the indexes.
// cmp(a, b) should return a negative number when a < b, a positive
// number when a > b and zero when a == b.
// It is *essential* that it always returns 0 when the two are equal
// for the stable function to actually work.
func (dt *Table) SortStableFunc(cmp func(dt *Table, i, j int) int) {
dt.IndexesNeeded()
slices.SortStableFunc(dt.Indexes, func(a, b int) int {
return cmp(dt, a, b) // key point: these are already indirected through indexes!!
})
}
// SortColumns sorts the indexes into our Table according to values in
// given column names, using either ascending or descending order
// (use [tensor.Ascending] or [tensor.Descending] for self-documentation),
// and optionally using a stable sort.
// Uses first cell of higher dimensional data.
func (dt *Table) SortColumns(ascending, stable bool, columns ...string) { //types:add
dt.SortColumnIndexes(ascending, stable, dt.ColumnIndexList(columns...)...)
}
// SortColumnIndexes sorts the indexes into our Table according to values in
// given list of column indexes, using either ascending or descending order for
// all of the columns. Uses first cell of higher dimensional data.
func (dt *Table) SortColumnIndexes(ascending, stable bool, colIndexes ...int) {
dt.IndexesNeeded()
sf := dt.SortFunc
if stable {
sf = dt.SortStableFunc
}
sf(func(dt *Table, i, j int) int {
for _, ci := range colIndexes {
cl := dt.ColumnByIndex(ci).Tensor
if cl.IsString() {
v := tensor.CompareAscending(cl.StringRow(i, 0), cl.StringRow(j, 0), ascending)
if v != 0 {
return v
}
} else {
v := tensor.CompareAscending(cl.FloatRow(i, 0), cl.FloatRow(j, 0), ascending)
if v != 0 {
return v
}
}
}
return 0
})
}
// SortIndexes sorts the indexes into our Table directly in
// numerical order, producing the native ordering, while preserving
// any filtering that might have occurred.
func (dt *Table) SortIndexes() {
if dt.Indexes == nil {
return
}
sort.Ints(dt.Indexes)
}
// FilterFunc is a function used for filtering that returns
// true if Table row should be included in the current filtered
// view of the table, and false if it should be removed.
type FilterFunc func(dt *Table, row int) bool
// Filter filters the indexes into our Table using given Filter function.
// The Filter function operates directly on row numbers into the Table
// as these row numbers have already been projected through the indexes.
func (dt *Table) Filter(filterer func(dt *Table, row int) bool) {
dt.IndexesNeeded()
sz := len(dt.Indexes)
for i := sz - 1; i >= 0; i-- { // always go in reverse for filtering
if !filterer(dt, dt.Indexes[i]) { // delete
dt.Indexes = append(dt.Indexes[:i], dt.Indexes[i+1:]...)
}
}
}
// FilterString filters the indexes using string values in column compared to given
// string. Includes rows with matching values unless the Exclude option is set.
// If Contains option is set, it only checks if row contains string;
// if IgnoreCase, ignores case, otherwise filtering is case sensitive.
// Uses first cell from higher dimensions.
// Returns error if column name not found.
func (dt *Table) FilterString(columnName string, str string, opts tensor.StringMatch) error { //types:add
dt.IndexesNeeded()
cl, err := dt.ColumnTry(columnName)
if err != nil {
return err
}
cl.FilterString(str, opts)
dt.IndexesFromTensor(cl)
return nil
}
// New returns a new table with column data organized according to
// the indexes. If Indexes are nil, a clone of the current table is returned,
// but this function is only sensible if there is an indexed view in place.
func (dt *Table) New() *Table {
if dt.Indexes == nil {
return dt.Clone()
}
rows := len(dt.Indexes)
nt := dt.Clone()
nt.Indexes = nil
nt.SetNumRows(rows)
if rows == 0 {
return nt
}
for ci, cl := range nt.Columns.Values {
scl := dt.Columns.Values[ci]
_, csz := cl.Shape().RowCellSize()
for i, srw := range dt.Indexes {
cl.CopyCellsFrom(scl, i*csz, srw*csz, csz)
}
}
return nt
}
// DeleteRows deletes n rows of Indexes starting at given index in the list of indexes.
// This does not affect the underlying tensor data; to create an actual in-memory
// ordering with rows deleted, use [Table.New].
func (dt *Table) DeleteRows(at, n int) {
dt.IndexesNeeded()
dt.Indexes = append(dt.Indexes[:at], dt.Indexes[at+n:]...)
}
// Swap switches the indexes for i and j
func (dt *Table) Swap(i, j int) {
dt.Indexes[i], dt.Indexes[j] = dt.Indexes[j], dt.Indexes[i]
}
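// A sketch of the indexed-view workflow (assumes dt is a *Table with a
// float64 "Score" column; only Indexes change, not the underlying data):
//
//    dt.SortColumn("Score", tensor.Descending)
//    dt.Filter(func(dt *table.Table, row int) bool {
//        return dt.Column("Score").Tensor.FloatRow(row, 0) > 0.5
//    })
//    top := dt.New() // materialize the sorted, filtered view
//    dt.Sequential() // drop indexes, back to native order
//    _ = top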
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package table
import (
"bufio"
"bytes"
"encoding/csv"
"fmt"
"io"
"io/fs"
"log"
"log/slog"
"math"
"os"
"reflect"
"strconv"
"strings"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/fsx"
"cogentcore.org/lab/tensor"
)
const (
// Headers is passed to CSV methods for the headers arg, to use headers
// that capture full type and tensor shape information.
Headers = true
// NoHeaders is passed to CSV methods for the headers arg, to not use headers
NoHeaders = false
)
// SaveCSV writes a table to a comma-separated-values (CSV) file
// (where comma = any delimiter, specified in the delim arg).
// If headers = true then generate column headers that capture the type
// and tensor cell geometry of the columns, enabling full reloading
// of exactly the same table format and data (recommended).
// Otherwise, only the data is written.
func (dt *Table) SaveCSV(filename fsx.Filename, delim tensor.Delims, headers bool) error { //types:add
fp, err := os.Create(string(filename))
if err != nil {
log.Println(err)
return err
}
defer fp.Close()
bw := bufio.NewWriter(fp)
err = dt.WriteCSV(bw, delim, headers)
bw.Flush()
return err
}
// String returns a string of the CSV formatted file for the table.
func (dt *Table) String() string {
var b bytes.Buffer
dt.WriteCSV(&b, tensor.Tab, true)
return b.String()
}
// OpenCSV reads a table from a comma-separated-values (CSV) file
// (where comma = any delimiter, specified in the delim arg),
// using the Go standard encoding/csv reader conforming to the official CSV standard.
// If the table does not currently have any columns, the first row of the file
// is assumed to be headers, and columns are constructed therefrom.
// If the file was saved from a table with headers, then these have full configuration
// information for tensor type and dimensionality.
// If the table DOES have existing columns, then those are used robustly
// for whatever information fits from each row of the file.
func (dt *Table) OpenCSV(filename fsx.Filename, delim tensor.Delims) error { //types:add
fp, err := os.Open(string(filename))
if err != nil {
return errors.Log(err)
}
defer fp.Close()
return dt.ReadCSV(bufio.NewReader(fp), delim)
}
// OpenFS is the version of [Table.OpenCSV] that uses an [fs.FS] filesystem.
func (dt *Table) OpenFS(fsys fs.FS, filename string, delim tensor.Delims) error {
fp, err := fsys.Open(filename)
if err != nil {
return errors.Log(err)
}
defer fp.Close()
return dt.ReadCSV(bufio.NewReader(fp), delim)
}
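// A sketch of a CSV round trip (assumes dt is a populated *Table and
// "data.tsv" is a writable path):
//
//    dt.SaveCSV("data.tsv", tensor.Tab, table.Headers)
//    dt2 := table.New()
//    dt2.OpenCSV("data.tsv", tensor.Tab)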
// ReadCSV reads a table from a comma-separated-values (CSV) file
// (where comma = any delimiter, specified in the delim arg),
// using the Go standard encoding/csv reader conforming to the official CSV standard.
// If the table does not currently have any columns, the first row of the file
// is assumed to be headers, and columns are constructed therefrom.
// If the file was saved from a table with headers, then these have full configuration
// information for tensor type and dimensionality.
// If the table DOES have existing columns, then those are used robustly
// for whatever information fits from each row of the file.
func (dt *Table) ReadCSV(r io.Reader, delim tensor.Delims) error {
dt.Sequential()
cr := csv.NewReader(r)
cr.Comma = delim.Rune()
rec, err := cr.ReadAll() // todo: lazy, avoid resizing
if err != nil || len(rec) == 0 {
return err
}
rows := len(rec)
strow := 0
if dt.NumColumns() == 0 || DetectTableHeaders(rec[0]) {
dt.DeleteAll()
err := ConfigFromHeaders(dt, rec[0], rec)
if err != nil {
log.Println(err.Error())
return err
}
strow++
rows--
}
dt.SetNumRows(rows)
for ri := 0; ri < rows; ri++ {
dt.ReadCSVRow(rec[ri+strow], ri)
}
return nil
}
// ReadCSVRow reads a record of CSV data into given row in table
func (dt *Table) ReadCSVRow(rec []string, row int) {
ci := 0
if rec[0] == "_D:" { // data row
ci++
}
nan := math.NaN()
for _, tsr := range dt.Columns.Values {
_, csz := tsr.Shape().RowCellSize()
stoff := row * csz
for cc := 0; cc < csz; cc++ {
str := rec[ci]
if !tsr.IsString() {
if str == "" || str == "NaN" || str == "-NaN" || str == "Inf" || str == "-Inf" {
tsr.SetFloat1D(nan, stoff+cc)
} else {
tsr.SetString1D(strings.TrimSpace(str), stoff+cc)
}
} else {
tsr.SetString1D(strings.TrimSpace(str), stoff+cc)
}
ci++
if ci >= len(rec) {
return
}
}
}
}
// ConfigFromHeaders attempts to configure Table based on the headers.
// For non-table headers, data is examined to determine types.
func ConfigFromHeaders(dt *Table, hdrs []string, rec [][]string) error {
if DetectTableHeaders(hdrs) {
return ConfigFromTableHeaders(dt, hdrs)
}
return ConfigFromDataValues(dt, hdrs, rec)
}
// DetectTableHeaders looks for special header characters, returning true if found.
func DetectTableHeaders(hdrs []string) bool {
for _, hd := range hdrs {
hd = strings.TrimSpace(hd)
if hd == "" {
continue
}
if hd == "_H:" {
return true
}
if _, ok := TableHeaderToType[hd[0]]; !ok { // all must be table
return false
}
}
return true
}
// ConfigFromTableHeaders attempts to configure a Table based on special table headers
func ConfigFromTableHeaders(dt *Table, hdrs []string) error {
for _, hd := range hdrs {
hd = strings.TrimSpace(hd)
if hd == "" || hd == "_H:" {
continue
}
typ, hd := TableColumnType(hd)
dimst := strings.Index(hd, "]<")
if dimst > 0 {
dims := hd[dimst+2 : len(hd)-1]
lbst := strings.Index(hd, "[")
hd = hd[:lbst]
csh := ShapeFromString(dims)
// new tensor starting
dt.AddColumnOfType(hd, typ, csh...)
continue
}
dimst = strings.Index(hd, "[")
if dimst > 0 {
continue
}
dt.AddColumnOfType(hd, typ)
}
return nil
}
// TableHeaderToType maps special header characters to data type
var TableHeaderToType = map[byte]reflect.Kind{
'$': reflect.String,
'%': reflect.Float32,
'#': reflect.Float64,
'|': reflect.Int,
'^': reflect.Bool,
}
// TableHeaderChar returns the special header character based on given data type
func TableHeaderChar(typ reflect.Kind) byte {
switch {
case typ == reflect.Bool:
return '^'
case typ == reflect.Float32:
return '%'
case typ == reflect.Float64:
return '#'
case typ >= reflect.Int && typ <= reflect.Uintptr:
return '|'
default:
return '$'
}
}
// TableColumnType parses the column header for special table type information
func TableColumnType(nm string) (reflect.Kind, string) {
typ, ok := TableHeaderToType[nm[0]]
if ok {
nm = nm[1:]
} else {
typ = reflect.String // most general, default
}
return typ, nm
}
// ShapeFromString parses string representation of shape as N:d,d,..
func ShapeFromString(dims string) []int {
clni := strings.Index(dims, ":")
nd, _ := strconv.Atoi(dims[:clni])
sh := make([]int, nd)
ci := clni + 1
for i := 0; i < nd; i++ {
dstr := ""
if i < nd-1 {
nci := strings.Index(dims[ci:], ",")
dstr = dims[ci : ci+nci]
ci += nci + 1
} else {
dstr = dims[ci:]
}
d, _ := strconv.Atoi(dstr)
sh[i] = d
}
return sh
}
// ConfigFromDataValues configures a Table based on data types inferred
// from the string representation of given records, using header names if present.
func ConfigFromDataValues(dt *Table, hdrs []string, rec [][]string) error {
nr := len(rec)
for ci, hd := range hdrs {
hd = strings.TrimSpace(hd)
if hd == "" {
hd = fmt.Sprintf("col_%d", ci)
}
nmatch := 0
typ := reflect.String
for ri := 1; ri < nr; ri++ {
rv := rec[ri][ci]
if rv == "" {
continue
}
ctyp := InferDataType(rv)
switch {
case ctyp == reflect.String: // definitive
typ = ctyp
break
case typ == ctyp && (nmatch > 1 || ri == nr-1): // good enough
break
case typ == ctyp: // gather more info
nmatch++
case typ == reflect.String: // always upgrade from string default
nmatch = 0
typ = ctyp
case typ == reflect.Int && ctyp == reflect.Float64: // upgrade
nmatch = 0
typ = ctyp
}
}
dt.AddColumnOfType(hd, typ)
}
return nil
}
// InferDataType returns the inferred data type for the given string.
// It only deals with float64, int, and string types.
func InferDataType(str string) reflect.Kind {
if strings.Contains(str, ".") {
_, err := strconv.ParseFloat(str, 64)
if err == nil {
return reflect.Float64
}
}
_, err := strconv.ParseInt(str, 10, 64)
if err == nil {
return reflect.Int
}
// try float again just in case..
_, err = strconv.ParseFloat(str, 64)
if err == nil {
return reflect.Float64
}
return reflect.String
}
//////// WriteCSV
// WriteCSV writes only the rows in the table's indexed view to a comma-separated-values (CSV) file
// (where comma = any delimiter, specified in the delim arg).
// If headers = true then generate column headers that capture the type
// and tensor cell geometry of the columns, enabling full reloading
// of exactly the same table format and data (recommended).
// Otherwise, only the data is written.
func (dt *Table) WriteCSV(w io.Writer, delim tensor.Delims, headers bool) error {
ncol := 0
var err error
if headers {
ncol, err = dt.WriteCSVHeaders(w, delim)
if err != nil {
log.Println(err)
return err
}
}
cw := csv.NewWriter(w)
cw.Comma = delim.Rune()
nrow := dt.NumRows()
for ri := range nrow {
ix := dt.RowIndex(ri)
err = dt.WriteCSVRowWriter(cw, ix, ncol)
if err != nil {
log.Println(err)
return err
}
}
cw.Flush()
return nil
}
// WriteCSVHeaders writes headers to a comma-separated-values (CSV) file
// (where comma = any delimiter, specified in the delim arg).
// Returns number of columns in header
func (dt *Table) WriteCSVHeaders(w io.Writer, delim tensor.Delims) (int, error) {
cw := csv.NewWriter(w)
cw.Comma = delim.Rune()
hdrs := dt.TableHeaders()
nc := len(hdrs)
err := cw.Write(hdrs)
if err != nil {
return nc, err
}
cw.Flush()
return nc, nil
}
// WriteCSVRow writes given row to a comma-separated-values (CSV) file
// (where comma = any delimiter, specified in the delim arg)
func (dt *Table) WriteCSVRow(w io.Writer, row int, delim tensor.Delims) error {
cw := csv.NewWriter(w)
cw.Comma = delim.Rune()
err := dt.WriteCSVRowWriter(cw, row, 0)
cw.Flush()
return err
}
// WriteCSVRowWriter uses csv.Writer to write one row
func (dt *Table) WriteCSVRowWriter(cw *csv.Writer, row int, ncol int) error {
prec := -1
if ps, err := tensor.Precision(dt); err == nil {
prec = ps
}
var rec []string
if ncol > 0 {
rec = make([]string, 0, ncol)
} else {
rec = make([]string, 0)
}
rc := 0
for _, tsr := range dt.Columns.Values {
nd := tsr.NumDims()
if nd == 1 {
vl := ""
if prec <= 0 || tsr.IsString() {
vl = tsr.String1D(row)
} else {
vl = strconv.FormatFloat(tsr.Float1D(row), 'g', prec, 64)
}
if len(rec) <= rc {
rec = append(rec, vl)
} else {
rec[rc] = vl
}
rc++
} else {
csh := tensor.NewShape(tsr.ShapeSizes()[1:]...) // cell shape
tc := csh.Len()
for ti := 0; ti < tc; ti++ {
vl := ""
if prec <= 0 || tsr.IsString() {
vl = tsr.String1D(row*tc + ti)
} else {
vl = strconv.FormatFloat(tsr.Float1D(row*tc+ti), 'g', prec, 64)
}
if len(rec) <= rc {
rec = append(rec, vl)
} else {
rec[rc] = vl
}
rc++
}
}
}
err := cw.Write(rec)
return err
}
// TableHeaders generates special header strings from the table
// with full information about type and tensor cell dimensionality.
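// For example (a sketch based on the format constructed below), a float32
// column named "Input" with 2x3 cells produces headers like:
//
//    %Input[2:0,0]<2:2,3>  %Input[2:0,1]  %Input[2:0,2]  %Input[2:1,0] ...
//
// where the first header carries the cell shape <2:2,3> that
// ConfigFromTableHeaders and ShapeFromString use to reconstruct the column.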
func (dt *Table) TableHeaders() []string {
hdrs := []string{}
for i, nm := range dt.Columns.Keys {
tsr := dt.Columns.Values[i]
nm = string([]byte{TableHeaderChar(tsr.DataType())}) + nm
if tsr.NumDims() == 1 {
hdrs = append(hdrs, nm)
} else {
csh := tensor.NewShape(tsr.ShapeSizes()[1:]...) // cell shape
tc := csh.Len()
nd := csh.NumDims()
fnm := nm + fmt.Sprintf("[%v:", nd)
dn := fmt.Sprintf("<%v:", nd)
ffnm := fnm
for di := 0; di < nd; di++ {
ffnm += "0"
dn += fmt.Sprintf("%v", csh.DimSize(di))
if di < nd-1 {
ffnm += ","
dn += ","
}
}
ffnm += "]" + dn + ">"
hdrs = append(hdrs, ffnm)
for ti := 1; ti < tc; ti++ {
idx := csh.IndexFrom1D(ti)
ffnm := fnm
for di := 0; di < nd; di++ {
ffnm += fmt.Sprintf("%v", idx[di])
if di < nd-1 {
ffnm += ","
}
}
ffnm += "]"
hdrs = append(hdrs, ffnm)
}
}
}
return hdrs
}
// CleanCatTSV cleans a TSV file formed by concatenating multiple files together.
// Removes redundant headers and then sorts by given set of columns.
func CleanCatTSV(filename string, sorts ...string) error {
str, err := os.ReadFile(filename)
if err != nil {
slog.Error(err.Error())
return err
}
lns := strings.Split(string(str), "\n")
if len(lns) == 0 {
return nil
}
hdr := lns[0]
f, err := os.Create(filename)
if err != nil {
slog.Error(err.Error())
return err
}
for i, ln := range lns {
if i > 0 && ln == hdr {
continue
}
io.WriteString(f, ln)
io.WriteString(f, "\n")
}
f.Close()
dt := New()
err = dt.OpenCSV(fsx.Filename(filename), tensor.Detect)
if err != nil {
slog.Error(err.Error())
return err
}
dt.SortColumns(tensor.Ascending, tensor.StableSort, sorts...)
st := dt.New()
err = st.SaveCSV(fsx.Filename(filename), tensor.Tab, true)
if err != nil {
slog.Error(err.Error())
}
return err
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package table
import (
"os"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/metadata"
"cogentcore.org/lab/tensor"
)
func setLogRow(dt *Table, row int) {
metadata.Set(dt, "LogRow", row)
}
func logRow(dt *Table) int {
return errors.Ignore1(metadata.Get[int](dt, "LogRow"))
}
func setLogDelim(dt *Table, delim tensor.Delims) {
metadata.Set(dt, "LogDelim", delim)
}
func logDelim(dt *Table) tensor.Delims {
return errors.Ignore1(metadata.Get[tensor.Delims](dt, "LogDelim"))
}
// OpenLog opens a log file for this table, which supports incremental
// output of table data as it is generated, using the standard [Table.SaveCSV]
// output formatting, using given delimiter between values on a line.
// Call [Table.WriteToLog] to write any new data rows to
// the open log file, and [Table.CloseLog] to close the file.
func (dt *Table) OpenLog(filename string, delim tensor.Delims) error {
f, err := os.Create(filename)
if err != nil {
return err
}
metadata.SetFile(dt, f)
setLogDelim(dt, delim)
setLogRow(dt, 0)
return nil
}
var (
ErrLogNoNewRows = errors.New("no new rows to write")
)
// WriteToLog writes any accumulated rows in the table to the file
// opened by [Table.OpenLog]. A Header row is written for the first output.
// If the current number of rows is less than the last number of rows,
// all of those rows are written under the assumption that the rows
// were reset via [Table.SetNumRows].
// Returns error for any failure, including [ErrLogNoNewRows] if
// no new rows are available to write.
func (dt *Table) WriteToLog() error {
f := metadata.File(dt)
if f == nil {
return errors.New("tensor.Table.WriteToLog: log file was not opened")
}
delim := logDelim(dt)
lrow := logRow(dt)
nr := dt.NumRows()
if nr == 0 || lrow == nr {
return ErrLogNoNewRows
}
if lrow == 0 {
dt.WriteCSVHeaders(f, delim)
}
sr := lrow
if nr < lrow {
sr = 0
}
for r := sr; r < nr; r++ {
dt.WriteCSVRow(f, r, delim)
}
setLogRow(dt, nr)
return nil
}
// CloseLog closes the log file opened by [Table.OpenLog].
func (dt *Table) CloseLog() {
f := metadata.File(dt)
f.Close()
}
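// A sketch of incremental logging (assumes dt is a *Table that gains new
// rows over time, e.g. once per training epoch):
//
//    dt.OpenLog("run.tsv", tensor.Tab)
//    for epoch := 0; epoch < 10; epoch++ {
//        // ... add a row of results to dt ...
//        dt.WriteToLog() // writes only rows added since the last call
//    }
//    dt.CloseLog()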
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package table
import (
"fmt"
"reflect"
"cogentcore.org/core/base/reflectx"
)
// NewSliceTable returns a new Table with data from the given slice
// of structs.
func NewSliceTable(st any) (*Table, error) {
npv := reflectx.NonPointerValue(reflect.ValueOf(st))
if npv.Kind() != reflect.Slice {
return nil, fmt.Errorf("NewSliceTable: not a slice")
}
eltyp := reflectx.NonPointerType(npv.Type().Elem())
if eltyp.Kind() != reflect.Struct {
return nil, fmt.Errorf("NewSliceTable: element type is not a struct")
}
dt := New()
for i := 0; i < eltyp.NumField(); i++ {
f := eltyp.Field(i)
kind := f.Type.Kind()
if !reflectx.KindIsBasic(kind) {
continue
}
dt.AddColumnOfType(f.Name, kind)
}
UpdateSliceTable(st, dt)
return dt, nil
}
// UpdateSliceTable updates the given Table with data from the given slice
// of structs, which must be of the same type as used to configure the table.
func UpdateSliceTable(st any, dt *Table) {
npv := reflectx.NonPointerValue(reflect.ValueOf(st))
eltyp := reflectx.NonPointerType(npv.Type().Elem())
nr := npv.Len()
dt.SetNumRows(nr)
for ri := 0; ri < nr; ri++ {
for i := 0; i < eltyp.NumField(); i++ {
f := eltyp.Field(i)
kind := f.Type.Kind()
if !reflectx.KindIsBasic(kind) {
continue
}
val := npv.Index(ri).Field(i).Interface()
cl := dt.Column(f.Name)
if kind == reflect.String {
cl.SetStringRow(val.(string), ri, 0)
} else {
fv, _ := reflectx.ToFloat(val)
cl.SetFloatRow(fv, ri, 0)
}
}
}
}
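// A sketch of building a table from a slice of structs (Trial is a
// hypothetical type, not part of this package):
//
//    type Trial struct {
//        Name  string
//        Score float64
//    }
//    trials := []Trial{{"a", 0.5}, {"b", 0.9}}
//    dt, err := table.NewSliceTable(trials)
//    // dt has a "Name" string column and a "Score" float64 column, one row
//    // per element; UpdateSliceTable(trials, dt) refreshes it after changes.
//    _, _ = dt, err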
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package table
//go:generate core generate
import (
"fmt"
"reflect"
"slices"
"cogentcore.org/core/base/metadata"
"cogentcore.org/lab/tensor"
)
// Table is a table of Tensor columns aligned by a common outermost row dimension.
// Use the [Table.Column] (by name) and [Table.ColumnByIndex] methods to obtain a
// [tensor.Rows] view of the column, using the shared [Table.Indexes] of the Table.
// Thus, a coordinated sorting and filtered view of the column data is automatically
// available for any of the tensor package functions that use [tensor.Tensor] as the one
// common data representation for all operations.
// Tensor Columns are always raw value types and support SubSpace operations on cells.
type Table struct { //types:add
// Columns has the list of column tensor data for this table.
// Different tables can provide different indexed views onto the same Columns.
Columns *Columns
// Indexes are the indexes into Tensor rows, with nil = sequential.
// Only set if order is different from default sequential order.
// These indexes are shared into the `tensor.Rows` Column values
// to provide a coordinated indexed view into the underlying data.
Indexes []int
// Meta is the metadata, used extensively for Name, Precision, Doc, etc.
// Use standard Go camel-case key names, following the standards in [metadata].
Meta metadata.Data
}
// New returns a new Table with its own (empty) set of Columns.
// Can pass an optional name which calls metadata SetName.
func New(name ...string) *Table {
dt := &Table{}
dt.Columns = NewColumns()
if len(name) > 0 {
metadata.SetName(dt, name[0])
}
return dt
}
// NewView returns a new Table with its own Rows view into the
// same underlying set of Column tensor data as the source table.
// Indexes are copied from the existing table; use Sequential
// to reset to a full sequential view.
func NewView(src *Table) *Table {
dt := &Table{Columns: src.Columns}
if src.Indexes != nil {
dt.Indexes = slices.Clone(src.Indexes)
}
dt.Meta.Copy(src.Meta)
return dt
}
// Init initializes a new empty table with [NewColumns].
func (dt *Table) Init() {
dt.Columns = NewColumns()
}
func (dt *Table) Metadata() *metadata.Data { return &dt.Meta }
// IsValidRow returns error if the row is invalid, if error checking is needed.
func (dt *Table) IsValidRow(row int) error {
if row < 0 || row >= dt.NumRows() {
return fmt.Errorf("table.Table IsValidRow: row %d is out of valid range [0..%d]", row, dt.NumRows())
}
return nil
}
// NumColumns returns the number of columns.
func (dt *Table) NumColumns() int { return dt.Columns.Len() }
// Column returns the tensor with given column name, as a [tensor.Rows]
// with the shared [Table.Indexes] from this table. It is best practice to
// access columns by name, and direct access through [Table.Columns] does not
// provide the shared table-wide Indexes.
// Returns nil if not found.
func (dt *Table) Column(name string) *tensor.Rows {
cl := dt.Columns.At(name)
if cl == nil {
return nil
}
return tensor.NewRows(cl, dt.Indexes...)
}
// ColumnTry is a version of [Table.Column] that also returns an error
// if the column name is not found, for cases when error is needed.
func (dt *Table) ColumnTry(name string) (*tensor.Rows, error) {
cl := dt.Column(name)
if cl != nil {
return cl, nil
}
return nil, fmt.Errorf("table.Table: Column named %q not found", name)
}
// ColumnByIndex returns the tensor at the given column index,
// as a [tensor.Rows] with the shared [Table.Indexes] from this table.
// It is best practice to instead access columns by name using [Table.Column].
// Direct access through [Table.Columns] does not provide the shared table-wide Indexes.
// Will panic if out of range.
func (dt *Table) ColumnByIndex(idx int) *tensor.Rows {
cl := dt.Columns.Values[idx]
return tensor.NewRows(cl, dt.Indexes...)
}
// ColumnList returns a list of tensors with given column names,
// as [tensor.Rows] with the shared [Table.Indexes] from this table.
func (dt *Table) ColumnList(names ...string) []tensor.Tensor {
list := make([]tensor.Tensor, 0, len(names))
for _, nm := range names {
cl := dt.Column(nm)
if cl != nil {
list = append(list, cl)
}
}
return list
}
// ColumnName returns the name of given column.
func (dt *Table) ColumnName(i int) string {
return dt.Columns.Keys[i]
}
// ColumnIndex returns the index for given column name.
func (dt *Table) ColumnIndex(name string) int {
return dt.Columns.IndexByKey(name)
}
// ColumnIndexList returns a list of indexes to columns of given names.
func (dt *Table) ColumnIndexList(names ...string) []int {
list := make([]int, 0, len(names))
for _, nm := range names {
ci := dt.ColumnIndex(nm)
if ci >= 0 {
list = append(list, ci)
}
}
return list
}
// AddColumn adds a new column to the table, of given type and column name
// (which must be unique). If no cellSizes are specified, it holds scalar values,
// otherwise the cells are n-dimensional tensors of given size.
func AddColumn[T tensor.DataTypes](dt *Table, name string, cellSizes ...int) tensor.Tensor {
rows := dt.Columns.Rows
sz := append([]int{rows}, cellSizes...)
tsr := tensor.New[T](sz...)
// tsr.SetNames("Row")
dt.AddColumn(name, tsr)
return tsr
}
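// Illustrative sketch of the generic AddColumn (hypothetical names,
// not from the original source):
//
//	dt := New("patterns")
//	// each row cell is a 2x3 tensor of float32 values:
//	img := AddColumn[float32](dt, "Image", 2, 3)
//	dt.SetNumRows(10)
//	_ = img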
// InsertColumn inserts a new column to the table, of given type and column name
// (which must be unique), at given index.
// If no cellSizes are specified, it holds scalar values,
// otherwise the cells are n-dimensional tensors of given size.
func InsertColumn[T tensor.DataTypes](dt *Table, name string, idx int, cellSizes ...int) tensor.Tensor {
rows := dt.Columns.Rows
sz := append([]int{rows}, cellSizes...)
tsr := tensor.New[T](sz...)
// tsr.SetNames("Row")
dt.InsertColumn(idx, name, tsr)
return tsr
}
// AddColumn adds the given [tensor.Values] as a column to the table,
// returning an error and not adding if the name is not unique.
// Automatically adjusts the shape to fit the current number of rows.
func (dt *Table) AddColumn(name string, tsr tensor.Values) error {
return dt.Columns.AddColumn(name, tsr)
}
// InsertColumn inserts the given [tensor.Values] as a column to the table at given index,
// returning an error and not adding if the name is not unique.
// Automatically adjusts the shape to fit the current number of rows.
func (dt *Table) InsertColumn(idx int, name string, tsr tensor.Values) error {
return dt.Columns.InsertColumn(idx, name, tsr)
}
// AddColumnOfType adds a new column to the table, of given reflect type
// and column name (which must be unique).
// If no cellSizes are specified, it holds scalar values,
// otherwise the cells are n-dimensional tensors of given size.
// Supported types include string, bool (for [tensor.Bool]), float32, float64, int, int32, and byte.
func (dt *Table) AddColumnOfType(name string, typ reflect.Kind, cellSizes ...int) tensor.Tensor {
rows := dt.Columns.Rows
sz := append([]int{rows}, cellSizes...)
tsr := tensor.NewOfType(typ, sz...)
// tsr.SetNames("Row")
dt.AddColumn(name, tsr)
return tsr
}
// AddStringColumn adds a new String column with given name.
// If no cellSizes are specified, it holds scalar values,
// otherwise the cells are n-dimensional tensors of given size.
func (dt *Table) AddStringColumn(name string, cellSizes ...int) *tensor.String {
return AddColumn[string](dt, name, cellSizes...).(*tensor.String)
}
// AddFloat64Column adds a new float64 column with given name.
// If no cellSizes are specified, it holds scalar values,
// otherwise the cells are n-dimensional tensors of given size.
func (dt *Table) AddFloat64Column(name string, cellSizes ...int) *tensor.Float64 {
return AddColumn[float64](dt, name, cellSizes...).(*tensor.Float64)
}
// AddFloat32Column adds a new float32 column with given name.
// If no cellSizes are specified, it holds scalar values,
// otherwise the cells are n-dimensional tensors of given size.
func (dt *Table) AddFloat32Column(name string, cellSizes ...int) *tensor.Float32 {
return AddColumn[float32](dt, name, cellSizes...).(*tensor.Float32)
}
// AddIntColumn adds a new int column with given name.
// If no cellSizes are specified, it holds scalar values,
// otherwise the cells are n-dimensional tensors of given size.
func (dt *Table) AddIntColumn(name string, cellSizes ...int) *tensor.Int {
return AddColumn[int](dt, name, cellSizes...).(*tensor.Int)
}
// DeleteColumnName deletes the column of the given name.
// Returns false if not found.
func (dt *Table) DeleteColumnName(name string) bool {
return dt.Columns.DeleteByKey(name)
}
// DeleteColumnByIndex deletes columns within the index range [i:j].
func (dt *Table) DeleteColumnByIndex(i, j int) {
dt.Columns.DeleteByIndex(i, j)
}
// DeleteAll deletes all columns, does full reset.
func (dt *Table) DeleteAll() {
dt.Indexes = nil
dt.Columns.Reset()
}
// AddRows adds n rows to end of underlying Table, and to the indexes in this view.
func (dt *Table) AddRows(n int) *Table { //types:add
return dt.SetNumRows(dt.Columns.Rows + n)
}
// InsertRows adds n rows to end of underlying Table, and to the indexes starting at
// given index in this view, providing an efficient insertion operation that only
// exists in the indexed view. To create an in-memory ordering, use [Table.New].
func (dt *Table) InsertRows(at, n int) *Table {
dt.IndexesNeeded()
strow := dt.Columns.Rows
stidx := len(dt.Indexes)
dt.SetNumRows(strow + n) // adds n indexes to end of list
// move those indexes to at:at+n in index list
dt.Indexes = append(dt.Indexes[:at], append(dt.Indexes[stidx:], dt.Indexes[at:]...)...)
dt.Indexes = dt.Indexes[:strow+n]
return dt
}
// SetNumRows sets the number of rows in the table, across all columns.
// If rows = 0 then effective number of rows in tensors is 1, as this dim cannot be 0.
// If indexes are in place and rows are added, indexes for the new rows are added.
func (dt *Table) SetNumRows(rows int) *Table { //types:add
strow := dt.Columns.Rows
dt.Columns.SetNumRows(rows)
if dt.Indexes == nil {
return dt
}
if rows > strow {
for i := range rows - strow {
dt.Indexes = append(dt.Indexes, strow+i)
}
} else {
dt.ValidIndexes()
}
return dt
}
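// Illustrative sketch of row management (hypothetical usage,
// not from the original source):
//
//	dt := New("log")
//	dt.AddFloat64Column("Loss")
//	dt.SetNumRows(10) // all columns now have 10 rows
//	dt.AddRows(5)     // now 15 rows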
// SetNumRowsToMax gets the current max number of rows across all the column tensors,
// and sets the number of rows to that. This will automatically pad shorter columns
// so they all have the same number of rows. If a table has columns that are not fully
// under its own control, they can change size, so this reestablishes
// a common row dimension.
func (dt *Table) SetNumRowsToMax() {
var maxRow int
for _, tsr := range dt.Columns.Values {
maxRow = max(maxRow, tsr.DimSize(0))
}
dt.SetNumRows(maxRow)
}
// note: there is no really clean definition of CopyFrom -- no point in re-using
// an existing table -- just clone it.
// Clone returns a complete copy of this table, including cloning
// the underlying Columns tensors, and the current [Table.Indexes].
// See also [Table.New] to flatten the current indexes.
func (dt *Table) Clone() *Table {
cp := &Table{}
cp.Columns = dt.Columns.Clone()
cp.Meta.Copy(dt.Meta)
if dt.Indexes != nil {
cp.Indexes = slices.Clone(dt.Indexes)
}
return cp
}
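// Illustrative sketch contrasting Clone and NewView (hypothetical usage,
// not from the original source), given an existing *Table dt:
//
//	vw := NewView(dt) // shares the same underlying Columns tensors as dt
//	cp := dt.Clone()  // fully independent copy of Columns and Indexes
//	_, _ = vw, cp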
// AppendRows appends the rows of the given table to this table, for the columns shared by both tables.
func (dt *Table) AppendRows(dt2 *Table) {
strow := dt.Columns.Rows
n := dt2.Columns.Rows
dt.Columns.AppendRows(dt2.Columns)
if dt.Indexes == nil {
return
}
for i := range n {
dt.Indexes = append(dt.Indexes, strow+i)
}
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package table
import (
"errors"
"fmt"
"reflect"
"strings"
"cogentcore.org/lab/tensor"
)
// InsertKeyColumns returns a copy of this table with new columns
// having the given values, inserted at the start, used as legend keys etc.
// Args must be in pairs: column name, value. All rows get the same value.
func (dt *Table) InsertKeyColumns(args ...string) *Table {
n := len(args)
if n%2 != 0 {
fmt.Println("InsertKeyColumns requires even number of args as column name, value pairs")
return dt
}
c := dt.Clone()
nc := n / 2
for j := range nc {
colNm := args[2*j]
val := args[2*j+1]
col := tensor.NewString(c.Columns.Rows)
if c.Column(colNm) == nil {
c.InsertColumn(0, colNm, col)
}
for i := range col.Values {
col.Values[i] = val
}
}
return c
}
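// Illustrative sketch (the key column names and values are hypothetical),
// given an existing *Table dt:
//
//	keyed := dt.InsertKeyColumns("Condition", "Control", "Run", "1")
//	// keyed has two new string columns inserted at the start, each with
//	// the same value repeated across all rows.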
// ConfigFromTable configures the columns of this table according to the
// values in the first two columns of given format table, conventionally named
// Name, Type (but names are not used), which must be of the string type.
func (dt *Table) ConfigFromTable(ft *Table) error {
nmcol := ft.ColumnByIndex(0)
tycol := ft.ColumnByIndex(1)
var errs []error
for i := range ft.NumRows() {
name := nmcol.String1D(i)
typ := strings.ToLower(tycol.String1D(i))
kind := reflect.Float64
switch typ {
case "string":
kind = reflect.String
case "bool":
kind = reflect.Bool
case "float32":
kind = reflect.Float32
case "float64":
kind = reflect.Float64
case "int":
kind = reflect.Int
case "int32":
kind = reflect.Int32
case "byte", "uint8":
kind = reflect.Uint8
default:
err := fmt.Errorf("ConfigFromTable: type string %q not recognized", typ)
errs = append(errs, err)
}
dt.AddColumnOfType(name, kind)
}
return errors.Join(errs...)
}
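// Illustrative sketch of building a format table (hypothetical column
// names and types, not from the original source):
//
//	ft := New("format")
//	nm := ft.AddStringColumn("Name")
//	ty := ft.AddStringColumn("Type")
//	ft.SetNumRows(2)
//	nm.SetString1D("Trial", 0)
//	ty.SetString1D("int", 0)
//	nm.SetString1D("Score", 1)
//	ty.SetString1D("float64", 1)
//	err := dt.ConfigFromTable(ft) // dt gets an int Trial column and a float64 Score column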
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
import (
"fmt"
"slices"
"cogentcore.org/core/base/errors"
)
// AlignShapes aligns the shapes of two tensors, a and b for a binary
// computation producing an output, returning the effective aligned shapes
// for a, b, and the output, all with the same number of dimensions.
// Alignment proceeds from the innermost dimension out, with 1s provided
// beyond the number of dimensions for a or b.
// The output has the max of the dimension sizes for each dimension.
// An error is returned if the rules of alignment are violated:
// each dimension size must be either the same, or one of them
// is equal to 1. This corresponds to the "broadcasting" logic of NumPy.
func AlignShapes(a, b Tensor) (as, bs, os *Shape, err error) {
asz := a.ShapeSizes()
bsz := b.ShapeSizes()
an := len(asz)
bn := len(bsz)
n := max(an, bn)
osizes := make([]int, n)
asizes := make([]int, n)
bsizes := make([]int, n)
for d := range n {
ai := an - 1 - d
bi := bn - 1 - d
oi := n - 1 - d
ad := 1
bd := 1
if ai >= 0 {
ad = asz[ai]
}
if bi >= 0 {
bd = bsz[bi]
}
if ad != bd && !(ad == 1 || bd == 1) {
err = fmt.Errorf("tensor.AlignShapes: output dimension %d does not align for a=%d b=%d: must be either the same or one of them is a 1", oi, ad, bd)
return
}
od := max(ad, bd)
osizes[oi] = od
asizes[oi] = ad
bsizes[oi] = bd
}
as = NewShape(asizes...)
bs = NewShape(bsizes...)
os = NewShape(osizes...)
return
}
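// Illustrative sketch of broadcasting alignment (hypothetical shapes):
//
//	a := NewFloat64(4, 3) // shape (4, 3)
//	b := NewFloat64(3)    // shape (3), aligned as (1, 3)
//	as, bs, os, err := AlignShapes(a, b)
//	// os is (4, 3); bs is (1, 3), so b's single row is reused
//	// (via WrapIndex1D) for every row of the output.
//	_, _, _, _ = as, bs, os, err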
// WrapIndex1D returns the 1d flat index for given n-dimensional index
// based on given shape, where any singleton dimension sizes cause the
// resulting index value to remain at 0, effectively causing that dimension
// to wrap around, while the other tensor is presumably using the full range
// of the values along this dimension. See [AlignShapes] for more info.
func WrapIndex1D(sh *Shape, i ...int) int {
nd := sh.NumDims()
ai := slices.Clone(i)
for d := range nd {
if sh.DimSize(d) == 1 {
ai[d] = 0
}
}
return sh.IndexTo1D(ai...)
}
// AlignForAssign ensures that the shapes of two tensors, a and b
// have the proper alignment for assigning b into a.
// Alignment proceeds from the innermost dimension out, with 1s provided
// beyond the number of dimensions for a or b.
// An error is returned if the rules of alignment are violated:
// each dimension size must be either the same, or b is equal to 1.
// This corresponds to the "broadcasting" logic of NumPy.
func AlignForAssign(a, b Tensor) (as, bs *Shape, err error) {
asz := a.ShapeSizes()
bsz := b.ShapeSizes()
an := len(asz)
bn := len(bsz)
n := max(an, bn)
asizes := make([]int, n)
bsizes := make([]int, n)
for d := range n {
ai := an - 1 - d
bi := bn - 1 - d
oi := n - 1 - d
ad := 1
bd := 1
if ai >= 0 {
ad = asz[ai]
}
if bi >= 0 {
bd = bsz[bi]
}
if ad != bd && bd != 1 {
err = fmt.Errorf("tensor.AlignShapes: dimension %d does not align for a=%d b=%d: must be either the same or b is a 1", oi, ad, bd)
return
}
asizes[oi] = ad
bsizes[oi] = bd
}
as = NewShape(asizes...)
bs = NewShape(bsizes...)
return
}
// SplitAtInnerDims returns the sizes of the given tensor's shape
// with the given number of inner-most dimensions retained as is,
// and those above collapsed to a single dimension.
// If the total number of dimensions is < nInner the result is nil.
func SplitAtInnerDims(tsr Tensor, nInner int) []int {
sizes := tsr.ShapeSizes()
nd := len(sizes)
if nd < nInner {
return nil
}
rsz := make([]int, nInner+1)
split := nd - nInner
rows := sizes[:split]
copy(rsz[1:], sizes[split:])
nr := 1
for _, r := range rows {
nr *= r
}
rsz[0] = nr
return rsz
}
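// Illustrative sketch (hypothetical shape):
//
//	t := NewFloat64(2, 3, 4, 5)
//	SplitAtInnerDims(t, 2) // returns []int{6, 4, 5}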
// FloatAssignFunc sets a to a binary function of a and b float64 values.
func FloatAssignFunc(fun func(a, b float64) float64, a, b Tensor) error {
as, bs, err := AlignForAssign(a, b)
if err != nil {
return err
}
alen := as.Len()
VectorizeThreaded(1, func(tsr ...Tensor) int { return alen },
func(idx int, tsr ...Tensor) {
ai := as.IndexFrom1D(idx)
bi := WrapIndex1D(bs, ai...)
tsr[0].SetFloat1D(fun(tsr[0].Float1D(idx), tsr[1].Float1D(bi)), idx)
}, a, b)
return nil
}
// StringAssignFunc sets a to a binary function of a and b string values.
func StringAssignFunc(fun func(a, b string) string, a, b Tensor) error {
as, bs, err := AlignForAssign(a, b)
if err != nil {
return err
}
alen := as.Len()
VectorizeThreaded(1, func(tsr ...Tensor) int { return alen },
func(idx int, tsr ...Tensor) {
ai := as.IndexFrom1D(idx)
bi := WrapIndex1D(bs, ai...)
tsr[0].SetString1D(fun(tsr[0].String1D(idx), tsr[1].String1D(bi)), idx)
}, a, b)
return nil
}
// FloatBinaryFunc sets output to a binary function of a, b float64 values.
// The flops (floating point operations) estimate is used to control parallel
// threading using goroutines, and should reflect number of flops in the function.
// See [VectorizeThreaded] for more information.
func FloatBinaryFunc(flops int, fun func(a, b float64) float64, a, b Tensor) Tensor {
return CallOut2Gen2(FloatBinaryFuncOut, flops, fun, a, b)
}
// FloatBinaryFuncOut sets output to a binary function of a, b float64 values.
func FloatBinaryFuncOut(flops int, fun func(a, b float64) float64, a, b Tensor, out Values) error {
as, bs, os, err := AlignShapes(a, b)
if err != nil {
return err
}
out.SetShapeSizes(os.Sizes...)
olen := os.Len()
VectorizeThreaded(flops, func(tsr ...Tensor) int { return olen },
func(idx int, tsr ...Tensor) {
oi := os.IndexFrom1D(idx)
ai := WrapIndex1D(as, oi...)
bi := WrapIndex1D(bs, oi...)
out.SetFloat1D(fun(tsr[0].Float1D(ai), tsr[1].Float1D(bi)), idx)
}, a, b, out)
return nil
}
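// Illustrative sketch of a broadcasting elementwise operation
// (hypothetical values; the addition function is just an example):
//
//	a := NewFloat64FromValues(1, 2, 3)
//	b := NewFloat64Scalar(10)
//	sum := FloatBinaryFunc(1, func(x, y float64) float64 { return x + y }, a, b)
//	// sum holds 11, 12, 13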
// StringBinaryFunc sets output to a binary function of a, b string values.
func StringBinaryFunc(fun func(a, b string) string, a, b Tensor) Tensor {
return CallOut2Gen1(StringBinaryFuncOut, fun, a, b)
}
// StringBinaryFuncOut sets output to a binary function of a, b string values.
func StringBinaryFuncOut(fun func(a, b string) string, a, b Tensor, out Values) error {
as, bs, os, err := AlignShapes(a, b)
if err != nil {
return err
}
out.SetShapeSizes(os.Sizes...)
olen := os.Len()
VectorizeThreaded(1, func(tsr ...Tensor) int { return olen },
func(idx int, tsr ...Tensor) {
oi := os.IndexFrom1D(idx)
ai := WrapIndex1D(as, oi...)
bi := WrapIndex1D(bs, oi...)
out.SetString1D(fun(tsr[0].String1D(ai), tsr[1].String1D(bi)), idx)
}, a, b, out)
return nil
}
// FloatFunc sets output to a function of tensor float64 values.
// The flops (floating point operations) estimate is used to control parallel
// threading using goroutines, and should reflect number of flops in the function.
// See [VectorizeThreaded] for more information.
func FloatFunc(flops int, fun func(in float64) float64, in Tensor) Values {
return CallOut1Gen2(FloatFuncOut, flops, fun, in)
}
// FloatFuncOut sets output to a function of tensor float64 values.
func FloatFuncOut(flops int, fun func(in float64) float64, in Tensor, out Values) error {
SetShapeFrom(out, in)
n := in.Len()
VectorizeThreaded(flops, func(tsr ...Tensor) int { return n },
func(idx int, tsr ...Tensor) {
tsr[1].SetFloat1D(fun(tsr[0].Float1D(idx)), idx)
}, in, out)
return nil
}
// FloatSetFunc sets tensor float64 values from a function,
// which gets the index. Must be parallel threadsafe.
// The flops (floating point operations) estimate is used to control parallel
// threading using goroutines, and should reflect number of flops in the function.
// See [VectorizeThreaded] for more information.
func FloatSetFunc(flops int, fun func(idx int) float64, a Tensor) error {
n := a.Len()
VectorizeThreaded(flops, func(tsr ...Tensor) int { return n },
func(idx int, tsr ...Tensor) {
tsr[0].SetFloat1D(fun(idx), idx)
}, a)
return nil
}
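// Illustrative sketch (hypothetical usage; the index function must be
// safe to call from multiple goroutines):
//
//	t := NewFloat64(100)
//	FloatSetFunc(1, func(idx int) float64 { return float64(idx * idx) }, t)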
//////// Bool
// BoolStringsFunc sets boolean output value based on a function involving
// string values from the two tensors.
func BoolStringsFunc(fun func(a, b string) bool, a, b Tensor) *Bool {
out := NewBool()
errors.Log(BoolStringsFuncOut(fun, a, b, out))
return out
}
// BoolStringsFuncOut sets boolean output value based on a function involving
// string values from the two tensors.
func BoolStringsFuncOut(fun func(a, b string) bool, a, b Tensor, out *Bool) error {
as, bs, os, err := AlignShapes(a, b)
if err != nil {
return err
}
out.SetShapeSizes(os.Sizes...)
olen := os.Len()
VectorizeThreaded(5, func(tsr ...Tensor) int { return olen },
func(idx int, tsr ...Tensor) {
oi := os.IndexFrom1D(idx)
ai := WrapIndex1D(as, oi...)
bi := WrapIndex1D(bs, oi...)
out.SetBool1D(fun(tsr[0].String1D(ai), tsr[1].String1D(bi)), idx)
}, a, b, out)
return nil
}
// BoolFloatsFunc sets boolean output value based on a function involving
// float64 values from the two tensors.
func BoolFloatsFunc(fun func(a, b float64) bool, a, b Tensor) *Bool {
out := NewBool()
errors.Log(BoolFloatsFuncOut(fun, a, b, out))
return out
}
// BoolFloatsFuncOut sets boolean output value based on a function involving
// float64 values from the two tensors.
func BoolFloatsFuncOut(fun func(a, b float64) bool, a, b Tensor, out *Bool) error {
as, bs, os, err := AlignShapes(a, b)
if err != nil {
return err
}
out.SetShapeSizes(os.Sizes...)
olen := os.Len()
VectorizeThreaded(5, func(tsr ...Tensor) int { return olen },
func(idx int, tsr ...Tensor) {
oi := os.IndexFrom1D(idx)
ai := WrapIndex1D(as, oi...)
bi := WrapIndex1D(bs, oi...)
out.SetBool1D(fun(tsr[0].Float1D(ai), tsr[1].Float1D(bi)), idx)
}, a, b, out)
return nil
}
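// Illustrative sketch of an elementwise comparison (hypothetical values):
//
//	a := NewFloat64FromValues(1, 5, 3)
//	b := NewFloat64FromValues(2, 2, 3)
//	gt := BoolFloatsFunc(func(x, y float64) bool { return x > y }, a, b)
//	// gt holds false, true, false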
// BoolIntsFunc sets boolean output value based on a function involving
// int values from the two tensors.
func BoolIntsFunc(fun func(a, b int) bool, a, b Tensor) *Bool {
out := NewBool()
errors.Log(BoolIntsFuncOut(fun, a, b, out))
return out
}
// BoolIntsFuncOut sets boolean output value based on a function involving
// int values from the two tensors.
func BoolIntsFuncOut(fun func(a, b int) bool, a, b Tensor, out *Bool) error {
as, bs, os, err := AlignShapes(a, b)
if err != nil {
return err
}
out.SetShapeSizes(os.Sizes...)
olen := os.Len()
VectorizeThreaded(5, func(tsr ...Tensor) int { return olen },
func(idx int, tsr ...Tensor) {
oi := os.IndexFrom1D(idx)
ai := WrapIndex1D(as, oi...)
bi := WrapIndex1D(bs, oi...)
out.SetBool1D(fun(tsr[0].Int1D(ai), tsr[1].Int1D(bi)), idx)
}, a, b, out)
return nil
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
import (
"reflect"
"slices"
"unsafe"
"cogentcore.org/core/base/metadata"
"cogentcore.org/core/base/reflectx"
"cogentcore.org/core/base/slicesx"
)
// Base is the base Tensor implementation for given type.
type Base[T any] struct {
// shape contains the N-dimensional shape and indexing functionality.
shape Shape
// Values is a flat 1D slice of the underlying data.
Values []T
// Meta data is used extensively for Name, Plot styles, etc.
// Use standard Go camel-case key names, standards in [metadata].
Meta metadata.Data
}
// Metadata returns the metadata for this tensor, which can be used
// to encode plotting options, etc.
func (tsr *Base[T]) Metadata() *metadata.Data { return &tsr.Meta }
func (tsr *Base[T]) Shape() *Shape { return &tsr.shape }
// ShapeSizes returns the sizes of each dimension as a slice of ints.
// The returned slice is a copy and can be modified without side effects.
func (tsr *Base[T]) ShapeSizes() []int { return slices.Clone(tsr.shape.Sizes) }
// SetShapeSizes sets the dimension sizes of the tensor, and resizes
// backing storage appropriately, retaining all existing data that fits.
func (tsr *Base[T]) SetShapeSizes(sizes ...int) {
tsr.shape.SetShapeSizes(sizes...)
nln := tsr.shape.Len()
tsr.Values = slicesx.SetLength(tsr.Values, nln)
}
// Len returns the number of elements in the tensor (product of shape dimensions).
func (tsr *Base[T]) Len() int { return tsr.shape.Len() }
// NumDims returns the total number of dimensions.
func (tsr *Base[T]) NumDims() int { return tsr.shape.NumDims() }
// DimSize returns size of given dimension.
func (tsr *Base[T]) DimSize(dim int) int { return tsr.shape.DimSize(dim) }
// DataType returns the type of the data elements in the tensor.
// Bool is returned for the Bool tensor type.
func (tsr *Base[T]) DataType() reflect.Kind {
var v T
return reflect.TypeOf(v).Kind()
}
func (tsr *Base[T]) Sizeof() int64 {
var v T
return int64(unsafe.Sizeof(v)) * int64(tsr.Len())
}
func (tsr *Base[T]) Bytes() []byte {
return slicesx.ToBytes(tsr.Values)
}
func (tsr *Base[T]) SetFromBytes(b []byte) {
var v T
tsz := unsafe.Sizeof(v)
d := unsafe.Slice((*T)(unsafe.Pointer(&b[0])), len(b)/int(tsz))
copy(tsr.Values, d)
}
func (tsr *Base[T]) Value(i ...int) T {
return tsr.Values[tsr.shape.IndexTo1D(i...)]
}
func (tsr *Base[T]) ValuePtr(i ...int) *T {
return &tsr.Values[tsr.shape.IndexTo1D(i...)]
}
func (tsr *Base[T]) Set(val T, i ...int) {
tsr.Values[tsr.shape.IndexTo1D(i...)] = val
}
func (tsr *Base[T]) Value1D(i int) T { return tsr.Values[i] }
func (tsr *Base[T]) Set1D(val T, i int) { tsr.Values[i] = val }
// SetNumRows sets the number of rows (outermost dimension) in a RowMajor organized tensor.
// It is safe to set this to 0. For incrementally growing tensors (e.g., a log)
// it is best to first set the anticipated full size, which allocates the
// full amount of memory, and then set to 0 and grow incrementally.
func (tsr *Base[T]) SetNumRows(rows int) {
if tsr.NumDims() == 0 {
tsr.SetShapeSizes(0)
}
_, cells := tsr.shape.RowCellSize()
nln := rows * cells
tsr.shape.Sizes[0] = rows
tsr.Values = slicesx.SetLength(tsr.Values, nln)
}
// subSpaceImpl returns a new tensor with innermost subspace at given
// offset(s) in outermost dimension(s) (len(offs) < NumDims).
// The new tensor points to the values of this tensor (i.e., modifications
// will affect both), as its Values slice is a view onto the original (which
// is why only inner-most contiguous subspaces are supported).
// Use AsValues() method to separate the two.
func (tsr *Base[T]) subSpaceImpl(offs ...int) *Base[T] {
nd := tsr.NumDims()
od := len(offs)
if od > nd {
return nil
}
var ssz []int
if od == nd { // scalar subspace
ssz = []int{1}
} else {
ssz = tsr.shape.Sizes[od:]
}
stsr := &Base[T]{}
stsr.SetShapeSizes(ssz...)
sti := make([]int, nd)
copy(sti, offs)
stoff := tsr.shape.IndexTo1D(sti...)
sln := stsr.Len()
stsr.Values = tsr.Values[stoff : stoff+sln]
return stsr
}
//////// Strings
func (tsr *Base[T]) StringValue(i ...int) string {
return reflectx.ToString(tsr.Values[tsr.shape.IndexTo1D(i...)])
}
func (tsr *Base[T]) String1D(i int) string {
return reflectx.ToString(tsr.Values[i])
}
func (tsr *Base[T]) StringRow(row, cell int) string {
_, sz := tsr.shape.RowCellSize()
return reflectx.ToString(tsr.Values[row*sz+cell])
}
// Label satisfies the core.Labeler interface for a summary description of the tensor.
func (tsr *Base[T]) Label() string {
return label(metadata.Name(tsr), &tsr.shape)
}
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
import (
"reflect"
"unsafe"
"cogentcore.org/core/base/slicesx"
)
// ToBinary returns a binary encoding of the tensor that
// includes its type, shape and all data.
// [FromBinary] makes a tensor from this binary data.
func ToBinary(tsr Values) []byte {
shape := []int{int(tsr.DataType()), tsr.NumDims()}
shape = append(shape, tsr.Shape().Sizes...)
b := slicesx.ToBytes(shape)
b = append(b, tsr.Bytes()...)
return b
}
// FromBinary returns a [Values] tensor reconstructed
// from the binary encoding generated by ToBinary.
func FromBinary(b []byte) Values {
shape := unsafe.Slice((*int)(unsafe.Pointer(&b[0])), 2)
typ := reflect.Kind(shape[0])
ndim := shape[1]
shape = unsafe.Slice((*int)(unsafe.Pointer(&b[0])), ndim+2)
tsr := NewOfType(typ, shape[2:]...)
si := int(unsafe.Sizeof(ndim)) * (ndim + 2)
tsr.SetFromBytes(b[si:])
return tsr
}
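// Illustrative round-trip sketch (hypothetical usage):
//
//	src := NewFloat64(2, 3)
//	b := ToBinary(src)
//	dst := FromBinary(b) // same type, shape, and data as src
//	_ = dst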
// Copyright (c) 2024, The Cogent Core Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package bitslice implements a simple slice-of-bits using a []byte slice for storage,
// which is used for efficient storage of boolean data, such as projection connectivity patterns.
package bitslice
import "fmt"
// bitslice.Slice is the slice of []byte that holds the bits.
// The first byte maintains the number of bits used in the last byte (0-7);
// when 0, the prior byte is completely full and a new one must be added on append.
type Slice []byte
// BitIndex returns the byte, bit index of given bit index
func BitIndex(idx int) (byte int, bit uint32) {
return idx / 8, uint32(idx % 8)
}
// Make makes a new bitslice of given length and capacity (optional; pass 0 for default),
// both specified in *bits* (rounded up to whole bytes as needed).
// It also reserves the first byte for the extra bits value.
func Make(ln, cp int) Slice {
by, bi := BitIndex(ln)
bln := by
if bi != 0 {
bln++
}
var sl Slice
if cp > 0 {
sl = make(Slice, bln+1, (cp/8)+2)
} else {
sl = make(Slice, bln+1)
}
sl[0] = byte(bi)
return sl
}
// Len returns the length of the slice in bits
func (bs *Slice) Len() int {
ln := len(*bs)
if ln == 0 {
return 0
}
eb := (*bs)[0]
bln := ln - 1
if eb != 0 {
bln--
}
tln := bln*8 + int(eb)
return tln
}
// Cap returns the capacity of the slice in bits -- always a multiple of 8.
func (bs *Slice) Cap() int {
return (cap(*bs) - 1) * 8
}
// SetLen sets the length of the slice, copying values if a new allocation is required
func (bs *Slice) SetLen(ln int) {
by, bi := BitIndex(ln)
bln := by
if bi != 0 {
bln++
}
if cap(*bs) >= bln+1 {
*bs = (*bs)[0 : bln+1]
(*bs)[0] = byte(bi)
} else {
sl := make(Slice, bln+1)
sl[0] = byte(bi)
copy(sl, *bs)
*bs = sl
}
}
// Set sets value of given bit index -- no extra range checking is performed -- will panic if out of range
func (bs *Slice) Set(val bool, idx int) {
by, bi := BitIndex(idx)
if val {
(*bs)[by+1] |= 1 << bi
} else {
(*bs)[by+1] &^= 1 << bi
}
}
// Index returns bit value at given bit index
func (bs *Slice) Index(idx int) bool {
by, bi := BitIndex(idx)
return ((*bs)[by+1] & (1 << bi)) != 0
}
// Append adds a bit to the slice and returns the resulting slice, which may or may not be a new allocation.
func (bs *Slice) Append(val bool) Slice {
if len(*bs) == 0 {
*bs = Make(1, 0)
bs.Set(val, 0)
return *bs
}
ln := bs.Len()
eb := (*bs)[0]
if eb == 0 {
*bs = append(*bs, 0) // now we add
(*bs)[0] = 1
} else if eb < 7 {
(*bs)[0]++
} else {
(*bs)[0] = 0
}
bs.Set(val, ln)
return *bs
}
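// Illustrative sketch of basic bitslice usage (hypothetical values):
//
//	bs := Make(4, 0)     // 4 bits, all false
//	bs.Set(true, 2)      // set bit 2
//	bs = bs.Append(true) // now 5 bits
//	n := bs.Len()        // 5
//	_ = n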
// SetAll sets all values to either on or off -- much faster than setting individual bits
func (bs *Slice) SetAll(val bool) {
ln := len(*bs)
for i := 1; i < ln; i++ {
if val {
(*bs)[i] = 0xFF
} else {
(*bs)[i] = 0
}
}
}
// ToBools converts to a []bool slice.
func (bs *Slice) ToBools() []bool {
ln := bs.Len() // length in bits, not bytes
bb := make([]bool, ln)
for i := 0; i < ln; i++ {
bb[i] = bs.Index(i)
}
return bb
}
// Clone creates a new copy of this bitslice with separate memory
func (bs *Slice) Clone() Slice {
cp := make(Slice, len(*bs))
copy(cp, *bs)
return cp
}
// SubSlice returns a new Slice from given start, end range indexes of this slice
// if end is <= 0 then the length of the source slice is used (equivalent to omitting
// the number after the : in a Go subslice expression)
func (bs *Slice) SubSlice(start, end int) Slice {
ln := bs.Len()
if end <= 0 {
end = ln
}
if end > ln {
panic("bitslice.SubSlice: end index is beyond length of slice")
}
if start > end {
panic("bitslice.SubSlice: start index greater than end index")
}
nln := end - start
if nln <= 0 {
return Slice{}
}
ss := Make(nln, 0)
for i := 0; i < nln; i++ {
ss.Set(bs.Index(i+start), i)
}
return ss
}
// Delete returns a new bit slice with N elements removed starting at given index.
// This must be a copy given the nature of the 8-bit aliasing.
func (bs *Slice) Delete(start, n int) Slice {
ln := bs.Len()
if n <= 0 {
panic("bitslice.Delete: n <= 0")
}
if start >= ln {
panic("bitslice.Delete: start index >= length")
}
end := start + n
if end > ln {
panic("bitslice.Delete: end index greater than length")
}
nln := ln - n
if nln <= 0 {
return Slice{}
}
ss := Make(nln, 0)
for i := 0; i < start; i++ {
ss.Set(bs.Index(i), i)
}
for i := end; i < ln; i++ {
ss.Set(bs.Index(i), i-n)
}
return ss
}
// Insert returns a new bit slice with N false elements inserted starting at given index.
// This must be a copy given the nature of the 8-bit aliasing.
func (bs *Slice) Insert(start, n int) Slice {
ln := bs.Len()
if n <= 0 {
panic("bitslice.Insert: n <= 0")
}
if start > ln {
panic("bitslice.Insert: start index greater than length")
}
nln := ln + n
ss := Make(nln, 0)
for i := 0; i < start; i++ {
ss.Set(bs.Index(i), i)
}
for i := start; i < ln; i++ {
ss.Set(bs.Index(i), i+n)
}
return ss
}
// String satisfies the fmt.Stringer interface
func (bs *Slice) String() string {
ln := bs.Len()
if ln == 0 {
if *bs == nil {
return "nil"
}
return "[]"
}
mx := ln
if mx > 1000 {
mx = 1000
}
str := "["
for i := 0; i < mx; i++ {
val := bs.Index(i)
if val {
str += "1 "
} else {
str += "0 "
}
if (i+1)%80 == 0 {
str += "\n"
}
}
if ln > mx {
str += fmt.Sprintf("...(len=%v)", ln)
}
str += "]"
return str
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
import (
"fmt"
"reflect"
"slices"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/metadata"
"cogentcore.org/core/base/num"
"cogentcore.org/core/base/reflectx"
"cogentcore.org/lab/tensor/bitslice"
)
// Bool is a tensor of bits backed by a [bitslice.Slice] for efficient storage
// of binary, boolean data. Bool does not support [RowMajor.SubSpace] access
// and related methods due to the nature of the underlying data representation.
type Bool struct {
// shape contains the N-dimensional shape and indexing functionality.
shape Shape
// Values is a flat 1D slice of the underlying data, using [bitslice].
Values bitslice.Slice
// Meta data is used extensively for Name, Plot styles, etc.
// Use standard Go camel-case key names, standards in [metadata].
Meta metadata.Data
}
// NewBool returns a new n-dimensional tensor of bit values
// with the given sizes per dimension (shape).
func NewBool(sizes ...int) *Bool {
tsr := &Bool{}
tsr.SetShapeSizes(sizes...)
tsr.Values = bitslice.Make(tsr.Len(), 0)
return tsr
}
// NewBoolShape returns a new n-dimensional tensor of bit values
// using given shape.
func NewBoolShape(shape *Shape) *Bool {
tsr := &Bool{}
tsr.shape.CopyFrom(shape)
tsr.Values = bitslice.Make(tsr.Len(), 0)
return tsr
}
// NewBoolFromValues returns a new 1-dimensional Bool tensor
// initialized from the given slice of bool values, which are copied
// into the underlying bit-packed representation.
func NewBoolFromValues(vals ...bool) *Bool {
n := len(vals)
tsr := &Bool{}
tsr.SetShapeSizes(n)
for i, b := range vals {
tsr.Values.Set(b, i)
}
return tsr
}
// Float64ToBool converts float64 value to bool.
func Float64ToBool(val float64) bool {
return num.ToBool(val)
}
// BoolToFloat64 converts bool to float64 value.
func BoolToFloat64(bv bool) float64 {
return num.FromBool[float64](bv)
}
// IntToBool converts int value to bool.
func IntToBool(val int) bool {
return num.ToBool(val)
}
// BoolToInt converts bool to int value.
func BoolToInt(bv bool) int {
return num.FromBool[int](bv)
}
// String satisfies the fmt.Stringer interface for string of tensor data.
func (tsr *Bool) String() string { return Sprintf("", tsr, 0) }
// Label satisfies the core.Labeler interface for a summary description of the tensor
func (tsr *Bool) Label() string {
return label(metadata.Name(tsr), tsr.Shape())
}
func (tsr *Bool) IsString() bool { return false }
func (tsr *Bool) AsValues() Values { return tsr }
// DataType returns the type of the data elements in the tensor.
// Bool is returned for the Bool tensor type.
func (tsr *Bool) DataType() reflect.Kind { return reflect.Bool }
func (tsr *Bool) Sizeof() int64 { return int64(len(tsr.Values)) }
func (tsr *Bool) Bytes() []byte { return tsr.Values }
func (tsr *Bool) SetFromBytes(b []byte) { copy(tsr.Values, b) }
func (tsr *Bool) Shape() *Shape { return &tsr.shape }
// ShapeSizes returns the sizes of each dimension as a slice of ints.
// The returned slice is a copy and can be modified without side effects.
func (tsr *Bool) ShapeSizes() []int { return slices.Clone(tsr.shape.Sizes) }
// Metadata returns the metadata for this tensor, which can be used
// to encode plotting options, etc.
func (tsr *Bool) Metadata() *metadata.Data { return &tsr.Meta }
// Len returns the number of elements in the tensor (product of shape dimensions).
func (tsr *Bool) Len() int { return tsr.shape.Len() }
// NumDims returns the total number of dimensions.
func (tsr *Bool) NumDims() int { return tsr.shape.NumDims() }
// DimSize returns size of given dimension
func (tsr *Bool) DimSize(dim int) int { return tsr.shape.DimSize(dim) }
func (tsr *Bool) SetShapeSizes(sizes ...int) {
tsr.shape.SetShapeSizes(sizes...)
nln := tsr.Len()
tsr.Values.SetLen(nln)
}
// SetNumRows sets the number of rows (outermost dimension) in a RowMajor organized tensor.
// It is safe to set this to 0. For incrementally growing tensors (e.g., a log)
// it is best to first set the anticipated full size, which allocates the
// full amount of memory, and then set to 0 and grow incrementally.
func (tsr *Bool) SetNumRows(rows int) {
_, cells := tsr.shape.RowCellSize()
nln := rows * cells
tsr.shape.Sizes[0] = rows
tsr.Values.SetLen(nln)
}
// SubSpace is not possible with Bool.
func (tsr *Bool) SubSpace(offs ...int) Values { return nil }
// RowTensor not possible with Bool.
func (tsr *Bool) RowTensor(row int) Values { return nil }
// SetRowTensor not possible with Bool.
func (tsr *Bool) SetRowTensor(val Values, row int) {}
// AppendRow not possible with Bool.
func (tsr *Bool) AppendRow(val Values) {}
/////// Bool
func (tsr *Bool) Value(i ...int) bool {
return tsr.Values.Index(tsr.shape.IndexTo1D(i...))
}
func (tsr *Bool) Set(val bool, i ...int) {
tsr.Values.Set(val, tsr.shape.IndexTo1D(i...))
}
func (tsr *Bool) Value1D(i int) bool { return tsr.Values.Index(i) }
func (tsr *Bool) Set1D(val bool, i int) { tsr.Values.Set(val, i) }
/////// Strings
func (tsr *Bool) String1D(off int) string {
return reflectx.ToString(tsr.Values.Index(off))
}
func (tsr *Bool) SetString1D(val string, off int) {
if bv, err := reflectx.ToBool(val); err == nil {
tsr.Values.Set(bv, off)
}
}
func (tsr *Bool) StringValue(i ...int) string {
return reflectx.ToString(tsr.Values.Index(tsr.shape.IndexTo1D(i...)))
}
func (tsr *Bool) SetString(val string, i ...int) {
if bv, err := reflectx.ToBool(val); err == nil {
tsr.Values.Set(bv, tsr.shape.IndexTo1D(i...))
}
}
func (tsr *Bool) StringRow(row, cell int) string {
_, sz := tsr.shape.RowCellSize()
return reflectx.ToString(tsr.Values.Index(row*sz + cell))
}
func (tsr *Bool) SetStringRow(val string, row, cell int) {
if bv, err := reflectx.ToBool(val); err == nil {
_, sz := tsr.shape.RowCellSize()
tsr.Values.Set(bv, row*sz+cell)
}
}
// AppendRowString not possible with Bool.
func (tsr *Bool) AppendRowString(val ...string) {}
/////// Floats
func (tsr *Bool) Float(i ...int) float64 {
return BoolToFloat64(tsr.Values.Index(tsr.shape.IndexTo1D(i...)))
}
func (tsr *Bool) SetFloat(val float64, i ...int) {
tsr.Values.Set(Float64ToBool(val), tsr.shape.IndexTo1D(i...))
}
func (tsr *Bool) Float1D(off int) float64 {
return BoolToFloat64(tsr.Values.Index(off))
}
func (tsr *Bool) SetFloat1D(val float64, off int) {
tsr.Values.Set(Float64ToBool(val), off)
}
func (tsr *Bool) FloatRow(row, cell int) float64 {
_, sz := tsr.shape.RowCellSize()
return BoolToFloat64(tsr.Values.Index(row*sz + cell))
}
func (tsr *Bool) SetFloatRow(val float64, row, cell int) {
_, sz := tsr.shape.RowCellSize()
tsr.Values.Set(Float64ToBool(val), row*sz+cell)
}
// AppendRowFloat not possible with Bool.
func (tsr *Bool) AppendRowFloat(val ...float64) {}
/////// Ints
func (tsr *Bool) Int(i ...int) int {
return BoolToInt(tsr.Values.Index(tsr.shape.IndexTo1D(i...)))
}
func (tsr *Bool) SetInt(val int, i ...int) {
tsr.Values.Set(IntToBool(val), tsr.shape.IndexTo1D(i...))
}
func (tsr *Bool) Int1D(off int) int {
return BoolToInt(tsr.Values.Index(off))
}
func (tsr *Bool) SetInt1D(val int, off int) {
tsr.Values.Set(IntToBool(val), off)
}
func (tsr *Bool) IntRow(row, cell int) int {
_, sz := tsr.shape.RowCellSize()
return BoolToInt(tsr.Values.Index(row*sz + cell))
}
func (tsr *Bool) SetIntRow(val int, row, cell int) {
_, sz := tsr.shape.RowCellSize()
tsr.Values.Set(IntToBool(val), row*sz+cell)
}
// AppendRowInt not possible with Bool.
func (tsr *Bool) AppendRowInt(val ...int) {}
/////// Bools
func (tsr *Bool) Bool(i ...int) bool {
return tsr.Values.Index(tsr.shape.IndexTo1D(i...))
}
func (tsr *Bool) SetBool(val bool, i ...int) {
tsr.Values.Set(val, tsr.shape.IndexTo1D(i...))
}
func (tsr *Bool) Bool1D(off int) bool {
return tsr.Values.Index(off)
}
func (tsr *Bool) SetBool1D(val bool, off int) {
tsr.Values.Set(val, off)
}
// SetZeros is a convenience function to initialize all values to 0 (false).
func (tsr *Bool) SetZeros() {
ln := tsr.Len()
for j := 0; j < ln; j++ {
tsr.Values.Set(false, j)
}
}
// SetTrue is a convenience function to initialize all values to true (1).
func (tsr *Bool) SetTrue() {
ln := tsr.Len()
for j := 0; j < ln; j++ {
tsr.Values.Set(true, j)
}
}
// Clone clones this tensor, creating a duplicate copy of itself with its
// own separate memory representation of all the values, and returns
// that as a Tensor (which can be converted into the known type as needed).
func (tsr *Bool) Clone() Values {
csr := NewBoolShape(&tsr.shape)
csr.Values = tsr.Values.Clone()
return csr
}
// CopyFrom copies all available values from the other tensor into this tensor, with an
// optimized implementation if the other tensor is of the same type, and
// otherwise it goes through appropriate standard type.
func (tsr *Bool) CopyFrom(frm Values) {
if fsm, ok := frm.(*Bool); ok {
copy(tsr.Values, fsm.Values)
return
}
sz := min(len(tsr.Values), frm.Len())
for i := range sz {
tsr.Values.Set(Float64ToBool(frm.Float1D(i)), i)
}
}
// AppendFrom appends values from other tensor into this tensor,
// which must have the same cell size as this tensor.
// It uses an optimized implementation if the other tensor
// is of the same type, and otherwise it goes through
// appropriate standard type.
func (tsr *Bool) AppendFrom(frm Values) Values {
rows, cell := tsr.shape.RowCellSize()
frows, fcell := frm.Shape().RowCellSize()
if cell != fcell {
errors.Log(fmt.Errorf("tensor.AppendFrom: cell sizes do not match: %d != %d", cell, fcell))
return tsr
}
tsr.SetNumRows(rows + frows)
st := rows * cell
fsz := frows * fcell
if fsm, ok := frm.(*Bool); ok {
copy(tsr.Values[st:st+fsz], fsm.Values)
return tsr
}
for i := range fsz {
tsr.Values.Set(Float64ToBool(frm.Float1D(i)), st+i)
}
return tsr
}
// CopyCellsFrom copies given range of values from other tensor into this tensor,
// using flat 1D indexes: to = starting index in this Tensor to start copying into,
// start = starting index on from Tensor to start copying from, and n = number of
// values to copy. Uses an optimized implementation if the other tensor is
// of the same type, and otherwise it goes through appropriate standard type.
func (tsr *Bool) CopyCellsFrom(frm Values, to, start, n int) {
if fsm, ok := frm.(*Bool); ok {
for i := range n {
tsr.Values.Set(fsm.Values.Index(start+i), to+i)
}
return
}
for i := range n {
tsr.Values.Set(Float64ToBool(frm.Float1D(start+i)), to+i)
}
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
import (
"math"
"cogentcore.org/core/base/errors"
)
// Clone returns a copy of the given tensor.
// If it is raw [Values] then a [Values.Clone] is returned.
// Otherwise if it is a view, then [Tensor.AsValues] is returned.
// This is equivalent to the NumPy copy function.
func Clone(tsr Tensor) Values {
if vl, ok := tsr.(Values); ok {
return vl.Clone()
}
return tsr.AsValues()
}
// Flatten returns a copy of the given tensor as a 1D flat list
// of values, by calling Clone(As1D(tsr)).
// It is equivalent to the NumPy flatten function.
func Flatten(tsr Tensor) Values {
if msk, ok := tsr.(*Masked); ok {
return msk.AsValues()
}
return Clone(As1D(tsr))
}
// Squeeze returns a [Reshaped] view of the given tensor with all singleton
// (size = 1) dimensions removed (if none, just returns the tensor).
func Squeeze(tsr Tensor) Tensor {
nd := tsr.NumDims()
sh := tsr.ShapeSizes()
reshape := make([]int, 0, nd)
for _, sz := range sh {
if sz > 1 {
reshape = append(reshape, sz)
}
}
if len(reshape) == nd {
return tsr
}
return NewReshaped(tsr, reshape...)
}
// As1D returns a 1D tensor, which is either the input tensor if it is
// already 1D, or a new [Reshaped] 1D view of it.
// This can be useful e.g., for stats and metric functions that operate
// on a 1D list of values. See also [Flatten].
func As1D(tsr Tensor) Tensor {
if tsr.NumDims() == 1 {
return tsr
}
return NewReshaped(tsr, tsr.Len())
}
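// Illustrative sketch (hypothetical shapes, not from the original source):
//
//	t := NewFloat64(1, 4, 1, 3)
//	sq := Squeeze(t) // Reshaped view with shape (4, 3)
//	fl := As1D(sq)   // 1D view of length 12
//	_, _ = sq, fl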
// Cells1D returns a flat 1D view of the innermost cells for given row index.
// For a [RowMajor] tensor, it uses the [RowTensor] subspace directly,
// otherwise it uses [Sliced] to extract the cells. In either case,
// [As1D] is used to ensure the result is a 1D tensor.
func Cells1D(tsr Tensor, row int) Tensor {
if rm, ok := tsr.(RowMajor); ok {
return As1D(rm.RowTensor(row))
}
return As1D(NewSliced(tsr, []int{row}))
}
// MustBeValues returns the given tensor as a [Values] subtype, or nil and
// an error if it is not one. Typically outputs of compute operations must
// be values, and are reshaped to hold the results as needed.
func MustBeValues(tsr Tensor) (Values, error) {
vl, ok := tsr.(Values)
if !ok {
return nil, errors.New("tensor.MustBeValues: tensor must be a Values type")
}
return vl, nil
}
// MustBeSameShape returns an error if the two tensors do not have the same shape.
func MustBeSameShape(a, b Tensor) error {
if !a.Shape().IsEqual(b.Shape()) {
return errors.New("tensor.MustBeSameShape: tensors must have the same shape")
}
return nil
}
// SetShape sets the dimension sizes from given Shape
func SetShape(vals Values, sh *Shape) {
vals.SetShapeSizes(sh.Sizes...)
}
// SetShapeSizesFromTensor sets the dimension sizes as 1D int values from given tensor.
// The backing storage is resized appropriately, retaining all existing data that fits.
func SetShapeSizesFromTensor(vals Values, sizes Tensor) {
vals.SetShapeSizes(AsIntSlice(sizes)...)
}
// SetShapeFrom sets shape of given tensor from a source tensor.
func SetShapeFrom(vals Values, from Tensor) {
vals.SetShapeSizes(from.ShapeSizes()...)
}
// AsFloat64Scalar returns the first value of tensor as a float64 scalar.
// Returns 0 if no values.
func AsFloat64Scalar(tsr Tensor) float64 {
if tsr.Len() == 0 {
return 0
}
return tsr.Float1D(0)
}
// AsIntScalar returns the first value of tensor as an int scalar.
// Returns 0 if no values.
func AsIntScalar(tsr Tensor) int {
if tsr.Len() == 0 {
return 0
}
return tsr.Int1D(0)
}
// AsStringScalar returns the first value of tensor as a string scalar.
// Returns "" if no values.
func AsStringScalar(tsr Tensor) string {
if tsr.Len() == 0 {
return ""
}
return tsr.String1D(0)
}
// AsFloat64Slice returns all the tensor values as a slice of float64's.
// This allocates a new slice for the return values, and is not
// a good option for performance-critical code.
func AsFloat64Slice(tsr Tensor) []float64 {
if tsr.Len() == 0 {
return nil
}
sz := tsr.Len()
slc := make([]float64, sz)
for i := range sz {
slc[i] = tsr.Float1D(i)
}
return slc
}
// AsIntSlice returns all the tensor values as a slice of ints.
// This allocates a new slice for the return values, and is not
// a good option for performance-critical code.
func AsIntSlice(tsr Tensor) []int {
if tsr.Len() == 0 {
return nil
}
sz := tsr.Len()
slc := make([]int, sz)
for i := range sz {
slc[i] = tsr.Int1D(i)
}
return slc
}
// AsStringSlice returns all the tensor values as a slice of strings.
// This allocates a new slice for the return values, and is not
// a good option for performance-critical code.
func AsStringSlice(tsr Tensor) []string {
if tsr.Len() == 0 {
return nil
}
sz := tsr.Len()
slc := make([]string, sz)
for i := range sz {
slc[i] = tsr.String1D(i)
}
return slc
}
// AsFloat64 returns the tensor as a [Float64] tensor.
// If it is already a Float64, it is returned as such.
// Otherwise, a new Float64 tensor is created and values are copied.
// Use this function for interfacing with gonum or other apis that
// only operate on float64 types.
func AsFloat64(tsr Tensor) *Float64 {
if f, ok := tsr.(*Float64); ok {
return f
}
f := NewFloat64(tsr.ShapeSizes()...)
f.CopyFrom(tsr.AsValues())
return f
}
// AsFloat32 returns the tensor as a [Float32] tensor.
// If it is already a Float32, it is returned as such.
// Otherwise, a new Float32 tensor is created and values are copied.
func AsFloat32(tsr Tensor) *Float32 {
if f, ok := tsr.(*Float32); ok {
return f
}
f := NewFloat32(tsr.ShapeSizes()...)
f.CopyFrom(tsr.AsValues())
return f
}
// AsString returns the tensor as a [String] tensor.
// If it is already a String, it is returned as such.
// Otherwise, a new String tensor is created and values are copied.
func AsString(tsr Tensor) *String {
if f, ok := tsr.(*String); ok {
return f
}
f := NewString(tsr.ShapeSizes()...)
f.CopyFrom(tsr.AsValues())
return f
}
// AsInt returns the tensor as an [Int] tensor.
// If it is already an Int, it is returned as such.
// Otherwise, a new Int tensor is created and values are copied.
func AsInt(tsr Tensor) *Int {
if f, ok := tsr.(*Int); ok {
return f
}
f := NewInt(tsr.ShapeSizes()...)
f.CopyFrom(tsr.AsValues())
return f
}
// Range returns the min, max (and associated indexes, -1 = no values) for the tensor.
// This is needed for display and is thus in the tensor API on Values.
func Range(vals Values) (min, max float64, minIndex, maxIndex int) {
minIndex = -1
maxIndex = -1
n := vals.Len()
for j := range n {
fv := vals.Float1D(j)
if math.IsNaN(fv) {
continue
}
if fv < min || minIndex < 0 {
min = fv
minIndex = j
}
if fv > max || maxIndex < 0 {
max = fv
maxIndex = j
}
}
return
}
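// Illustrative sketch (hypothetical values):
//
//	v := NewFloat64FromValues(3, -1, 7)
//	mn, mx, mni, mxi := Range(v) // -1, 7, 1, 2
//	_, _, _, _ = mn, mx, mni, mxi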
// ContainsFloat returns true if source tensor contains any of given vals,
// using Float value method for comparison.
func ContainsFloat(tsr, vals Tensor) bool {
nv := vals.Len()
if nv == 0 {
return false
}
n := tsr.Len()
for i := range n {
tv := tsr.Float1D(i)
for j := range nv {
if tv == vals.Float1D(j) {
return true
}
}
}
return false
}
// ContainsInt returns true if source tensor contains any of given vals,
// using Int value method for comparison.
func ContainsInt(tsr, vals Tensor) bool {
nv := vals.Len()
if nv == 0 {
return false
}
n := tsr.Len()
for i := range n {
tv := tsr.Int1D(i)
for j := range nv {
if tv == vals.Int1D(j) {
return true
}
}
}
return false
}
// ContainsString returns true if source tensor contains any of given vals,
// using String value method for comparison, and given options for how to
// compare the strings.
func ContainsString(tsr, vals Tensor, opts StringMatch) bool {
nv := vals.Len()
if nv == 0 {
return false
}
n := tsr.Len()
for i := range n {
tv := tsr.String1D(i)
for j := range nv {
if opts.Match(tv, vals.String1D(j)) {
return true
}
}
}
return false
}
// CopyFromLargerShape copies values from another tensor of a larger
// shape, using indexes in this shape. The other tensor must have at
// least the same or greater shape values on each dimension as the target.
// Uses float numbers to copy if not a string.
func CopyFromLargerShape(tsr, from Tensor) {
n := tsr.Len()
if tsr.IsString() {
for i := range n {
idx := tsr.Shape().IndexFrom1D(i)
v := from.StringValue(idx...)
tsr.SetString(v, idx...)
}
} else {
for i := range n {
idx := tsr.Shape().IndexFrom1D(i)
v := from.Float(idx...)
tsr.SetFloat(v, idx...)
}
}
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
import (
"math/rand"
"slices"
"cogentcore.org/core/base/slicesx"
)
// NewFromValues returns a new Values tensor from the given list of values,
// which must all be of the same type, for the supported tensor types including
// string, bool, and the standard float and int types.
func NewFromValues(val ...any) Values {
if len(val) == 0 {
return nil
}
switch val[0].(type) {
case string:
return NewStringFromValues(slicesx.As[any, string](val)...)
case bool:
return NewBoolFromValues(slicesx.As[any, bool](val)...)
case float64:
return NewNumberFromValues(slicesx.As[any, float64](val)...)
case float32:
return NewNumberFromValues(slicesx.As[any, float32](val)...)
case int:
return NewNumberFromValues(slicesx.As[any, int](val)...)
case int32:
return NewNumberFromValues(slicesx.As[any, int32](val)...)
case uint32:
return NewNumberFromValues(slicesx.As[any, uint32](val)...)
case int64:
return NewNumberFromValues(slicesx.As[any, int64](val)...)
case byte:
return NewNumberFromValues(slicesx.As[any, byte](val)...)
}
return nil
}
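// Illustrative sketch (hypothetical values; all values must share one type):
//
//	t := NewFromValues(1.5, 2.5, 3.5) // returns a *Float64 of length 3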
// NewFloat64Scalar is a convenience method for a Tensor
// representation of a single float64 scalar value.
func NewFloat64Scalar(val float64) *Float64 {
return NewNumberFromValues(val)
}
// NewFloat32Scalar is a convenience method for a Tensor
// representation of a single float32 scalar value.
func NewFloat32Scalar(val float32) *Float32 {
return NewNumberFromValues(val)
}
// NewIntScalar is a convenience method for a Tensor
// representation of a single int scalar value.
func NewIntScalar(val int) *Int {
return NewNumberFromValues(val)
}
// NewStringScalar is a convenience method for a Tensor
// representation of a single string scalar value.
func NewStringScalar(val string) *String {
return NewStringFromValues(val)
}
// NewFloat64FromValues returns a new 1-dimensional tensor of given value type
// initialized directly from the given slice values, which are not copied.
// The resulting Tensor thus "wraps" the given values.
func NewFloat64FromValues(vals ...float64) *Float64 {
return NewNumberFromValues(vals...)
}
// NewFloat32FromValues returns a new 1-dimensional tensor of given value type
// initialized directly from the given slice values, which are not copied.
// The resulting Tensor thus "wraps" the given values.
func NewFloat32FromValues(vals ...float32) *Float32 {
return NewNumberFromValues(vals...)
}
// NewIntFromValues returns a new 1-dimensional tensor of given value type
// initialized directly from the given slice values, which are not copied.
// The resulting Tensor thus "wraps" the given values.
func NewIntFromValues(vals ...int) *Int {
return NewNumberFromValues(vals...)
}
// NewStringFromValues returns a new 1-dimensional tensor of given value type
// initialized directly from the given slice values, which are not copied.
// The resulting Tensor thus "wraps" the given values.
func NewStringFromValues(vals ...string) *String {
n := len(vals)
tsr := &String{}
tsr.Values = vals
tsr.SetShapeSizes(n)
return tsr
}
// SetAllFloat64 sets all values of given tensor to given value.
func SetAllFloat64(tsr Tensor, val float64) {
VectorizeThreaded(1, func(tsr ...Tensor) int { return tsr[0].Len() },
func(idx int, tsr ...Tensor) {
tsr[0].SetFloat1D(val, idx)
}, tsr)
}
// SetAllInt sets all values of given tensor to given value.
func SetAllInt(tsr Tensor, val int) {
VectorizeThreaded(1, func(tsr ...Tensor) int { return tsr[0].Len() },
func(idx int, tsr ...Tensor) {
tsr[0].SetInt1D(val, idx)
}, tsr)
}
// SetAllString sets all values of given tensor to given value.
func SetAllString(tsr Tensor, val string) {
VectorizeThreaded(1, func(tsr ...Tensor) int { return tsr[0].Len() },
func(idx int, tsr ...Tensor) {
tsr[0].SetString1D(val, idx)
}, tsr)
}
// NewFloat64Full returns a new tensor full of given scalar value,
// of given shape sizes.
func NewFloat64Full(val float64, sizes ...int) *Float64 {
tsr := NewFloat64(sizes...)
SetAllFloat64(tsr, val)
return tsr
}
// NewFloat64Ones returns a new tensor full of 1s,
// of given shape sizes.
func NewFloat64Ones(sizes ...int) *Float64 {
tsr := NewFloat64(sizes...)
SetAllFloat64(tsr, 1.0)
return tsr
}
// NewIntFull returns a new tensor full of given scalar value,
// of given shape sizes.
func NewIntFull(val int, sizes ...int) *Int {
tsr := NewInt(sizes...)
SetAllInt(tsr, val)
return tsr
}
// NewStringFull returns a new tensor full of given scalar value,
// of given shape sizes.
func NewStringFull(val string, sizes ...int) *String {
tsr := NewString(sizes...)
SetAllString(tsr, val)
return tsr
}
// NewFloat64Rand returns a new tensor full of random numbers from
// global random source, of given shape sizes.
func NewFloat64Rand(sizes ...int) *Float64 {
tsr := NewFloat64(sizes...)
FloatSetFunc(1, func(idx int) float64 { return rand.Float64() }, tsr)
return tsr
}
// NewIntRange returns a new [Int] [Tensor] with given [Slice]
// range parameters, with the same semantics as NumPy arange based on
// the number of arguments passed:
// - 1 = stop
// - 2 = start, stop
// - 3 = start, stop, step
func NewIntRange(svals ...int) *Int {
if len(svals) == 0 {
return NewInt()
}
sl := Slice{}
switch len(svals) {
case 1:
sl.Stop = svals[0]
case 2:
sl.Start = svals[0]
sl.Stop = svals[1]
case 3:
sl.Start = svals[0]
sl.Stop = svals[1]
sl.Step = svals[2]
}
return sl.IntTensor(sl.Stop)
}
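// Illustrative sketch (hypothetical values; same semantics as NumPy arange):
//
//	NewIntRange(5)       // 0, 1, 2, 3, 4
//	NewIntRange(2, 8)    // 2, 3, 4, 5, 6, 7
//	NewIntRange(2, 8, 2) // 2, 4, 6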
// NewFloat64SpacedLinear returns a new [Float64] tensor with num linearly
// spaced numbers between start and stop values, as tensors, which
// must be the same length and determine the cell shape of the output.
// If num is 0, then a default of 50 is used.
// If endpoint = true, then the stop value is _inclusive_, i.e., it will
// be the final value, otherwise it is exclusive.
// This corresponds to the NumPy linspace function.
func NewFloat64SpacedLinear(start, stop Tensor, num int, endpoint bool) *Float64 {
if num <= 0 {
num = 50
}
fnum := float64(num)
if endpoint {
fnum -= 1
}
step := Clone(start)
n := step.Len()
for i := range n {
step.SetFloat1D((stop.Float1D(i)-start.Float1D(i))/fnum, i)
}
var tsr *Float64
if start.Len() == 1 {
tsr = NewFloat64(num)
} else {
tsz := slices.Clone(start.Shape().Sizes)
tsz = append([]int{num}, tsz...)
tsr = NewFloat64(tsz...)
}
for r := range num {
for i := range n {
tsr.SetFloatRow(start.Float1D(i)+float64(r)*step.Float1D(i), r, i)
}
}
return tsr
}
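// Illustrative sketch (hypothetical scalar start and stop values):
//
//	start := NewFloat64Scalar(0)
//	stop := NewFloat64Scalar(1)
//	ls := NewFloat64SpacedLinear(start, stop, 5, true)
//	// ls holds 0, 0.25, 0.5, 0.75, 1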
// Code generated by "core generate"; DO NOT EDIT.
package tensor
import (
"cogentcore.org/core/enums"
)
var _DelimsValues = []Delims{0, 1, 2, 3}
// DelimsN is the highest valid value for type Delims, plus one.
const DelimsN Delims = 4
var _DelimsValueMap = map[string]Delims{`Tab`: 0, `Comma`: 1, `Space`: 2, `Detect`: 3}
var _DelimsDescMap = map[Delims]string{0: `Tab is the tab rune delimiter, for TSV tab separated values`, 1: `Comma is the comma rune delimiter, for CSV comma separated values`, 2: `Space is the space rune delimiter, for SSV space separated value`, 3: `Detect is used during reading a file -- reads the first line and detects tabs or commas`}
var _DelimsMap = map[Delims]string{0: `Tab`, 1: `Comma`, 2: `Space`, 3: `Detect`}
// String returns the string representation of this Delims value.
func (i Delims) String() string { return enums.String(i, _DelimsMap) }
// SetString sets the Delims value from its string representation,
// and returns an error if the string is invalid.
func (i *Delims) SetString(s string) error { return enums.SetString(i, s, _DelimsValueMap, "Delims") }
// Int64 returns the Delims value as an int64.
func (i Delims) Int64() int64 { return int64(i) }
// SetInt64 sets the Delims value from an int64.
func (i *Delims) SetInt64(in int64) { *i = Delims(in) }
// Desc returns the description of the Delims value.
func (i Delims) Desc() string { return enums.Desc(i, _DelimsDescMap) }
// DelimsValues returns all possible values for the type Delims.
func DelimsValues() []Delims { return _DelimsValues }
// Values returns all possible values for the type Delims.
func (i Delims) Values() []enums.Enum { return enums.Values(_DelimsValues) }
// MarshalText implements the [encoding.TextMarshaler] interface.
func (i Delims) MarshalText() ([]byte, error) { return []byte(i.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *Delims) UnmarshalText(text []byte) error { return enums.UnmarshalText(i, text, "Delims") }
var _SlicesMagicValues = []SlicesMagic{0, 1, 2}
// SlicesMagicN is the highest valid value for type SlicesMagic, plus one.
const SlicesMagicN SlicesMagic = 3
var _SlicesMagicValueMap = map[string]SlicesMagic{`FullAxis`: 0, `NewAxis`: 1, `Ellipsis`: 2}
var _SlicesMagicDescMap = map[SlicesMagic]string{0: `FullAxis indicates that the full existing axis length should be used. This is equivalent to Slice{}, but is more semantic. In NumPy it is equivalent to a single : colon.`, 1: `NewAxis creates a new singleton (length=1) axis, used to reshape without changing the size. Can also be used in [Reshaped].`, 2: `Ellipsis (...) is used in [NewSliced] expressions to produce a flexibly-sized stretch of FullAxis dimensions, which automatically aligns the remaining slice elements based on the source dimensionality.`}
var _SlicesMagicMap = map[SlicesMagic]string{0: `FullAxis`, 1: `NewAxis`, 2: `Ellipsis`}
// String returns the string representation of this SlicesMagic value.
func (i SlicesMagic) String() string { return enums.String(i, _SlicesMagicMap) }
// SetString sets the SlicesMagic value from its string representation,
// and returns an error if the string is invalid.
func (i *SlicesMagic) SetString(s string) error {
return enums.SetString(i, s, _SlicesMagicValueMap, "SlicesMagic")
}
// Int64 returns the SlicesMagic value as an int64.
func (i SlicesMagic) Int64() int64 { return int64(i) }
// SetInt64 sets the SlicesMagic value from an int64.
func (i *SlicesMagic) SetInt64(in int64) { *i = SlicesMagic(in) }
// Desc returns the description of the SlicesMagic value.
func (i SlicesMagic) Desc() string { return enums.Desc(i, _SlicesMagicDescMap) }
// SlicesMagicValues returns all possible values for the type SlicesMagic.
func SlicesMagicValues() []SlicesMagic { return _SlicesMagicValues }
// Values returns all possible values for the type SlicesMagic.
func (i SlicesMagic) Values() []enums.Enum { return enums.Values(_SlicesMagicValues) }
// MarshalText implements the [encoding.TextMarshaler] interface.
func (i SlicesMagic) MarshalText() ([]byte, error) { return []byte(i.String()), nil }
// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (i *SlicesMagic) UnmarshalText(text []byte) error {
return enums.UnmarshalText(i, text, "SlicesMagic")
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
import (
"fmt"
"reflect"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/metadata"
)
// Func represents a registered tensor function, which takes some number
// of input Tensor arguments and produces output Tensor arguments
// (typically 1). There can also be an 'any' first
// argument to support other kinds of parameters.
// This is used to make tensor functions available to the Goal language.
type Func struct {
// Name is the original CamelCase Go name for function
Name string
// Fun is the function, which must _only_ take some number of Tensor
// args, with an optional any first arg.
Fun any
// Args has parsed information about the function args, for Goal.
Args []*Arg
}
// Arg has key information that Goal needs about each arg, for converting
// expressions into the appropriate type.
type Arg struct {
// Type has full reflection type info.
Type reflect.Type
// IsTensor is true if it satisfies the Tensor interface.
IsTensor bool
// IsInt is true if Kind = Int, for shape, slice etc params.
IsInt bool
// IsVariadic is true if this is the last arg and has ...; type will be a slice.
IsVariadic bool
}
// NewFunc creates a new Func description of the given
// function, which must have a signature like this:
// func([opt any,] a, b, out tensor.Tensor) error
// i.e., taking some specific number of Tensor arguments (up to 5).
// Functions can also take an 'any' first argument to handle other
// non-tensor inputs (e.g., function pointer, dirfs directory, etc).
// The name should be a standard 'package.FuncName' qualified, exported
// CamelCase name. The function arguments are determined automatically
// and classified for use by the Goal transpiler (see [Arg]).
func NewFunc(name string, fun any) (*Func, error) {
fn := &Func{Name: name, Fun: fun}
fn.GetArgs()
return fn, nil
}
// GetArgs gets key info about each arg, for use by Goal transpiler.
func (fn *Func) GetArgs() {
ft := reflect.TypeOf(fn.Fun)
n := ft.NumIn()
if n == 0 {
return
}
fn.Args = make([]*Arg, n)
tsrt := reflect.TypeFor[Tensor]()
for i := range n {
at := ft.In(i)
ag := &Arg{Type: at}
if ft.IsVariadic() && i == n-1 {
ag.IsVariadic = true
}
if at.Kind() == reflect.Int || (at.Kind() == reflect.Slice && at.Elem().Kind() == reflect.Int) {
ag.IsInt = true
} else if at.Implements(tsrt) {
ag.IsTensor = true
}
fn.Args[i] = ag
}
}
func (fn *Func) String() string {
s := fn.Name + "("
na := len(fn.Args)
for i, a := range fn.Args {
if a.IsVariadic {
s += "..."
}
ts := a.Type.String()
if ts == "interface {}" {
ts = "any"
}
s += ts
if i < na-1 {
s += ", "
}
}
s += ")"
return s
}
// Funcs is the global tensor named function registry.
// All functions must have a signature like this:
// func([opt any,] a, b, out tensor.Tensor) error
// i.e., taking some specific number of Tensor arguments (up to 5).
// Functions can also take an 'any' first argument to handle other
// non-tensor inputs (e.g., function pointer, dirfs directory, etc).
// This is used to make tensor functions available to the Goal
// language.
var Funcs map[string]*Func
// AddFunc adds given named function to the global tensor named function
// registry, which is used by Goal to call functions by name.
// See [NewFunc] for more information.
func AddFunc(name string, fun any) error {
if Funcs == nil {
Funcs = make(map[string]*Func)
}
_, ok := Funcs[name]
if ok {
return fmt.Errorf("tensor.AddFunc: function of name %q already exists, not added", name)
}
fn, err := NewFunc(name, fun)
if errors.Log(err) != nil {
return err
}
Funcs[name] = fn
// note: can record orig camel name if needed for docs etc later.
return nil
}
// FuncByName finds function of given name in the registry,
// returning an error if the function name has not been registered.
func FuncByName(name string) (*Func, error) {
fn, ok := Funcs[name]
if !ok {
return nil, fmt.Errorf("tensor.FuncByName: function of name %q not registered", name)
}
return fn, nil
}
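// A minimal registration sketch (the "mypkg.Add" name and the no-op
// function body here are hypothetical, for illustration only):
//
//	AddFunc("mypkg.Add", func(a, b, out Tensor) error { return nil })
//	fn, err := FuncByName("mypkg.Add")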
// These generic functions provide a one-liner for wrapping functions
// that take an output Tensor as the last argument, which is important
// for memory re-use of the output in performance-critical cases.
// The names indicate the number of input tensor arguments.
// Up to 2 additional generic non-Tensor inputs are supported,
// via the Gen1 and Gen2 versions.
// FloatPromoteType returns the DataType to use for the given Tensor(s),
// promoting to the highest-precision float kind present
// (Float64 over Float32). Otherwise it returns the type of the first tensor.
func FloatPromoteType(tsr ...Tensor) reflect.Kind {
ft := tsr[0].DataType()
for i := 1; i < len(tsr); i++ {
t := tsr[i].DataType()
if t == reflect.Float64 {
ft = t
} else if t == reflect.Float32 && ft != reflect.Float64 {
ft = t
}
}
return ft
}
// CallOut1 adds output [Values] tensor for function.
func CallOut1(fun func(a Tensor, out Values) error, a Tensor) Values {
out := NewOfType(a.DataType())
errors.Log(fun(a, out))
return out
}
// CallOut1Float64 adds Float64 output [Values] tensor for function.
func CallOut1Float64(fun func(a Tensor, out Values) error, a Tensor) Values {
out := NewFloat64()
errors.Log(fun(a, out))
return out
}
// CallOut2Float64 adds a Float64 output [Values] tensor for the given two-input function.
func CallOut2Float64(fun func(a, b Tensor, out Values) error, a, b Tensor) Values {
out := NewFloat64()
errors.Log(fun(a, b, out))
return out
}
// CallOut2 adds an output [Values] tensor for the given two-input function,
// using [FloatPromoteType] to determine the output data type.
func CallOut2(fun func(a, b Tensor, out Values) error, a, b Tensor) Values {
out := NewOfType(FloatPromoteType(a, b))
errors.Log(fun(a, b, out))
return out
}
// CallOut3 adds an output [Values] tensor for the given three-input function,
// using [FloatPromoteType] to determine the output data type.
func CallOut3(fun func(a, b, c Tensor, out Values) error, a, b, c Tensor) Values {
out := NewOfType(FloatPromoteType(a, b, c))
errors.Log(fun(a, b, c, out))
return out
}
// CallOut2Bool adds a [Bool] output tensor for the given two-input function.
func CallOut2Bool(fun func(a, b Tensor, out *Bool) error, a, b Tensor) *Bool {
out := NewBool()
errors.Log(fun(a, b, out))
return out
}
// CallOut1Gen1 adds an output [Values] tensor for the given one-input function
// that takes one additional generic (non-tensor) first argument.
func CallOut1Gen1[T any](fun func(g T, a Tensor, out Values) error, g T, a Tensor) Values {
out := NewOfType(a.DataType())
errors.Log(fun(g, a, out))
return out
}
// CallOut1Gen2 adds an output [Values] tensor for the given one-input function
// that takes two additional generic (non-tensor) first arguments.
func CallOut1Gen2[T any, S any](fun func(g T, h S, a Tensor, out Values) error, g T, h S, a Tensor) Values {
out := NewOfType(a.DataType())
errors.Log(fun(g, h, a, out))
return out
}
// CallOut2Gen1 adds an output [Values] tensor for the given two-input function
// that takes one additional generic (non-tensor) first argument.
func CallOut2Gen1[T any](fun func(g T, a, b Tensor, out Values) error, g T, a, b Tensor) Values {
out := NewOfType(FloatPromoteType(a, b))
errors.Log(fun(g, a, b, out))
return out
}
// CallOut2Gen2 adds an output [Values] tensor for the given two-input function
// that takes two additional generic (non-tensor) first arguments.
func CallOut2Gen2[T any, S any](fun func(g T, h S, a, b Tensor, out Values) error, g T, h S, a, b Tensor) Values {
out := NewOfType(FloatPromoteType(a, b))
errors.Log(fun(g, h, a, b, out))
return out
}
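// A typical wrapper sketch for the CallOut helpers above: a package exposes
// an output-allocating convenience on top of an out-taking implementation.
// AddOut here is a hypothetical function with the out-taking signature:
//
//	func Add(a, b Tensor) Values { return CallOut2(AddOut, a, b) }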
//////// Metadata
// SetCalcFunc sets a function to calculate updated value for given tensor,
// storing the function pointer in the Metadata "CalcFunc" key for the tensor.
// Can be called by [Calc] function.
func SetCalcFunc(tsr Tensor, fun func() error) {
tsr.Metadata().Set("CalcFunc", fun)
}
// Calc calls function set by [SetCalcFunc] to compute an updated value for
// given tensor. Returns an error if func not set, or any error from func itself.
// Function is stored as CalcFunc in Metadata.
func Calc(tsr Tensor) error {
fun, err := metadata.Get[func() error](*tsr.Metadata(), "CalcFunc")
if err != nil {
return err
}
return fun()
}
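// A minimal sketch of SetCalcFunc and Calc usage:
//
//	tsr := NewFloat64(10)
//	SetCalcFunc(tsr, func() error {
//		tsr.SetFloat1D(42, 0) // recompute values as needed
//		return nil
//	})
//	err := Calc(tsr) // runs the stored CalcFunc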
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
import (
"reflect"
"slices"
"cogentcore.org/core/base/metadata"
"cogentcore.org/core/base/reflectx"
)
// Indexed provides an arbitrarily indexed view onto another "source" [Tensor]
// with each index value providing a full n-dimensional index into the source.
// The shape of this view is determined by the shape of the [Indexed.Indexes]
// tensor up to the final innermost dimension, which holds the index values.
// Thus the innermost dimension size of the indexes is equal to the number
// of dimensions in the source tensor. Given the essential role of the
// indexes in this view, it is not usable without the indexes.
// This view is not memory-contiguous and does not support the [RowMajor]
// interface or efficient access to inner-dimensional subspaces.
// To produce a new concrete [Values] that has raw data actually
// organized according to the indexed order (i.e., the copy function
// of numpy), call [Indexed.AsValues].
type Indexed struct { //types:add
// Tensor source that we are an indexed view onto.
Tensor Tensor
// Indexes is the list of indexes into the source tensor,
// with the innermost dimension providing the index values
// (size = number of dimensions in the source tensor), and
// the remaining outer dimensions determine the shape
// of this [Indexed] tensor view.
Indexes *Int
}
// NewIndexed returns a new [Indexed] view of given tensor,
// with tensor of indexes into the source tensor.
func NewIndexed(tsr Tensor, idx *Int) *Indexed {
ix := &Indexed{Tensor: tsr}
ix.Indexes = idx
return ix
}
// AsIndexed returns the tensor as an [Indexed] view, if it is one.
// Otherwise, it returns nil; there is no usable "null" Indexed view.
func AsIndexed(tsr Tensor) *Indexed {
if ix, ok := tsr.(*Indexed); ok {
return ix
}
return nil
}
// SetTensor sets the given tensor as the source for this indexed view.
func (ix *Indexed) SetTensor(tsr Tensor) {
ix.Tensor = tsr
}
// SourceIndexes returns the actual indexes into underlying source tensor
// based on given list of indexes into the [Indexed.Indexes] tensor,
// _excluding_ the final innermost dimension.
func (ix *Indexed) SourceIndexes(i ...int) []int {
idx := slices.Clone(i)
idx = append(idx, 0) // first index
oned := ix.Indexes.Shape().IndexTo1D(idx...)
nd := ix.Tensor.NumDims()
return ix.Indexes.Values[oned : oned+nd]
}
// SourceIndexesFrom1D returns the full indexes into source tensor based on the
// given 1d index, which is based on the outer dimensions, excluding the
// final innermost dimension.
func (ix *Indexed) SourceIndexesFrom1D(oned int) []int {
nd := ix.Tensor.NumDims()
oned *= nd
return ix.Indexes.Values[oned : oned+nd]
}
func (ix *Indexed) Label() string { return label(metadata.Name(ix), ix.Shape()) }
func (ix *Indexed) String() string { return Sprintf("", ix, 0) }
func (ix *Indexed) Metadata() *metadata.Data { return ix.Tensor.Metadata() }
func (ix *Indexed) IsString() bool { return ix.Tensor.IsString() }
func (ix *Indexed) DataType() reflect.Kind { return ix.Tensor.DataType() }
func (ix *Indexed) Shape() *Shape { return NewShape(ix.ShapeSizes()...) }
func (ix *Indexed) Len() int { return ix.Shape().Len() }
func (ix *Indexed) NumDims() int { return ix.Indexes.NumDims() - 1 }
func (ix *Indexed) DimSize(dim int) int { return ix.Indexes.DimSize(dim) }
func (ix *Indexed) ShapeSizes() []int {
si := slices.Clone(ix.Indexes.ShapeSizes())
return si[:len(si)-1] // exclude last dim
}
// AsValues returns a copy of this tensor as raw [Values].
// This "renders" the Indexed view into a fully contiguous
// and optimized memory representation of that view, which will be faster
// to access for further processing, and enables all the additional
// functionality provided by the [Values] interface.
func (ix *Indexed) AsValues() Values {
dt := ix.Tensor.DataType()
vt := NewOfType(dt, ix.ShapeSizes()...)
n := ix.Len()
switch {
case ix.Tensor.IsString():
for i := range n {
vt.SetString1D(ix.String1D(i), i)
}
case reflectx.KindIsFloat(dt):
for i := range n {
vt.SetFloat1D(ix.Float1D(i), i)
}
default:
for i := range n {
vt.SetInt1D(ix.Int1D(i), i)
}
}
return vt
}
//////// Floats
// Float returns the value of given index as a float64.
// The indexes are indirected through the [Indexed.Indexes].
func (ix *Indexed) Float(i ...int) float64 {
return ix.Tensor.Float(ix.SourceIndexes(i...)...)
}
// SetFloat sets the value of given index as a float64
// The indexes are indirected through the [Indexed.Indexes].
func (ix *Indexed) SetFloat(val float64, i ...int) {
ix.Tensor.SetFloat(val, ix.SourceIndexes(i...)...)
}
// Float1D is somewhat expensive if indexes are set, because it needs to convert
// the flat index back into a full n-dimensional index and then use that api.
func (ix *Indexed) Float1D(i int) float64 {
return ix.Tensor.Float(ix.SourceIndexesFrom1D(i)...)
}
// SetFloat1D is somewhat expensive if indexes are set, because it needs to convert
// the flat index back into a full n-dimensional index and then use that api.
func (ix *Indexed) SetFloat1D(val float64, i int) {
ix.Tensor.SetFloat(val, ix.SourceIndexesFrom1D(i)...)
}
//////// Strings
// StringValue returns the value of given index as a string.
// The indexes are indirected through the [Indexed.Indexes].
func (ix *Indexed) StringValue(i ...int) string {
return ix.Tensor.StringValue(ix.SourceIndexes(i...)...)
}
// SetString sets the value of given index as a string
// The indexes are indirected through the [Indexed.Indexes].
func (ix *Indexed) SetString(val string, i ...int) {
ix.Tensor.SetString(val, ix.SourceIndexes(i...)...)
}
// String1D is somewhat expensive if indexes are set, because it needs to convert
// the flat index back into a full n-dimensional index and then use that api.
func (ix *Indexed) String1D(i int) string {
return ix.Tensor.StringValue(ix.SourceIndexesFrom1D(i)...)
}
// SetString1D is somewhat expensive if indexes are set, because it needs to convert
// the flat index back into a full n-dimensional index and then use that api.
func (ix *Indexed) SetString1D(val string, i int) {
ix.Tensor.SetString(val, ix.SourceIndexesFrom1D(i)...)
}
//////// Ints
// Int returns the value of given index as an int.
// The indexes are indirected through the [Indexed.Indexes].
func (ix *Indexed) Int(i ...int) int {
return ix.Tensor.Int(ix.SourceIndexes(i...)...)
}
// SetInt sets the value of given index as an int
// The indexes are indirected through the [Indexed.Indexes].
func (ix *Indexed) SetInt(val int, i ...int) {
ix.Tensor.SetInt(val, ix.SourceIndexes(i...)...)
}
// Int1D is somewhat expensive if indexes are set, because it needs to convert
// the flat index back into a full n-dimensional index and then use that api.
func (ix *Indexed) Int1D(i int) int {
return ix.Tensor.Int(ix.SourceIndexesFrom1D(i)...)
}
// SetInt1D is somewhat expensive if indexes are set, because it needs to convert
// the flat index back into a full n-dimensional index and then use that api.
func (ix *Indexed) SetInt1D(val int, i int) {
ix.Tensor.SetInt(val, ix.SourceIndexesFrom1D(i)...)
}
// check for interface impl
var _ Tensor = (*Indexed)(nil)
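// A minimal sketch of an Indexed view, picking elements (0,0) and (2,2)
// from a 3x3 source tensor:
//
//	src := NewFloat64(3, 3)
//	ixs := NewIntFromValues(0, 0, 2, 2)
//	ixs.SetShapeSizes(2, 2)    // 2 indexes x 2 source dimensions
//	iv := NewIndexed(src, ixs) // 1D view of length 2
//	v := iv.Float(0)           // value of src at (0, 0)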
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
import (
"bufio"
"bytes"
"encoding/csv"
"fmt"
"io"
"io/fs"
"log"
"os"
"strconv"
"strings"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/fsx"
"cogentcore.org/core/base/metadata"
"cogentcore.org/core/base/reflectx"
)
// Delims are standard CSV delimiter options (Tab, Comma, Space, Detect).
type Delims int32 //enums:enum
const (
// Tab is the tab rune delimiter, for TSV tab separated values
Tab Delims = iota
// Comma is the comma rune delimiter, for CSV comma separated values
Comma
// Space is the space rune delimiter, for SSV space separated values
Space
// Detect is used during reading a file -- reads the first line and detects tabs or commas
Detect
)
// Rune returns the rune character for the delimiter.
func (dl Delims) Rune() rune {
switch dl {
case Tab:
return '\t'
case Comma:
return ','
case Space:
return ' '
}
return '\t'
}
// SetPrecision sets the "Precision" metadata value that determines
// the precision to use in writing floating point numbers to files.
func SetPrecision(obj any, prec int) {
metadata.Set(obj, "Precision", prec)
}
// Precision gets the "Precision" metadata value that determines
// the precision to use in writing floating point numbers to files.
// It returns an error if not set.
func Precision(obj any) (int, error) {
return metadata.Get[int](obj, "Precision")
}
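// A minimal sketch of setting precision before saving (assuming tsr is an
// existing Tensor; "data.tsv" is an arbitrary example filename):
//
//	SetPrecision(tsr, 4) // write floats with 4 significant digits
//	err := SaveCSV(tsr, "data.tsv", Tab)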
// SaveCSV writes a tensor to a comma-separated-values (CSV) file
// (where comma = any delimiter, specified in the delim arg).
// Outer-most dims are rows in the file, and inner-most is column --
// Reading just grabs all values and doesn't care about shape.
func SaveCSV(tsr Tensor, filename fsx.Filename, delim Delims) error {
fp, err := os.Create(string(filename))
if err != nil {
log.Println(err)
return err
}
defer fp.Close()
bw := bufio.NewWriter(fp)
WriteCSV(tsr, bw, delim)
bw.Flush()
return nil
}
// OpenCSV reads a tensor from a comma-separated-values (CSV) file
// (where comma = any delimiter, specified in the delim arg),
// using the Go standard encoding/csv reader conforming
// to the official CSV standard.
// Reads all values and assigns as many as fit.
func OpenCSV(tsr Tensor, filename fsx.Filename, delim Delims) error {
fp, err := os.Open(string(filename))
if err != nil {
return errors.Log(err)
}
defer fp.Close()
return ReadCSV(tsr, bufio.NewReader(fp), delim)
}
// OpenFS is the version of [OpenCSV] that uses an [fs.FS] filesystem.
func OpenFS(tsr Tensor, fsys fs.FS, filename string, delim Delims) error {
fp, err := fsys.Open(filename)
if err != nil {
return errors.Log(err)
}
defer fp.Close()
return ReadCSV(tsr, fp, delim)
}
//////// WriteCSV
// WriteCSV writes a tensor to a comma-separated-values (CSV) file
// (where comma = any delimiter, specified in the delim arg).
// Outer-most dims are rows in the file, and inner-most is column --
// Reading just grabs all values and doesn't care about shape.
func WriteCSV(tsr Tensor, w io.Writer, delim Delims) error {
prec := -1
if ps, err := Precision(tsr); err == nil {
prec = ps
}
cw := csv.NewWriter(w)
cw.Comma = delim.Rune()
nrow := tsr.DimSize(0)
nin := tsr.Len() / nrow
rec := make([]string, nin)
str := tsr.IsString()
for ri := 0; ri < nrow; ri++ {
for ci := 0; ci < nin; ci++ {
idx := ri*nin + ci
if str {
rec[ci] = tsr.String1D(idx)
} else {
rec[ci] = strconv.FormatFloat(tsr.Float1D(idx), 'g', prec, 64)
}
}
err := cw.Write(rec)
if err != nil {
log.Println(err)
return err
}
}
cw.Flush()
return nil
}
// ReadCSV reads a tensor from a comma-separated-values (CSV) file
// (where comma = any delimiter, specified in the delim arg),
// using the Go standard encoding/csv reader conforming
// to the official CSV standard.
// Reads all values and assigns as many as fit.
func ReadCSV(tsr Tensor, r io.Reader, delim Delims) error {
cr := csv.NewReader(r)
cr.Comma = delim.Rune()
rec, err := cr.ReadAll() // todo: lazy, avoid resizing
if err != nil || len(rec) == 0 {
return err
}
rows := len(rec)
cols := len(rec[0])
sz := tsr.Len()
idx := 0
for ri := 0; ri < rows; ri++ {
for ci := 0; ci < cols; ci++ {
str := rec[ri][ci]
tsr.SetString1D(str, idx)
idx++
if idx >= sz {
goto done
}
}
}
done:
return nil
}
func label(nm string, sh *Shape) string {
if nm != "" {
nm += " " + sh.String()
} else {
nm = sh.String()
}
return nm
}
// padToLength returns the given string with added spaces
// to pad out to the target length. At least 1 space will be added.
func padToLength(str string, tlen int) string {
slen := len(str)
if slen < tlen-1 {
return str + strings.Repeat(" ", tlen-slen)
}
return str + " "
}
// prepadToLength returns the given string with added spaces
// to pad out to the target length at the start (for numbers).
// At least 1 space will be added.
func prepadToLength(str string, tlen int) string {
slen := len(str)
if slen < tlen-1 {
return strings.Repeat(" ", tlen-slen-1) + str + " "
}
return str + " "
}
// MaxPrintLineWidth is the maximum line width in characters
// to generate for tensor Sprintf function.
var MaxPrintLineWidth = 80
// Sprintf returns a string representation of the given tensor,
// with a maximum number of elements as given by maxLen: output is
// terminated when it exceeds that length. If maxLen = 0, [MaxSprintLength] is used.
// The format is the per-element format string.
// If empty, it uses a general %g for numbers or %s for strings.
func Sprintf(format string, tsr Tensor, maxLen int) string {
if maxLen == 0 {
maxLen = MaxSprintLength
}
defFmt := format == ""
if defFmt {
switch {
case tsr.IsString():
format = "%s"
case reflectx.KindIsInt(tsr.DataType()):
format = "%.10g"
default:
format = "%.10g"
}
}
nd := tsr.NumDims()
if nd == 1 && tsr.DimSize(0) == 1 { // scalar special case
if tsr.IsString() {
return fmt.Sprintf(format, tsr.String1D(0))
} else {
return fmt.Sprintf(format, tsr.Float1D(0))
}
}
mxwd := 0
n := min(tsr.Len(), maxLen)
for i := range n {
s := ""
if tsr.IsString() {
s = fmt.Sprintf(format, tsr.String1D(i))
} else {
s = fmt.Sprintf(format, tsr.Float1D(i))
}
if len(s) > mxwd {
mxwd = len(s)
}
}
onedRow := false
shp := tsr.Shape()
rowShape, colShape, _, colIdxs := Projection2DDimShapes(shp, onedRow)
rows, cols, _, _ := Projection2DShape(shp, onedRow)
rowWd := len(rowShape.String()) + 1
legend := ""
if nd > 2 {
leg := bytes.Repeat([]byte("r "), nd)
for _, i := range colIdxs {
leg[2*i] = 'c'
}
legend = "[" + string(leg[:len(leg)-1]) + "]"
}
rowWd = max(rowWd, len(legend)+1)
hdrWd := len(colShape.String()) + 1
colWd := mxwd + 1
var b strings.Builder
b.WriteString(tsr.Label())
noidx := false
if tsr.NumDims() == 1 {
b.WriteString(" ")
rowWd = len(tsr.Label()) + 1
noidx = true
} else {
b.WriteString("\n")
}
if !noidx && nd > 1 && cols > 1 {
colWd = max(colWd, hdrWd)
b.WriteString(padToLength(legend, rowWd))
totWd := rowWd
for c := 0; c < cols; c++ {
_, cc := Projection2DCoords(shp, onedRow, 0, c)
s := prepadToLength(fmt.Sprintf("%v", cc), colWd)
if totWd+len(s) > MaxPrintLineWidth {
b.WriteString("\n" + strings.Repeat(" ", rowWd))
totWd = rowWd
}
b.WriteString(s)
totWd += len(s)
}
b.WriteString("\n")
}
ctr := 0
for r := range rows {
rc, _ := Projection2DCoords(shp, onedRow, r, 0)
if !noidx {
b.WriteString(padToLength(fmt.Sprintf("%v", rc), rowWd))
}
ri := r
totWd := rowWd
for c := 0; c < cols; c++ {
s := ""
if tsr.IsString() {
s = padToLength(fmt.Sprintf(format, Projection2DString(tsr, onedRow, ri, c)), colWd)
} else {
s = prepadToLength(fmt.Sprintf(format, Projection2DValue(tsr, onedRow, ri, c)), colWd)
}
if totWd+len(s) > MaxPrintLineWidth {
b.WriteString("\n" + strings.Repeat(" ", rowWd))
totWd = rowWd
}
b.WriteString(s)
totWd += len(s)
}
b.WriteString("\n")
ctr += cols
if ctr > maxLen {
b.WriteString("...\n")
break
}
}
return b.String()
}
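// A minimal sketch of Sprintf usage (assuming tsr is an existing Tensor):
//
//	s := Sprintf("%7.4g", tsr, 100) // custom element format, at most ~100 elements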
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
import (
"math"
"reflect"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/metadata"
"cogentcore.org/core/base/reflectx"
)
// Masked is a filtering wrapper around another "source" [Tensor],
// that provides a bit-masked view onto the Tensor defined by a [Bool] [Values]
// tensor with a matching shape. If the bool mask has a 'false'
// then the corresponding value cannot be Set, and Float access returns
// NaN indicating missing data (other type access returns the zero value).
// A new Masked view defaults to a full transparent view of the source tensor.
// To produce a new [Values] tensor with only the 'true' cases,
// (i.e., the copy function of numpy), call [Masked.AsValues].
type Masked struct { //types:add
// Tensor source that we are a masked view onto.
Tensor Tensor
// Bool tensor with same shape as source tensor, providing mask.
Mask *Bool
}
// NewMasked returns a new [Masked] view of given tensor,
// with given [Bool] mask values. If no mask is provided,
// a default full transparent (all bool values = true) mask is used.
func NewMasked(tsr Tensor, mask ...*Bool) *Masked {
ms := &Masked{Tensor: tsr}
if len(mask) == 1 {
ms.Mask = mask[0]
ms.SyncShape()
} else {
ms.Mask = NewBoolShape(tsr.Shape())
ms.Mask.SetTrue()
}
return ms
}
// Mask is the general purpose masking function, which checks
// whether the mask arg is a [Bool] tensor and uses it if so.
// Otherwise, it logs an error.
func Mask(tsr, mask Tensor) Tensor {
if mb, ok := mask.(*Bool); ok {
return NewMasked(tsr, mb)
}
errors.Log(errors.New("tensor.Mask: provided tensor is not a Bool tensor"))
return tsr
}
// AsMasked returns the tensor as a [Masked] view.
// If it already is one, then it is returned, otherwise it is wrapped
// with an initially fully transparent mask.
func AsMasked(tsr Tensor) *Masked {
if ms, ok := tsr.(*Masked); ok {
return ms
}
return NewMasked(tsr)
}
// SetTensor sets the given source tensor. If the shape does not match
// the current Mask, then a new transparent mask is established.
func (ms *Masked) SetTensor(tsr Tensor) {
ms.Tensor = tsr
ms.SyncShape()
}
// SyncShape ensures that [Masked.Mask] shape is the same as source tensor.
// If the Mask does not exist or is a different shape from the source,
// then it is created or reshaped, and all values set to true ("transparent").
func (ms *Masked) SyncShape() {
if ms.Mask == nil {
ms.Mask = NewBoolShape(ms.Tensor.Shape())
ms.Mask.SetTrue()
return
}
if !ms.Mask.Shape().IsEqual(ms.Tensor.Shape()) {
SetShapeFrom(ms.Mask, ms.Tensor)
ms.Mask.SetTrue()
}
}
func (ms *Masked) Label() string { return label(metadata.Name(ms), ms.Shape()) }
func (ms *Masked) String() string { return Sprintf("", ms, 0) }
func (ms *Masked) Metadata() *metadata.Data { return ms.Tensor.Metadata() }
func (ms *Masked) IsString() bool { return ms.Tensor.IsString() }
func (ms *Masked) DataType() reflect.Kind { return ms.Tensor.DataType() }
func (ms *Masked) ShapeSizes() []int { return ms.Tensor.ShapeSizes() }
func (ms *Masked) Shape() *Shape { return ms.Tensor.Shape() }
func (ms *Masked) Len() int { return ms.Tensor.Len() }
func (ms *Masked) NumDims() int { return ms.Tensor.NumDims() }
func (ms *Masked) DimSize(dim int) int { return ms.Tensor.DimSize(dim) }
// AsValues returns a copy of this tensor as raw [Values].
// This "renders" the Masked view into a fully contiguous
// and optimized memory representation of that view.
// Because the masking pattern is unpredictable, only a 1D shape is possible.
func (ms *Masked) AsValues() Values {
dt := ms.Tensor.DataType()
n := ms.Len()
switch {
case ms.Tensor.IsString():
vals := make([]string, 0, n)
for i := range n {
if !ms.Mask.Bool1D(i) {
continue
}
vals = append(vals, ms.Tensor.String1D(i))
}
return NewStringFromValues(vals...)
case reflectx.KindIsFloat(dt):
vals := make([]float64, 0, n)
for i := range n {
if !ms.Mask.Bool1D(i) {
continue
}
vals = append(vals, ms.Tensor.Float1D(i))
}
return NewFloat64FromValues(vals...)
default:
vals := make([]int, 0, n)
for i := range n {
if !ms.Mask.Bool1D(i) {
continue
}
vals = append(vals, ms.Tensor.Int1D(i))
}
return NewIntFromValues(vals...)
}
}
// SourceIndexes returns a flat [Int] tensor of the mask values
// that match the given getTrue argument state.
// These can be used as indexes in the [Indexed] view, for example.
// The resulting tensor is 2D with inner dimension = number of source
// tensor dimensions, to hold the indexes, and outer dimension = number
// of indexes.
func (ms *Masked) SourceIndexes(getTrue bool) *Int {
n := ms.Len()
nd := ms.Tensor.NumDims()
idxs := make([]int, 0, n*nd)
for i := range n {
if ms.Mask.Bool1D(i) != getTrue {
continue
}
ix := ms.Tensor.Shape().IndexFrom1D(i)
idxs = append(idxs, ix...)
}
it := NewIntFromValues(idxs...)
it.SetShapeSizes(len(idxs)/nd, nd)
return it
}
//////// Floats
func (ms *Masked) Float(i ...int) float64 {
if !ms.Mask.Bool(i...) {
return math.NaN()
}
return ms.Tensor.Float(i...)
}
func (ms *Masked) SetFloat(val float64, i ...int) {
if !ms.Mask.Bool(i...) {
return
}
ms.Tensor.SetFloat(val, i...)
}
func (ms *Masked) Float1D(i int) float64 {
if !ms.Mask.Bool1D(i) {
return math.NaN()
}
return ms.Tensor.Float1D(i)
}
func (ms *Masked) SetFloat1D(val float64, i int) {
if !ms.Mask.Bool1D(i) {
return
}
ms.Tensor.SetFloat1D(val, i)
}
//////// Strings
func (ms *Masked) StringValue(i ...int) string {
if !ms.Mask.Bool(i...) {
return ""
}
return ms.Tensor.StringValue(i...)
}
func (ms *Masked) SetString(val string, i ...int) {
if !ms.Mask.Bool(i...) {
return
}
ms.Tensor.SetString(val, i...)
}
func (ms *Masked) String1D(i int) string {
if !ms.Mask.Bool1D(i) {
return ""
}
return ms.Tensor.String1D(i)
}
func (ms *Masked) SetString1D(val string, i int) {
if !ms.Mask.Bool1D(i) {
return
}
ms.Tensor.SetString1D(val, i)
}
//////// Ints
func (ms *Masked) Int(i ...int) int {
if !ms.Mask.Bool(i...) {
return 0
}
return ms.Tensor.Int(i...)
}
func (ms *Masked) SetInt(val int, i ...int) {
if !ms.Mask.Bool(i...) {
return
}
ms.Tensor.SetInt(val, i...)
}
func (ms *Masked) Int1D(i int) int {
if !ms.Mask.Bool1D(i) {
return 0
}
return ms.Tensor.Int1D(i)
}
func (ms *Masked) SetInt1D(val int, i int) {
if !ms.Mask.Bool1D(i) {
return
}
ms.Tensor.SetInt1D(val, i)
}
// Filter sets the mask values using given Filter function.
// The filter function gets the 1D index into the source tensor.
func (ms *Masked) Filter(filterer func(tsr Tensor, idx int) bool) *Masked {
n := ms.Tensor.Len()
for i := range n {
ms.Mask.SetBool1D(filterer(ms.Tensor, i), i)
}
return ms
}
// check for interface impl
var _ Tensor = (*Masked)(nil)
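// A minimal sketch of a Masked view used for filtering (assuming tsr is an
// existing numeric Tensor):
//
//	ms := NewMasked(tsr)
//	ms.Filter(func(t Tensor, i int) bool { return t.Float1D(i) > 0 })
//	pos := ms.AsValues()          // 1D Values with only the positive values
//	ixs := ms.SourceIndexes(true) // or their indexes, e.g., for an Indexed view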
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
import (
"fmt"
"strconv"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/num"
"cogentcore.org/core/base/reflectx"
)
// Number is a tensor of numerical values
type Number[T num.Number] struct {
Base[T]
}
// Float64 is an alias for Number[float64].
type Float64 = Number[float64]
// Float32 is an alias for Number[float32].
type Float32 = Number[float32]
// Int is an alias for Number[int].
type Int = Number[int]
// Int32 is an alias for Number[int32].
type Int32 = Number[int32]
// Uint32 is an alias for Number[uint32].
type Uint32 = Number[uint32]
// Byte is an alias for Number[byte].
type Byte = Number[byte]
// NewFloat32 returns a new [Float32] tensor
// with the given sizes per dimension (shape).
func NewFloat32(sizes ...int) *Float32 {
return New[float32](sizes...).(*Float32)
}
// NewFloat64 returns a new [Float64] tensor
// with the given sizes per dimension (shape).
func NewFloat64(sizes ...int) *Float64 {
return New[float64](sizes...).(*Float64)
}
// NewInt returns a new Int tensor
// with the given sizes per dimension (shape).
func NewInt(sizes ...int) *Int {
return New[int](sizes...).(*Int)
}
// NewInt32 returns a new Int32 tensor
// with the given sizes per dimension (shape).
func NewInt32(sizes ...int) *Int32 {
return New[int32](sizes...).(*Int32)
}
// NewUint32 returns a new Uint32 tensor
// with the given sizes per dimension (shape).
func NewUint32(sizes ...int) *Uint32 {
return New[uint32](sizes...).(*Uint32)
}
// NewByte returns a new Byte tensor
// with the given sizes per dimension (shape).
func NewByte(sizes ...int) *Byte {
return New[uint8](sizes...).(*Byte)
}
// NewNumber returns a new n-dimensional tensor of numerical values
// with the given sizes per dimension (shape).
func NewNumber[T num.Number](sizes ...int) *Number[T] {
tsr := &Number[T]{}
tsr.SetShapeSizes(sizes...)
tsr.Values = make([]T, tsr.Len())
return tsr
}
// NewNumberShape returns a new n-dimensional tensor of numerical values
// using given shape.
func NewNumberShape[T num.Number](shape *Shape) *Number[T] {
tsr := &Number[T]{}
tsr.shape.CopyFrom(shape)
tsr.Values = make([]T, tsr.Len())
return tsr
}
// todo: this should in principle work with yaegi:add but it is crashing
// will come back to it later.
// NewNumberFromValues returns a new 1-dimensional tensor of given value type
// initialized directly from the given slice values, which are not copied.
// The resulting Tensor thus "wraps" the given values.
func NewNumberFromValues[T num.Number](vals ...T) *Number[T] {
n := len(vals)
tsr := &Number[T]{}
tsr.Values = vals
tsr.SetShapeSizes(n)
return tsr
}
// String satisfies the fmt.Stringer interface for string of tensor data.
func (tsr *Number[T]) String() string { return Sprintf("", tsr, 0) }
func (tsr *Number[T]) IsString() bool { return false }
func (tsr *Number[T]) AsValues() Values { return tsr }
func (tsr *Number[T]) SetAdd(val T, i ...int) {
tsr.Values[tsr.shape.IndexTo1D(i...)] += val
}
func (tsr *Number[T]) SetSub(val T, i ...int) {
tsr.Values[tsr.shape.IndexTo1D(i...)] -= val
}
func (tsr *Number[T]) SetMul(val T, i ...int) {
tsr.Values[tsr.shape.IndexTo1D(i...)] *= val
}
func (tsr *Number[T]) SetDiv(val T, i ...int) {
tsr.Values[tsr.shape.IndexTo1D(i...)] /= val
}
/////// Strings
func (tsr *Number[T]) SetString(val string, i ...int) {
if fv, err := strconv.ParseFloat(val, 64); err == nil {
tsr.Values[tsr.shape.IndexTo1D(i...)] = T(fv)
}
}
func (tsr *Number[T]) SetString1D(val string, i int) {
if fv, err := strconv.ParseFloat(val, 64); err == nil {
tsr.Values[i] = T(fv)
}
}
func (tsr *Number[T]) SetStringRow(val string, row, cell int) {
if fv, err := strconv.ParseFloat(val, 64); err == nil {
_, sz := tsr.shape.RowCellSize()
tsr.Values[row*sz+cell] = T(fv)
}
}
// AppendRowString adds a row and sets string value(s), up to number of cells.
func (tsr *Number[T]) AppendRowString(val ...string) {
if tsr.NumDims() == 0 {
tsr.SetShapeSizes(0)
}
nrow, sz := tsr.shape.RowCellSize()
tsr.SetNumRows(nrow + 1)
mx := min(sz, len(val))
for i := range mx {
tsr.SetStringRow(val[i], nrow, i)
}
}
/////// Floats
func (tsr *Number[T]) Float(i ...int) float64 {
return float64(tsr.Values[tsr.shape.IndexTo1D(i...)])
}
func (tsr *Number[T]) SetFloat(val float64, i ...int) {
tsr.Values[tsr.shape.IndexTo1D(i...)] = T(val)
}
func (tsr *Number[T]) Float1D(i int) float64 {
return float64(tsr.Values[NegIndex(i, len(tsr.Values))])
}
func (tsr *Number[T]) SetFloat1D(val float64, i int) {
tsr.Values[NegIndex(i, len(tsr.Values))] = T(val)
}
func (tsr *Number[T]) FloatRow(row, cell int) float64 {
_, sz := tsr.shape.RowCellSize()
i := row*sz + cell
return float64(tsr.Values[NegIndex(i, len(tsr.Values))])
}
func (tsr *Number[T]) SetFloatRow(val float64, row, cell int) {
_, sz := tsr.shape.RowCellSize()
tsr.Values[row*sz+cell] = T(val)
}
// AppendRowFloat adds a row and sets float value(s), up to number of cells.
func (tsr *Number[T]) AppendRowFloat(val ...float64) {
if tsr.NumDims() == 0 {
tsr.SetShapeSizes(0)
}
nrow, sz := tsr.shape.RowCellSize()
tsr.SetNumRows(nrow + 1)
mx := min(sz, len(val))
for i := range mx {
tsr.SetFloatRow(val[i], nrow, i)
}
}
/////// Ints
func (tsr *Number[T]) Int(i ...int) int {
return int(tsr.Values[tsr.shape.IndexTo1D(i...)])
}
func (tsr *Number[T]) SetInt(val int, i ...int) {
tsr.Values[tsr.shape.IndexTo1D(i...)] = T(val)
}
func (tsr *Number[T]) Int1D(i int) int {
return int(tsr.Values[NegIndex(i, len(tsr.Values))])
}
func (tsr *Number[T]) SetInt1D(val int, i int) {
tsr.Values[NegIndex(i, len(tsr.Values))] = T(val)
}
func (tsr *Number[T]) IntRow(row, cell int) int {
_, sz := tsr.shape.RowCellSize()
i := row*sz + cell
return int(tsr.Values[i])
}
func (tsr *Number[T]) SetIntRow(val int, row, cell int) {
_, sz := tsr.shape.RowCellSize()
tsr.Values[row*sz+cell] = T(val)
}
// AppendRowInt adds a row and sets int value(s), up to number of cells.
func (tsr *Number[T]) AppendRowInt(val ...int) {
if tsr.NumDims() == 0 {
tsr.SetShapeSizes(0)
}
nrow, sz := tsr.shape.RowCellSize()
tsr.SetNumRows(nrow + 1)
mx := min(sz, len(val))
for i := range mx {
tsr.SetIntRow(val[i], nrow, i)
}
}
// SetZeros is a simple convenience function to initialize all values to 0.
func (tsr *Number[T]) SetZeros() {
for j := range tsr.Values {
tsr.Values[j] = 0
}
}
// Clone clones this tensor, creating a duplicate copy of itself with its
// own separate memory representation of all the values.
func (tsr *Number[T]) Clone() Values {
csr := NewNumberShape[T](&tsr.shape)
copy(csr.Values, tsr.Values)
return csr
}
// CopyFrom copies all available values from the other tensor into this tensor,
// with an optimized implementation if the other tensor is of the same type,
// and otherwise it goes through the appropriate standard type.
func (tsr *Number[T]) CopyFrom(frm Values) {
if fsm, ok := frm.(*Number[T]); ok {
copy(tsr.Values, fsm.Values)
return
}
sz := min(tsr.Len(), frm.Len())
if reflectx.KindIsInt(tsr.DataType()) {
for i := range sz {
tsr.Values[i] = T(frm.Int1D(i))
}
} else {
for i := range sz {
tsr.Values[i] = T(frm.Float1D(i))
}
}
}
// AppendFrom appends values from other tensor into this tensor,
// which must have the same cell size as this tensor.
// It uses an optimized implementation if the other tensor
// is of the same type, and otherwise it goes through
// appropriate standard type.
func (tsr *Number[T]) AppendFrom(frm Values) Values {
rows, cell := tsr.shape.RowCellSize()
frows, fcell := frm.Shape().RowCellSize()
if cell != fcell {
errors.Log(fmt.Errorf("tensor.AppendFrom: cell sizes do not match: %d != %d", cell, fcell))
return tsr
}
tsr.SetNumRows(rows + frows)
st := rows * cell
fsz := frows * fcell
if fsm, ok := frm.(*Number[T]); ok {
copy(tsr.Values[st:st+fsz], fsm.Values)
return tsr
}
for i := range fsz {
tsr.Values[st+i] = T(frm.Float1D(i))
}
return tsr
}
// CopyCellsFrom copies given range of values from other tensor into this tensor,
// using flat 1D indexes: to = starting index in this Tensor to start copying into,
// start = starting index on from Tensor to start copying from, and n = number of
// values to copy. Uses an optimized implementation if the other tensor is
// of the same type, and otherwise it goes through appropriate standard type.
func (tsr *Number[T]) CopyCellsFrom(frm Values, to, start, n int) {
if fsm, ok := frm.(*Number[T]); ok {
copy(tsr.Values[to:to+n], fsm.Values[start:start+n])
return
}
for i := range n {
tsr.Values[to+i] = T(frm.Float1D(start + i))
}
}
// SubSpace returns a new tensor with innermost subspace at given
// offset(s) in outermost dimension(s) (len(offs) < NumDims).
// The new tensor points to the values of this tensor (i.e., modifications
// will affect both), as its Values slice is a view onto the original (which
// is why only inner-most contiguous subspaces are supported).
// Use AsValues() method to separate the two.
func (tsr *Number[T]) SubSpace(offs ...int) Values {
b := tsr.subSpaceImpl(offs...)
rt := &Number[T]{Base: *b}
return rt
}
// RowTensor is a convenience version of [RowMajor.SubSpace] to return the
// SubSpace for the outermost row dimension. [Rows] defines a version
// of this that indirects through the row indexes.
func (tsr *Number[T]) RowTensor(row int) Values {
return tsr.SubSpace(row)
}
// SetRowTensor sets the values of the SubSpace at given row to given values.
func (tsr *Number[T]) SetRowTensor(val Values, row int) {
_, cells := tsr.shape.RowCellSize()
st := row * cells
mx := min(val.Len(), cells)
tsr.CopyCellsFrom(val, st, 0, mx)
}
// AppendRow adds a row and sets values to given values.
func (tsr *Number[T]) AppendRow(val Values) {
if tsr.NumDims() == 0 {
tsr.SetShapeSizes(0)
}
nrow := tsr.DimSize(0)
tsr.SetNumRows(nrow + 1)
tsr.SetRowTensor(val, nrow)
}
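// A minimal sketch of basic Number tensor usage:
//
//	tsr := NewFloat64(3, 4)        // 3 rows x 4 columns
//	tsr.SetFloat(3.14, 2, 1)       // set by n-dimensional index
//	tsr.AppendRowFloat(1, 2, 3, 4) // add a 4th row with the given cell values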
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
const (
// OnedRow is for the onedRow argument to Projection2D functions,
// specifying that the 1D case goes along the row.
OnedRow = true
// OnedColumn is for the onedRow argument to Projection2D functions,
// specifying that the 1D case goes along the column.
OnedColumn = false
)
// Projection2DShape returns the size of a 2D projection of the given tensor Shape,
// collapsing higher dimensions down to 2D (and 1D up to 2D).
// For the 1D case, onedRow determines if the values are row-wise or not.
// Even multiples of inner-most dimensions are placed along the row, odd in the column.
// If there are an odd number of dimensions, the first dimension is row-wise, and
// the remaining inner dimensions use the above logic from there, as if it was even.
// rowEx returns the number of "extra" (outer-dimensional) rows
// and colEx returns the number of extra cols, to add extra spacing between these dimensions.
func Projection2DShape(shp *Shape, onedRow bool) (rows, cols, rowEx, colEx int) {
if shp.Len() == 0 {
return 1, 1, 0, 0
}
nd := shp.NumDims()
if nd == 1 {
if onedRow {
return shp.DimSize(0), 1, 0, 0
}
return 1, shp.DimSize(0), 0, 0
}
if nd == 2 {
return shp.DimSize(0), shp.DimSize(1), 0, 0
}
rowShape, colShape, rowIdxs, colIdxs := Projection2DDimShapes(shp, onedRow)
rows = rowShape.Len()
cols = colShape.Len()
nri := len(rowIdxs)
if nri > 1 {
rowEx = 1
for i := range nri - 1 {
rowEx *= shp.DimSize(rowIdxs[i])
}
}
nci := len(colIdxs)
if nci > 1 {
colEx = 1
for i := range nci - 1 {
colEx *= shp.DimSize(colIdxs[i])
}
}
return
}
// Projection2DDimShapes returns the shapes and dimension indexes for a 2D projection
// of given tensor Shape, collapsing higher dimensions down to 2D (and 1D up to 2D).
// For the 1D case, onedRow determines if the values are row-wise or not.
// Even multiples of inner-most dimensions are placed along the row, odd in the column.
// If there are an odd number of dimensions, the first dimension is row-wise, and
// the remaining inner dimensions use the above logic from there, as if it was even.
// This is the main organizing function for all Projection2D calls.
func Projection2DDimShapes(shp *Shape, onedRow bool) (rowShape, colShape *Shape, rowIdxs, colIdxs []int) {
nd := shp.NumDims()
if nd == 1 {
if onedRow {
return NewShape(shp.DimSize(0)), NewShape(1), []int{0}, nil
}
return NewShape(1), NewShape(shp.DimSize(0)), nil, []int{0}
}
if nd == 2 {
return NewShape(shp.DimSize(0)), NewShape(shp.DimSize(1)), []int{0}, []int{1}
}
var rs, cs []int
odd := nd%2 == 1
sd := 0
end := nd
if odd {
end = nd - 1
sd = 1
rs = []int{shp.DimSize(0)}
rowIdxs = []int{0}
}
for d := range end {
ad := d + sd
if d%2 == 0 { // even goes to row
rs = append(rs, shp.DimSize(ad))
rowIdxs = append(rowIdxs, ad)
} else {
cs = append(cs, shp.DimSize(ad))
colIdxs = append(colIdxs, ad)
}
}
rowShape = NewShape(rs...)
colShape = NewShape(cs...)
return
}
// Projection2DIndex returns the flat 1D index for given row, col coords for a 2D projection
// of the given tensor shape, collapsing higher dimensions down to 2D (and 1D up to 2D).
// See [Projection2DShape] for full info.
func Projection2DIndex(shp *Shape, onedRow bool, row, col int) int {
if shp.Len() == 0 {
return 0
}
nd := shp.NumDims()
if nd == 1 {
if onedRow {
return row
}
return col
}
if nd == 2 {
return shp.IndexTo1D(row, col)
}
rowShape, colShape, rowIdxs, colIdxs := Projection2DDimShapes(shp, onedRow)
ris := rowShape.IndexFrom1D(row)
cis := colShape.IndexFrom1D(col)
ixs := make([]int, nd)
for i, ri := range rowIdxs {
ixs[ri] = ris[i]
}
for i, ci := range colIdxs {
ixs[ci] = cis[i]
}
return shp.IndexTo1D(ixs...)
}
// Projection2DCoords returns the corresponding full-dimensional coordinates
// that go into the given row, col coords for a 2D projection of the given tensor,
// collapsing higher dimensions down to 2D (and 1D up to 2D).
// See [Projection2DShape] for full info.
func Projection2DCoords(shp *Shape, onedRow bool, row, col int) (rowCoords, colCoords []int) {
if shp.Len() == 0 {
return []int{0}, []int{0}
}
idx := Projection2DIndex(shp, onedRow, row, col)
dims := shp.IndexFrom1D(idx)
nd := shp.NumDims()
if nd == 1 {
if onedRow {
return dims, []int{0}
}
return []int{0}, dims
}
if nd == 2 {
return dims[:1], dims[1:]
}
_, _, rowIdxs, colIdxs := Projection2DDimShapes(shp, onedRow)
rowCoords = make([]int, len(rowIdxs))
colCoords = make([]int, len(colIdxs))
for i, ri := range rowIdxs {
rowCoords[i] = dims[ri]
}
for i, ci := range colIdxs {
colCoords[i] = dims[ci]
}
return
}
// Projection2DValue returns the float64 value at given row, col coords for a 2D projection
// of the given tensor, collapsing higher dimensions down to 2D (and 1D up to 2D).
// See [Projection2DShape] for full info.
func Projection2DValue(tsr Tensor, onedRow bool, row, col int) float64 {
idx := Projection2DIndex(tsr.Shape(), onedRow, row, col)
return tsr.Float1D(idx)
}
// Projection2DString returns the string value at given row, col coords for a 2D projection
// of the given tensor, collapsing higher dimensions down to 2D (and 1D up to 2D).
// See [Projection2DShape] for full info.
func Projection2DString(tsr Tensor, onedRow bool, row, col int) string {
idx := Projection2DIndex(tsr.Shape(), onedRow, row, col)
return tsr.String1D(idx)
}
// Projection2DSet sets a float64 value at given row, col coords for a 2D projection
// of the given tensor, collapsing higher dimensions down to 2D (and 1D up to 2D).
// See [Projection2DShape] for full info.
func Projection2DSet(tsr Tensor, onedRow bool, row, col int, val float64) {
idx := Projection2DIndex(tsr.Shape(), onedRow, row, col)
tsr.SetFloat1D(val, idx)
}
// Projection2DSetString sets a string value at given row, col coords for a 2D projection
// of the given tensor, collapsing higher dimensions down to 2D (and 1D up to 2D).
// See [Projection2DShape] for full info.
func Projection2DSetString(tsr Tensor, onedRow bool, row, col int, val string) {
idx := Projection2DIndex(tsr.Shape(), onedRow, row, col)
tsr.SetString1D(val, idx)
}
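// A worked sketch of the 2D projection logic above: a 2x3x4 tensor projects
// to 6 rows (2*3) by 4 columns, with 2 extra spacing rows for the outermost
// dimension:
//
//	rows, cols, rowEx, colEx := Projection2DShape(NewShape(2, 3, 4), OnedRow)
//	// rows == 6, cols == 4, rowEx == 2, colEx == 0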
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
import (
"reflect"
"slices"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/metadata"
)
// Reshaped is a reshaping wrapper around another "source" [Tensor],
// that provides a length-preserving reshaped view onto the source Tensor.
// Reshaping by adding new size=1 dimensions (via [NewAxis] value) is
// often important for properly aligning two tensors in a computationally
// compatible manner; see the [AlignShapes] function.
// [Reshaped.AsValues] on this view returns a new [Values] with the view
// shape, calling [Clone] on the source tensor to get the values.
type Reshaped struct { //types:add
// Tensor source that we are a reshaped view onto.
Tensor Tensor
// Reshape is the effective shape we use for access.
// This must have the same Len() as the source Tensor.
Reshape Shape
}
// NewReshaped returns a new [Reshaped] view of given tensor, with given shape
// sizes. If no such sizes are provided, the source shape is used.
// A single -1 value can be used to automatically specify the remaining tensor
// length, as long as the other sizes are an even multiple of the total length.
// A single -1 returns a 1D view of the entire tensor.
func NewReshaped(tsr Tensor, sizes ...int) *Reshaped {
rs := &Reshaped{Tensor: tsr}
if len(sizes) == 0 {
rs.Reshape.CopyFrom(tsr.Shape())
} else {
errors.Log(rs.SetShapeSizes(sizes...))
}
return rs
}
// Reshape returns a view of the given tensor with given shape sizes.
// A single -1 value can be used to automatically specify the remaining tensor
// length, as long as the other sizes are an even multiple of the total length.
// A single -1 returns a 1D view of the entire tensor.
func Reshape(tsr Tensor, sizes ...int) Tensor {
if len(sizes) == 0 {
err := errors.New("tensor.Reshape: must pass shape sizes")
errors.Log(err)
return tsr
}
if len(sizes) == 1 {
sz := sizes[0]
if sz == -1 {
return As1D(tsr)
}
}
rs := &Reshaped{Tensor: tsr}
errors.Log(rs.SetShapeSizes(sizes...))
return rs
}
// Transpose returns a new [Reshaped] view of the given tensor
// with the strides switched so that the row and column dimensions
// are effectively reversed.
func Transpose(tsr Tensor) Tensor {
rs := &Reshaped{Tensor: tsr}
rs.Reshape.CopyFrom(tsr.Shape())
rs.Reshape.Strides = ColumnMajorStrides(rs.Reshape.Sizes...)
return rs
}
// NewRowCellsView returns a 2D [Reshaped] view onto the given tensor,
// with a single outer "row" dimension and a single inner "cells" dimension,
// with the given 'split' dimension specifying where the cells start.
// All dimensions prior to split are collapsed to form the new outer row dimension,
// and the remainder are collapsed to form the 1D cells dimension.
// This is useful for stats, metrics and other packages that operate
// on data in this shape.
func NewRowCellsView(tsr Tensor, split int) *Reshaped {
sizes := tsr.ShapeSizes()
rows := sizes[:split]
cells := sizes[split:]
nr := 1
for _, r := range rows {
nr *= r
}
nc := 1
for _, c := range cells {
nc *= c
}
return NewReshaped(tsr, nr, nc)
}
// AsReshaped returns the tensor as a [Reshaped] view.
// If it already is one, then it is returned, otherwise it is wrapped
// with an initial shape equal to the source tensor.
func AsReshaped(tsr Tensor) *Reshaped {
if rs, ok := tsr.(*Reshaped); ok {
return rs
}
return NewReshaped(tsr)
}
// SetShapeSizes sets our shape sizes to the given values, which must result in
// the same length as the source tensor. An error is returned if not.
// If a different subset of content is desired, use another view such as [Sliced].
// Note that any number of size = 1 dimensions can be added without affecting
// the length, and the [NewAxis] value can be used to semantically
// indicate when such a new dimension is being inserted. This is often useful
// for aligning two tensors to achieve a desired computation; see [AlignShapes]
// function. A single -1 can be used to specify a dimension size that takes the
// remaining length, as long as the other sizes are an even multiple of the length.
// Passing a single -1 by itself produces a 1D view of the full length.
func (rs *Reshaped) SetShapeSizes(sizes ...int) error {
sln := rs.Tensor.Len()
if sln == 0 {
return nil
}
if sln == 1 {
sz := sizes[0]
if sz < 0 {
rs.Reshape.SetShapeSizes(sln)
return nil
}
}
sz := slices.Clone(sizes)
ln := 1
negIdx := -1
for i, s := range sz {
if s < 0 {
negIdx = i
} else {
ln *= s
}
}
if negIdx >= 0 {
if sln%ln != 0 {
return errors.New("tensor.Reshaped SetShapeSizes: -1 cannot be used because the remaining dimensions are not an even multiple of the source tensor length")
}
sz[negIdx] = sln / ln
}
rs.Reshape.SetShapeSizes(sz...)
if rs.Reshape.Len() != sln {
return errors.New("tensor.Reshaped SetShapeSizes: new length is different from source tensor; use Sliced or other views to change view content")
}
return nil
}
func (rs *Reshaped) Label() string { return label(metadata.Name(rs), rs.Shape()) }
func (rs *Reshaped) String() string { return Sprintf("", rs, 0) }
func (rs *Reshaped) Metadata() *metadata.Data { return rs.Tensor.Metadata() }
func (rs *Reshaped) IsString() bool { return rs.Tensor.IsString() }
func (rs *Reshaped) DataType() reflect.Kind { return rs.Tensor.DataType() }
func (rs *Reshaped) ShapeSizes() []int { return slices.Clone(rs.Reshape.Sizes) }
func (rs *Reshaped) Shape() *Shape { return &rs.Reshape }
func (rs *Reshaped) Len() int { return rs.Reshape.Len() }
func (rs *Reshaped) NumDims() int { return rs.Reshape.NumDims() }
func (rs *Reshaped) DimSize(dim int) int { return rs.Reshape.DimSize(dim) }
// AsValues returns a copy of this tensor as raw [Values], with
// the same shape as our view. This calls [Clone] on the source
// tensor to get the Values and then sets our shape sizes to it.
func (rs *Reshaped) AsValues() Values {
vals := Clone(rs.Tensor)
vals.SetShapeSizes(rs.Reshape.Sizes...)
return vals
}
//////// Floats
func (rs *Reshaped) Float(i ...int) float64 {
return rs.Tensor.Float1D(rs.Reshape.IndexTo1D(i...))
}
func (rs *Reshaped) SetFloat(val float64, i ...int) {
rs.Tensor.SetFloat1D(val, rs.Reshape.IndexTo1D(i...))
}
func (rs *Reshaped) Float1D(i int) float64 { return rs.Tensor.Float1D(i) }
func (rs *Reshaped) SetFloat1D(val float64, i int) { rs.Tensor.SetFloat1D(val, i) }
//////// Strings
func (rs *Reshaped) StringValue(i ...int) string {
return rs.Tensor.String1D(rs.Reshape.IndexTo1D(i...))
}
func (rs *Reshaped) SetString(val string, i ...int) {
rs.Tensor.SetString1D(val, rs.Reshape.IndexTo1D(i...))
}
func (rs *Reshaped) String1D(i int) string { return rs.Tensor.String1D(i) }
func (rs *Reshaped) SetString1D(val string, i int) { rs.Tensor.SetString1D(val, i) }
//////// Ints
func (rs *Reshaped) Int(i ...int) int {
return rs.Tensor.Int1D(rs.Reshape.IndexTo1D(i...))
}
func (rs *Reshaped) SetInt(val int, i ...int) {
rs.Tensor.SetInt1D(val, rs.Reshape.IndexTo1D(i...))
}
func (rs *Reshaped) Int1D(i int) int { return rs.Tensor.Int1D(i) }
func (rs *Reshaped) SetInt1D(val int, i int) { rs.Tensor.SetInt1D(val, i) }
// check for interface impl
var _ Tensor = (*Reshaped)(nil)
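// A minimal sketch of Reshaped views:
//
//	tsr := NewFloat64(4, 3)
//	r3d := Reshape(tsr, 2, 2, 3)  // same 12 values viewed as 2x2x3
//	flat := Reshape(tsr, -1)      // 1D view of all 12 values
//	rc := NewRowCellsView(r3d, 1) // 2 rows x 6 cells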
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
import (
"cmp"
"math"
"math/rand"
"reflect"
"slices"
"sort"
"strings"
"cogentcore.org/core/base/metadata"
)
// Rows is a row-indexed wrapper view around a [Values] [Tensor] that allows
// arbitrary row-wise ordering and filtering according to the [Rows.Indexes].
// Sorting and filtering a tensor along this outermost row dimension only
// requires updating the indexes while leaving the underlying Tensor alone.
// Unlike the more general [Sliced] view, Rows maintains memory contiguity
// for the inner dimensions ("cells") within each row, and supports the [RowMajor]
// interface, with the [Set]FloatRow[Cell] methods providing efficient access.
// Use [Rows.AsValues] to obtain a concrete [Values] representation with the
// current row sorting.
type Rows struct { //types:add
// Tensor source that we are an indexed view onto.
// Note that this must be a concrete [Values] tensor, to enable efficient
// [RowMajor] access and subspace functions.
Tensor Values
// Indexes are the indexes into Tensor rows, with nil = sequential.
// Only set if order is different from default sequential order.
// Use the [Rows.RowIndex] method for nil-aware logic.
Indexes []int
}
// NewRows returns a new [Rows] view of given tensor,
// with optional list of indexes (none / nil = sequential).
func NewRows(tsr Values, idxs ...int) *Rows {
rw := &Rows{Tensor: tsr, Indexes: slices.Clone(idxs)}
return rw
}
// AsRows returns the tensor as a [Rows] view.
// If it already is one, then it is returned, otherwise
// a new Rows is created to wrap around the given tensor, which is
// enforced to be a [Values] tensor either because it already is one,
// or by calling [Tensor.AsValues] on it.
func AsRows(tsr Tensor) *Rows {
if rw, ok := tsr.(*Rows); ok {
return rw
}
return NewRows(tsr.AsValues())
}
// SetTensor sets the given [Values] tensor as the source, with sequential initial indexes.
func (rw *Rows) SetTensor(tsr Values) {
rw.Tensor = tsr
rw.Sequential()
}
func (rw *Rows) IsString() bool { return rw.Tensor.IsString() }
func (rw *Rows) DataType() reflect.Kind { return rw.Tensor.DataType() }
// RowIndex returns the actual index into underlying tensor row based on given
// index value. If Indexes == nil, index is passed through.
func (rw *Rows) RowIndex(idx int) int {
if rw.Indexes == nil {
return idx
}
return rw.Indexes[idx]
}
// NumRows returns the effective number of rows in this Rows view,
// which is the length of the index list, or the size of the outermost
// row dimension of the tensor if there are no indexes (full sequential view).
func (rw *Rows) NumRows() int {
if rw.Indexes == nil {
return rw.Tensor.DimSize(0)
}
return len(rw.Indexes)
}
func (rw *Rows) String() string { return Sprintf("", rw.Tensor, 0) }
func (rw *Rows) Label() string { return rw.Tensor.Label() }
func (rw *Rows) Metadata() *metadata.Data { return rw.Tensor.Metadata() }
func (rw *Rows) NumDims() int { return rw.Tensor.NumDims() }
// ShapeSizes returns the sizes of each dimension as a slice of ints.
// If we have Indexes, this is the effective shape sizes using
// the current number of indexes as the outermost row dimension size.
func (rw *Rows) ShapeSizes() []int {
if rw.Indexes == nil || rw.Tensor.NumDims() == 0 {
return rw.Tensor.ShapeSizes()
}
sh := slices.Clone(rw.Tensor.ShapeSizes())
sh[0] = len(rw.Indexes)
return sh
}
// Shape() returns a [Shape] representation of the tensor shape
// (dimension sizes). If we have Indexes, this is the effective
// shape using the current number of indexes as the outermost row dimension size.
func (rw *Rows) Shape() *Shape {
if rw.Indexes == nil {
return rw.Tensor.Shape()
}
return NewShape(rw.ShapeSizes()...)
}
// Len returns the total number of elements in the tensor,
// taking into account the Indexes via [Rows],
// as NumRows() * cell size.
func (rw *Rows) Len() int {
rows := rw.NumRows()
_, cells := rw.Tensor.Shape().RowCellSize()
return cells * rows
}
// DimSize returns size of given dimension, returning NumRows()
// for first dimension.
func (rw *Rows) DimSize(dim int) int {
if dim == 0 {
return rw.NumRows()
}
return rw.Tensor.DimSize(dim)
}
// RowCellSize returns the size of the outermost Row shape dimension
// (via [Rows.NumRows] method), and the size of all the remaining
// inner dimensions (the "cell" size).
func (rw *Rows) RowCellSize() (rows, cells int) {
_, cells = rw.Tensor.Shape().RowCellSize()
rows = rw.NumRows()
return
}
// ValidIndexes deletes all invalid indexes from the list.
// Call this if rows may have been deleted from the tensor.
func (rw *Rows) ValidIndexes() {
if rw.Tensor.DimSize(0) <= 0 || rw.Indexes == nil {
rw.Indexes = nil
return
}
ni := rw.NumRows()
for i := ni - 1; i >= 0; i-- {
if rw.Indexes[i] >= rw.Tensor.DimSize(0) {
rw.Indexes = append(rw.Indexes[:i], rw.Indexes[i+1:]...)
}
}
}
// Sequential sets Indexes to nil, resulting in sequential row-wise access into tensor.
func (rw *Rows) Sequential() { //types:add
rw.Indexes = nil
}
// IndexesNeeded is called prior to an operation that needs actual indexes,
// e.g., Sort, Filter. If Indexes == nil, they are set to all rows, otherwise
// current indexes are left as is. Use Sequential, then IndexesNeeded to ensure
// all rows are represented.
func (rw *Rows) IndexesNeeded() {
if rw.Tensor.DimSize(0) <= 0 {
rw.Indexes = nil
return
}
if rw.Indexes != nil {
return
}
rw.Indexes = make([]int, rw.Tensor.DimSize(0))
for i := range rw.Indexes {
rw.Indexes[i] = i
}
}
// ExcludeMissing deletes indexes where the values are missing, as indicated by NaN.
// Uses first cell of higher dimensional data.
func (rw *Rows) ExcludeMissing() { //types:add
if rw.Tensor.DimSize(0) <= 0 {
rw.Indexes = nil
return
}
rw.IndexesNeeded()
ni := rw.NumRows()
for i := ni - 1; i >= 0; i-- {
if math.IsNaN(rw.Tensor.FloatRow(rw.Indexes[i], 0)) {
rw.Indexes = append(rw.Indexes[:i], rw.Indexes[i+1:]...)
}
}
}
// Permuted sets indexes to a permuted order. If indexes already exist,
// then the existing list of indexes is permuted; otherwise a new set of
// permuted indexes is generated.
func (rw *Rows) Permuted() {
if rw.Tensor.DimSize(0) <= 0 {
rw.Indexes = nil
return
}
if rw.Indexes == nil {
rw.Indexes = rand.Perm(rw.Tensor.DimSize(0))
} else {
rand.Shuffle(len(rw.Indexes), func(i, j int) {
rw.Indexes[i], rw.Indexes[j] = rw.Indexes[j], rw.Indexes[i]
})
}
}
const (
// Ascending specifies an ascending sort direction for tensor Sort routines.
Ascending = true
// Descending specifies a descending sort direction for tensor Sort routines.
Descending = false
// StableSort specifies using the stable, original order-preserving sort, which is slower.
StableSort = true
// UnstableSort specifies using the faster sort that does not preserve the original order of equal elements.
UnstableSort = false
)
// SortFunc sorts the row-wise indexes using given compare function.
// The compare function operates directly on row numbers into the Tensor
// as these row numbers have already been projected through the indexes.
// cmp(a, b) should return a negative number when a < b, a positive
// number when a > b and zero when a == b.
func (rw *Rows) SortFunc(cmp func(tsr Values, i, j int) int) {
rw.IndexesNeeded()
slices.SortFunc(rw.Indexes, func(a, b int) int {
return cmp(rw.Tensor, a, b) // key point: these are already indirected through indexes!!
})
}
// SortIndexes sorts the indexes into our Tensor directly in
// numerical order, producing the native ordering, while preserving
// any filtering that might have occurred.
func (rw *Rows) SortIndexes() {
if rw.Indexes == nil {
return
}
sort.Ints(rw.Indexes)
}
// CompareAscending is a sort compare function that reverses direction
// based on the ascending bool.
func CompareAscending[T cmp.Ordered](a, b T, ascending bool) int {
if ascending {
return cmp.Compare(a, b)
}
return cmp.Compare(b, a)
}
// Sort does default alpha or numeric sort of row-wise data.
// Uses first cell of higher dimensional data.
func (rw *Rows) Sort(ascending bool) {
if rw.Tensor.IsString() {
rw.SortFunc(func(tsr Values, i, j int) int {
return CompareAscending(tsr.StringRow(i, 0), tsr.StringRow(j, 0), ascending)
})
} else {
rw.SortFunc(func(tsr Values, i, j int) int {
return CompareAscending(tsr.FloatRow(i, 0), tsr.FloatRow(j, 0), ascending)
})
}
}
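// exampleRowsSortFunc is an illustrative sketch (not part of the API) of a
// custom comparator with [Rows.SortFunc], ordering rows of a 2D tensor by
// their second cell. It assumes [NewInt] accepts multiple dimension sizes
// (like [NewString]) and that [Int] provides the same [Values] row accessors
// as the [String] type defined elsewhere in this package.
func exampleRowsSortFunc() {
	vals := NewInt(3, 2) // 3 rows x 2 cells
	for r, v := range []int{9, 4, 7} {
		vals.SetIntRow(v, r, 1) // second cell of each row: 9, 4, 7
	}
	rw := NewRows(vals)
	rw.SortFunc(func(tsr Values, i, j int) int {
		// i, j are already source row numbers, projected through the indexes.
		return CompareAscending(tsr.FloatRow(i, 1), tsr.FloatRow(j, 1), Ascending)
	})
	// rw.Indexes is now [1, 2, 0]: rows ordered by their second cell 4, 7, 9.
}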
// SortStableFunc stably sorts the row-wise indexes using given compare function.
// The compare function operates directly on row numbers into the Tensor
// as these row numbers have already been projected through the indexes.
// cmp(a, b) should return a negative number when a < b, a positive
// number when a > b and zero when a == b.
// It is *essential* that it always returns 0 when the two are equal
// for the stable function to actually work.
func (rw *Rows) SortStableFunc(cmp func(tsr Values, i, j int) int) {
rw.IndexesNeeded()
slices.SortStableFunc(rw.Indexes, func(a, b int) int {
return cmp(rw.Tensor, a, b) // key point: these are already indirected through indexes!!
})
}
// SortStable does stable default alpha or numeric sort.
// Uses first cell of higher dimensional data.
func (rw *Rows) SortStable(ascending bool) {
if rw.Tensor.IsString() {
rw.SortStableFunc(func(tsr Values, i, j int) int {
return CompareAscending(tsr.StringRow(i, 0), tsr.StringRow(j, 0), ascending)
})
} else {
rw.SortStableFunc(func(tsr Values, i, j int) int {
return CompareAscending(tsr.FloatRow(i, 0), tsr.FloatRow(j, 0), ascending)
})
}
}
// FilterFunc is a function used for filtering that returns
// true if Tensor row should be included in the current filtered
// view of the tensor, and false if it should be removed.
type FilterFunc func(tsr Values, row int) bool
// Filter filters the indexes using given Filter function.
// The Filter function operates directly on row numbers into the Tensor
// as these row numbers have already been projected through the indexes.
func (rw *Rows) Filter(filterer func(tsr Values, row int) bool) {
rw.IndexesNeeded()
sz := len(rw.Indexes)
for i := sz - 1; i >= 0; i-- { // always go in reverse for filtering
if !filterer(rw.Tensor, rw.Indexes[i]) { // delete
rw.Indexes = append(rw.Indexes[:i], rw.Indexes[i+1:]...)
}
}
}
// todo: move to stringsx
// StringMatch are options for how to compare strings.
type StringMatch struct { //types:add
// Contains means the string only needs to contain the target string,
// with the default (false) requiring a complete match to entire string.
Contains bool
// IgnoreCase means that differences in case are ignored in comparing strings,
// with the default (false) using case.
IgnoreCase bool
// Exclude means to exclude matches,
// with the default (false) being to include.
Exclude bool
}
// Match compares two strings according to the options.
// trg is the target string being tested: it is the string
// that must contain (or equal) the given str string,
// not the other way around.
func (sm *StringMatch) Match(trg, str string) bool {
has := false
switch {
case sm.Contains && sm.IgnoreCase:
has = strings.Contains(strings.ToLower(trg), strings.ToLower(str))
case sm.Contains:
has = strings.Contains(trg, str)
case sm.IgnoreCase:
has = strings.EqualFold(trg, str)
default:
has = (trg == str)
}
if sm.Exclude {
return !has
}
return has
}
// FilterString filters the indexes using string values compared to given
// string. Includes rows with matching values unless the Exclude option is set.
// If Contains option is set, it only checks if row contains string;
// if IgnoreCase, ignores case, otherwise filtering is case sensitive.
// Uses first cell of higher dimensional data.
func (rw *Rows) FilterString(str string, opts StringMatch) { //types:add
rw.Filter(func(tsr Values, row int) bool {
val := tsr.StringRow(row, 0)
has := opts.Match(val, str)
return has
})
}
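// exampleFilterString is an illustrative sketch (not part of the API) of row
// filtering with [Rows.FilterString] and the [StringMatch] options, using only
// constructors and methods defined in this package.
func exampleFilterString() {
	names := NewString(3)
	names.SetString1D("alpha", 0)
	names.SetString1D("Beta", 1)
	names.SetString1D("alphabet", 2)
	rw := NewRows(names)
	rw.FilterString("alpha", StringMatch{Contains: true, IgnoreCase: true})
	// rw.Indexes is now [0, 2]: only rows containing "alpha" remain.
}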
// AsValues returns this tensor as raw [Values].
// If the row [Rows.Indexes] are nil, then the wrapped Values tensor
// is returned. Otherwise, it "renders" the Rows view into a fully contiguous
// and optimized memory representation of that view, which will be faster
// to access for further processing, and enables all the additional
// functionality provided by the [Values] interface.
func (rw *Rows) AsValues() Values {
if rw.Indexes == nil {
return rw.Tensor
}
vt := NewOfType(rw.Tensor.DataType(), rw.ShapeSizes()...)
rows := rw.NumRows()
for r := range rows {
vt.SetRowTensor(rw.RowTensor(r), r)
}
return vt
}
// CloneIndexes returns a copy of the current Rows view with new indexes,
// with a pointer to the same underlying Tensor as the source.
func (rw *Rows) CloneIndexes() *Rows {
nix := &Rows{}
nix.Tensor = rw.Tensor
nix.CopyIndexes(rw)
return nix
}
// CopyIndexes copies indexes from other Rows view.
func (rw *Rows) CopyIndexes(oix *Rows) {
if oix.Indexes == nil {
rw.Indexes = nil
} else {
rw.Indexes = slices.Clone(oix.Indexes)
}
}
// addRowsIndexes appends n new row indexes to Indexes, starting at the current tensor size.
func (rw *Rows) addRowsIndexes(n int) { //types:add
if rw.Indexes == nil {
return
}
stidx := rw.Tensor.DimSize(0)
for i := stidx; i < stidx+n; i++ {
rw.Indexes = append(rw.Indexes, i)
}
}
// AddRows adds n rows to end of underlying Tensor, and to the indexes in this view
func (rw *Rows) AddRows(n int) { //types:add
stidx := rw.Tensor.DimSize(0)
rw.addRowsIndexes(n)
rw.Tensor.SetNumRows(stidx + n)
}
// InsertRows adds n rows to end of underlying Tensor, and to the indexes starting at
// given index in this view
func (rw *Rows) InsertRows(at, n int) {
stidx := rw.Tensor.DimSize(0)
rw.IndexesNeeded()
rw.Tensor.SetNumRows(stidx + n)
nw := make([]int, n, n+len(rw.Indexes)-at)
for i := 0; i < n; i++ {
nw[i] = stidx + i
}
rw.Indexes = append(rw.Indexes[:at], append(nw, rw.Indexes[at:]...)...)
}
// DeleteRows deletes n rows of indexes starting at given index in the list of indexes
func (rw *Rows) DeleteRows(at, n int) {
rw.IndexesNeeded()
rw.Indexes = append(rw.Indexes[:at], rw.Indexes[at+n:]...)
}
// Swap switches the indexes for i and j
func (rw *Rows) Swap(i, j int) {
if rw.Indexes == nil {
return
}
rw.Indexes[i], rw.Indexes[j] = rw.Indexes[j], rw.Indexes[i]
}
/////// Floats
// Float returns the value of given index as a float64.
// The first index value is indirected through the indexes.
func (rw *Rows) Float(i ...int) float64 {
if rw.Indexes == nil {
return rw.Tensor.Float(i...)
}
ic := slices.Clone(i)
ic[0] = rw.Indexes[ic[0]]
return rw.Tensor.Float(ic...)
}
// SetFloat sets the value of given index as a float64
// The first index value is indirected through the [Rows.Indexes].
func (rw *Rows) SetFloat(val float64, i ...int) {
if rw.Indexes == nil {
rw.Tensor.SetFloat(val, i...)
return
}
ic := slices.Clone(i)
ic[0] = rw.Indexes[ic[0]]
rw.Tensor.SetFloat(val, ic...)
}
// FloatRow returns the value at given row and cell,
// where row is outermost dim, and cell is 1D index into remaining inner dims.
// Row is indirected through the [Rows.Indexes].
// This is the preferred interface for all Rows operations.
func (rw *Rows) FloatRow(row, cell int) float64 {
return rw.Tensor.FloatRow(rw.RowIndex(row), cell)
}
// SetFloatRow sets the value at given row and cell,
// where row is outermost dim, and cell is 1D index into remaining inner dims.
// Row is indirected through the [Rows.Indexes].
// This is the preferred interface for all Rows operations.
func (rw *Rows) SetFloatRow(val float64, row, cell int) {
rw.Tensor.SetFloatRow(val, rw.RowIndex(row), cell)
}
// Float1D is somewhat expensive if indexes are set, because it needs to convert
// the flat index back into a full n-dimensional index and then use that api.
func (rw *Rows) Float1D(i int) float64 {
if rw.Indexes == nil {
return rw.Tensor.Float1D(i)
}
return rw.Float(rw.Tensor.Shape().IndexFrom1D(i)...)
}
// SetFloat1D is somewhat expensive if indexes are set, because it needs to convert
// the flat index back into a full n-dimensional index and then use that api.
func (rw *Rows) SetFloat1D(val float64, i int) {
if rw.Indexes == nil {
rw.Tensor.SetFloat1D(val, i)
return
}
rw.SetFloat(val, rw.Tensor.Shape().IndexFrom1D(i)...)
}
/////// Strings
// StringValue returns the value of given index as a string.
// The first index value is indirected through the indexes.
func (rw *Rows) StringValue(i ...int) string {
if rw.Indexes == nil {
return rw.Tensor.StringValue(i...)
}
ic := slices.Clone(i)
ic[0] = rw.Indexes[ic[0]]
return rw.Tensor.StringValue(ic...)
}
// SetString sets the value of given index as a string
// The first index value is indirected through the [Rows.Indexes].
func (rw *Rows) SetString(val string, i ...int) {
if rw.Indexes == nil {
rw.Tensor.SetString(val, i...)
return
}
ic := slices.Clone(i)
ic[0] = rw.Indexes[ic[0]]
rw.Tensor.SetString(val, ic...)
}
// StringRow returns the value at given row and cell,
// where row is outermost dim, and cell is 1D index into remaining inner dims.
// Row is indirected through the [Rows.Indexes].
// This is the preferred interface for all Rows operations.
func (rw *Rows) StringRow(row, cell int) string {
return rw.Tensor.StringRow(rw.RowIndex(row), cell)
}
// SetStringRow sets the value at given row and cell,
// where row is outermost dim, and cell is 1D index into remaining inner dims.
// Row is indirected through the [Rows.Indexes].
// This is the preferred interface for all Rows operations.
func (rw *Rows) SetStringRow(val string, row, cell int) {
rw.Tensor.SetStringRow(val, rw.RowIndex(row), cell)
}
// AppendRowFloat adds a row and sets float value(s), up to number of cells.
func (rw *Rows) AppendRowFloat(val ...float64) {
rw.addRowsIndexes(1)
rw.Tensor.AppendRowFloat(val...)
}
// String1D is somewhat expensive if indexes are set, because it needs to convert
// the flat index back into a full n-dimensional index and then use that api.
func (rw *Rows) String1D(i int) string {
if rw.Indexes == nil {
return rw.Tensor.String1D(i)
}
return rw.StringValue(rw.Tensor.Shape().IndexFrom1D(i)...)
}
// SetString1D is somewhat expensive if indexes are set, because it needs to convert
// the flat index back into a full n-dimensional index and then use that api.
func (rw *Rows) SetString1D(val string, i int) {
if rw.Indexes == nil {
rw.Tensor.SetString1D(val, i)
return
}
rw.SetString(val, rw.Tensor.Shape().IndexFrom1D(i)...)
}
// AppendRowString adds a row and sets string value(s), up to number of cells.
func (rw *Rows) AppendRowString(val ...string) {
rw.addRowsIndexes(1)
rw.Tensor.AppendRowString(val...)
}
/////// Ints
// Int returns the value of given index as an int.
// The first index value is indirected through the indexes.
func (rw *Rows) Int(i ...int) int {
if rw.Indexes == nil {
return rw.Tensor.Int(i...)
}
ic := slices.Clone(i)
ic[0] = rw.Indexes[ic[0]]
return rw.Tensor.Int(ic...)
}
// SetInt sets the value of given index as an int
// The first index value is indirected through the [Rows.Indexes].
func (rw *Rows) SetInt(val int, i ...int) {
if rw.Indexes == nil {
rw.Tensor.SetInt(val, i...)
return
}
ic := slices.Clone(i)
ic[0] = rw.Indexes[ic[0]]
rw.Tensor.SetInt(val, ic...)
}
// IntRow returns the value at given row and cell,
// where row is outermost dim, and cell is 1D index into remaining inner dims.
// Row is indirected through the [Rows.Indexes].
// This is the preferred interface for all Rows operations.
func (rw *Rows) IntRow(row, cell int) int {
return rw.Tensor.IntRow(rw.RowIndex(row), cell)
}
// SetIntRow sets the value at given row and cell,
// where row is outermost dim, and cell is 1D index into remaining inner dims.
// Row is indirected through the [Rows.Indexes].
// This is the preferred interface for all Rows operations.
func (rw *Rows) SetIntRow(val int, row, cell int) {
rw.Tensor.SetIntRow(val, rw.RowIndex(row), cell)
}
// AppendRowInt adds a row and sets int value(s), up to number of cells.
func (rw *Rows) AppendRowInt(val ...int) {
rw.addRowsIndexes(1)
rw.Tensor.AppendRowInt(val...)
}
// Int1D is somewhat expensive if indexes are set, because it needs to convert
// the flat index back into a full n-dimensional index and then use that api.
func (rw *Rows) Int1D(i int) int {
if rw.Indexes == nil {
return rw.Tensor.Int1D(i)
}
return rw.Int(rw.Tensor.Shape().IndexFrom1D(i)...)
}
// SetInt1D is somewhat expensive if indexes are set, because it needs to convert
// the flat index back into a full n-dimensional index and then use that api.
func (rw *Rows) SetInt1D(val int, i int) {
if rw.Indexes == nil {
rw.Tensor.SetInt1D(val, i)
return
}
rw.SetInt(val, rw.Tensor.Shape().IndexFrom1D(i)...)
}
/////// SubSpaces
// SubSpace returns a new tensor with innermost subspace at given
// offset(s) in outermost dimension(s) (len(offs) < NumDims).
// The new tensor points to the values of this tensor (i.e., modifications
// will affect both), as its Values slice is a view onto the original (which
// is why only inner-most contiguous subspaces are supported).
// Use Clone() method to separate the two.
// Rows version does indexed indirection of the outermost row dimension
// of the offsets.
func (rw *Rows) SubSpace(offs ...int) Values {
if len(offs) == 0 {
return nil
}
offs[0] = rw.RowIndex(offs[0])
return rw.Tensor.SubSpace(offs...)
}
// RowTensor is a convenience version of [Rows.SubSpace] to return the
// SubSpace for the outermost row dimension, indirected through the indexes.
func (rw *Rows) RowTensor(row int) Values {
return rw.Tensor.RowTensor(rw.RowIndex(row))
}
// SetRowTensor sets the values of the SubSpace at given row to given values,
// with row indirected through the indexes.
func (rw *Rows) SetRowTensor(val Values, row int) {
rw.Tensor.SetRowTensor(val, rw.RowIndex(row))
}
// AppendRow adds a row and sets values to given values.
func (rw *Rows) AppendRow(val Values) {
nrow := rw.Tensor.DimSize(0)
rw.AddRows(1)
rw.Tensor.SetRowTensor(val, nrow)
}
// check for interface impl
var _ RowMajor = (*Rows)(nil)
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
import (
"fmt"
"slices"
)
// Shape manages a tensor's shape information, including sizes and strides,
// and can compute the flat index into an underlying 1D data storage array based on an
// n-dimensional index (and vice-versa).
// Per Go / C / Python conventions, indexes are Row-Major, ordered from
// outer to inner left-to-right, so the inner-most is right-most.
type Shape struct {
// size per dimension.
Sizes []int
// strides (element offsets) for each dimension, used to compute the flat 1D index.
Strides []int `display:"-"`
}
// NewShape returns a new shape with given sizes.
// RowMajor ordering is used by default.
func NewShape(sizes ...int) *Shape {
sh := &Shape{}
sh.SetShapeSizes(sizes...)
return sh
}
// SetShapeSizes sets the shape sizes from list of ints.
// RowMajor ordering is used by default.
func (sh *Shape) SetShapeSizes(sizes ...int) {
sh.Sizes = slices.Clone(sizes)
sh.Strides = RowMajorStrides(sizes...)
}
// SetShapeSizesFromTensor sets the shape sizes from given tensor.
// RowMajor ordering is used by default.
func (sh *Shape) SetShapeSizesFromTensor(sizes Tensor) {
sh.SetShapeSizes(AsIntSlice(sizes)...)
}
// SizesAsTensor returns shape sizes as an Int Tensor.
func (sh *Shape) SizesAsTensor() *Int {
return NewIntFromValues(sh.Sizes...)
}
// CopyFrom copies the shape parameters from another Shape struct.
// It copies the data so it is not accidentally subject to updates.
func (sh *Shape) CopyFrom(cp *Shape) {
sh.Sizes = slices.Clone(cp.Sizes)
sh.Strides = slices.Clone(cp.Strides)
}
// Len returns the total length of elements in the tensor
// (i.e., the product of the shape sizes).
func (sh *Shape) Len() int {
if len(sh.Sizes) == 0 {
return 0
}
ln := 1
for _, v := range sh.Sizes {
ln *= v
}
return ln
}
// NumDims returns the total number of dimensions.
func (sh *Shape) NumDims() int { return len(sh.Sizes) }
// DimSize returns the size of given dimension.
func (sh *Shape) DimSize(i int) int {
return sh.Sizes[i]
}
// IndexIsValid returns true if the given index is valid (within range for all dimensions).
func (sh *Shape) IndexIsValid(idx ...int) bool {
if len(idx) != sh.NumDims() {
return false
}
for i, v := range sh.Sizes {
if idx[i] < 0 || idx[i] >= v {
return false
}
}
return true
}
// IsEqual returns true if this shape is the same as the other (does not compare names).
func (sh *Shape) IsEqual(oth *Shape) bool {
if slices.Compare(sh.Sizes, oth.Sizes) != 0 {
return false
}
if slices.Compare(sh.Strides, oth.Strides) != 0 {
return false
}
return true
}
// RowCellSize returns the size of the outermost Row shape dimension,
// and the size of all the remaining inner dimensions (the "cell" size).
// Used for Tensors that are columns in a data table.
func (sh *Shape) RowCellSize() (rows, cells int) {
if len(sh.Sizes) == 0 {
return 0, 1
}
rows = sh.Sizes[0]
if len(sh.Sizes) == 1 {
cells = 1
} else if rows > 0 {
cells = sh.Len() / rows
} else {
ln := 1
for _, v := range sh.Sizes[1:] {
ln *= v
}
cells = ln
}
return
}
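// exampleRowCellSize is an illustrative sketch (not part of the API): for a
// 5x2x3 shape, the row count is the outermost size and the cell size is the
// product of the remaining inner dimensions.
func exampleRowCellSize() {
	rows, cells := NewShape(5, 2, 3).RowCellSize() // rows = 5, cells = 6
	_, _ = rows, cells
}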
// IndexTo1D returns the flat 1D index from given n-dimensional indices.
// No checking is done on the length or size of the index values relative
// to the shape of the tensor.
func (sh *Shape) IndexTo1D(index ...int) int {
oned := 0
for i, v := range index {
oned += v * sh.Strides[i]
}
return oned
}
// IndexFrom1D returns the n-dimensional index from a "flat" 1D array index.
func (sh *Shape) IndexFrom1D(oned int) []int {
nd := len(sh.Sizes)
index := make([]int, nd)
rem := oned
for i := nd - 1; i >= 0; i-- {
s := sh.Sizes[i]
if s == 0 {
return index
}
iv := rem % s
rem /= s
index[i] = iv
}
return index
}
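// exampleShapeIndexing is an illustrative sketch (not part of the API) of the
// row-major index math implemented by [Shape.IndexTo1D] and [Shape.IndexFrom1D].
func exampleShapeIndexing() {
	sh := NewShape(3, 4, 5)       // strides are [20, 5, 1]
	flat := sh.IndexTo1D(2, 1, 3) // 2*20 + 1*5 + 3*1 = 48
	idx := sh.IndexFrom1D(flat)   // recovers [2, 1, 3]
	_, _ = flat, idx
}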
// String satisfies the fmt.Stringer interface
func (sh *Shape) String() string {
return fmt.Sprintf("%v", sh.Sizes)
}
// RowMajorStrides returns strides for sizes where the first dimension is outermost
// and subsequent dimensions are progressively inner.
func RowMajorStrides(sizes ...int) []int {
if len(sizes) == 0 {
return nil
}
sizes[0] = max(1, sizes[0]) // critical for strides to not be all zero when rows = 0
rem := int(1)
for _, v := range sizes {
rem *= v
}
if rem == 0 {
strides := make([]int, len(sizes))
for i := range strides {
strides[i] = rem
}
return strides
}
strides := make([]int, len(sizes))
for i, v := range sizes {
rem /= v
strides[i] = rem
}
return strides
}
// ColumnMajorStrides returns strides for sizes where the first dimension is inner-most
// and subsequent dimensions are progressively outer.
func ColumnMajorStrides(sizes ...int) []int {
total := int(1)
for _, v := range sizes {
if v == 0 {
strides := make([]int, len(sizes))
for i := range strides {
strides[i] = total
}
return strides
}
}
strides := make([]int, len(sizes))
for i, v := range sizes {
strides[i] = total
total *= v
}
return strides
}
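// exampleStrides is an illustrative sketch (not part of the API) comparing the
// two stride orderings for the same sizes, using only the functions above.
func exampleStrides() {
	rm := RowMajorStrides(3, 4, 5)    // [20, 5, 1]: inner-most index varies fastest
	cm := ColumnMajorStrides(3, 4, 5) // [1, 3, 12]: outer-most index varies fastest
	_, _ = rm, cm
}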
// AddShapes returns a new shape by adding two shapes one after the other.
func AddShapes(shape1, shape2 *Shape) *Shape {
sh1 := shape1.Sizes
sh2 := shape2.Sizes
nsh := make([]int, len(sh1)+len(sh2))
copy(nsh, sh1)
copy(nsh[len(sh1):], sh2)
sh := NewShape(nsh...)
return sh
}
// CellsSize returns the sizes of the inner cell dimensions given
// overall tensor sizes. It returns []int{1} for the 1D case.
// Used for ensuring cell-wise outputs are the right size.
func CellsSize(sizes []int) []int {
csz := slices.Clone(sizes)
if len(csz) == 1 {
csz[0] = 1
} else {
csz = csz[1:]
}
return csz
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
import (
"math/rand"
"reflect"
"slices"
"sort"
"cogentcore.org/core/base/metadata"
"cogentcore.org/core/base/reflectx"
"cogentcore.org/core/base/slicesx"
)
// Sliced provides a re-sliced view onto another "source" [Tensor],
// defined by a set of [Sliced.Indexes] for each dimension (must have
// at least 1 index per dimension to avoid a null view).
// Thus, each dimension can be transformed in arbitrary ways relative
// to the original tensor (filtered subsets, reversals, sorting, etc).
// This view is not memory-contiguous and does not support the [RowMajor]
// interface or efficient access to inner-dimensional subspaces.
// A new Sliced view defaults to a full transparent view of the source tensor.
// There is additional cost for every access operation associated with the
// indexed indirection, and access is always via the full n-dimensional indexes.
// See also [Rows] for a version that only indexes the outermost row dimension,
// which is much more efficient for this common use-case, and does support [RowMajor].
// To produce a new concrete [Values] that has raw data actually organized according
// to the indexed order (i.e., the copy function of numpy), call [Sliced.AsValues].
type Sliced struct { //types:add
// Tensor source that we are an indexed view onto.
Tensor Tensor
// Indexes are the indexes for each dimension, with dimensions as the outer
// slice (enforced to be the same length as the NumDims of the source Tensor),
// and a list of dimension index values (within range of DimSize(d)).
// A nil list of indexes for a dimension automatically provides a full,
// sequential view of that dimension.
Indexes [][]int
}
// NewSliced returns a new [Sliced] view of given tensor,
// with optional list of indexes for each dimension (none / nil = sequential).
// Any dimensions without indexes default to nil = full sequential view.
func NewSliced(tsr Tensor, idxs ...[]int) *Sliced {
sl := &Sliced{Tensor: tsr, Indexes: idxs}
sl.ValidIndexes()
return sl
}
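// exampleSliced is an illustrative sketch (not part of the API) of a [Sliced]
// view that keeps rows sequential while reversing the column order of a 2x3
// source. It assumes [NewInt] accepts multiple dimension sizes (like
// [NewString]) and that [Int] implements the [Values] accessors shown for
// [String] elsewhere in this package.
func exampleSliced() {
	src := NewInt(2, 3)
	for i := range 6 {
		src.SetInt1D(i+1, i) // 1..6 in row-major order
	}
	sl := NewSliced(src, nil, []int{2, 1, 0}) // nil = full sequential rows
	v := sl.Int(0, 0)                         // reads src[0, 2] = 3
	_ = v
}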
// AnySlice returns a new Tensor view using the given index
// variables to be used for Sliced, Masked, or Indexed
// depending on what is present.
// - If a [Bool] tensor is provided then [NewMasked] is called.
// - If a single tensor is provided with Len > max(1, tsr.NumDims()),
// [NewIndexed] is called.
// - Otherwise, [Reslice] is called with the args.
func AnySlice(tsr Tensor, idx ...any) Tensor {
n := len(idx)
if n == 0 {
return tsr
}
if n == 1 {
if b, ok := idx[0].(*Bool); ok {
return NewMasked(tsr, b)
}
if i, ok := idx[0].(Tensor); ok {
if i.Len() > 1 && i.Len() > tsr.NumDims() {
return NewIndexed(tsr, AsInt(i))
}
}
}
return Reslice(tsr, idx...)
}
// Reslice returns a new [Sliced] (and potentially [Reshaped]) view of given tensor,
// with given slice expressions for each dimension, which can be:
// - an integer, indicating a specific index value along that dimension.
// Can use negative numbers to index from the end.
// This axis will also be removed using a [Reshaped].
// - a [Slice] object expressing a range of indexes.
// - [FullAxis] includes the full original axis (equivalent to `Slice{}`).
// - [Ellipsis] creates a flexibly-sized stretch of FullAxis dimensions,
// which automatically aligns the remaining slice elements based on the source
// dimensionality.
// - [NewAxis] creates a new singleton (length=1) axis, used to reshape
// without changing the size. This triggers a [Reshaped].
// - any remaining dimensions without indexes default to nil = full sequential view.
func Reslice(tsr Tensor, sls ...any) Tensor {
ns := len(sls)
if ns == 0 {
return NewSliced(tsr)
}
nd := tsr.NumDims()
ed := nd - ns // extra dimensions
ixs := make([][]int, nd)
doReshape := false // indicates if we need a Reshaped
reshape := make([]int, 0, nd+2) // if we need one, this is the target shape
ci := 0
for d := range ns {
s := sls[d]
if st, ok := s.(Tensor); ok {
doReshape = true // doesn't add to new shape.
ni := st.Len()
for i := range ni {
ix := st.Int1D(i)
if ix < 0 {
ixs[ci] = []int{tsr.DimSize(ci) + ix}
} else {
ixs[ci] = []int{ix}
}
ci++
}
continue
}
switch x := s.(type) {
case int:
doReshape = true // doesn't add to new shape.
if x < 0 {
ixs[ci] = []int{tsr.DimSize(ci) + x}
} else {
ixs[ci] = []int{x}
}
case Slice:
ixs[ci] = x.IntSlice(tsr.DimSize(ci))
reshape = append(reshape, len(ixs[ci]))
case SlicesMagic:
switch x {
case FullAxis:
ixs[ci] = Slice{}.IntSlice(tsr.DimSize(ci))
reshape = append(reshape, len(ixs[ci]))
case NewAxis:
ed++ // we are not real
doReshape = true
reshape = append(reshape, 1)
continue // skip the increment in ci
case Ellipsis:
ed++ // extra for us
for range ed {
ixs[ci] = Slice{}.IntSlice(tsr.DimSize(ci))
reshape = append(reshape, len(ixs[ci]))
ci++
}
if ed > 0 {
ci--
}
ed = 0 // ate them up
}
}
ci++
}
for range ed { // fill any extra dimensions
ixs[ci] = Slice{}.IntSlice(tsr.DimSize(ci))
reshape = append(reshape, len(ixs[ci]))
ci++
}
sl := NewSliced(tsr, ixs...)
if doReshape {
if len(reshape) == 0 { // all indexes
reshape = []int{1}
}
return NewReshaped(sl, reshape...)
}
return sl
}
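// exampleReslice is an illustrative sketch (not part of the API) of typical
// [Reslice] expressions on a 3x4 source tensor. It assumes [NewInt] accepts
// multiple dimension sizes like [NewString]; the slice logic itself is
// exactly what [Reslice] above implements.
func exampleReslice() {
	src := NewInt(3, 4)
	row1 := Reslice(src, 1)               // int index removes the axis: 1D view of row 1, shape [4]
	lastCol := Reslice(src, FullAxis, -1) // negative index counts from the end: last column, shape [3]
	sub := Reslice(src, NewSlice(0, 2, 0), NewSlice(1, 3, 0)) // 2x2 block: rows 0-1, columns 1-2
	_, _, _ = row1, lastCol, sub
}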
// AsSliced returns the tensor as a [Sliced] view.
// If it already is one, then it is returned, otherwise it is wrapped
// in a new Sliced, with default full sequential ("transparent") view.
func AsSliced(tsr Tensor) *Sliced {
if sl, ok := tsr.(*Sliced); ok {
return sl
}
return NewSliced(tsr)
}
// SetTensor sets tensor as source for this view, and initializes a full
// transparent view onto source (calls [Sliced.Sequential]).
func (sl *Sliced) SetTensor(tsr Tensor) {
sl.Tensor = tsr
sl.Sequential()
}
// SourceIndex returns the actual index into source tensor dimension
// based on given index value.
func (sl *Sliced) SourceIndex(dim, idx int) int {
ix := sl.Indexes[dim]
if ix == nil {
return idx
}
return ix[idx]
}
// SourceIndexes returns the actual n-dimensional indexes into source tensor
// based on given list of indexes based on the Sliced view shape.
func (sl *Sliced) SourceIndexes(i ...int) []int {
ix := slices.Clone(i)
for d, idx := range i {
ix[d] = sl.SourceIndex(d, idx)
}
return ix
}
// SourceIndexesFrom1D returns the n-dimensional indexes into source tensor
// based on the given 1D index based on the Sliced view shape.
func (sl *Sliced) SourceIndexesFrom1D(oned int) []int {
sh := sl.Shape()
oix := sh.IndexFrom1D(oned) // full indexes in our coords
return sl.SourceIndexes(oix...)
}
// ValidIndexes ensures that [Sliced.Indexes] are valid,
// removing any out-of-range values and setting the view to nil (full sequential)
// for any dimension with no indexes (which is an invalid condition).
// Call this when any structural changes are made to underlying Tensor.
func (sl *Sliced) ValidIndexes() {
nd := sl.Tensor.NumDims()
sl.Indexes = slicesx.SetLength(sl.Indexes, nd)
for d := range nd {
ni := len(sl.Indexes[d])
if ni == 0 { // invalid
sl.Indexes[d] = nil // full
continue
}
ds := sl.Tensor.DimSize(d)
ix := sl.Indexes[d]
for i := ni - 1; i >= 0; i-- {
if ix[i] >= ds {
ix = append(ix[:i], ix[i+1:]...)
}
}
sl.Indexes[d] = ix
}
}
// Sequential sets all Indexes to nil, resulting in full sequential access into tensor.
func (sl *Sliced) Sequential() { //types:add
nd := sl.Tensor.NumDims()
sl.Indexes = slicesx.SetLength(sl.Indexes, nd)
for d := range nd {
sl.Indexes[d] = nil
}
}
// IndexesNeeded is called prior to an operation that needs actual indexes,
// on given dimension. If Indexes == nil, they are set to all items, otherwise
// current indexes are left as is. Use Sequential, then IndexesNeeded to ensure
// all dimension indexes are represented.
func (sl *Sliced) IndexesNeeded(d int) {
ix := sl.Indexes[d]
if ix != nil {
return
}
ix = make([]int, sl.Tensor.DimSize(d))
for i := range ix {
ix[i] = i
}
sl.Indexes[d] = ix
}
func (sl *Sliced) Label() string { return label(metadata.Name(sl), sl.Shape()) }
func (sl *Sliced) String() string { return Sprintf("", sl, 0) }
func (sl *Sliced) Metadata() *metadata.Data { return sl.Tensor.Metadata() }
func (sl *Sliced) IsString() bool { return sl.Tensor.IsString() }
func (sl *Sliced) DataType() reflect.Kind { return sl.Tensor.DataType() }
func (sl *Sliced) Shape() *Shape { return NewShape(sl.ShapeSizes()...) }
func (sl *Sliced) Len() int { return sl.Shape().Len() }
func (sl *Sliced) NumDims() int { return sl.Tensor.NumDims() }
// ShapeSizes returns the sizes of each dimension as a slice of ints.
// For each dimension, this is the effective size using
// the current number of indexes per dimension.
func (sl *Sliced) ShapeSizes() []int {
nd := sl.Tensor.NumDims()
if nd == 0 {
return sl.Tensor.ShapeSizes()
}
sh := slices.Clone(sl.Tensor.ShapeSizes())
for d := range nd {
if sl.Indexes[d] != nil {
sh[d] = len(sl.Indexes[d])
}
}
return sh
}
// DimSize returns the effective view size of given dimension.
func (sl *Sliced) DimSize(dim int) int {
if sl.Indexes[dim] != nil {
return len(sl.Indexes[dim])
}
return sl.Tensor.DimSize(dim)
}
// AsValues returns a copy of this tensor as raw [Values].
// This "renders" the Sliced view into a fully contiguous
// and optimized memory representation of that view, which will be faster
// to access for further processing, and enables all the additional
// functionality provided by the [Values] interface.
func (sl *Sliced) AsValues() Values {
dt := sl.Tensor.DataType()
vt := NewOfType(dt, sl.ShapeSizes()...)
n := sl.Len()
switch {
case sl.Tensor.IsString():
for i := range n {
vt.SetString1D(sl.String1D(i), i)
}
case reflectx.KindIsFloat(dt):
for i := range n {
vt.SetFloat1D(sl.Float1D(i), i)
}
default:
for i := range n {
vt.SetInt1D(sl.Int1D(i), i)
}
}
return vt
}
//////// Floats
// Float returns the value of given index as a float64.
// The indexes are indirected through the [Sliced.Indexes].
func (sl *Sliced) Float(i ...int) float64 {
return sl.Tensor.Float(sl.SourceIndexes(i...)...)
}
// SetFloat sets the value of given index as a float64
// The indexes are indirected through the [Sliced.Indexes].
func (sl *Sliced) SetFloat(val float64, i ...int) {
sl.Tensor.SetFloat(val, sl.SourceIndexes(i...)...)
}
// Float1D is somewhat expensive if indexes are set, because it needs to convert
// the flat index back into a full n-dimensional index and then use that api.
func (sl *Sliced) Float1D(i int) float64 {
return sl.Tensor.Float(sl.SourceIndexesFrom1D(i)...)
}
// SetFloat1D is somewhat expensive if indexes are set, because it needs to convert
// the flat index back into a full n-dimensional index and then use that api.
func (sl *Sliced) SetFloat1D(val float64, i int) {
sl.Tensor.SetFloat(val, sl.SourceIndexesFrom1D(i)...)
}
//////// Strings
// StringValue returns the value of given index as a string.
// The indexes are indirected through the [Sliced.Indexes].
func (sl *Sliced) StringValue(i ...int) string {
return sl.Tensor.StringValue(sl.SourceIndexes(i...)...)
}
// SetString sets the value of given index as a string
// The indexes are indirected through the [Sliced.Indexes].
func (sl *Sliced) SetString(val string, i ...int) {
sl.Tensor.SetString(val, sl.SourceIndexes(i...)...)
}
// String1D is somewhat expensive if indexes are set, because it needs to convert
// the flat index back into a full n-dimensional index and then use that api.
func (sl *Sliced) String1D(i int) string {
return sl.Tensor.StringValue(sl.SourceIndexesFrom1D(i)...)
}
// SetString1D is somewhat expensive if indexes are set, because it needs to convert
// the flat index back into a full n-dimensional index and then use that api.
func (sl *Sliced) SetString1D(val string, i int) {
sl.Tensor.SetString(val, sl.SourceIndexesFrom1D(i)...)
}
//////// Ints
// Int returns the value of given index as an int.
// The indexes are indirected through the [Sliced.Indexes].
func (sl *Sliced) Int(i ...int) int {
return sl.Tensor.Int(sl.SourceIndexes(i...)...)
}
// SetInt sets the value of given index as an int
// The indexes are indirected through the [Sliced.Indexes].
func (sl *Sliced) SetInt(val int, i ...int) {
sl.Tensor.SetInt(val, sl.SourceIndexes(i...)...)
}
// Int1D is somewhat expensive if indexes are set, because it needs to convert
// the flat index back into a full n-dimensional index and then use that api.
func (sl *Sliced) Int1D(i int) int {
return sl.Tensor.Int(sl.SourceIndexesFrom1D(i)...)
}
// SetInt1D is somewhat expensive if indexes are set, because it needs to convert
// the flat index back into a full n-dimensional index and then use that api.
func (sl *Sliced) SetInt1D(val int, i int) {
sl.Tensor.SetInt(val, sl.SourceIndexesFrom1D(i)...)
}
// Permuted sets indexes in given dimension to a permuted order.
// If indexes already exist, then the existing list of indexes is permuted;
// otherwise a new set of permuted indexes is generated.
func (sl *Sliced) Permuted(dim int) {
ix := sl.Indexes[dim]
if ix == nil {
ix = rand.Perm(sl.Tensor.DimSize(dim))
} else {
rand.Shuffle(len(ix), func(i, j int) {
ix[i], ix[j] = ix[j], ix[i]
})
}
sl.Indexes[dim] = ix
}
// SortFunc sorts the indexes along given dimension using given compare function.
// The compare function operates directly on indexes into the source Tensor
// that Sliced is a view of, as these row numbers have already been projected
// through the indexes. That is why the tensor is passed through to the compare
// function, to ensure the proper tensor values are being used.
// cmp(a, b) should return a negative number when a < b, a positive
// number when a > b and zero when a == b.
func (sl *Sliced) SortFunc(dim int, cmp func(tsr Tensor, i, j int) int) {
sl.IndexesNeeded(dim)
ix := sl.Indexes[dim]
slices.SortFunc(ix, func(a, b int) int {
return cmp(sl.Tensor, a, b) // key point: these are already indirected through indexes!!
})
sl.Indexes[dim] = ix
}
// SortIndexes sorts the indexes along given dimension directly in
// numerical order, producing the native ordering, while preserving
// any filtering that might have occurred.
func (sl *Sliced) SortIndexes(dim int) {
ix := sl.Indexes[dim]
if ix == nil {
return
}
sort.Ints(ix)
sl.Indexes[dim] = ix
}
// SortStableFunc stably sorts along given dimension using given compare function.
// The compare function operates directly on indexes into the source Tensor
// that Sliced is a view of, as these row numbers have already been projected
// through the indexes. That is why the tensor is passed through to the compare
// function, to ensure the proper tensor values are being used.
// cmp(a, b) should return a negative number when a < b, a positive
// number when a > b and zero when a == b.
// It is *essential* that it always returns 0 when the two are equal
// for the stable function to actually work.
func (sl *Sliced) SortStableFunc(dim int, cmp func(tsr Tensor, i, j int) int) {
sl.IndexesNeeded(dim)
ix := sl.Indexes[dim]
slices.SortStableFunc(ix, func(a, b int) int {
return cmp(sl.Tensor, a, b) // key point: these are already indirected through indexes!!
})
sl.Indexes[dim] = ix
}
// Filter filters the indexes for the given dimension using the given filter
// function, which is passed the source tensor and an index into the source
// data along that dimension (already projected through the view indexes),
// returning true to keep that index and false to remove it.
func (sl *Sliced) Filter(dim int, filterer func(tsr Tensor, dim, idx int) bool) {
sl.IndexesNeeded(dim)
ix := sl.Indexes[dim]
sz := len(ix)
for i := sz - 1; i >= 0; i-- { // always go in reverse for filtering
if !filterer(sl.Tensor, dim, ix[i]) { // delete; ix[i] is already a source index
ix = append(ix[:i], ix[i+1:]...)
}
}
sl.Indexes[dim] = ix
}
// check for interface impl
var _ Tensor = (*Sliced)(nil)
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
// SlicesMagic are special elements in slice expressions, including
// NewAxis, FullAxis, and Ellipsis in [NewSliced] expressions.
type SlicesMagic int //enums:enum
const (
// FullAxis indicates that the full existing axis length should be used.
// This is equivalent to Slice{}, but is more semantic. In NumPy it is
// equivalent to a single : colon.
FullAxis SlicesMagic = iota
// NewAxis creates a new singleton (length=1) axis, used to reshape
// without changing the size. Can also be used in [Reshaped].
NewAxis
// Ellipsis (...) is used in [NewSliced] expressions to produce
// a flexibly-sized stretch of FullAxis dimensions, which automatically
// aligns the remaining slice elements based on the source dimensionality.
Ellipsis
)
// Slice represents a slice of index values, for extracting slices of data,
// along a dimension of a given size, which is provided separately as an argument.
// Uses standard 'for' loop logic with a Start and _exclusive_ Stop value,
// and a Step increment: for i := Start; i < Stop; i += Step.
// The values stored in this struct are the _inputs_ for computing the actual
// slice values based on the actual size parameter for the dimension.
// Negative numbers count back from the end (i.e., size + val), and
// the zero value results in a list of all values in the dimension, with Step = 1 if 0.
// The behavior is identical to the NumPy slice.
type Slice struct {
// Start is the starting value. If 0 and Step < 0, = size-1;
// If negative, = size+Start.
Start int
// Stop value. If 0 and Step >= 0, = size;
// If 0 and Step < 0, = -1, to include whole range.
// If negative = size+Stop.
Stop int
// Step increment. If 0, = 1; if negative then Start must be > Stop
// to produce anything.
Step int
}
// NewSlice returns a new Slice with given start, stop, step values.
func NewSlice(start, stop, step int) Slice {
return Slice{Start: start, Stop: stop, Step: step}
}
// GetStart is the actual start value given the size of the dimension.
func (sl Slice) GetStart(size int) int {
if sl.Start == 0 && sl.Step < 0 {
return size - 1
}
if sl.Start < 0 {
return size + sl.Start
}
return sl.Start
}
// GetStop is the actual end value given the size of the dimension.
func (sl Slice) GetStop(size int) int {
if sl.Stop == 0 && sl.Step >= 0 {
return size
}
if sl.Stop == 0 && sl.Step < 0 {
return -1
}
if sl.Stop < 0 {
return size + sl.Stop
}
return min(sl.Stop, size)
}
// GetStep is the actual increment value.
func (sl Slice) GetStep() int {
if sl.Step == 0 {
return 1
}
return sl.Step
}
// Len is the number of elements in the actual slice given
// size of the dimension.
func (sl Slice) Len(size int) int {
s := sl.GetStart(size)
e := sl.GetStop(size)
i := sl.GetStep()
n := max((e-s)/i, 0)
pe := s + n*i
if i < 0 {
if pe > e {
n++
}
} else {
if pe < e {
n++
}
}
return n
}
// ToIntSlice writes values to given []int slice, with given size parameter
// for the dimension being sliced. If slice is wrong size to hold values,
// not all are written: allocate ints using Len(size) to fit.
func (sl Slice) ToIntSlice(size int, ints []int) {
n := len(ints)
if n == 0 {
return
}
s := sl.GetStart(size)
e := sl.GetStop(size)
inc := sl.GetStep()
idx := 0
if inc < 0 {
for i := s; i > e; i += inc {
ints[idx] = i
idx++
if idx >= n {
break
}
}
} else {
for i := s; i < e; i += inc {
ints[idx] = i
idx++
if idx >= n {
break
}
}
}
}
// IntSlice returns []int slice with slice index values, up to given actual size.
func (sl Slice) IntSlice(size int) []int {
n := sl.Len(size)
if n == 0 {
return nil
}
ints := make([]int, n)
sl.ToIntSlice(size, ints)
return ints
}
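// exampleSlice is an illustrative sketch (not part of the API) of how the zero
// value and negative values of [Slice] expand for a dimension of size 8,
// using only the [Slice] methods defined above.
func exampleSlice() {
	all := Slice{}.IntSlice(8)                   // [0 1 2 3 4 5 6 7]
	odds := Slice{Start: 1, Step: 2}.IntSlice(8) // [1 3 5 7]
	lastTwo := Slice{Start: -2}.IntSlice(8)      // [6 7]
	rev := Slice{Step: -1}.IntSlice(8)           // [7 6 5 4 3 2 1 0]
	_, _, _, _ = all, odds, lastTwo, rev
}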
// IntTensor returns an [Int] [Tensor] for slice, using actual size.
func (sl Slice) IntTensor(size int) *Int {
n := sl.Len(size)
if n == 0 {
return nil
}
tsr := NewInt(n)
sl.ToIntSlice(size, tsr.Values)
return tsr
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
import (
"encoding/binary"
"fmt"
"strconv"
"cogentcore.org/core/base/errors"
)
// String is a tensor of string values
type String struct {
Base[string]
}
// NewString returns a new n-dimensional tensor of string values
// with the given sizes per dimension (shape).
func NewString(sizes ...int) *String {
tsr := &String{}
tsr.SetShapeSizes(sizes...)
tsr.Values = make([]string, tsr.Len())
return tsr
}
// NewStringShape returns a new n-dimensional tensor of string values
// using given shape.
func NewStringShape(shape *Shape) *String {
tsr := &String{}
tsr.shape.CopyFrom(shape)
tsr.Values = make([]string, tsr.Len())
return tsr
}
// StringToFloat64 converts string value to float64 using strconv,
// returning 0 if any error
func StringToFloat64(str string) float64 {
if fv, err := strconv.ParseFloat(str, 64); err == nil {
return fv
}
return 0
}
// Float64ToString converts float64 to string value using strconv, g format
func Float64ToString(val float64) string {
return strconv.FormatFloat(val, 'g', -1, 64)
}
// String satisfies the fmt.Stringer interface for string of tensor data.
func (tsr *String) String() string {
return Sprintf("", tsr, 0)
}
func (tsr *String) IsString() bool {
return true
}
func (tsr *String) AsValues() Values { return tsr }
// Bytes encodes the actual string values starting with the length
// of each string as an encoded int value. SetFromBytes decodes
// this format to reconstruct the contents.
func (tsr *String) Bytes() []byte {
n := tsr.Len()
lb := make([]byte, 8)
var b []byte
for i := range n {
s := tsr.Values[i]
bs := []byte(s)
binary.Encode(lb, binary.LittleEndian, int64(len(bs)))
b = append(b, lb...)
b = append(b, bs...)
}
return b
}
// SetFromBytes sets string values from strings encoded using
// [String.Bytes] function.
func (tsr *String) SetFromBytes(b []byte) {
n := tsr.Len()
si := 0
for i := 0; i < len(b) && si < n; {
var l int64
binary.Decode(b[i:i+8], binary.LittleEndian, &l)
s := string(b[i+8 : i+8+int(l)])
tsr.Values[si] = s
si++
i += 8 + int(l)
}
}
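// exampleStringBytes is an illustrative sketch (not part of the API) of the
// length-prefixed round trip implemented by [String.Bytes] and
// [String.SetFromBytes], using only methods defined in this file.
func exampleStringBytes() {
	src := NewString(2)
	src.SetString1D("hello", 0)
	src.SetString1D("world", 1)
	b := src.Bytes() // per value: 8-byte little-endian length, then the string bytes
	dst := NewString(2)
	dst.SetFromBytes(b) // dst now holds the same strings as src
}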
/////// Strings
func (tsr *String) SetString(val string, i ...int) {
tsr.Values[tsr.shape.IndexTo1D(i...)] = val
}
func (tsr *String) String1D(i int) string {
return tsr.Values[i]
}
func (tsr *String) SetString1D(val string, i int) {
tsr.Values[NegIndex(i, len(tsr.Values))] = val
}
func (tsr *String) StringRow(row, cell int) string {
_, sz := tsr.shape.RowCellSize()
return tsr.Values[row*sz+cell]
}
func (tsr *String) SetStringRow(val string, row, cell int) {
_, sz := tsr.shape.RowCellSize()
tsr.Values[row*sz+cell] = val
}
// AppendRowString adds a row and sets string value(s), up to number of cells.
func (tsr *String) AppendRowString(val ...string) {
if tsr.NumDims() == 0 {
tsr.SetShapeSizes(0)
}
nrow, sz := tsr.shape.RowCellSize()
tsr.SetNumRows(nrow + 1)
mx := min(sz, len(val))
for i := range mx {
tsr.SetStringRow(val[i], nrow, i)
}
}
/////// Floats
func (tsr *String) Float(i ...int) float64 {
return StringToFloat64(tsr.Values[tsr.shape.IndexTo1D(i...)])
}
func (tsr *String) SetFloat(val float64, i ...int) {
tsr.Values[tsr.shape.IndexTo1D(i...)] = Float64ToString(val)
}
func (tsr *String) Float1D(i int) float64 {
return StringToFloat64(tsr.Values[NegIndex(i, len(tsr.Values))])
}
func (tsr *String) SetFloat1D(val float64, i int) {
tsr.Values[NegIndex(i, len(tsr.Values))] = Float64ToString(val)
}
func (tsr *String) FloatRow(row, cell int) float64 {
_, sz := tsr.shape.RowCellSize()
return StringToFloat64(tsr.Values[row*sz+cell])
}
func (tsr *String) SetFloatRow(val float64, row, cell int) {
_, sz := tsr.shape.RowCellSize()
tsr.Values[row*sz+cell] = Float64ToString(val)
}
// AppendRowFloat adds a row and sets float value(s), up to number of cells.
func (tsr *String) AppendRowFloat(val ...float64) {
if tsr.NumDims() == 0 {
tsr.SetShapeSizes(0)
}
nrow, sz := tsr.shape.RowCellSize()
tsr.SetNumRows(nrow + 1)
mx := min(sz, len(val))
for i := range mx {
tsr.SetFloatRow(val[i], nrow, i)
}
}
/////// Ints
func (tsr *String) Int(i ...int) int {
return errors.Ignore1(strconv.Atoi(tsr.Values[tsr.shape.IndexTo1D(i...)]))
}
func (tsr *String) SetInt(val int, i ...int) {
tsr.Values[tsr.shape.IndexTo1D(i...)] = strconv.Itoa(val)
}
func (tsr *String) Int1D(i int) int {
return errors.Ignore1(strconv.Atoi(tsr.Values[NegIndex(i, len(tsr.Values))]))
}
func (tsr *String) SetInt1D(val int, i int) {
tsr.Values[NegIndex(i, len(tsr.Values))] = strconv.Itoa(val)
}
func (tsr *String) IntRow(row, cell int) int {
_, sz := tsr.shape.RowCellSize()
return errors.Ignore1(strconv.Atoi(tsr.Values[row*sz+cell]))
}
func (tsr *String) SetIntRow(val int, row, cell int) {
_, sz := tsr.shape.RowCellSize()
tsr.Values[row*sz+cell] = strconv.Itoa(val)
}
// AppendRowInt adds a row and sets int value(s), up to number of cells.
func (tsr *String) AppendRowInt(val ...int) {
if tsr.NumDims() == 0 {
tsr.SetShapeSizes(0)
}
nrow, sz := tsr.shape.RowCellSize()
tsr.SetNumRows(nrow + 1)
mx := min(sz, len(val))
for i := range mx {
tsr.SetIntRow(val[i], nrow, i)
}
}
// SetZeros is a simple convenience function to initialize all values to the
// zero value of the type (empty strings for string type).
func (tsr *String) SetZeros() {
for j := range tsr.Values {
tsr.Values[j] = ""
}
}
// Clone clones this tensor, creating a duplicate copy of itself with its
// own separate memory representation of all the values, and returns
// that as a Tensor (which can be converted into the known type as needed).
func (tsr *String) Clone() Values {
csr := NewStringShape(&tsr.shape)
copy(csr.Values, tsr.Values)
return csr
}
// CopyFrom copies all available values from other tensor into this tensor, with an
// optimized implementation if the other tensor is of the same type, and
// otherwise it goes through appropriate standard type.
func (tsr *String) CopyFrom(frm Values) {
if fsm, ok := frm.(*String); ok {
copy(tsr.Values, fsm.Values)
return
}
sz := min(tsr.Len(), frm.Len())
for i := 0; i < sz; i++ {
tsr.Values[i] = Float64ToString(frm.Float1D(i))
}
}
// AppendFrom appends values from other tensor into this tensor,
// which must have the same cell size as this tensor.
// It uses an optimized implementation if the other tensor
// is of the same type, and otherwise it goes through
// appropriate standard type.
func (tsr *String) AppendFrom(frm Values) Values {
rows, cell := tsr.shape.RowCellSize()
frows, fcell := frm.Shape().RowCellSize()
if cell != fcell {
errors.Log(fmt.Errorf("tensor.AppendFrom: cell sizes do not match: %d != %d", cell, fcell))
return tsr
}
tsr.SetNumRows(rows + frows)
st := rows * cell
fsz := frows * fcell
if fsm, ok := frm.(*String); ok {
copy(tsr.Values[st:st+fsz], fsm.Values)
return tsr
}
for i := 0; i < fsz; i++ {
tsr.Values[st+i] = Float64ToString(frm.Float1D(i))
}
return tsr
}
// CopyCellsFrom copies given range of values from other tensor into this tensor,
// using flat 1D indexes: to = starting index in this Tensor to start copying into,
// start = starting index on from Tensor to start copying from, and n = number of
// values to copy. Uses an optimized implementation if the other tensor is
// of the same type, and otherwise it goes through appropriate standard type.
func (tsr *String) CopyCellsFrom(frm Values, to, start, n int) {
if fsm, ok := frm.(*String); ok {
for i := 0; i < n; i++ {
tsr.Values[to+i] = fsm.Values[start+i]
}
return
}
for i := 0; i < n; i++ {
tsr.Values[to+i] = Float64ToString(frm.Float1D(start + i))
}
}
// SubSpace returns a new tensor with innermost subspace at given
// offset(s) in outermost dimension(s) (len(offs) < NumDims).
// The new tensor points to the values of this tensor (i.e., modifications
// will affect both), as its Values slice is a view onto the original (which
// is why only inner-most contiguous subspaces are supported).
// Use Clone() method to separate the two.
func (tsr *String) SubSpace(offs ...int) Values {
b := tsr.subSpaceImpl(offs...)
rt := &String{Base: *b}
return rt
}
// RowTensor is a convenience version of [RowMajor.SubSpace] to return the
// SubSpace for the outermost row dimension. [Rows] defines a version
// of this that indirects through the row indexes.
func (tsr *String) RowTensor(row int) Values {
return tsr.SubSpace(row)
}
// SetRowTensor sets the values of the SubSpace at given row to given values.
func (tsr *String) SetRowTensor(val Values, row int) {
_, cells := tsr.shape.RowCellSize()
st := row * cells
mx := min(val.Len(), cells)
tsr.CopyCellsFrom(val, st, 0, mx)
}
// AppendRow adds a row and sets values to given values.
func (tsr *String) AppendRow(val Values) {
if tsr.NumDims() == 0 {
tsr.SetShapeSizes(0)
}
nrow := tsr.DimSize(0)
tsr.SetNumRows(nrow + 1)
tsr.SetRowTensor(val, nrow)
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
//go:generate core generate
import (
"fmt"
"reflect"
"cogentcore.org/core/base/metadata"
)
// DataTypes are the primary tensor data types with specific support.
// Any numerical type can also be used. bool is represented using an
// efficient bit slice.
type DataTypes interface {
string | bool | float32 | float64 | int | int64 | uint64 | int32 | uint32 | byte
}
// MaxSprintLength is the default maximum length of a String() representation
// of a tensor, as generated by the Sprint function. Defaults to 1000.
var MaxSprintLength = 1000
// todo: add a conversion function to copy data from Column-Major to a tensor:
// It is also possible to use Column-Major order, which is used in R, Julia, and MATLAB
// where the inner-most index is first and outermost last.
// Tensor is the most general interface for n-dimensional tensors.
// Per C / Go / Python conventions, indexes are Row-Major, ordered from
// outer to inner left-to-right, so the inner-most is right-most.
// It is implemented for raw [Values] with direct integer indexing
// by the [Number], [String], and [Bool] types, covering the different
// concrete types specified by [DataTypes] (see [Values] for
// additional interface methods for raw value types).
// For float32 and float64 values, use NaN to indicate missing values,
// as all of the data analysis and plot packages skip NaNs.
// View Tensor types provide different ways of viewing a source tensor,
// including [Sliced] for arbitrary slices of dimension indexes,
// [Masked] for boolean masked access of individual elements,
// and [Indexed] for arbitrary indexes of values, organized into the
// shape of the indexes, not the original source data.
// [Reshaped] provides length preserving reshaping (mostly for computational
// alignment purposes), and [Rows] provides an optimized row-indexed
// view for [table.Table] data.
type Tensor interface {
fmt.Stringer
// Label satisfies the core.Labeler interface for a summary
// description of the tensor, including metadata Name if set.
Label() string
// Metadata returns the metadata for this tensor, which can be used
// to encode name, docs, shape dimension names, plotting options, etc.
Metadata() *metadata.Data
// Shape() returns a [Shape] representation of the tensor shape
// (dimension sizes). For tensors that present a view onto another
// tensor, this typically must be constructed.
// In general, it is better to use the specific [Tensor.ShapeSizes],
// [Tensor.DimSize], etc., as needed.
Shape() *Shape
// ShapeSizes returns the sizes of each dimension as a slice of ints.
// The returned slice is a copy and can be modified without side effects.
ShapeSizes() []int
// Len returns the total number of elements in the tensor,
// i.e., the product of all shape dimensions.
// Len must always be such that the 1D() accessors return
// values using indexes from 0..Len()-1.
Len() int
// NumDims returns the total number of dimensions.
NumDims() int
// DimSize returns size of given dimension.
DimSize(dim int) int
// DataType returns the type of the data elements in the tensor.
// Bool is returned for the Bool tensor type.
DataType() reflect.Kind
// IsString returns true if the data type is a String; otherwise it is numeric.
IsString() bool
// AsValues returns this tensor as raw [Values]. If it already is,
// it is returned directly. If it is a View tensor, the view is
// "rendered" into a fully contiguous and optimized [Values] representation
// of that view, which will be faster to access for further processing,
// and enables all the additional functionality provided by the [Values] interface.
AsValues() Values
//////// Floats
// Float returns the value of given n-dimensional index (matching Shape) as a float64.
Float(i ...int) float64
// SetFloat sets the value of given n-dimensional index (matching Shape) as a float64.
SetFloat(val float64, i ...int)
// Float1D returns the value of given 1-dimensional index (0-Len()-1) as a float64.
// If index is negative, it indexes from the end of the list (-1 = last).
// This can be somewhat expensive in wrapper views ([Rows], [Sliced]), which
// convert the flat index back into a full n-dimensional index and use that API.
// [Tensor.FloatRow] is preferred.
Float1D(i int) float64
// SetFloat1D sets the value of given 1-dimensional index (0-Len()-1) as a float64.
// If index is negative, it indexes from the end of the list (-1 = last).
// This can be somewhat expensive in the commonly-used [Rows] view;
// [Tensor.SetFloatRow] is preferred.
SetFloat1D(val float64, i int)
//////// Strings
// StringValue returns the value of given n-dimensional index (matching Shape) as a string.
// 'String' conflicts with [fmt.Stringer], so we have to use StringValue here.
StringValue(i ...int) string
// SetString sets the value of given n-dimensional index (matching Shape) as a string.
SetString(val string, i ...int)
// String1D returns the value of given 1-dimensional index (0-Len()-1) as a string.
// If index is negative, it indexes from the end of the list (-1 = last).
String1D(i int) string
// SetString1D sets the value of given 1-dimensional index (0-Len()-1) as a string.
// If index is negative, it indexes from the end of the list (-1 = last).
SetString1D(val string, i int)
//////// Ints
// Int returns the value of given n-dimensional index (matching Shape) as an int.
Int(i ...int) int
// SetInt sets the value of given n-dimensional index (matching Shape) as an int.
SetInt(val int, i ...int)
// Int1D returns the value of given 1-dimensional index (0-Len()-1) as an int.
// If index is negative, it indexes from the end of the list (-1 = last).
Int1D(i int) int
// SetInt1D sets the value of given 1-dimensional index (0-Len()-1) as an int.
// If index is negative, it indexes from the end of the list (-1 = last).
SetInt1D(val int, i int)
}
// NegIndex handles negative index values as counting backward from n.
func NegIndex(i, n int) int {
if i < 0 {
return n + i
}
return i
}
// Copyright (c) 2020, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensormpi
import (
"fmt"
"log"
"cogentcore.org/lab/base/mpi"
)
// AllocN allocates n items to the current mpi proc based on WorldSize and WorldRank.
// Returns start and end (exclusive) range for current proc.
func AllocN(n int) (st, end int, err error) {
nproc := mpi.WorldSize()
if n%nproc != 0 {
err = fmt.Errorf("tensormpi.AllocN: number: %d is not an even multiple of number of MPI procs: %d -- must be!", n, nproc)
log.Println(err)
}
pt := n / nproc
st = pt * mpi.WorldRank()
end = st + pt
return
}
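// exampleAllocN is a minimal usage sketch: split n items across MPI procs
// with AllocN and loop over only this proc's range; the per-item work here
// is hypothetical.
func exampleAllocN(n int) error {
	st, end, err := AllocN(n)
	if err != nil {
		return err // n is not an even multiple of the number of procs
	}
	for i := st; i < end; i++ {
		_ = i // process item i on this proc
	}
	return nil
}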
// Copyright (c) 2021, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensormpi
import (
"errors"
"fmt"
"math/rand"
"cogentcore.org/lab/base/mpi"
)
// RandCheck checks that the current random numbers generated across each
// MPI processor are identical.
func RandCheck(comm *mpi.Comm) error {
ws := comm.Size()
rnd := rand.Int()
src := []int{rnd}
agg := make([]int, ws)
err := comm.AllGatherInt(agg, src)
if err != nil {
return err
}
errs := ""
for i := range agg {
if agg[i] != rnd {
errs += fmt.Sprintf("%d ", i)
}
}
if errs != "" {
err = errors.New("tensormpi.RandCheck: random numbers differ in procs: " + errs)
mpi.Printf("%s\n", err)
}
return err
}
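// exampleRandCheck is a minimal usage sketch, assuming MPI has already been
// initialized and a world communicator created elsewhere: verify that all
// procs are drawing from the same random number stream before proceeding.
func exampleRandCheck(comm *mpi.Comm) {
	if err := RandCheck(comm); err != nil {
		mpi.Printf("random streams differ across procs: %v\n", err)
	}
}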
// Copyright (c) 2020, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensormpi
import (
"cogentcore.org/lab/base/mpi"
"cogentcore.org/lab/table"
)
// GatherTableRows does an MPI AllGather on given src table data, gathering into dest.
// dest will have np * src.Rows Rows, filled with each processor's data, in order.
// dest must be a clone of src: if it does not have the same number of columns, it is configured from src.
func GatherTableRows(dest, src *table.Table, comm *mpi.Comm) {
sr := src.NumRows()
np := mpi.WorldSize()
dr := np * sr
if dest.NumColumns() != src.NumColumns() {
*dest = *src.Clone()
}
dest.SetNumRows(dr)
for ci, st := range src.Columns.Values {
dt := dest.Columns.Values[ci]
GatherTensorRows(dt, st, comm)
}
}
// ReduceTable does an MPI AllReduce on given src table data using given operation,
// gathering into dest.
// Each processor must have the same table organization -- the tensor values are
// just aggregated directly across processors.
// dest will be a clone of src if not the same (cols & rows).
// Does nothing for strings.
func ReduceTable(dest, src *table.Table, comm *mpi.Comm, op mpi.Op) {
sr := src.NumRows()
if dest.NumColumns() != src.NumColumns() {
*dest = *src.Clone()
}
dest.SetNumRows(sr)
for ci, st := range src.Columns.Values {
dt := dest.Columns.Values[ci]
ReduceTensor(dt, st, comm, op)
}
}
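// exampleGatherTable is a minimal usage sketch: gather each proc's rows of a
// local log table into a combined table replicated on all procs. Per the
// GatherTableRows docs, dest starts as a clone of src.
func exampleGatherTable(local *table.Table, comm *mpi.Comm) *table.Table {
	all := local.Clone()
	GatherTableRows(all, local, comm)
	return all
}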
// Copyright (c) 2020, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensormpi
import (
"reflect"
"cogentcore.org/lab/base/mpi"
"cogentcore.org/lab/tensor"
)
// GatherTensorRows does an MPI AllGather on given src tensor data, gathering into dest,
// using a row-based tensor organization (as in a table.Table).
// dest will have np * src.Rows Rows, filled with each processor's data, in order.
// dest must have same overall shape as src at start, but rows will be enforced.
func GatherTensorRows(dest, src tensor.Values, comm *mpi.Comm) error {
dt := src.DataType()
if dt == reflect.String {
return GatherTensorRowsString(dest.(*tensor.String), src.(*tensor.String), comm)
}
sr, _ := src.Shape().RowCellSize()
dr, _ := dest.Shape().RowCellSize()
np := mpi.WorldSize()
dl := np * sr
if dr != dl {
dest.SetNumRows(dl)
dr = dl
}
var err error
switch dt {
case reflect.Bool:
// todo
case reflect.Uint8:
dt := dest.(*tensor.Byte)
st := src.(*tensor.Byte)
err = comm.AllGatherU8(dt.Values, st.Values)
case reflect.Int32:
dt := dest.(*tensor.Int32)
st := src.(*tensor.Int32)
err = comm.AllGatherI32(dt.Values, st.Values)
case reflect.Int:
dt := dest.(*tensor.Int)
st := src.(*tensor.Int)
err = comm.AllGatherInt(dt.Values, st.Values)
case reflect.Float32:
dt := dest.(*tensor.Float32)
st := src.(*tensor.Float32)
err = comm.AllGatherF32(dt.Values, st.Values)
case reflect.Float64:
dt := dest.(*tensor.Float64)
st := src.(*tensor.Float64)
err = comm.AllGatherF64(dt.Values, st.Values)
}
return err
}
// GatherTensorRowsString does an MPI AllGather on given String src tensor data,
// gathering into dest, using a row-based tensor organization (as in a table.Table).
// dest will have np * src.Rows Rows, filled with each processor's data, in order.
// dest must have same overall shape as src at start, but rows will be enforced.
func GatherTensorRowsString(dest, src *tensor.String, comm *mpi.Comm) error {
sr, _ := src.Shape().RowCellSize()
dr, _ := dest.Shape().RowCellSize()
np := mpi.WorldSize()
dl := np * sr
if dr != dl {
dest.SetNumRows(dl)
dr = dl
}
ssz := len(src.Values)
dsz := len(dest.Values)
sln := make([]int, ssz)
dln := make([]int, dsz)
for i, s := range src.Values {
sln[i] = len(s)
}
err := comm.AllGatherInt(dln, sln)
if err != nil {
return err
}
mxlen := 0
for _, l := range dln {
mxlen = max(mxlen, l)
}
if mxlen == 0 {
return nil // nothing to transfer
}
sdt := make([]byte, ssz*mxlen)
ddt := make([]byte, dsz*mxlen)
idx := 0
for _, s := range src.Values {
l := len(s)
copy(sdt[idx:idx+l], []byte(s))
idx += mxlen
}
err = comm.AllGatherU8(ddt, sdt)
idx = 0
for i := range dest.Values {
l := dln[i]
s := string(ddt[idx : idx+l])
dest.Values[i] = s
idx += mxlen
}
return err
}
// ReduceTensor does an MPI AllReduce on given src tensor data, using given operation,
// gathering into dest. dest must have same overall shape as src -- will be enforced.
// IMPORTANT: src and dest must be different slices!
// Each processor must have the same shape and organization for this to make sense.
// Does nothing for strings.
func ReduceTensor(dest, src tensor.Values, comm *mpi.Comm, op mpi.Op) error {
dt := src.DataType()
if dt == reflect.String {
return nil
}
slen := src.Len()
if slen != dest.Len() {
tensor.SetShapeFrom(dest, src)
}
var err error
switch dt {
case reflect.Bool:
dt := dest.(*tensor.Bool)
st := src.(*tensor.Bool)
err = comm.AllReduceU8(op, dt.Values, st.Values)
case reflect.Uint8:
dt := dest.(*tensor.Byte)
st := src.(*tensor.Byte)
err = comm.AllReduceU8(op, dt.Values, st.Values)
case reflect.Int32:
dt := dest.(*tensor.Int32)
st := src.(*tensor.Int32)
err = comm.AllReduceI32(op, dt.Values, st.Values)
case reflect.Int:
dt := dest.(*tensor.Int)
st := src.(*tensor.Int)
err = comm.AllReduceInt(op, dt.Values, st.Values)
case reflect.Float32:
dt := dest.(*tensor.Float32)
st := src.(*tensor.Float32)
err = comm.AllReduceF32(op, dt.Values, st.Values)
case reflect.Float64:
dt := dest.(*tensor.Float64)
st := src.(*tensor.Float64)
err = comm.AllReduceF64(op, dt.Values, st.Values)
}
return err
}
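// exampleReduceSum is a minimal usage sketch: element-wise sum of a float64
// tensor across all procs via AllReduce, into a separate destination tensor
// of the same shape. mpi.OpSum is assumed to be the summation Op constant
// provided by the mpi package.
func exampleReduceSum(dest, src *tensor.Float64, comm *mpi.Comm) error {
	return ReduceTensor(dest, src, comm, mpi.OpSum)
}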
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tmath
import (
"cogentcore.org/core/base/errors"
"cogentcore.org/lab/tensor"
)
// Equal stores in the output the bool value a == b.
func Equal(a, b tensor.Tensor) *tensor.Bool {
return tensor.CallOut2Bool(EqualOut, a, b)
}
// EqualOut stores in the output the bool value a == b.
func EqualOut(a, b tensor.Tensor, out *tensor.Bool) error {
if a.IsString() {
return tensor.BoolStringsFuncOut(func(a, b string) bool { return a == b }, a, b, out)
}
return tensor.BoolFloatsFuncOut(func(a, b float64) bool { return a == b }, a, b, out)
}
// Less stores in the output the bool value a < b.
func Less(a, b tensor.Tensor) *tensor.Bool {
return tensor.CallOut2Bool(LessOut, a, b)
}
// LessOut stores in the output the bool value a < b.
func LessOut(a, b tensor.Tensor, out *tensor.Bool) error {
if a.IsString() {
return tensor.BoolStringsFuncOut(func(a, b string) bool { return a < b }, a, b, out)
}
return tensor.BoolFloatsFuncOut(func(a, b float64) bool { return a < b }, a, b, out)
}
// Greater stores in the output the bool value a > b.
func Greater(a, b tensor.Tensor) *tensor.Bool {
return tensor.CallOut2Bool(GreaterOut, a, b)
}
// GreaterOut stores in the output the bool value a > b.
func GreaterOut(a, b tensor.Tensor, out *tensor.Bool) error {
if a.IsString() {
return tensor.BoolStringsFuncOut(func(a, b string) bool { return a > b }, a, b, out)
}
return tensor.BoolFloatsFuncOut(func(a, b float64) bool { return a > b }, a, b, out)
}
// NotEqual stores in the output the bool value a != b.
func NotEqual(a, b tensor.Tensor) *tensor.Bool {
return tensor.CallOut2Bool(NotEqualOut, a, b)
}
// NotEqualOut stores in the output the bool value a != b.
func NotEqualOut(a, b tensor.Tensor, out *tensor.Bool) error {
if a.IsString() {
return tensor.BoolStringsFuncOut(func(a, b string) bool { return a != b }, a, b, out)
}
return tensor.BoolFloatsFuncOut(func(a, b float64) bool { return a != b }, a, b, out)
}
// LessEqual stores in the output the bool value a <= b.
func LessEqual(a, b tensor.Tensor) *tensor.Bool {
return tensor.CallOut2Bool(LessEqualOut, a, b)
}
// LessEqualOut stores in the output the bool value a <= b.
func LessEqualOut(a, b tensor.Tensor, out *tensor.Bool) error {
if a.IsString() {
return tensor.BoolStringsFuncOut(func(a, b string) bool { return a <= b }, a, b, out)
}
return tensor.BoolFloatsFuncOut(func(a, b float64) bool { return a <= b }, a, b, out)
}
// GreaterEqual stores in the output the bool value a >= b.
func GreaterEqual(a, b tensor.Tensor) *tensor.Bool {
return tensor.CallOut2Bool(GreaterEqualOut, a, b)
}
// GreaterEqualOut stores in the output the bool value a >= b.
func GreaterEqualOut(a, b tensor.Tensor, out *tensor.Bool) error {
if a.IsString() {
return tensor.BoolStringsFuncOut(func(a, b string) bool { return a >= b }, a, b, out)
}
return tensor.BoolFloatsFuncOut(func(a, b float64) bool { return a >= b }, a, b, out)
}
// Or stores in the output the bool value a || b.
func Or(a, b tensor.Tensor) *tensor.Bool {
return tensor.CallOut2Bool(OrOut, a, b)
}
// OrOut stores in the output the bool value a || b.
func OrOut(a, b tensor.Tensor, out *tensor.Bool) error {
return tensor.BoolIntsFuncOut(func(a, b int) bool { return a > 0 || b > 0 }, a, b, out)
}
// And stores in the output the bool value a && b.
func And(a, b tensor.Tensor) *tensor.Bool {
return tensor.CallOut2Bool(AndOut, a, b)
}
// AndOut stores in the output the bool value a && b.
func AndOut(a, b tensor.Tensor, out *tensor.Bool) error {
return tensor.BoolIntsFuncOut(func(a, b int) bool { return a > 0 && b > 0 }, a, b, out)
}
// Not stores in the output the bool value !a.
func Not(a tensor.Tensor) *tensor.Bool {
out := tensor.NewBool()
errors.Log(NotOut(a, out))
return out
}
// NotOut stores in the output the bool value !a.
func NotOut(a tensor.Tensor, out *tensor.Bool) error {
out.SetShapeSizes(a.Shape().Sizes...)
alen := a.Len()
tensor.VectorizeThreaded(1, func(tsr ...tensor.Tensor) int { return alen },
func(idx int, tsr ...tensor.Tensor) {
out.SetBool1D(tsr[0].Int1D(idx) == 0, idx)
}, a, out)
return nil
}
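// exampleCompare is a minimal usage sketch: element-wise comparisons produce
// bool mask tensors, and Not inverts one; ignoring NaN values,
// Not(GreaterEqual(a, b)) matches Less(a, b).
func exampleCompare(a, b tensor.Tensor) *tensor.Bool {
	ge := GreaterEqual(a, b)
	return Not(ge)
}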
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tmath
import (
"math"
"cogentcore.org/lab/tensor"
)
func Abs(in tensor.Tensor) tensor.Values {
return tensor.CallOut1Float64(AbsOut, in)
}
func AbsOut(in tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Abs(a) }, in, out)
}
func Acos(in tensor.Tensor) tensor.Values {
return tensor.CallOut1Float64(AcosOut, in)
}
func AcosOut(in tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Acos(a) }, in, out)
}
func Acosh(in tensor.Tensor) tensor.Values {
return tensor.CallOut1Float64(AcoshOut, in)
}
func AcoshOut(in tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Acosh(a) }, in, out)
}
func Asin(in tensor.Tensor) tensor.Values {
return tensor.CallOut1Float64(AsinOut, in)
}
func AsinOut(in tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Asin(a) }, in, out)
}
func Asinh(in tensor.Tensor) tensor.Values {
return tensor.CallOut1Float64(AsinhOut, in)
}
func AsinhOut(in tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Asinh(a) }, in, out)
}
func Atan(in tensor.Tensor) tensor.Values {
return tensor.CallOut1Float64(AtanOut, in)
}
func AtanOut(in tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Atan(a) }, in, out)
}
func Atanh(in tensor.Tensor) tensor.Values {
return tensor.CallOut1Float64(AtanhOut, in)
}
func AtanhOut(in tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Atanh(a) }, in, out)
}
func Cbrt(in tensor.Tensor) tensor.Values {
return tensor.CallOut1Float64(CbrtOut, in)
}
func CbrtOut(in tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Cbrt(a) }, in, out)
}
func Ceil(in tensor.Tensor) tensor.Values {
return tensor.CallOut1Float64(CeilOut, in)
}
func CeilOut(in tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Ceil(a) }, in, out)
}
func Cos(in tensor.Tensor) tensor.Values {
return tensor.CallOut1Float64(CosOut, in)
}
func CosOut(in tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Cos(a) }, in, out)
}
func Cosh(in tensor.Tensor) tensor.Values {
return tensor.CallOut1Float64(CoshOut, in)
}
func CoshOut(in tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Cosh(a) }, in, out)
}
func Erf(in tensor.Tensor) tensor.Values {
return tensor.CallOut1Float64(ErfOut, in)
}
func ErfOut(in tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Erf(a) }, in, out)
}
func Erfc(in tensor.Tensor) tensor.Values {
return tensor.CallOut1Float64(ErfcOut, in)
}
func ErfcOut(in tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Erfc(a) }, in, out)
}
func Erfcinv(in tensor.Tensor) tensor.Values {
return tensor.CallOut1Float64(ErfcinvOut, in)
}
func ErfcinvOut(in tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Erfcinv(a) }, in, out)
}
func Erfinv(in tensor.Tensor) tensor.Values {
return tensor.CallOut1Float64(ErfinvOut, in)
}
func ErfinvOut(in tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Erfinv(a) }, in, out)
}
func Exp(in tensor.Tensor) tensor.Values {
return tensor.CallOut1Float64(ExpOut, in)
}
func ExpOut(in tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Exp(a) }, in, out)
}
func Exp2(in tensor.Tensor) tensor.Values {
return tensor.CallOut1Float64(Exp2Out, in)
}
func Exp2Out(in tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Exp2(a) }, in, out)
}
func Expm1(in tensor.Tensor) tensor.Values {
return tensor.CallOut1Float64(Expm1Out, in)
}
func Expm1Out(in tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Expm1(a) }, in, out)
}
func Floor(in tensor.Tensor) tensor.Values {
return tensor.CallOut1Float64(FloorOut, in)
}
func FloorOut(in tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Floor(a) }, in, out)
}
func Gamma(in tensor.Tensor) tensor.Values {
return tensor.CallOut1Float64(GammaOut, in)
}
func GammaOut(in tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Gamma(a) }, in, out)
}
func J0(in tensor.Tensor) tensor.Values {
return tensor.CallOut1Float64(J0Out, in)
}
func J0Out(in tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.J0(a) }, in, out)
}
func J1(in tensor.Tensor) tensor.Values {
return tensor.CallOut1Float64(J1Out, in)
}
func J1Out(in tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.J1(a) }, in, out)
}
func Log(in tensor.Tensor) tensor.Values {
return tensor.CallOut1Float64(LogOut, in)
}
func LogOut(in tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Log(a) }, in, out)
}
func Log10(in tensor.Tensor) tensor.Values {
return tensor.CallOut1Float64(Log10Out, in)
}
func Log10Out(in tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Log10(a) }, in, out)
}
func Log1p(in tensor.Tensor) tensor.Values {
return tensor.CallOut1Float64(Log1pOut, in)
}
func Log1pOut(in tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Log1p(a) }, in, out)
}
func Log2(in tensor.Tensor) tensor.Values {
return tensor.CallOut1Float64(Log2Out, in)
}
func Log2Out(in tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Log2(a) }, in, out)
}
func Logb(in tensor.Tensor) tensor.Values {
return tensor.CallOut1Float64(LogbOut, in)
}
func LogbOut(in tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Logb(a) }, in, out)
}
func Round(in tensor.Tensor) tensor.Values {
return tensor.CallOut1Float64(RoundOut, in)
}
func RoundOut(in tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Round(a) }, in, out)
}
func RoundToEven(in tensor.Tensor) tensor.Values {
return tensor.CallOut1Float64(RoundToEvenOut, in)
}
func RoundToEvenOut(in tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.RoundToEven(a) }, in, out)
}
func Sin(in tensor.Tensor) tensor.Values {
return tensor.CallOut1Float64(SinOut, in)
}
func SinOut(in tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Sin(a) }, in, out)
}
func Sinh(in tensor.Tensor) tensor.Values {
return tensor.CallOut1Float64(SinhOut, in)
}
func SinhOut(in tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Sinh(a) }, in, out)
}
func Sqrt(in tensor.Tensor) tensor.Values {
return tensor.CallOut1Float64(SqrtOut, in)
}
func SqrtOut(in tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Sqrt(a) }, in, out)
}
func Tan(in tensor.Tensor) tensor.Values {
return tensor.CallOut1Float64(TanOut, in)
}
func TanOut(in tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Tan(a) }, in, out)
}
func Tanh(in tensor.Tensor) tensor.Values {
return tensor.CallOut1Float64(TanhOut, in)
}
func TanhOut(in tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Tanh(a) }, in, out)
}
func Trunc(in tensor.Tensor) tensor.Values {
return tensor.CallOut1Float64(TruncOut, in)
}
func TruncOut(in tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Trunc(a) }, in, out)
}
func Y0(in tensor.Tensor) tensor.Values {
return tensor.CallOut1Float64(Y0Out, in)
}
func Y0Out(in tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Y0(a) }, in, out)
}
func Y1(in tensor.Tensor) tensor.Values {
return tensor.CallOut1Float64(Y1Out, in)
}
func Y1Out(in tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(a float64) float64 { return math.Y1(a) }, in, out)
}
//////// Binary
func Atan2(y, x tensor.Tensor) tensor.Values {
return tensor.CallOut2(Atan2Out, y, x)
}
func Atan2Out(y, x tensor.Tensor, out tensor.Values) error {
return tensor.FloatBinaryFuncOut(1, func(a, b float64) float64 { return math.Atan2(a, b) }, y, x, out)
}
func Copysign(x, y tensor.Tensor) tensor.Values {
return tensor.CallOut2(CopysignOut, x, y)
}
func CopysignOut(x, y tensor.Tensor, out tensor.Values) error {
return tensor.FloatBinaryFuncOut(1, func(a, b float64) float64 { return math.Copysign(a, b) }, x, y, out)
}
func Dim(x, y tensor.Tensor) tensor.Values {
return tensor.CallOut2(DimOut, x, y)
}
func DimOut(x, y tensor.Tensor, out tensor.Values) error {
return tensor.FloatBinaryFuncOut(1, func(a, b float64) float64 { return math.Dim(a, b) }, x, y, out)
}
func Hypot(x, y tensor.Tensor) tensor.Values {
return tensor.CallOut2(HypotOut, x, y)
}
func HypotOut(x, y tensor.Tensor, out tensor.Values) error {
return tensor.FloatBinaryFuncOut(1, func(a, b float64) float64 { return math.Hypot(a, b) }, x, y, out)
}
func Max(x, y tensor.Tensor) tensor.Values {
return tensor.CallOut2(MaxOut, x, y)
}
func MaxOut(x, y tensor.Tensor, out tensor.Values) error {
return tensor.FloatBinaryFuncOut(1, func(a, b float64) float64 { return math.Max(a, b) }, x, y, out)
}
func Min(x, y tensor.Tensor) tensor.Values {
return tensor.CallOut2(MinOut, x, y)
}
func MinOut(x, y tensor.Tensor, out tensor.Values) error {
return tensor.FloatBinaryFuncOut(1, func(a, b float64) float64 { return math.Min(a, b) }, x, y, out)
}
func Nextafter(x, y tensor.Tensor) tensor.Values {
return tensor.CallOut2(NextafterOut, x, y)
}
func NextafterOut(x, y tensor.Tensor, out tensor.Values) error {
return tensor.FloatBinaryFuncOut(1, func(a, b float64) float64 { return math.Nextafter(a, b) }, x, y, out)
}
func Pow(x, y tensor.Tensor) tensor.Values {
return tensor.CallOut2(PowOut, x, y)
}
func PowOut(x, y tensor.Tensor, out tensor.Values) error {
return tensor.FloatBinaryFuncOut(1, func(a, b float64) float64 { return math.Pow(a, b) }, x, y, out)
}
func Remainder(x, y tensor.Tensor) tensor.Values {
return tensor.CallOut2(RemainderOut, x, y)
}
func RemainderOut(x, y tensor.Tensor, out tensor.Values) error {
return tensor.FloatBinaryFuncOut(1, func(a, b float64) float64 { return math.Remainder(a, b) }, x, y, out)
}
/*
func Nextafter32(x, y float32) (r float32)
func Inf(sign int) float64
func IsInf(f float64, sign int) bool
func IsNaN(f float64) (is bool)
func NaN() float64
func Signbit(x float64) bool
func Float32bits(f float32) uint32
func Float32frombits(b uint32) float32
func Float64bits(f float64) uint64
func Float64frombits(b uint64) float64
func FMA(x, y, z float64) float64
func Jn(n int, in tensor.Tensor, out tensor.Values)
func Yn(n int, in tensor.Tensor, out tensor.Values)
func Ldexp(frac float64, exp int) float64
func Ilogb(x float64) int
func Pow10(n int) float64
func Frexp(f float64) (frac float64, exp int)
func Modf(f float64) (int float64, frac float64)
func Lgamma(x float64) (lgamma float64, sign int)
func Sincos(x float64) (sin, cos float64)
*/
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tmath
import (
"math"
"cogentcore.org/lab/tensor"
)
// Assign assigns values from b into a.
func Assign(a, b tensor.Tensor) error {
return tensor.FloatAssignFunc(func(a, b float64) float64 { return b }, a, b)
}
// AddAssign does += add assign values from b into a.
func AddAssign(a, b tensor.Tensor) error {
if a.IsString() {
return tensor.StringAssignFunc(func(a, b string) string { return a + b }, a, b)
}
return tensor.FloatAssignFunc(func(a, b float64) float64 { return a + b }, a, b)
}
// SubAssign does -= sub assign values from b into a.
func SubAssign(a, b tensor.Tensor) error {
return tensor.FloatAssignFunc(func(a, b float64) float64 { return a - b }, a, b)
}
// MulAssign does *= mul assign values from b into a.
func MulAssign(a, b tensor.Tensor) error {
return tensor.FloatAssignFunc(func(a, b float64) float64 { return a * b }, a, b)
}
// DivAssign does /= divide assign values from b into a.
func DivAssign(a, b tensor.Tensor) error {
return tensor.FloatAssignFunc(func(a, b float64) float64 { return a / b }, a, b)
}
// ModAssign does %= modulus assign values from b into a.
func ModAssign(a, b tensor.Tensor) error {
return tensor.FloatAssignFunc(func(a, b float64) float64 { return math.Mod(a, b) }, a, b)
}
// Inc increments values in given tensor by 1.
func Inc(a tensor.Tensor) error {
alen := a.Len()
tensor.VectorizeThreaded(1, func(tsr ...tensor.Tensor) int { return alen },
func(idx int, tsr ...tensor.Tensor) {
tsr[0].SetFloat1D(tsr[0].Float1D(idx)+1.0, idx)
}, a)
return nil
}
// Dec decrements values in given tensor by 1.
func Dec(a tensor.Tensor) error {
alen := a.Len()
tensor.VectorizeThreaded(1, func(tsr ...tensor.Tensor) int { return alen },
func(idx int, tsr ...tensor.Tensor) {
tsr[0].SetFloat1D(tsr[0].Float1D(idx)-1.0, idx)
}, a)
return nil
}
// Add adds two tensors into output.
func Add(a, b tensor.Tensor) tensor.Values {
return tensor.CallOut2(AddOut, a, b)
}
// AddOut adds two tensors into output.
func AddOut(a, b tensor.Tensor, out tensor.Values) error {
if a.IsString() {
return tensor.StringBinaryFuncOut(func(a, b string) string { return a + b }, a, b, out)
}
return tensor.FloatBinaryFuncOut(1, func(a, b float64) float64 { return a + b }, a, b, out)
}
// Sub subtracts tensors into output.
func Sub(a, b tensor.Tensor) tensor.Values {
return tensor.CallOut2(SubOut, a, b)
}
// SubOut subtracts two tensors into output.
func SubOut(a, b tensor.Tensor, out tensor.Values) error {
return tensor.FloatBinaryFuncOut(1, func(a, b float64) float64 { return a - b }, a, b, out)
}
// Mul multiplies tensors into output.
func Mul(a, b tensor.Tensor) tensor.Values {
return tensor.CallOut2(MulOut, a, b)
}
// MulOut multiplies two tensors into output.
func MulOut(a, b tensor.Tensor, out tensor.Values) error {
return tensor.FloatBinaryFuncOut(1, func(a, b float64) float64 { return a * b }, a, b, out)
}
// Div divides tensors into output. It always does floating point division,
// even with integer operands.
func Div(a, b tensor.Tensor) tensor.Values {
return tensor.CallOut2Float64(DivOut, a, b)
}
// DivOut divides two tensors into output.
func DivOut(a, b tensor.Tensor, out tensor.Values) error {
return tensor.FloatBinaryFuncOut(1, func(a, b float64) float64 { return a / b }, a, b, out)
}
// Mod performs modulus a%b on tensors into output.
func Mod(a, b tensor.Tensor) tensor.Values {
return tensor.CallOut2(ModOut, a, b)
}
// ModOut performs modulus a%b on tensors into output.
func ModOut(a, b tensor.Tensor, out tensor.Values) error {
return tensor.FloatBinaryFuncOut(1, func(a, b float64) float64 { return math.Mod(a, b) }, a, b, out)
}
// Negate stores in the output the negated value -a.
func Negate(a tensor.Tensor) tensor.Values {
return tensor.CallOut1(NegateOut, a)
}
// NegateOut stores in the output the negated value -a.
func NegateOut(a tensor.Tensor, out tensor.Values) error {
return tensor.FloatFuncOut(1, func(in float64) float64 { return -in }, a, out)
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
import (
"fmt"
"reflect"
"cogentcore.org/core/base/metadata"
)
// Values is an extended [Tensor] interface for raw value tensors.
// This supports direct setting of the shape of the underlying values,
// sub-space access to inner-dimensional subspaces of values, etc.
type Values interface {
RowMajor
// SetShapeSizes sets the dimension sizes of the tensor, and resizes
// backing storage appropriately, retaining all existing data that fits.
SetShapeSizes(sizes ...int)
// SetNumRows sets the number of rows (outermost dimension).
// It is safe to set this to 0. For incrementally growing tensors (e.g., a log)
// it is best to first set the anticipated full size, which allocates the
// full amount of memory, and then set to 0 and grow incrementally.
SetNumRows(rows int)
// Sizeof returns the number of bytes contained in the Values of this tensor.
// For String types, this is just the string pointers, not the string content.
Sizeof() int64
// Bytes returns the underlying byte representation of the tensor values.
// This is the actual underlying data, so make a copy if it could be
// unintentionally modified or retained beyond immediate use.
// For the [String] type, this is a len int + bytes encoding of each string.
Bytes() []byte
// SetFromBytes sets the values from given bytes. See [Values.Bytes] for
// the [String] encoding.
SetFromBytes(b []byte)
// SetZeros is a convenience function to initialize all values to the
// zero value of the type (empty strings for string type).
// New tensors always start out with zeros.
SetZeros()
// Clone clones this tensor, creating a duplicate copy of itself with its
// own separate memory representation of all the values.
Clone() Values
// CopyFrom copies all values from other tensor into this tensor, with an
// optimized implementation if the other tensor is of the same type, and
// otherwise it goes through the appropriate standard type (Float, Int, String).
CopyFrom(from Values)
// CopyCellsFrom copies given range of values from other tensor into this tensor,
// using flat 1D indexes: to = starting index in this Tensor to start copying into,
// start = starting index on from Tensor to start copying from, and n = number of
// values to copy. Uses an optimized implementation if the other tensor is
// of the same type, and otherwise it goes through appropriate standard type.
CopyCellsFrom(from Values, to, start, n int)
// AppendFrom appends all values from other tensor into this tensor, with an
// optimized implementation if the other tensor is of the same type, and
// otherwise it goes through the appropriate standard type (Float, Int, String).
AppendFrom(from Values) Values
}
// New returns a new n-dimensional tensor of given value type
// with the given sizes per dimension (shape).
func New[T DataTypes](sizes ...int) Values {
var v T
switch any(v).(type) {
case string:
return NewString(sizes...)
case bool:
return NewBool(sizes...)
case float64:
return NewNumber[float64](sizes...)
case float32:
return NewNumber[float32](sizes...)
case int:
return NewNumber[int](sizes...)
case int64:
return NewNumber[int64](sizes...)
case uint64:
return NewNumber[uint64](sizes...)
case int32:
return NewNumber[int32](sizes...)
case uint32:
return NewNumber[uint32](sizes...)
case byte:
return NewNumber[byte](sizes...)
default:
panic("tensor.New: unexpected error: type not supported")
}
}
// NewOfType returns a new n-dimensional tensor of given reflect.Kind type
// with the given sizes per dimension (shape).
// Types supported are listed in [DataTypes].
func NewOfType(typ reflect.Kind, sizes ...int) Values {
switch typ {
case reflect.String:
return NewString(sizes...)
case reflect.Bool:
return NewBool(sizes...)
case reflect.Float64:
return NewNumber[float64](sizes...)
case reflect.Float32:
return NewNumber[float32](sizes...)
case reflect.Int:
return NewNumber[int](sizes...)
case reflect.Int64:
return NewNumber[int64](sizes...)
case reflect.Uint64:
return NewNumber[uint64](sizes...)
case reflect.Int32:
return NewNumber[int32](sizes...)
case reflect.Uint32:
return NewNumber[uint32](sizes...)
case reflect.Uint8:
return NewNumber[byte](sizes...)
default:
panic(fmt.Sprintf("tensor.NewOfType: type not supported: %v", typ))
}
}
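// exampleNew is a minimal usage sketch: create a 3x4 float64 tensor with the
// generic constructor and set one element using the n-dimensional accessors.
func exampleNew() Values {
	tsr := New[float64](3, 4)
	tsr.SetFloat(1.5, 2, 3) // row 2, column 3
	return tsr
}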
// metadata helpers
// SetShapeNames sets the tensor shape dimension names into given metadata.
func SetShapeNames(md *metadata.Data, names ...string) {
md.Set("ShapeNames", names)
}
// ShapeNames gets the tensor shape dimension names from given metadata.
func ShapeNames(md *metadata.Data) []string {
names, _ := metadata.Get[[]string](*md, "ShapeNames")
return names
}
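// exampleShapeNames is a minimal usage sketch: record dimension names in a
// tensor's metadata and read them back.
func exampleShapeNames(tsr Tensor) []string {
	SetShapeNames(tsr.Metadata(), "Row", "Col")
	return ShapeNames(tsr.Metadata())
}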
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensor
import (
"math"
"runtime"
"sync"
)
var (
// ThreadingThreshold is the threshold in number of flops (floating point ops),
// computed as tensor N * flops per element, to engage actual parallel processing.
// Heuristically, numbers below this threshold do not result in
// an overall speedup, due to overhead costs. See tmath/ops_test.go for benchmark.
ThreadingThreshold = 300
// NumThreads is the number of threads to use for parallel threading.
// The default of 0 causes the [runtime.GOMAXPROCS] to be used.
NumThreads = 0
)
// Vectorize applies given function 'fun' to tensor elements indexed
// by given index, with the 'nfun' providing the number of indexes
// to vectorize over, and initializing any output vectors.
// Thus the nfun is often specific to a particular class of functions.
// Both functions are called with the same set
// of Tensors passed as the final argument(s).
// The role of each tensor is function-dependent: there could be multiple
// inputs and outputs, and the output could be effectively scalar,
// as in a sum operation. The interpretation of the index is
// function dependent as well, but often is used to iterate over
// the outermost row dimension of the tensor.
// This version runs purely sequentially on this goroutine.
// See VectorizeThreaded and VectorizeGPU for other versions.
func Vectorize(nfun func(tsr ...Tensor) int, fun func(idx int, tsr ...Tensor), tsr ...Tensor) {
n := nfun(tsr...)
if n <= 0 {
return
}
for idx := range n {
fun(idx, tsr...)
}
}
// VectorizeThreaded is a version of [Vectorize] that will automatically
// distribute the computation in parallel across multiple "threads" (goroutines)
// if the number of elements to be computed times the given flops
// (floating point operations) for the function exceeds the [ThreadingThreshold].
// Heuristically, numbers below this threshold do not result
// in an overall speedup, due to overhead costs.
// Each elemental math operation in the function adds a flop.
// See estimates in [tmath] for basic math functions.
func VectorizeThreaded(flops int, nfun func(tsr ...Tensor) int, fun func(idx int, tsr ...Tensor), tsr ...Tensor) {
n := nfun(tsr...)
if n <= 0 {
return
}
if flops < 0 {
flops = 1
}
if n*flops < ThreadingThreshold {
Vectorize(nfun, fun, tsr...)
return
}
VectorizeOnThreads(0, nfun, fun, tsr...)
}
// DefaultNumThreads returns the default number of threads to use:
// NumThreads if non-zero, otherwise [runtime.GOMAXPROCS].
func DefaultNumThreads() int {
if NumThreads > 0 {
return NumThreads
}
return runtime.GOMAXPROCS(0)
}
// VectorizeOnThreads runs given [Vectorize] function on given number
// of threads. Use [VectorizeThreaded] to only use parallel threads when
// it is likely to be beneficial, in terms of the ThreadingThreshold.
// If threads is 0, then the [DefaultNumThreads] will be used:
// GOMAXPROCS subject to NumThreads constraint if non-zero.
func VectorizeOnThreads(threads int, nfun func(tsr ...Tensor) int, fun func(idx int, tsr ...Tensor), tsr ...Tensor) {
if threads == 0 {
threads = DefaultNumThreads()
}
n := nfun(tsr...)
if n <= 0 {
return
}
nper := int(math.Ceil(float64(n) / float64(threads)))
wait := sync.WaitGroup{}
for start := 0; start < n; start += nper {
end := start + nper
if end > n {
end = n
}
wait.Add(1) // todo: move out of loop
go func() {
for idx := start; idx < end; idx++ {
fun(idx, tsr...)
}
wait.Done()
}()
}
wait.Wait()
}
// NFirstRows is an N function for Vectorize that returns the number of
// outer-dimension rows (or Indexes) of the first tensor.
func NFirstRows(tsr ...Tensor) int {
if len(tsr) == 0 {
return 0
}
return tsr[0].DimSize(0)
}
// NFirstLen is an N function for Vectorize that returns the number of
// elements in the tensor, taking into account the Indexes view.
func NFirstLen(tsr ...Tensor) int {
if len(tsr) == 0 {
return 0
}
return tsr[0].Len()
}
// NMinLen is an N function for Vectorize that returns the min number of
// elements across given number of tensors in the list. Use a closure
// to call this with the nt.
func NMinLen(nt int, tsr ...Tensor) int {
nt = min(len(tsr), nt)
if nt == 0 {
return 0
}
n := tsr[0].Len()
for i := 1; i < nt; i++ {
n = min(n, tsr[i].Len())
}
return n
}
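// exampleVectorizeAdd is a minimal usage sketch: use VectorizeThreaded with
// the NFirstLen N-function to add tensor a into out element-wise, declaring
// 1 flop per element. Assumes out already has the same length as a.
func exampleVectorizeAdd(a, out Tensor) {
	VectorizeThreaded(1, NFirstLen, func(idx int, tsr ...Tensor) {
		tsr[1].SetFloat1D(tsr[1].Float1D(idx)+tsr[0].Float1D(idx), idx)
	}, a, out)
}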
// Copyright (c) 2019, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensorcore
import (
"cogentcore.org/core/base/metadata"
"cogentcore.org/core/core"
"cogentcore.org/core/math32/minmax"
)
// Layout contains layout options for displaying tensors.
type Layout struct { //types:add --setters
// OddRow means that even-numbered dimensions are displayed as Y*X rectangles.
// This determines along which dimension to display any remaining
// odd dimension: OddRow = true means organize vertically along the row
// dimension; false means organize horizontally across the column dimension.
OddRow bool
// TopZero means that the Y=0 coordinate is displayed from the top-down;
// otherwise the Y=0 coordinate is displayed from the bottom up,
// which is typical for emergent network patterns.
TopZero bool
// Image will display the data as a bitmap image. If a 2D tensor, then it will
// be a greyscale image. If a 3D tensor with size of either the first
// or last dim being either 3 or 4, then it is an RGB(A) color image.
Image bool
}
// GridStyle contains options for displaying tensors.
type GridStyle struct { //types:add --setters
Layout
// Range to plot
Range minmax.Range64 `display:"inline"`
// MinMax has the actual range of data, if not using fixed Range.
MinMax minmax.F64 `display:"inline"`
// ColorMap is the name of the color map to use in translating values to colors.
ColorMap core.ColorMapName
// GridFill sets proportion of grid square filled by the color block:
// 1 = all, .5 = half, etc.
GridFill float32 `min:"0.1" max:"1" step:"0.1" default:"0.9,1"`
// DimExtra is the amount of extra space to add at dimension boundaries,
// as a proportion of total grid size.
DimExtra float32 `min:"0" max:"1" step:"0.02" default:"0.1,0.3"`
// Size sets the minimum and maximum size for grid squares.
Size minmax.F32 `display:"inline"`
// TotalSize sets the total preferred display size along largest dimension.
// Grid squares will be sized to fit within this size,
// subject to the Size.Min / Max constraints, which have precedence.
TotalSize float32
// FontSize is the font size in standard Dp units for labels.
FontSize float32
// ColumnRotation is the rotation angle in degrees for column labels
ColumnRotation float32 `default:"90"`
}
// Defaults sets defaults for values that are at nonsensical initial values
func (gs *GridStyle) Defaults() {
gs.Range.SetMin(-1).SetMax(1)
gs.ColorMap = "ColdHot"
gs.GridFill = 0.9
gs.DimExtra = 0.3
gs.Size.Set(2, 32)
gs.TotalSize = 100
gs.FontSize = 16
gs.ColumnRotation = 90
}
// NewGridStyle returns a new GridStyle with defaults.
func NewGridStyle() *GridStyle {
gs := &GridStyle{}
gs.Defaults()
return gs
}
func (gs *GridStyle) ApplyStylersFrom(obj any) {
st := GetGridStylersFrom(obj)
if st == nil {
return
}
st.Run(gs)
}
// GridStylers is a list of styling functions that set GridStyle properties.
// These are called in the order added.
type GridStylers []func(s *GridStyle)
// Add adds a styling function to the list.
func (st *GridStylers) Add(f func(s *GridStyle)) {
*st = append(*st, f)
}
// Run runs the list of styling functions on given [GridStyle] object.
func (st *GridStylers) Run(s *GridStyle) {
for _, f := range *st {
f(s)
}
}
// SetGridStylersTo sets the [GridStylers] into given object's [metadata].
func SetGridStylersTo(obj any, st GridStylers) {
metadata.Set(obj, "GridStylers", st)
}
// GetGridStylersFrom returns [GridStylers] from given object's [metadata].
// Returns nil if none or no metadata.
func GetGridStylersFrom(obj any) GridStylers {
st, _ := metadata.Get[GridStylers](obj, "GridStylers")
return st
}
// AddGridStylerTo adds the given [GridStyler] function into given object's [metadata].
func AddGridStylerTo(obj any, f func(s *GridStyle)) {
st := GetGridStylersFrom(obj)
st.Add(f)
SetGridStylersTo(obj, st)
}
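// exampleGridStyler is a minimal usage sketch: attach a styling function to
// an object's metadata (e.g., a tensor column) so that any grid displaying
// it uses a fixed range and the default ColdHot color map.
func exampleGridStyler(obj any) {
	AddGridStylerTo(obj, func(s *GridStyle) {
		s.Range.SetMin(0).SetMax(1)
		s.ColorMap = "ColdHot"
	})
}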
// Copyright (c) 2023, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package tensorcore provides GUI Cogent Core widgets for tensor types.
package tensorcore
//go:generate core generate
import (
"bytes"
"encoding/csv"
"fmt"
"image"
"log"
"strconv"
"strings"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/fileinfo"
"cogentcore.org/core/base/fileinfo/mimedata"
"cogentcore.org/core/core"
"cogentcore.org/core/events"
"cogentcore.org/core/icons"
"cogentcore.org/core/styles"
"cogentcore.org/core/styles/states"
"cogentcore.org/core/styles/units"
"cogentcore.org/core/tree"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
)
// Table provides a GUI widget for representing [table.Table] values.
type Table struct {
core.ListBase
// Table is the table that we're a view of.
Table *table.Table `set:"-"`
// GridStyle has global grid display styles. GridStylers on the Table
// are applied to this on top of defaults.
GridStyle GridStyle `set:"-"`
// ColumnGridStyle has per column grid display styles.
ColumnGridStyle map[int]*GridStyle `set:"-"`
// current sort index.
SortIndex int
// whether current sort order is descending.
SortDescending bool
// number of columns in table (as of last update).
nCols int `edit:"-"`
// headerWidths has the number of characters in each column header.
headerWidths []int `copier:"-" display:"-" json:"-" xml:"-"`
// colMaxWidths records maximum width in chars of string type fields.
colMaxWidths []int `set:"-" copier:"-" json:"-" xml:"-"`
// blank values for out-of-range rows.
blankString string
blankFloat float64
// blankCells has per column blank tensor cells.
blankCells map[int]*tensor.Float64 `set:"-"`
}
// check for interface impl
var _ core.Lister = (*Table)(nil)
func (tb *Table) Init() {
tb.ListBase.Init()
tb.SortIndex = -1
tb.GridStyle.Defaults()
tb.ColumnGridStyle = map[int]*GridStyle{}
tb.blankCells = map[int]*tensor.Float64{}
tb.Makers.Normal[0] = func(p *tree.Plan) { // TODO: reduce redundancy with ListBase Maker
svi := tb.This.(core.Lister)
svi.UpdateSliceSize()
scrollTo := -1
if tb.InitSelectedIndex >= 0 {
tb.SelectedIndex = tb.InitSelectedIndex
tb.InitSelectedIndex = -1
scrollTo = tb.SelectedIndex
}
if scrollTo >= 0 {
tb.ScrollToIndex(scrollTo)
}
tb.UpdateStartIndex()
tb.UpdateMaxWidths()
tb.Updater(func() {
tb.UpdateStartIndex()
})
tb.MakeHeader(p)
tb.MakeGrid(p, func(p *tree.Plan) {
for i := 0; i < tb.VisibleRows; i++ {
svi.MakeRow(p, i)
}
})
}
}
func (tb *Table) SliceIndex(i int) (si, vi int, invis bool) {
si = tb.StartIndex + i
vi = -1
if si < tb.Table.NumRows() {
vi = tb.Table.RowIndex(si)
}
invis = vi < 0
return
}
// StyleValue performs additional value widget styling
func (tb *Table) StyleValue(w core.Widget, s *styles.Style, row, col int) {
hw := float32(tb.headerWidths[col])
if col == tb.SortIndex {
hw += 6
}
if len(tb.colMaxWidths) > col {
hw = max(float32(tb.colMaxWidths[col]), hw)
}
hv := units.Ch(1.1 * hw) // 1.1 works
s.Min.X.Value = max(s.Min.X.Value, hv.Convert(s.Min.X.Unit, &s.UnitContext).Value)
s.SetTextWrap(false)
}
// SetTable sets the source table that we are viewing, using a sequential view,
// and then configures the display
func (tb *Table) SetTable(dt *table.Table) *Table {
if dt == nil {
tb.Table = nil
} else {
tb.Table = table.NewView(dt)
tb.GridStyle.ApplyStylersFrom(tb.Table)
}
tb.This.(core.Lister).UpdateSliceSize()
tb.SetSliceBase()
tb.Update()
return tb
}
// SetSlice sets the source table to a [table.NewSliceTable]
// from the given slice.
func (tb *Table) SetSlice(sl any) *Table {
return tb.SetTable(errors.Log1(table.NewSliceTable(sl)))
}
// AsyncUpdateTable updates the display for asynchronous updating from
// other goroutines. Also updates indexview (calling Sequential).
func (tb *Table) AsyncUpdateTable() {
tb.AsyncLock()
tb.Table.Sequential()
tb.ScrollToIndexNoUpdate(tb.SliceSize - 1)
tb.Update()
tb.AsyncUnlock()
}
func (tb *Table) UpdateSliceSize() int {
tb.Table.ValidIndexes() // table could have changed
if tb.Table.NumRows() == 0 {
tb.Table.Sequential()
}
tb.SliceSize = tb.Table.NumRows()
tb.nCols = tb.Table.NumColumns()
return tb.SliceSize
}
func (tb *Table) UpdateMaxWidths() {
if len(tb.headerWidths) != tb.nCols {
tb.headerWidths = make([]int, tb.nCols)
tb.colMaxWidths = make([]int, tb.nCols)
}
if tb.SliceSize == 0 {
return
}
for fli := 0; fli < tb.nCols; fli++ {
tb.colMaxWidths[fli] = 0
col := tb.Table.Columns.Values[fli]
stsr, isstr := col.(*tensor.String)
if !isstr {
continue
}
mxw := 0
nr := tb.Table.NumRows()
for r := range nr {
sval := stsr.Values[tb.Table.RowIndex(r)]
mxw = max(mxw, len(sval))
}
tb.colMaxWidths[fli] = mxw
}
}
func (tb *Table) MakeHeader(p *tree.Plan) {
tree.AddAt(p, "header", func(w *core.Frame) {
core.ToolbarStyles(w)
w.FinalStyler(func(s *styles.Style) {
s.Display = styles.Flex // note: ToolbarStyles sets to None if no children, which can happen transiently in tables -- it doesn't recover from that.
s.Padding.Zero()
s.Grow.Set(0, 0)
s.Gap.Set(units.Em(0.5)) // matches grid default
})
w.Maker(func(p *tree.Plan) {
if tb.ShowIndexes {
tree.AddAt(p, "_head-index", func(w *core.Text) { // TODO: is not working
w.SetType(core.TextBodyMedium)
w.Styler(func(s *styles.Style) {
s.Align.Self = styles.Center
})
w.SetText("Index")
})
}
for fli := 0; fli < tb.nCols; fli++ {
field := tb.Table.Columns.Keys[fli]
tree.AddAt(p, "head-"+field, func(w *core.Button) {
w.SetType(core.ButtonAction)
w.Styler(func(s *styles.Style) {
s.Justify.Content = styles.Start
})
w.OnClick(func(e events.Event) {
tb.SortColumn(fli)
})
if tb.Table.Columns.Values[fli].NumDims() > 1 {
w.AddContextMenu(func(m *core.Scene) {
core.NewButton(m).SetText("Edit grid style").SetIcon(icons.Edit).
OnClick(func(e events.Event) {
tb.EditGridStyle(fli)
})
})
}
w.Updater(func() {
field := tb.Table.Columns.Keys[fli]
w.SetText(field).SetTooltip(field + " (tap to sort by)")
tb.headerWidths[fli] = len(field)
if fli == tb.SortIndex {
if tb.SortDescending {
w.SetIndicator(icons.KeyboardArrowDown)
} else {
w.SetIndicator(icons.KeyboardArrowUp)
}
} else {
w.SetIndicator(icons.Blank)
}
})
})
}
})
})
}
// SliceHeader returns the Frame header for slice grid
func (tb *Table) SliceHeader() *core.Frame {
return tb.Child(0).(*core.Frame)
}
// RowWidgetNs returns number of widgets per row and offset for index label
func (tb *Table) RowWidgetNs() (nWidgPerRow, idxOff int) {
nWidgPerRow = 1 + tb.nCols
idxOff = 1
if !tb.ShowIndexes {
nWidgPerRow -= 1
idxOff = 0
}
return
}
func (tb *Table) MakeRow(p *tree.Plan, i int) {
svi := tb.This.(core.Lister)
si, _, invis := svi.SliceIndex(i)
itxt := strconv.Itoa(i)
if tb.ShowIndexes {
tb.MakeGridIndex(p, i, si, itxt, invis)
}
for fli := 0; fli < tb.nCols; fli++ {
col := tb.Table.Columns.Values[fli]
valnm := fmt.Sprintf("value-%v.%v", fli, itxt)
_, isstr := col.(*tensor.String)
if col.NumDims() == 1 {
str := ""
fval := float64(0)
tree.AddNew(p, valnm, func() core.Value {
if isstr {
return core.NewValue(&str, "")
} else {
return core.NewValue(&fval, "")
}
}, func(w core.Value) {
wb := w.AsWidget()
tb.MakeValue(w, i)
w.AsTree().SetProperty(core.ListColProperty, fli)
if !tb.IsReadOnly() {
wb.OnChange(func(e events.Event) {
_, vi, invis := svi.SliceIndex(i)
if !invis {
if isstr {
col.SetString1D(str, vi)
} else {
col.SetFloat1D(fval, vi)
}
}
tb.This.(core.Lister).UpdateMaxWidths()
tb.SendChange()
})
}
wb.Updater(func() {
col := tb.Table.Columns.Values[fli]
_, vi, invis := svi.SliceIndex(i)
if !invis {
if isstr {
str = col.String1D(vi)
core.Bind(&str, w)
wb.SetTooltip(str)
} else {
fval = col.Float1D(vi)
wb.SetTooltip(fmt.Sprintf("%g", fval))
core.Bind(&fval, w)
}
} else {
wb.SetTooltip("")
if isstr {
core.Bind(tb.blankString, w)
} else {
core.Bind(tb.blankFloat, w)
}
}
wb.SetReadOnly(tb.IsReadOnly())
wb.SetState(invis, states.Invisible)
if svi.HasStyler() {
w.Style()
}
if invis {
wb.SetSelected(false)
}
})
})
} else {
tree.AddAt(p, valnm, func(w *TensorGrid) {
w.SetReadOnly(tb.IsReadOnly())
wb := w.AsWidget()
w.SetProperty(core.ListRowProperty, i)
w.SetProperty(core.ListColProperty, fli)
w.Styler(func(s *styles.Style) {
s.Grow.Set(0, 0)
})
wb.Updater(func() {
si, vi, invis := svi.SliceIndex(i)
var cell tensor.Tensor
if invis {
cell = tb.blankCell(fli, col)
} else {
cell = col.RowTensor(vi)
}
wb.ValueTitle = tb.ValueTitle + "[" + strconv.Itoa(si) + "]"
w.SetState(invis, states.Invisible)
w.SetTensor(cell)
w.GridStyle = *tb.GetColumnGridStyle(fli)
})
})
}
}
}
// blankCell returns tensor blanks for given tensor col
func (tb *Table) blankCell(cidx int, col tensor.Tensor) *tensor.Float64 {
if ctb, has := tb.blankCells[cidx]; has {
return ctb
}
ctb := tensor.New[float64](col.ShapeSizes()...).(*tensor.Float64)
tb.blankCells[cidx] = ctb
return ctb
}
// GetColumnGridStyle gets grid style for given column.
func (tb *Table) GetColumnGridStyle(col int) *GridStyle {
if ctd, has := tb.ColumnGridStyle[col]; has {
return ctd
}
ctd := &GridStyle{}
*ctd = tb.GridStyle
if tb.Table != nil {
cl := tb.Table.Columns.Values[col]
ctd.ApplyStylersFrom(cl)
}
return ctd
}
// NewAt inserts a new blank element at given index in the slice -- -1
// means the end
func (tb *Table) NewAt(idx int) {
tb.NewAtSelect(idx)
tb.Table.InsertRows(idx, 1)
tb.SelectIndexEvent(idx, events.SelectOne)
tb.Update()
tb.IndexGrabFocus(idx)
}
// DeleteAt deletes element at given index from slice
func (tb *Table) DeleteAt(idx int) {
if idx < 0 || idx >= tb.SliceSize {
return
}
tb.DeleteAtSelect(idx)
tb.Table.DeleteRows(idx, 1)
tb.Update()
}
// SortColumn sorts the slice for given column index.
// Toggles ascending vs. descending if already sorting on this dimension.
func (tb *Table) SortColumn(fldIndex int) {
sgh := tb.SliceHeader()
_, idxOff := tb.RowWidgetNs()
for fli := 0; fli < tb.nCols; fli++ {
hdr := sgh.Child(idxOff + fli).(*core.Button)
hdr.SetType(core.ButtonAction)
if fli == fldIndex {
if tb.SortIndex == fli {
tb.SortDescending = !tb.SortDescending
} else {
tb.SortDescending = false
}
}
}
tb.SortIndex = fldIndex
if fldIndex == -1 {
tb.Table.SortIndexes()
} else {
tb.Table.IndexesNeeded()
col := tb.Table.ColumnByIndex(tb.SortIndex)
col.Sort(!tb.SortDescending)
tb.Table.IndexesFromTensor(col)
}
tb.Update() // requires full update due to sort button icon
}
// EditGridStyle shows an editor dialog for grid style for given column index.
func (tb *Table) EditGridStyle(col int) {
ctd := tb.GetColumnGridStyle(col)
d := core.NewBody("Tensor grid style")
core.NewForm(d).SetStruct(ctd).
OnChange(func(e events.Event) {
tb.ColumnGridStyle[col] = ctd
tb.Update()
})
core.NewButton(d).SetText("Edit global style").SetIcon(icons.Edit).
OnClick(func(e events.Event) {
tb.EditGlobalGridStyle()
})
d.RunWindowDialog(tb)
}
// EditGlobalGridStyle shows an editor dialog for global grid styles.
func (tb *Table) EditGlobalGridStyle() {
d := core.NewBody("Tensor grid style")
core.NewForm(d).SetStruct(&tb.GridStyle).
OnChange(func(e events.Event) {
tb.Update()
})
d.RunWindowDialog(tb)
}
func (tb *Table) HasStyler() bool { return false }
func (tb *Table) StyleRow(w core.Widget, idx, fidx int) {}
// SortFieldName returns the name of the field being sorted, along with :up or
// :down depending on descending
func (tb *Table) SortFieldName() string {
if tb.SortIndex >= 0 && tb.SortIndex < tb.nCols {
nm := tb.Table.Columns.Keys[tb.SortIndex]
if tb.SortDescending {
nm += ":down"
} else {
nm += ":up"
}
return nm
}
return ""
}
// SetSortFieldName sets sorting to happen on given field and direction -- see
// SortFieldName for details
func (tb *Table) SetSortFieldName(nm string) {
if nm == "" {
return
}
spnm := strings.Split(nm, ":")
got := false
for fli := 0; fli < tb.nCols; fli++ {
fld := tb.Table.Columns.Keys[fli]
if fld == spnm[0] {
got = true
// fmt.Println("sorting on:", fld.Name, fli, "from:", nm)
tb.SortIndex = fli
}
}
if len(spnm) == 2 {
if spnm[1] == "down" {
tb.SortDescending = true
} else {
tb.SortDescending = false
}
}
_ = got
// if got {
// tv.SortSlice()
// }
}
// RowFirstVisWidget returns the first visible widget for given row (could be
// index or not) -- false if out of range
func (tb *Table) RowFirstVisWidget(row int) (*core.WidgetBase, bool) {
if !tb.IsRowInBounds(row) {
return nil, false
}
nWidgPerRow, idxOff := tb.RowWidgetNs()
lg := tb.ListGrid
w := lg.Children[row*nWidgPerRow].(core.Widget).AsWidget()
if w.Geom.TotalBBox != (image.Rectangle{}) {
return w, true
}
ridx := nWidgPerRow * row
for fli := 0; fli < tb.nCols; fli++ {
w := lg.Child(ridx + idxOff + fli).(core.Widget).AsWidget()
if w.Geom.TotalBBox != (image.Rectangle{}) {
return w, true
}
}
return nil, false
}
// RowGrabFocus grabs the focus for the first focusable widget in given row --
// returns that element or nil if not successful -- note: grid must have
// already rendered for focus to be grabbed!
func (tb *Table) RowGrabFocus(row int) *core.WidgetBase {
if !tb.IsRowInBounds(row) || tb.InFocusGrab { // range check
return nil
}
nWidgPerRow, idxOff := tb.RowWidgetNs()
ridx := nWidgPerRow * row
lg := tb.ListGrid
// first check if we already have focus
for fli := 0; fli < tb.nCols; fli++ {
w := lg.Child(ridx + idxOff + fli).(core.Widget).AsWidget()
if w.StateIs(states.Focused) || w.ContainsFocus() {
return w
}
}
tb.InFocusGrab = true
defer func() { tb.InFocusGrab = false }()
for fli := 0; fli < tb.nCols; fli++ {
w := lg.Child(ridx + idxOff + fli).(core.Widget).AsWidget()
if w.CanFocus() {
w.SetFocus()
return w
}
}
return nil
}
//////// Header layout
func (tb *Table) SizeFinal() {
tb.ListBase.SizeFinal()
lg := tb.ListGrid
sh := tb.SliceHeader()
sh.ForWidgetChildren(func(i int, cw core.Widget, cwb *core.WidgetBase) bool {
sgb := core.AsWidget(lg.Child(i))
gsz := &sgb.Geom.Size
if gsz.Actual.Total.X == 0 {
return tree.Continue
}
ksz := &cwb.Geom.Size
ksz.Actual.Total.X = gsz.Actual.Total.X
ksz.Actual.Content.X = gsz.Actual.Content.X
ksz.Alloc.Total.X = gsz.Alloc.Total.X
ksz.Alloc.Content.X = gsz.Alloc.Content.X
return tree.Continue
})
gsz := &lg.Geom.Size
ksz := &sh.Geom.Size
if gsz.Actual.Total.X > 0 {
ksz.Actual.Total.X = gsz.Actual.Total.X
ksz.Actual.Content.X = gsz.Actual.Content.X
ksz.Alloc.Total.X = gsz.Alloc.Total.X
ksz.Alloc.Content.X = gsz.Alloc.Content.X
}
}
// SelectedColumnStrings returns the string values of the given column
// for the currently selected rows.
func (tb *Table) SelectedColumnStrings(colName string) []string {
dt := tb.Table
jis := tb.SelectedIndexesList(false)
if len(jis) == 0 || dt == nil {
return nil
}
var s []string
col := dt.Column(colName)
for _, i := range jis {
v := col.StringRow(i, 0)
s = append(s, v)
}
return s
}
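// Usage sketch (assumption, not in the original source): collect the values
// of one column for the rows currently selected in the view. The column name
// "Group" is hypothetical.
//
//	groups := tb.SelectedColumnStrings("Group")
//	for _, g := range groups {
//		fmt.Println(g)
//	}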
//////// Copy / Cut / Paste
func (tb *Table) MakeToolbar(p *tree.Plan) {
if tb.Table == nil {
return
}
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(tb.Table.AddRows).SetIcon(icons.Add)
w.SetAfterFunc(func() { tb.Update() })
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(tb.Table.SortColumns).SetText("Sort").SetIcon(icons.Sort)
w.SetAfterFunc(func() { tb.Update() })
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(tb.Table.FilterString).SetText("Filter").SetIcon(icons.FilterAlt)
w.SetAfterFunc(func() { tb.Update() })
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(tb.Table.Sequential).SetText("Unfilter").SetIcon(icons.FilterAltOff)
w.SetAfterFunc(func() { tb.Update() })
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(tb.Table.OpenCSV).SetIcon(icons.Open)
w.SetAfterFunc(func() { tb.Update() })
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(tb.Table.SaveCSV).SetIcon(icons.Save)
w.SetAfterFunc(func() { tb.Update() })
})
}
func (tb *Table) MimeDataType() string {
return fileinfo.DataCsv
}
// CopySelectToMime copies selected rows to mime data
func (tb *Table) CopySelectToMime() mimedata.Mimes {
nitms := len(tb.SelectedIndexes)
if nitms == 0 {
return nil
}
ix := table.NewView(tb.Table)
idx := tb.SelectedIndexesList(false) // ascending
iidx := make([]int, len(idx))
for i, di := range idx {
iidx[i] = tb.Table.RowIndex(di)
}
ix.Indexes = iidx
var b bytes.Buffer
ix.WriteCSV(&b, tensor.Tab, table.Headers)
md := mimedata.NewTextBytes(b.Bytes())
md[0].Type = fileinfo.DataCsv
return md
}
// FromMimeData returns CSV records decoded from the given mime data.
func (tb *Table) FromMimeData(md mimedata.Mimes) [][]string {
var recs [][]string
for _, d := range md {
if d.Type == fileinfo.DataCsv {
b := bytes.NewBuffer(d.Data)
cr := csv.NewReader(b)
cr.Comma = tensor.Tab.Rune()
rec, err := cr.ReadAll()
if err != nil || len(rec) == 0 {
log.Printf("Error reading CSV from clipboard: %s\n", err)
return nil
}
recs = append(recs, rec...)
}
}
return recs
}
// PasteAssign assigns mime data (only the first record) to the given row index.
func (tb *Table) PasteAssign(md mimedata.Mimes, idx int) {
recs := tb.FromMimeData(md)
if len(recs) < 2 { // recs[0] is the header row; need at least one data row
return
}
tb.Table.ReadCSVRow(recs[1], tb.Table.RowIndex(idx))
tb.UpdateChange()
}
// PasteAtIndex inserts rows from mime data at (before) the given slice index.
func (tb *Table) PasteAtIndex(md mimedata.Mimes, idx int) {
recs := tb.FromMimeData(md)
nr := len(recs) - 1
if nr <= 0 {
return
}
tb.Table.InsertRows(idx, nr)
for ri := 0; ri < nr; ri++ {
rec := recs[1+ri]
rw := tb.Table.RowIndex(idx + ri)
tb.Table.ReadCSVRow(rec, rw)
}
tb.SendChange()
tb.SelectIndexEvent(idx, events.SelectOne)
tb.Update()
}
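// Usage sketch (assumption): the clipboard methods above are normally invoked
// by the framework's copy/cut/paste event handling, but they can also be
// called directly; row is a hypothetical target row index.
//
//	md := tb.CopySelectToMime() // selected rows as TSV mime data
//	if md != nil {
//		tb.PasteAtIndex(md, row) // insert the copied rows before row
//	}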
// Copyright (c) 2023, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package tensorcore provides GUI Cogent Core widgets for tensor types.
package tensorcore
import (
"fmt"
"image"
"strconv"
"cogentcore.org/core/base/fileinfo"
"cogentcore.org/core/base/fileinfo/mimedata"
"cogentcore.org/core/base/fsx"
"cogentcore.org/core/core"
"cogentcore.org/core/events"
"cogentcore.org/core/icons"
"cogentcore.org/core/styles"
"cogentcore.org/core/styles/states"
"cogentcore.org/core/styles/units"
"cogentcore.org/core/tree"
"cogentcore.org/lab/tensor"
)
// TensorEditor provides a GUI widget for representing [tensor.Tensor] values.
type TensorEditor struct {
core.ListBase
// the tensor that we're a view of
Tensor tensor.Tensor `set:"-"`
// overall layout options for tensor display
Layout Layout `set:"-"`
// number of columns in table (as of last update)
NCols int `edit:"-"`
// headerWidths has the number of characters in each column header.
headerWidths []int `copier:"-" display:"-" json:"-" xml:"-"`
// colMaxWidths records maximum width in chars of string type fields
colMaxWidths []int `set:"-" copier:"-" json:"-" xml:"-"`
// blank values for out-of-range rows
BlankString string
BlankFloat float64
}
// check for interface impl
var _ core.Lister = (*TensorEditor)(nil)
func (tb *TensorEditor) Init() {
tb.ListBase.Init()
tb.Layout.OddRow = true
tb.Makers.Normal[0] = func(p *tree.Plan) { // TODO: reduce redundancy with ListBase Maker
svi := tb.This.(core.Lister)
svi.UpdateSliceSize()
scrollTo := -1
if tb.InitSelectedIndex >= 0 {
tb.SelectedIndex = tb.InitSelectedIndex
tb.InitSelectedIndex = -1
scrollTo = tb.SelectedIndex
}
if scrollTo >= 0 {
tb.ScrollToIndex(scrollTo)
}
tb.UpdateStartIndex()
tb.UpdateMaxWidths()
tb.Updater(func() {
if tb.Tensor.NumDims() == 1 {
tb.Layout.TopZero = true
}
tb.UpdateStartIndex()
})
tb.MakeHeader(p)
tb.MakeGrid(p, func(p *tree.Plan) {
for i := 0; i < tb.VisibleRows; i++ {
svi.MakeRow(p, i)
}
})
}
}
func (tb *TensorEditor) SliceIndex(i int) (si, vi int, invis bool) {
si = tb.StartIndex + i
vi = si
invis = si >= tb.SliceSize
if !tb.Layout.TopZero {
vi = (tb.SliceSize - 1) - si
}
return
}
// StyleValue performs additional value widget styling
func (tb *TensorEditor) StyleValue(w core.Widget, s *styles.Style, row, col int) {
hw := float32(tb.headerWidths[col])
if len(tb.colMaxWidths) > col {
hw = max(float32(tb.colMaxWidths[col]), hw)
}
hv := units.Ch(hw)
s.Min.X.Value = max(s.Min.X.Value, hv.Convert(s.Min.X.Unit, &s.UnitContext).Value)
s.SetTextWrap(false)
}
// SetTensor sets the source tensor that we are viewing,
// and then configures the display.
func (tb *TensorEditor) SetTensor(et tensor.Tensor) *TensorEditor {
if et == nil {
return nil
}
tb.Tensor = et
tb.This.(core.Lister).UpdateSliceSize()
tb.SetSliceBase()
tb.Update()
return tb
}
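// Usage sketch (assumption, not part of the original source): create a
// TensorEditor in a dialog body and point it at an existing tensor.
// Here tsr is assumed to be a non-nil tensor.Tensor and owner is a
// hypothetical parent widget.
//
//	d := core.NewBody("Tensor editor")
//	te := NewTensorEditor(d).SetTensor(tsr)
//	te.OnChange(func(e events.Event) {
//		// respond to edits, e.g., re-render a dependent view
//	})
//	d.RunWindowDialog(owner)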
func (tb *TensorEditor) UpdateSliceSize() int {
tb.SliceSize, tb.NCols, _, _ = tensor.Projection2DShape(tb.Tensor.Shape(), tb.Layout.OddRow)
return tb.SliceSize
}
func (tb *TensorEditor) UpdateMaxWidths() {
if len(tb.headerWidths) != tb.NCols {
tb.headerWidths = make([]int, tb.NCols)
tb.colMaxWidths = make([]int, tb.NCols)
}
if tb.SliceSize == 0 {
return
}
_, isstr := tb.Tensor.(*tensor.String)
for fli := 0; fli < tb.NCols; fli++ {
tb.colMaxWidths[fli] = 0
if !isstr {
continue
}
mxw := 0
// for _, ixi := range tb.Tensor.Indexes {
// if ixi >= 0 {
// sval := stsr.Values[ixi]
// mxw = max(mxw, len(sval))
// }
// }
tb.colMaxWidths[fli] = mxw
}
}
func (tb *TensorEditor) MakeHeader(p *tree.Plan) {
tree.AddAt(p, "header", func(w *core.Frame) {
core.ToolbarStyles(w)
w.FinalStyler(func(s *styles.Style) {
s.Padding.Zero()
s.Grow.Set(0, 0)
s.Gap.Set(units.Em(0.5)) // matches grid default
})
w.Maker(func(p *tree.Plan) {
if tb.ShowIndexes {
tree.AddAt(p, "_head-index", func(w *core.Text) { // TODO: is not working
w.SetType(core.TextBodyMedium)
w.Styler(func(s *styles.Style) {
s.Align.Self = styles.Center
})
w.SetText("Index")
})
}
for fli := 0; fli < tb.NCols; fli++ {
hdr := tb.ColumnHeader(fli)
tree.AddAt(p, "head-"+hdr, func(w *core.Button) {
w.SetType(core.ButtonAction)
w.Styler(func(s *styles.Style) {
s.Justify.Content = styles.Start
})
w.Updater(func() {
hdr := tb.ColumnHeader(fli)
w.SetText(hdr).SetTooltip(hdr)
tb.headerWidths[fli] = len(hdr)
})
})
}
})
})
}
func (tb *TensorEditor) ColumnHeader(col int) string {
_, cc := tensor.Projection2DCoords(tb.Tensor.Shape(), tb.Layout.OddRow, 0, col)
sitxt := ""
for i, ccc := range cc {
sitxt += fmt.Sprintf("%03d", ccc)
if i < len(cc)-1 {
sitxt += ","
}
}
return sitxt
}
// SliceHeader returns the Frame header for slice grid
func (tb *TensorEditor) SliceHeader() *core.Frame {
return tb.Child(0).(*core.Frame)
}
// RowWidgetNs returns number of widgets per row and offset for index label
func (tb *TensorEditor) RowWidgetNs() (nWidgPerRow, idxOff int) {
nWidgPerRow = 1 + tb.NCols
idxOff = 1
if !tb.ShowIndexes {
nWidgPerRow -= 1
idxOff = 0
}
return
}
func (tb *TensorEditor) MakeRow(p *tree.Plan, i int) {
svi := tb.This.(core.Lister)
si, _, invis := svi.SliceIndex(i)
itxt := strconv.Itoa(i)
if tb.ShowIndexes {
tb.MakeGridIndex(p, i, si, itxt, invis)
}
_, isstr := tb.Tensor.(*tensor.String)
for fli := 0; fli < tb.NCols; fli++ {
valnm := fmt.Sprintf("value-%v.%v", fli, itxt)
fval := float64(0)
str := ""
tree.AddNew(p, valnm, func() core.Value {
if isstr {
return core.NewValue(&str, "")
} else {
return core.NewValue(&fval, "")
}
}, func(w core.Value) {
wb := w.AsWidget()
tb.MakeValue(w, i)
w.AsTree().SetProperty(core.ListColProperty, fli)
if !tb.IsReadOnly() {
wb.OnChange(func(e events.Event) {
_, vi, invis := svi.SliceIndex(i)
if !invis {
if isstr {
tensor.Projection2DSetString(tb.Tensor, tb.Layout.OddRow, vi, fli, str)
} else {
tensor.Projection2DSet(tb.Tensor, tb.Layout.OddRow, vi, fli, fval)
}
}
tb.This.(core.Lister).UpdateMaxWidths()
tb.SendChange()
})
}
wb.Updater(func() {
_, vi, invis := svi.SliceIndex(i)
if !invis {
if isstr {
str = tensor.Projection2DString(tb.Tensor, tb.Layout.OddRow, vi, fli)
core.Bind(&str, w)
} else {
fval = tensor.Projection2DValue(tb.Tensor, tb.Layout.OddRow, vi, fli)
core.Bind(&fval, w)
}
} else {
if isstr {
core.Bind(tb.BlankString, w)
} else {
core.Bind(tb.BlankFloat, w)
}
}
wb.SetReadOnly(tb.IsReadOnly())
wb.SetState(invis, states.Invisible)
if svi.HasStyler() {
w.Style()
}
if invis {
wb.SetSelected(false)
}
})
})
}
}
func (tb *TensorEditor) HasStyler() bool { return false }
func (tb *TensorEditor) StyleRow(w core.Widget, idx, fidx int) {}
// RowFirstVisWidget returns the first visible widget for the given row
// (which could be the index widget or a value widget); returns false if out of range.
func (tb *TensorEditor) RowFirstVisWidget(row int) (*core.WidgetBase, bool) {
if !tb.IsRowInBounds(row) {
return nil, false
}
nWidgPerRow, idxOff := tb.RowWidgetNs()
lg := tb.ListGrid
w := lg.Children[row*nWidgPerRow].(core.Widget).AsWidget()
if w.Geom.TotalBBox != (image.Rectangle{}) {
return w, true
}
ridx := nWidgPerRow * row
for fli := 0; fli < tb.NCols; fli++ {
w := lg.Child(ridx + idxOff + fli).(core.Widget).AsWidget()
if w.Geom.TotalBBox != (image.Rectangle{}) {
return w, true
}
}
return nil, false
}
// RowGrabFocus grabs the focus for the first focusable widget in given row --
// returns that element or nil if not successful -- note: grid must have
// already rendered for focus to be grabbed!
func (tb *TensorEditor) RowGrabFocus(row int) *core.WidgetBase {
if !tb.IsRowInBounds(row) || tb.InFocusGrab { // range check
return nil
}
nWidgPerRow, idxOff := tb.RowWidgetNs()
ridx := nWidgPerRow * row
lg := tb.ListGrid
// first check if we already have focus
for fli := 0; fli < tb.NCols; fli++ {
w := lg.Child(ridx + idxOff + fli).(core.Widget).AsWidget()
if w.StateIs(states.Focused) || w.ContainsFocus() {
return w
}
}
tb.InFocusGrab = true
defer func() { tb.InFocusGrab = false }()
for fli := 0; fli < tb.NCols; fli++ {
w := lg.Child(ridx + idxOff + fli).(core.Widget).AsWidget()
if w.CanFocus() {
w.SetFocus()
return w
}
}
return nil
}
/////// Header layout
func (tb *TensorEditor) SizeFinal() {
tb.ListBase.SizeFinal()
lg := tb.ListGrid
sh := tb.SliceHeader()
sh.ForWidgetChildren(func(i int, cw core.Widget, cwb *core.WidgetBase) bool {
sgb := core.AsWidget(lg.Child(i))
gsz := &sgb.Geom.Size
if gsz.Actual.Total.X == 0 {
return tree.Continue
}
ksz := &cwb.Geom.Size
ksz.Actual.Total.X = gsz.Actual.Total.X
ksz.Actual.Content.X = gsz.Actual.Content.X
ksz.Alloc.Total.X = gsz.Alloc.Total.X
ksz.Alloc.Content.X = gsz.Alloc.Content.X
return tree.Continue
})
gsz := &lg.Geom.Size
ksz := &sh.Geom.Size
if gsz.Actual.Total.X > 0 {
ksz.Actual.Total.X = gsz.Actual.Total.X
ksz.Actual.Content.X = gsz.Actual.Content.X
ksz.Alloc.Total.X = gsz.Alloc.Total.X
ksz.Alloc.Content.X = gsz.Alloc.Content.X
}
}
//////// Copy / Cut / Paste
// SaveCSV writes the tensor to a tab-separated-values (TSV) file.
// The outer-most dimensions are written as rows in the file,
// and the inner-most dimension as columns.
func (tb *TensorEditor) SaveCSV(filename core.Filename) error { //types:add
return tensor.SaveCSV(tb.Tensor, fsx.Filename(filename), tensor.Tab)
}
// OpenCSV reads a tensor from a tab-separated-values (TSV) file,
// using the Go standard encoding/csv reader conforming
// to the official CSV standard.
// It reads all values and assigns as many as fit.
func (tb *TensorEditor) OpenCSV(filename core.Filename) error { //types:add
return tensor.OpenCSV(tb.Tensor, fsx.Filename(filename), tensor.Tab)
}
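// Usage sketch (assumption): round-trip the edited tensor through a TSV file
// using the methods above; the file name is hypothetical.
//
//	if err := te.SaveCSV(core.Filename("data.tsv")); err != nil {
//		// handle the error
//	}
//	if err := te.OpenCSV(core.Filename("data.tsv")); err != nil {
//		// handle the error
//	}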
func (tb *TensorEditor) MakeToolbar(p *tree.Plan) {
if tb.Tensor == nil {
return
}
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(tb.OpenCSV).SetIcon(icons.Open)
w.SetAfterFunc(func() { tb.Update() })
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(tb.SaveCSV).SetIcon(icons.Save)
w.SetAfterFunc(func() { tb.Update() })
})
}
func (tb *TensorEditor) MimeDataType() string {
return fileinfo.DataCsv
}
// CopySelectToMime copies selected rows to mime data
func (tb *TensorEditor) CopySelectToMime() mimedata.Mimes {
nitms := len(tb.SelectedIndexes)
if nitms == 0 {
return nil
}
// idx := tb.SelectedIndexesList(false) // ascending
// var b bytes.Buffer
// ix.WriteCSV(&b, tensor.Tab, table.Headers)
// md := mimedata.NewTextBytes(b.Bytes())
// md[0].Type = fileinfo.DataCsv
// return md
return nil
}
// FromMimeData returns CSV records decoded from the given mime data.
func (tb *TensorEditor) FromMimeData(md mimedata.Mimes) [][]string {
var recs [][]string
for _, d := range md {
if d.Type == fileinfo.DataCsv {
// b := bytes.NewBuffer(d.Data)
// cr := csv.NewReader(b)
// cr.Comma = tensor.Tab.Rune()
// rec, err := cr.ReadAll()
// if err != nil || len(rec) == 0 {
// log.Printf("Error reading CSV from clipboard: %s\n", err)
// return nil
// }
// recs = append(recs, rec...)
}
}
return recs
}
// PasteAssign assigns mime data (only the first record) to the given row index.
func (tb *TensorEditor) PasteAssign(md mimedata.Mimes, idx int) {
// recs := tb.FromMimeData(md)
// if len(recs) == 0 {
// return
// }
// tb.Tensor.ReadCSVRow(recs[1], tb.Tensor.Indexes[idx])
// tb.UpdateChange()
}
// PasteAtIndex inserts rows from mime data at (before) the given slice index.
func (tb *TensorEditor) PasteAtIndex(md mimedata.Mimes, idx int) {
// recs := tb.FromMimeData(md)
// nr := len(recs) - 1
// if nr <= 0 {
// return
// }
// tb.Tensor.InsertRows(idx, nr)
// for ri := 0; ri < nr; ri++ {
// rec := recs[1+ri]
// rw := tb.Tensor.Indexes[idx+ri]
// tb.Tensor.ReadCSVRow(rec, rw)
// }
// tb.SendChange()
// tb.SelectIndexEvent(idx, events.SelectOne)
// tb.Update()
}
// Copyright (c) 2019, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensorcore
import (
"image/color"
"log"
"cogentcore.org/core/base/slicesx"
"cogentcore.org/core/colors"
"cogentcore.org/core/colors/colormap"
"cogentcore.org/core/core"
"cogentcore.org/core/events"
"cogentcore.org/core/icons"
"cogentcore.org/core/math32"
"cogentcore.org/core/math32/minmax"
"cogentcore.org/core/styles"
"cogentcore.org/core/styles/abilities"
"cogentcore.org/core/styles/units"
"cogentcore.org/core/text/rich"
"cogentcore.org/core/tree"
"cogentcore.org/lab/tensor"
)
// LabelSpace is space after label in dot pixels.
const LabelSpace = 8
// TensorGrid is a widget that displays tensor values as a grid
// of colored squares. Higher-dimensional data is projected into 2D
// using [tensor.Projection2DShape] and related functions.
type TensorGrid struct {
core.WidgetBase
// Tensor is the tensor that we view.
Tensor tensor.Tensor `set:"-"`
// GridStyle has grid display style properties.
GridStyle GridStyle
// ColorMap is the colormap used to display values, selected by [GridStyle.ColorMap].
ColorMap *colormap.Map
// RowLabels are optional labels for each row of the 2D shape.
// Empty strings cause grouping with rendered lines.
RowLabels []string
// ColumnLabels are optional labels for each column of the 2D shape.
// Empty strings cause grouping with rendered lines.
ColumnLabels []string
rowMaxSz math32.Vector2 // maximum label size
rowMinBlank int // minimum number of blank rows
rowNGps int // number of groups in row (non-blank after blank)
colMaxSz math32.Vector2 // maximum label size
colMinBlank int // minimum number of blank cols
colNGps int // number of groups in col (non-blank after blank)
}
func (tg *TensorGrid) WidgetValue() any { return &tg.Tensor }
func (tg *TensorGrid) SetWidgetValue(value any) error {
tg.SetTensor(value.(tensor.Tensor))
return nil
}
func (tg *TensorGrid) Init() {
tg.WidgetBase.Init()
tg.GridStyle.Defaults()
tg.Styler(func(s *styles.Style) {
s.SetAbilities(true, abilities.DoubleClickable)
s.Background = colors.Scheme.Surface
s.Font.Size.Dp(tg.GridStyle.FontSize)
s.Font.Size.ToDots(&s.UnitContext)
ms := tg.MinSize()
s.Min.Set(units.Dot(ms.X), units.Dot(ms.Y))
s.Grow.Set(1, 1)
})
tg.OnDoubleClick(func(e events.Event) {
tg.TensorEditor()
})
tg.AddContextMenu(func(m *core.Scene) {
core.NewFuncButton(m).SetFunc(tg.TensorEditor).SetIcon(icons.Edit)
core.NewFuncButton(m).SetFunc(tg.EditStyle).SetIcon(icons.Edit)
})
}
func (tg *TensorGrid) MakeToolbar(p *tree.Plan) {
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(tg.TensorEditor).SetIcon(icons.Edit)
})
tree.Add(p, func(w *core.FuncButton) {
w.SetFunc(tg.EditStyle).SetIcon(icons.Edit)
})
}
// SetTensor sets the tensor. Must call Update after this.
func (tg *TensorGrid) SetTensor(tsr tensor.Tensor) *TensorGrid {
if _, ok := tsr.(*tensor.String); ok {
log.Printf("TensorGrid: String tensors cannot be displayed using TensorGrid\n")
return tg
}
tg.Tensor = tsr
if tg.Tensor != nil {
tg.GridStyle.ApplyStylersFrom(tg.Tensor)
}
return tg
}
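// Usage sketch (assumption): as noted above, SetTensor does not refresh the
// display, so call Update afterwards. Here tsr is an existing non-string
// tensor.Tensor and parent is a hypothetical parent widget.
//
//	tg := NewTensorGrid(parent).SetTensor(tsr)
//	tg.Update()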
// TensorEditor pulls up a TensorEditor of our tensor
func (tg *TensorGrid) TensorEditor() { //types:add
d := core.NewBody("Tensor editor")
tb := core.NewToolbar(d)
te := NewTensorEditor(d).SetTensor(tg.Tensor)
te.OnChange(func(e events.Event) {
tg.NeedsRender()
})
tb.Maker(te.MakeToolbar)
d.RunWindowDialog(tg)
}
func (tg *TensorGrid) EditStyle() { //types:add
d := core.NewBody("Tensor grid style")
core.NewForm(d).SetStruct(&tg.GridStyle).
OnChange(func(e events.Event) {
tg.Restyle()
})
d.RunWindowDialog(tg)
}
// MinSize returns minimum size based on tensor and display settings
func (tg *TensorGrid) MinSize() math32.Vector2 {
if tg.Tensor == nil || tg.Tensor.Len() == 0 {
return math32.Vector2{}
}
if tg.GridStyle.Image {
return math32.Vec2(float32(tg.Tensor.DimSize(1)), float32(tg.Tensor.DimSize(0)))
}
rows, cols, rowEx, colEx := tensor.Projection2DShape(tg.Tensor.Shape(), tg.GridStyle.OddRow)
frw := float32(rows) + float32(rowEx)*tg.GridStyle.DimExtra // extra spacing
fcl := float32(cols) + float32(colEx)*tg.GridStyle.DimExtra // extra spacing
mx := float32(max(frw, fcl))
gsz := tg.GridStyle.TotalSize / mx
gsz = tg.GridStyle.Size.ClampValue(gsz)
gsz = max(gsz, 2)
sz := math32.Vec2(gsz*float32(fcl), gsz*float32(frw))
if len(tg.RowLabels) > 0 {
tg.RowLabels = slicesx.SetLength(tg.RowLabels, rows)
}
if len(tg.ColumnLabels) > 0 {
tg.ColumnLabels = slicesx.SetLength(tg.ColumnLabels, cols)
}
tg.rowMinBlank, tg.rowNGps, tg.rowMaxSz = tg.SizeLabel(tg.RowLabels, false)
tg.colMinBlank, tg.colNGps, tg.colMaxSz = tg.SizeLabel(tg.ColumnLabels, true)
// tg.colMaxSz.Y += tg.rowMaxSz.Y // needs one more for some reason
if tg.rowMaxSz.X > 0 {
sz.X += tg.rowMaxSz.X + LabelSpace
}
if tg.colMaxSz.Y > 0 {
sz.Y += tg.colMaxSz.Y + LabelSpace
}
return sz
}
func (tg *TensorGrid) SizeLabel(lbs []string, col bool) (minBlank, ngps int, sz math32.Vector2) {
minBlank = len(lbs)
if minBlank == 0 {
return
}
mx := 0
mxi := 0
curblk := 0
ngps = 0
for i, lb := range lbs {
l := len(lb)
if l == 0 {
curblk++
continue
}
if curblk > 0 {
ngps++
}
if i > 0 {
minBlank = min(minBlank, curblk)
}
curblk = 0
if l > mx {
mx = l
mxi = i
}
}
minBlank = min(minBlank, curblk)
ts := tg.Scene.TextShaper()
if ts != nil {
sty, tsty := tg.Styles.NewRichText()
tx := rich.NewText(sty, []rune(lbs[mxi]))
lns := ts.WrapLines(tx, sty, tsty, math32.Vec2(10000, 1000))
sz = lns.Bounds.Size().Ceil()
if col {
sz.X, sz.Y = sz.Y, sz.X
}
}
return
}
// EnsureColorMap makes sure there is a valid color map that matches specified name
func (tg *TensorGrid) EnsureColorMap() {
if tg.ColorMap != nil && tg.ColorMap.Name != string(tg.GridStyle.ColorMap) {
tg.ColorMap = nil
}
if tg.ColorMap == nil {
ok := false
tg.ColorMap, ok = colormap.AvailableMaps[string(tg.GridStyle.ColorMap)]
if !ok {
tg.GridStyle.ColorMap = ""
tg.GridStyle.Defaults()
}
tg.ColorMap = colormap.AvailableMaps[string(tg.GridStyle.ColorMap)]
}
}
func (tg *TensorGrid) Color(val float64) (norm float64, clr color.Color) {
if tg.ColorMap.Indexed {
clr = tg.ColorMap.MapIndex(int(val))
} else {
norm = tg.GridStyle.Range.ClipNormValue(val)
clr = tg.ColorMap.Map(float32(norm))
}
return
}
func (tg *TensorGrid) UpdateRange() {
if !tg.GridStyle.Range.FixMin || !tg.GridStyle.Range.FixMax {
min, max, _, _ := tensor.Range(tg.Tensor.AsValues())
if !tg.GridStyle.Range.FixMin {
nmin := minmax.NiceRoundNumber(min, true) // true = below #
tg.GridStyle.Range.Min = nmin
}
if !tg.GridStyle.Range.FixMax {
nmax := minmax.NiceRoundNumber(max, false) // false = above #
tg.GridStyle.Range.Max = nmax
}
}
}
func (tg *TensorGrid) Render() {
if tg.Tensor == nil || tg.Tensor.Len() == 0 {
return
}
tg.EnsureColorMap()
tg.UpdateRange()
if tg.GridStyle.Image {
tg.renderImage()
return
}
dimEx := tg.GridStyle.DimExtra
tsr := tg.Tensor
pc := &tg.Scene.Painter
ts := tg.Scene.TextShaper()
sty, tsty := tg.Styles.NewRichText()
pos := tg.Geom.Pos.Content
sz := tg.Geom.Size.Actual.Content
// sz.SetSubScalar(tg.Disp.BotRtSpace.Dots)
effsz := sz
if tg.rowMaxSz.X > 0 {
effsz.X -= tg.rowMaxSz.X + LabelSpace
}
if tg.colMaxSz.Y > 0 {
effsz.Y -= tg.colMaxSz.Y + LabelSpace
}
pc.FillBox(pos, sz, tg.Styles.Background)
rows, cols, rowEx, colEx := tensor.Projection2DShape(tsr.Shape(), tg.GridStyle.OddRow)
rowsInner := rows
colsInner := cols
if rowEx > 0 {
rowsInner = rows / rowEx
}
if colEx > 0 {
colsInner = cols / colEx
}
// group lines
rowEx += tg.rowNGps
colEx += tg.colNGps
frw := float32(rows) + float32(rowEx)*dimEx // extra spacing
fcl := float32(cols) + float32(colEx)*dimEx // extra spacing
tsz := math32.Vec2(fcl, frw)
gsz := effsz.Div(tsz)
if len(tg.RowLabels) > 0 { // Render Rows
epos := pos
epos.Y += tg.colMaxSz.Y + LabelSpace
nr := len(tg.RowLabels)
mx := min(nr, rows)
ygp := 0
prvyblk := false
for y := 0; y < mx; y++ {
lb := tg.RowLabels[y]
if len(lb) == 0 {
prvyblk = true
continue
}
if prvyblk {
ygp++
prvyblk = false
}
yex := float32(ygp) * dimEx
tx := rich.NewText(sty, []rune(lb))
lns := ts.WrapLines(tx, sty, tsty, math32.Vec2(10000, 1000))
cr := math32.Vec2(0, float32(y)+yex)
pr := epos.Add(cr.Mul(gsz))
pc.DrawText(lns, pr)
}
pos.X += tg.rowMaxSz.X + LabelSpace
}
if len(tg.ColumnLabels) > 0 { // Render Cols
epos := pos
if tg.GridStyle.ColumnRotation > 0 {
epos.X += tg.colMaxSz.X
}
nc := len(tg.ColumnLabels)
mx := min(nc, cols)
xgp := 0
prvxblk := false
for x := 0; x < mx; x++ {
lb := tg.ColumnLabels[x]
if len(lb) == 0 {
prvxblk = true
continue
}
if prvxblk {
xgp++
prvxblk = false
}
xex := float32(xgp) * dimEx
tx := rich.NewText(sty, []rune(lb))
lns := ts.WrapLines(tx, sty, tsty, math32.Vec2(10000, 1000))
cr := math32.Vec2(float32(x)+xex, 0)
pr := epos.Add(cr.Mul(gsz))
rot := tg.GridStyle.ColumnRotation
if rot < 0 {
pr.Y += tg.colMaxSz.Y
}
rotx := math32.Rotate2DAround(math32.DegToRad(rot), pr)
m := pc.Paint.Transform
pc.Paint.Transform = m.Mul(rotx)
pc.DrawText(lns, pr)
pc.Paint.Transform = m
}
pos.Y += tg.colMaxSz.Y + LabelSpace
}
ssz := gsz.MulScalar(tg.GridStyle.GridFill) // smaller size with margin
prvyblk := false
ygp := 0
for y := 0; y < rows; y++ {
yex := float32(int(y/rowsInner)) * dimEx
if len(tg.RowLabels) > 0 {
ylb := tg.RowLabels[y]
if len(ylb) > 0 && prvyblk {
ygp++
prvyblk = false
} else if len(ylb) == 0 {
prvyblk = true
}
yex += float32(ygp) * dimEx
}
prvxblk := false
xgp := 0
for x := 0; x < cols; x++ {
xex := float32(int(x/colsInner)) * dimEx
ey := y
if !tg.GridStyle.TopZero {
ey = (rows - 1) - y
}
if len(tg.ColumnLabels) > 0 {
xlb := tg.ColumnLabels[x]
if len(xlb) > 0 && prvxblk {
xgp++
prvxblk = false
} else if len(xlb) == 0 {
prvxblk = true
}
xex += float32(xgp) * dimEx
}
val := tensor.Projection2DValue(tsr, tg.GridStyle.OddRow, ey, x)
cr := math32.Vec2(float32(x)+xex, float32(y)+yex)
pr := pos.Add(cr.Mul(gsz))
_, clr := tg.Color(val)
pc.FillBox(pr, ssz, colors.Uniform(clr))
}
}
}
func (tg *TensorGrid) renderImage() {
if tg.Tensor == nil || tg.Tensor.Len() == 0 {
return
}
pc := &tg.Scene.Painter
pos := tg.Geom.Pos.Content
sz := tg.Geom.Size.Actual.Content
pc.FillBox(pos, sz, tg.Styles.Background)
tsr := tg.Tensor
ysz := tsr.DimSize(0)
xsz := tsr.DimSize(1)
nclr := 1
outclr := false // outer dimension is color
if tsr.NumDims() == 3 {
if tsr.DimSize(0) == 3 || tsr.DimSize(0) == 4 {
outclr = true
ysz = tsr.DimSize(1)
xsz = tsr.DimSize(2)
nclr = tsr.DimSize(0)
} else {
nclr = tsr.DimSize(2)
}
}
tsz := math32.Vec2(float32(xsz), float32(ysz))
gsz := sz.Div(tsz)
for y := 0; y < ysz; y++ {
for x := 0; x < xsz; x++ {
ey := y
if !tg.GridStyle.TopZero {
ey = (ysz - 1) - y
}
switch {
case outclr:
var r, g, b, a float64
a = 1
r = tg.GridStyle.Range.ClipNormValue(tsr.Float(0, y, x))
g = tg.GridStyle.Range.ClipNormValue(tsr.Float(1, y, x))
b = tg.GridStyle.Range.ClipNormValue(tsr.Float(2, y, x))
if nclr > 3 {
a = tg.GridStyle.Range.ClipNormValue(tsr.Float(3, y, x))
}
cr := math32.Vec2(float32(x), float32(ey))
pr := pos.Add(cr.Mul(gsz))
pc.Stroke.Color = colors.Uniform(colors.FromFloat64(r, g, b, a))
pc.FillBox(pr, gsz, pc.Stroke.Color)
case nclr > 1:
var r, g, b, a float64
a = 1
r = tg.GridStyle.Range.ClipNormValue(tsr.Float(y, x, 0))
g = tg.GridStyle.Range.ClipNormValue(tsr.Float(y, x, 1))
b = tg.GridStyle.Range.ClipNormValue(tsr.Float(y, x, 2))
if nclr > 3 {
a = tg.GridStyle.Range.ClipNormValue(tsr.Float(y, x, 3))
}
cr := math32.Vec2(float32(x), float32(ey))
pr := pos.Add(cr.Mul(gsz))
pc.Stroke.Color = colors.Uniform(colors.FromFloat64(r, g, b, a))
pc.FillBox(pr, gsz, pc.Stroke.Color)
default:
val := tg.GridStyle.Range.ClipNormValue(tsr.Float(y, x))
cr := math32.Vec2(float32(x), float32(ey))
pr := pos.Add(cr.Mul(gsz))
pc.Stroke.Color = colors.Uniform(colors.FromFloat64(val, val, val, 1))
pc.FillBox(pr, gsz, pc.Stroke.Color)
}
}
}
}
// RepeatsToBlank returns string slice with any sequentially repeated strings
// set to blank (empty string), which drives grouping in the TensorGrid labels.
func RepeatsToBlank(str []string) []string {
sz := len(str)
br := make([]string, sz)
last := ""
for r, s := range str {
if s == last {
continue
}
br[r] = s
last = s
}
return br
}
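// Usage sketch (assumption): blank out repeated labels so the grid groups
// runs of identical names; the label values are hypothetical.
//
//	labels := []string{"A", "A", "B", "B", "B", "C"}
//	tg.SetRowLabels(RepeatsToBlank(labels)...)
//	// RepeatsToBlank(labels) == []string{"A", "", "B", "", "", "C"}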
// Code generated by "core generate"; DO NOT EDIT.
package tensorcore
import (
"cogentcore.org/core/colors/colormap"
"cogentcore.org/core/core"
"cogentcore.org/core/math32/minmax"
"cogentcore.org/core/tree"
"cogentcore.org/core/types"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
)
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/tensorcore.Layout", IDName: "layout", Doc: "Layout are layout options for displaying tensors.", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"--setters"}}}, Fields: []types.Field{{Name: "OddRow", Doc: "OddRow means that even-numbered dimensions are displayed as Y*X rectangles.\nThis determines along which dimension to display any remaining\nodd dimension: OddRow = true = organize vertically along row\ndimension, false = organize horizontally across column dimension."}, {Name: "TopZero", Doc: "TopZero means that the Y=0 coordinate is displayed from the top-down;\notherwise the Y=0 coordinate is displayed from the bottom up,\nwhich is typical for emergent network patterns."}, {Name: "Image", Doc: "Image will display the data as a bitmap image. If a 2D tensor, then it will\nbe a greyscale image. If a 3D tensor with size of either the first\nor last dim = either 3 or 4, then it is a RGB(A) color image."}}})
// SetOddRow sets the [Layout.OddRow]:
// OddRow means that even-numbered dimensions are displayed as Y*X rectangles.
// This determines along which dimension to display any remaining
// odd dimension: OddRow = true = organize vertically along row
// dimension, false = organize horizontally across column dimension.
func (t *Layout) SetOddRow(v bool) *Layout { t.OddRow = v; return t }
// SetTopZero sets the [Layout.TopZero]:
// TopZero means that the Y=0 coordinate is displayed from the top-down;
// otherwise the Y=0 coordinate is displayed from the bottom up,
// which is typical for emergent network patterns.
func (t *Layout) SetTopZero(v bool) *Layout { t.TopZero = v; return t }
// SetImage sets the [Layout.Image]:
// Image will display the data as a bitmap image. If a 2D tensor, then it will
// be a greyscale image. If a 3D tensor with size of either the first
// or last dim = either 3 or 4, then it is a RGB(A) color image.
func (t *Layout) SetImage(v bool) *Layout { t.Image = v; return t }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/tensorcore.GridStyle", IDName: "grid-style", Doc: "GridStyle are options for displaying tensors", Directives: []types.Directive{{Tool: "types", Directive: "add", Args: []string{"--setters"}}}, Embeds: []types.Field{{Name: "Layout"}}, Fields: []types.Field{{Name: "Range", Doc: "Range to plot"}, {Name: "MinMax", Doc: "MinMax has the actual range of data, if not using fixed Range."}, {Name: "ColorMap", Doc: "ColorMap is the name of the color map to use in translating values to colors."}, {Name: "GridFill", Doc: "GridFill sets proportion of grid square filled by the color block:\n1 = all, .5 = half, etc."}, {Name: "DimExtra", Doc: "DimExtra is the amount of extra space to add at dimension boundaries,\nas a proportion of total grid size."}, {Name: "Size", Doc: "Size sets the minimum and maximum size for grid squares."}, {Name: "TotalSize", Doc: "TotalSize sets the total preferred display size along largest dimension.\nGrid squares will be sized to fit within this size,\nsubject to the Size.Min / Max constraints, which have precedence."}, {Name: "FontSize", Doc: "FontSize is the font size in standard Dp units for labels."}, {Name: "ColumnRotation", Doc: "ColumnRotation is the rotation angle in degrees for column labels"}}})
// SetRange sets the [GridStyle.Range]:
// Range to plot
func (t *GridStyle) SetRange(v minmax.Range64) *GridStyle { t.Range = v; return t }
// SetMinMax sets the [GridStyle.MinMax]:
// MinMax has the actual range of data, if not using fixed Range.
func (t *GridStyle) SetMinMax(v minmax.F64) *GridStyle { t.MinMax = v; return t }
// SetColorMap sets the [GridStyle.ColorMap]:
// ColorMap is the name of the color map to use in translating values to colors.
func (t *GridStyle) SetColorMap(v core.ColorMapName) *GridStyle { t.ColorMap = v; return t }
// SetGridFill sets the [GridStyle.GridFill]:
// GridFill sets proportion of grid square filled by the color block:
// 1 = all, .5 = half, etc.
func (t *GridStyle) SetGridFill(v float32) *GridStyle { t.GridFill = v; return t }
// SetDimExtra sets the [GridStyle.DimExtra]:
// DimExtra is the amount of extra space to add at dimension boundaries,
// as a proportion of total grid size.
func (t *GridStyle) SetDimExtra(v float32) *GridStyle { t.DimExtra = v; return t }
// SetSize sets the [GridStyle.Size]:
// Size sets the minimum and maximum size for grid squares.
func (t *GridStyle) SetSize(v minmax.F32) *GridStyle { t.Size = v; return t }
// SetTotalSize sets the [GridStyle.TotalSize]:
// TotalSize sets the total preferred display size along largest dimension.
// Grid squares will be sized to fit within this size,
// subject to the Size.Min / Max constraints, which have precedence.
func (t *GridStyle) SetTotalSize(v float32) *GridStyle { t.TotalSize = v; return t }
// SetFontSize sets the [GridStyle.FontSize]:
// FontSize is the font size in standard Dp units for labels.
func (t *GridStyle) SetFontSize(v float32) *GridStyle { t.FontSize = v; return t }
// SetColumnRotation sets the [GridStyle.ColumnRotation]:
// ColumnRotation is the rotation angle in degrees for column labels
func (t *GridStyle) SetColumnRotation(v float32) *GridStyle { t.ColumnRotation = v; return t }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/tensorcore.Table", IDName: "table", Doc: "Table provides a GUI widget for representing [table.Table] values.", Embeds: []types.Field{{Name: "ListBase"}}, Fields: []types.Field{{Name: "Table", Doc: "Table is the table that we're a view of."}, {Name: "GridStyle", Doc: "GridStyle has global grid display styles. GridStylers on the Table\nare applied to this on top of defaults."}, {Name: "ColumnGridStyle", Doc: "ColumnGridStyle has per column grid display styles."}, {Name: "SortIndex", Doc: "current sort index."}, {Name: "SortDescending", Doc: "whether current sort order is descending."}, {Name: "nCols", Doc: "number of columns in table (as of last update)."}, {Name: "headerWidths", Doc: "headerWidths has number of characters in each header, per visfields."}, {Name: "colMaxWidths", Doc: "colMaxWidths records maximum width in chars of string type fields."}, {Name: "blankString", Doc: "blank values for out-of-range rows."}, {Name: "blankFloat"}, {Name: "blankCells", Doc: "blankCells has per column blank tensor cells."}}})
// NewTable returns a new [Table] with the given optional parent:
// Table provides a GUI widget for representing [table.Table] values.
func NewTable(parent ...tree.Node) *Table { return tree.New[Table](parent...) }
// SetSortIndex sets the [Table.SortIndex]:
// current sort index.
func (t *Table) SetSortIndex(v int) *Table { t.SortIndex = v; return t }
// SetSortDescending sets the [Table.SortDescending]:
// whether current sort order is descending.
func (t *Table) SetSortDescending(v bool) *Table { t.SortDescending = v; return t }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/tensorcore.TensorEditor", IDName: "tensor-editor", Doc: "TensorEditor provides a GUI widget for representing [tensor.Tensor] values.", Methods: []types.Method{{Name: "SaveCSV", Doc: "SaveTSV writes a tensor to a tab-separated-values (TSV) file.\nOuter-most dims are rows in the file, and inner-most is column --\nReading just grabs all values and doesn't care about shape.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"filename"}, Returns: []string{"error"}}, {Name: "OpenCSV", Doc: "OpenTSV reads a tensor from a tab-separated-values (TSV) file.\nusing the Go standard encoding/csv reader conforming\nto the official CSV standard.\nReads all values and assigns as many as fit.", Directives: []types.Directive{{Tool: "types", Directive: "add"}}, Args: []string{"filename"}, Returns: []string{"error"}}}, Embeds: []types.Field{{Name: "ListBase"}}, Fields: []types.Field{{Name: "Tensor", Doc: "the tensor that we're a view of"}, {Name: "Layout", Doc: "overall layout options for tensor display"}, {Name: "NCols", Doc: "number of columns in table (as of last update)"}, {Name: "headerWidths", Doc: "headerWidths has number of characters in each header, per visfields"}, {Name: "colMaxWidths", Doc: "colMaxWidths records maximum width in chars of string type fields"}, {Name: "BlankString", Doc: "blank values for out-of-range rows"}, {Name: "BlankFloat"}}})
// NewTensorEditor returns a new [TensorEditor] with the given optional parent:
// TensorEditor provides a GUI widget for representing [tensor.Tensor] values.
func NewTensorEditor(parent ...tree.Node) *TensorEditor { return tree.New[TensorEditor](parent...) }
// SetNCols sets the [TensorEditor.NCols]:
// number of columns in table (as of last update)
func (t *TensorEditor) SetNCols(v int) *TensorEditor { t.NCols = v; return t }
// SetBlankString sets the [TensorEditor.BlankString]:
// blank values for out-of-range rows
func (t *TensorEditor) SetBlankString(v string) *TensorEditor { t.BlankString = v; return t }
// SetBlankFloat sets the [TensorEditor.BlankFloat]
func (t *TensorEditor) SetBlankFloat(v float64) *TensorEditor { t.BlankFloat = v; return t }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/tensorcore.TensorGrid", IDName: "tensor-grid", Doc: "TensorGrid is a widget that displays tensor values as a grid\nof colored squares. Higher-dimensional data is projected into 2D\nusing [tensor.Projection2DShape] and related functions.", Methods: []types.Method{{Name: "TensorEditor", Doc: "TensorEditor pulls up a TensorEditor of our tensor", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}, {Name: "EditStyle", Directives: []types.Directive{{Tool: "types", Directive: "add"}}}}, Embeds: []types.Field{{Name: "WidgetBase"}}, Fields: []types.Field{{Name: "Tensor", Doc: "Tensor is the tensor that we view."}, {Name: "GridStyle", Doc: "GridStyle has grid display style properties."}, {Name: "ColorMap", Doc: "ColorMap is the colormap displayed (based on)"}, {Name: "RowLabels", Doc: "RowLabels are optional labels for each row of the 2D shape.\nEmpty strings cause grouping with rendered lines."}, {Name: "ColumnLabels", Doc: "ColumnLabels are optional labels for each column of the 2D shape.\nEmpty strings cause grouping with rendered lines."}, {Name: "rowMaxSz"}, {Name: "rowMinBlank"}, {Name: "rowNGps"}, {Name: "colMaxSz"}, {Name: "colMinBlank"}, {Name: "colNGps"}}})
// NewTensorGrid returns a new [TensorGrid] with the given optional parent:
// TensorGrid is a widget that displays tensor values as a grid
// of colored squares. Higher-dimensional data is projected into 2D
// using [tensor.Projection2DShape] and related functions.
func NewTensorGrid(parent ...tree.Node) *TensorGrid { return tree.New[TensorGrid](parent...) }
// SetGridStyle sets the [TensorGrid.GridStyle]:
// GridStyle has grid display style properties.
func (t *TensorGrid) SetGridStyle(v GridStyle) *TensorGrid { t.GridStyle = v; return t }
// SetColorMap sets the [TensorGrid.ColorMap]:
// ColorMap is the colormap displayed (based on)
func (t *TensorGrid) SetColorMap(v *colormap.Map) *TensorGrid { t.ColorMap = v; return t }
// SetRowLabels sets the [TensorGrid.RowLabels]:
// RowLabels are optional labels for each row of the 2D shape.
// Empty strings cause grouping with rendered lines.
func (t *TensorGrid) SetRowLabels(v ...string) *TensorGrid { t.RowLabels = v; return t }
// SetColumnLabels sets the [TensorGrid.ColumnLabels]:
// ColumnLabels are optional labels for each column of the 2D shape.
// Empty strings cause grouping with rendered lines.
func (t *TensorGrid) SetColumnLabels(v ...string) *TensorGrid { t.ColumnLabels = v; return t }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/tensorcore.TensorButton", IDName: "tensor-button", Doc: "TensorButton represents a Tensor with a button for making a [TensorGrid]\nviewer for an [tensor.Tensor].", Embeds: []types.Field{{Name: "Button"}}, Fields: []types.Field{{Name: "Tensor"}}})
// NewTensorButton returns a new [TensorButton] with the given optional parent:
// TensorButton represents a Tensor with a button for making a [TensorGrid]
// viewer for an [tensor.Tensor].
func NewTensorButton(parent ...tree.Node) *TensorButton { return tree.New[TensorButton](parent...) }
// SetTensor sets the [TensorButton.Tensor]
func (t *TensorButton) SetTensor(v tensor.Tensor) *TensorButton { t.Tensor = v; return t }
var _ = types.AddType(&types.Type{Name: "cogentcore.org/lab/tensorcore.TableButton", IDName: "table-button", Doc: "TableButton presents a button that pulls up the [Table] viewer for a [table.Table].", Embeds: []types.Field{{Name: "Button"}}, Fields: []types.Field{{Name: "Table"}}})
// NewTableButton returns a new [TableButton] with the given optional parent:
// TableButton presents a button that pulls up the [Table] viewer for a [table.Table].
func NewTableButton(parent ...tree.Node) *TableButton { return tree.New[TableButton](parent...) }
// SetTable sets the [TableButton.Table]
func (t *TableButton) SetTable(v *table.Table) *TableButton { t.Table = v; return t }
// Copyright (c) 2019, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensorcore
import (
"cogentcore.org/core/base/metadata"
"cogentcore.org/core/core"
"cogentcore.org/core/icons"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
)
func init() {
core.AddValueType[table.Table, TableButton]()
core.AddValueType[tensor.Float32, TensorButton]()
core.AddValueType[tensor.Float64, TensorButton]()
core.AddValueType[tensor.Int, TensorButton]()
core.AddValueType[tensor.Int32, TensorButton]()
core.AddValueType[tensor.Byte, TensorButton]()
core.AddValueType[tensor.String, TensorButton]()
core.AddValueType[tensor.Bool, TensorButton]()
// core.AddValueType[simat.SimMat, SimMatButton]()
}
// TensorButton represents a Tensor with a button for making a [TensorGrid]
// viewer for an [tensor.Tensor].
type TensorButton struct {
core.Button
Tensor tensor.Tensor
}
func (tb *TensorButton) WidgetValue() any { return &tb.Tensor }
func (tb *TensorButton) Init() {
tb.Button.Init()
tb.SetType(core.ButtonTonal).SetIcon(icons.Edit)
tb.Updater(func() {
text := "None"
if tb.Tensor != nil {
text = "Tensor"
}
tb.SetText(text)
})
core.InitValueButton(tb, true, func(d *core.Body) {
NewTensorGrid(d).SetTensor(tb.Tensor)
})
}
// TableButton presents a button that pulls up the [Table] viewer for a [table.Table].
type TableButton struct {
core.Button
Table *table.Table
}
func (tb *TableButton) WidgetValue() any { return &tb.Table }
func (tb *TableButton) Init() {
tb.Button.Init()
tb.SetType(core.ButtonTonal).SetIcon(icons.Edit)
tb.Updater(func() {
text := "None"
if tb.Table != nil {
if nm, err := metadata.Get[string](tb.Table.Meta, "name"); err == nil {
text = nm
} else {
text = "Table"
}
}
tb.SetText(text)
})
core.InitValueButton(tb, true, func(d *core.Body) {
NewTable(d).SetTable(tb.Table)
})
}
/*
// SimMatValue presents a button that pulls up the [SimMatGrid] viewer for a [table.Table].
type SimMatButton struct {
core.Button
SimMat *simat.SimMat
}
func (tb *SimMatButton) WidgetValue() any { return &tb.SimMat }
func (tb *SimMatButton) Init() {
tb.Button.Init()
tb.SetType(core.ButtonTonal).SetIcon(icons.Edit)
tb.Updater(func() {
text := "None"
if tb.SimMat != nil && tb.SimMat.Mat != nil {
if nm, has := tb.SimMat.Mat.MetaData("name"); has {
text = nm
} else {
text = "SimMat"
}
}
tb.SetText(text)
})
core.InitValueButton(tb, true, func(d *core.Body) {
NewSimMatGrid(d).SetSimMat(tb.SimMat)
})
}
*/
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensorfs
import (
"fmt"
"io"
"io/fs"
"path"
"strings"
"cogentcore.org/core/base/errors"
"cogentcore.org/lab/tensor"
)
var (
// CurDir is the current working directory.
CurDir *Node
// CurRoot is the current root tensorfs system.
// A default root tensorfs is created at startup.
CurRoot *Node
// ListOutput is where to send the output of List commands,
// if non-nil (otherwise os.Stdout).
ListOutput io.Writer
)
func init() {
CurRoot, _ = NewDir("data")
CurDir = CurRoot
}
// Record saves given tensor to current directory with given name.
func Record(tsr tensor.Tensor, name string) {
if CurDir == nil {
CurDir = CurRoot
}
SetTensor(CurDir, tsr, name)
}
// Chdir changes the current working tensorfs directory to the named directory.
func Chdir(dir string) error {
if CurDir == nil {
CurDir = CurRoot
}
if dir == "" {
CurDir = CurRoot
return nil
}
ndir, err := CurDir.DirAtPath(dir)
if err != nil {
return err
}
CurDir = ndir
return nil
}
// Mkdir creates a new directory with the specified name in the current directory.
// It returns an existing directory of the same name without error.
func Mkdir(dir string) *Node {
if CurDir == nil {
CurDir = CurRoot
}
if dir == "" {
err := &fs.PathError{Op: "Mkdir", Path: dir, Err: errors.New("path must not be empty")}
errors.Log(err)
return nil
}
return CurDir.Dir(dir)
}
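// Usage sketch (assumption): the package-level functions operate on the
// current working directory, shell-style; the directory name is hypothetical.
//
//	Mkdir("results") // create (or get) "results" under the current directory
//	if err := Chdir("results"); err != nil {
//		// handle the error
//	}
//	Chdir("") // change back to the root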
// List lists files using arguments (options and path) from the current directory.
func List(opts ...string) error {
if CurDir == nil {
CurDir = CurRoot
}
long := false
recursive := false
if len(opts) > 0 && len(opts[0]) > 0 && opts[0][0] == '-' {
op := opts[0]
if strings.Contains(op, "l") {
long = true
}
if strings.Contains(op, "r") {
recursive = true
}
opts = opts[1:]
}
dir := CurDir
if len(opts) > 0 {
nd, err := CurDir.DirAtPath(opts[0])
if err == nil {
dir = nd
}
}
ls := dir.List(long, recursive)
if ListOutput != nil {
fmt.Fprintln(ListOutput, ls)
} else {
fmt.Println(ls)
}
return nil
}
// Get returns the tensor value at the given path relative to the
// current working directory.
// This is a direct pointer to the stored tensor, so changes
// to it will change the stored value. Clone the tensor to make
// a new copy disconnected from the original.
func Get(name string) tensor.Tensor {
if CurDir == nil {
CurDir = CurRoot
}
if name == "" {
err := &fs.PathError{Op: "Get", Path: name, Err: errors.New("name must not be empty")}
errors.Log(err)
return nil
}
nd, err := CurDir.NodeAtPath(name)
if errors.Log(err) != nil {
return nil
}
if nd.IsDir() {
err := &fs.PathError{Op: "Get", Path: name, Err: errors.New("node is a directory, not a data node")}
errors.Log(err)
return nil
}
return nd.Tensor
}
// Set sets the tensor at the given name or path relative to the
// current working directory.
// If the node already exists, its previous tensor is replaced by the
// given one; otherwise, a new node is created.
func Set(name string, tsr tensor.Tensor) error {
if CurDir == nil {
CurDir = CurRoot
}
if name == "" {
err := &fs.PathError{Op: "Set", Path: name, Err: errors.New("name must not be empty")}
return errors.Log(err)
}
itm, err := CurDir.NodeAtPath(name)
if err == nil {
if itm.IsDir() {
err := &fs.PathError{Op: "Set", Path: name, Err: errors.New("existing node is a directory, not a data node")}
return errors.Log(err)
}
itm.Tensor = tsr
return nil
}
cd := CurDir
dir, name := path.Split(name)
if dir != "" {
d, err := CurDir.DirAtPath(dir)
if err != nil {
return errors.Log(err)
}
cd = d
}
SetTensor(cd, tsr, name)
return nil
}
// SetCopy sets the tensor at the given name or path relative to the
// current working directory.
// Unlike [Set], this version saves a [tensor.Clone] of the tensor,
// so future changes to the tensor do not affect this value.
func SetCopy(name string, tsr tensor.Tensor) error {
return Set(name, tensor.Clone(tsr))
}
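// Usage sketch (assumption): Set stores the given tensor directly, so later
// changes to it remain visible via Get, whereas SetCopy stores a clone.
// tsr is assumed to be an existing tensor.Tensor; names are hypothetical.
//
//	Mkdir("run1")
//	Set("run1/score", tsr)      // shares the same underlying tensor
//	SetCopy("run1/frozen", tsr) // stores an independent clone
//	if score := Get("run1/score"); score != nil {
//		// score is the stored tensor itself, not a copy
//	}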
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensorfs
import (
"errors"
"io/fs"
"time"
"cogentcore.org/lab/tensor"
)
const (
// Preserve is used for Overwrite flag, indicating to not overwrite and preserve existing.
Preserve = false
// Overwrite is used for Overwrite flag, indicating to overwrite existing.
Overwrite = true
)
// CopyFromValue copies value from given source node, cloning it.
func (d *Node) CopyFromValue(frd *Node) {
d.modTime = time.Now()
d.Tensor = tensor.Clone(frd.Tensor)
}
// Clone returns a copy of this node, recursively cloning directory nodes
// if it is a directory.
func (nd *Node) Clone() *Node {
if !nd.IsDir() {
cp, _ := newNode(nil, nd.name)
cp.Tensor = tensor.Clone(nd.Tensor)
return cp
}
nodes, _ := nd.Nodes()
cp, _ := NewDir(nd.name)
for _, it := range nodes {
cp.Add(it.Clone())
}
return cp
}
// Copy copies node(s) from the given from paths to the given to path or directory.
// If there are multiple from nodes, then to must be a directory.
// Copy must be called on a directory node.
func (dir *Node) Copy(overwrite bool, to string, from ...string) error {
if err := dir.mustDir("Copy", to); err != nil {
return err
}
switch {
case to == "":
return &fs.PathError{Op: "Copy", Path: to, Err: errors.New("to location is empty")}
case len(from) == 0:
return &fs.PathError{Op: "Copy", Path: to, Err: errors.New("no from sources specified")}
}
// todo: check for to conflict first here..
tod, _ := dir.NodeAtPath(to)
var errs []error
if len(from) > 1 && tod != nil && !tod.IsDir() {
return &fs.PathError{Op: "Copy", Path: to, Err: errors.New("multiple source nodes requires destination to be a directory")}
}
targd := dir
targf := to
if tod != nil && tod.IsDir() {
targd = tod
targf = ""
}
for _, fr := range from {
opstr := fr + " -> " + to
frd, err := dir.NodeAtPath(fr)
if err != nil {
errs = append(errs, err)
continue
}
if targf == "" {
if trg, ok := targd.nodes.AtTry(frd.name); ok { // target exists
switch {
case trg.IsDir() && frd.IsDir():
// todo: copy all nodes from frd into trg
case trg.IsDir(): // frd is not
errs = append(errs, &fs.PathError{Op: "Copy", Path: opstr, Err: errors.New("cannot copy from Value onto directory of same name")})
case frd.IsDir(): // trg is not
errs = append(errs, &fs.PathError{Op: "Copy", Path: opstr, Err: errors.New("cannot copy from Directory onto Value of same name")})
default: // both nodes
if overwrite { // todo: interactive!?
trg.CopyFromValue(frd)
}
}
continue
}
}
nw := frd.Clone()
if targf != "" {
nw.name = targf
}
targd.Add(nw)
}
return errors.Join(errs...)
}
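// Usage sketch (assumption): copy two value nodes into an existing
// subdirectory, overwriting any same-named values there. The node names
// ("backup", "score", "loss") are hypothetical and dir is a directory node.
//
//	if err := dir.Copy(Overwrite, "backup", "score", "loss"); err != nil {
//		// errors for individual sources are joined into one
//	}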
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensorfs
import (
"fmt"
"io/fs"
"path"
"slices"
"sort"
"strings"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/keylist"
"cogentcore.org/lab/tensor"
)
// Nodes is a map of directory entry names to Nodes.
// It retains the order that nodes were added in, which is
// the natural order nodes are processed in.
type Nodes = keylist.List[string, *Node]
// NewDir returns a new tensorfs directory with the given name.
// If parent != nil and a directory, this dir is added to it.
// If the parent already has an node of that name, it is returned,
// with an [fs.ErrExist] error.
// If the name is empty, then it is set to "root", the root directory.
// Note that "/" is not allowed for the root directory in Go [fs].
// If no parent (i.e., a new root) and CurRoot is nil, then it is set
// to this.
func NewDir(name string, parent ...*Node) (*Node, error) {
if name == "" {
name = "root"
}
var par *Node
if len(parent) == 1 {
par = parent[0]
}
dir, err := newNode(par, name)
if dir != nil && dir.nodes == nil {
dir.nodes = &Nodes{}
}
return dir, err
}
// Dir creates a new directory under given dir with the specified name
// if it doesn't already exist, otherwise returns the existing one.
// Slash (/) path separators can be used to specify a path of multiple directories.
// It logs an error and returns nil if this dir node is not a directory.
func (dir *Node) Dir(name string) *Node {
if err := dir.mustDir("Dir", name); errors.Log(err) != nil {
return nil
}
if len(name) == 0 {
return dir
}
path := strings.Split(name, "/")
if cd := dir.nodes.At(path[0]); cd != nil {
if len(path) > 1 {
return cd.Dir(strings.Join(path[1:], "/"))
}
return cd
}
nd, _ := NewDir(path[0], dir)
if len(path) > 1 {
return nd.Dir(strings.Join(path[1:], "/"))
}
return nd
}
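// Usage sketch (assumption): Dir creates intermediate directories as needed
// when given a slash-separated path; the names are hypothetical.
//
//	root, _ := NewDir("root")
//	sub := root.Dir("sims/run1/logs") // creates sims, sims/run1, sims/run1/logs
//	fmt.Println(sub.Path())           // prints "root/sims/run1/logs"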
// Node returns a Node in given directory by name.
// This is for fast access and direct usage of known
// nodes, and it will panic if this node is not a directory.
// Returns nil if no node of given name exists.
func (dir *Node) Node(name string) *Node {
return dir.nodes.At(name)
}
// Value returns the [tensor.Tensor] value for given node
// within this directory. This will panic if node is not
// found, and will return nil if it is not a Value
// (i.e., it is a directory).
func (dir *Node) Value(name string) tensor.Tensor {
return dir.nodes.At(name).Tensor
}
// Nodes returns a slice of Nodes in this directory, selected by the given
// variadic list of names. If the list is empty, all nodes in the directory
// are returned. The returned error reports any nodes not found, or if this
// node is not a directory.
func (dir *Node) Nodes(names ...string) ([]*Node, error) {
if err := dir.mustDir("Nodes", ""); err != nil {
return nil, err
}
var nds []*Node
if len(names) == 0 {
for _, it := range dir.nodes.Values {
nds = append(nds, it)
}
return nds, nil
}
var errs []error
for _, nm := range names {
dt := dir.nodes.At(nm)
if dt != nil {
nds = append(nds, dt)
} else {
err := fmt.Errorf("tensorfs Dir %q node not found: %q", dir.Path(), nm)
errs = append(errs, err)
}
}
return nds, errors.Join(errs...)
}
// Values returns a slice of tensor values in this directory, selected by
// the given variadic list of names. If the list is empty, all value nodes
// in the directory are returned. The returned error reports any nodes not
// found, or if this node is not a directory.
func (dir *Node) Values(names ...string) ([]tensor.Tensor, error) {
if err := dir.mustDir("Values", ""); err != nil {
return nil, err
}
var nds []tensor.Tensor
if len(names) == 0 {
for _, it := range dir.nodes.Values {
if it.Tensor != nil {
nds = append(nds, it.Tensor)
}
}
return nds, nil
}
var errs []error
for _, nm := range names {
it := dir.nodes.At(nm)
if it != nil && it.Tensor != nil {
nds = append(nds, it.Tensor)
} else {
err := fmt.Errorf("tensorfs Dir %q node not found: %q", dir.Path(), nm)
errs = append(errs, err)
}
}
return nds, errors.Join(errs...)
}
// ValuesFunc returns all tensor Values under given directory,
// filtered by given include function, in directory order (e.g., order added),
// recursively descending into directories to return a flat list of
// the entire subtree. The function can filter out directories to prune
// the tree, e.g., using `IsDir` method.
// If func is nil, all Value nodes are returned.
func (dir *Node) ValuesFunc(include func(nd *Node) bool) []tensor.Tensor {
if err := dir.mustDir("ValuesFunc", ""); err != nil {
return nil
}
var nds []tensor.Tensor
for _, it := range dir.nodes.Values {
if include != nil && !include(it) {
continue
}
if it.IsDir() {
subs := it.ValuesFunc(include)
nds = append(nds, subs...)
} else {
nds = append(nds, it.Tensor)
}
}
return nds
}
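// Usage sketch (assumption): the include function both filters values and
// prunes directories; returning false for directories restricts the result
// to values stored directly in dir, without descending into subdirectories.
//
//	topOnly := dir.ValuesFunc(func(nd *Node) bool {
//		return !nd.IsDir()
//	})
//	all := dir.ValuesFunc(nil) // nil include function: every value in the subtree
//	_, _ = topOnly, all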
// NodesFunc returns leaf Nodes under given directory, filtered by
// given include function, recursively descending into directories
// to return a flat list of the entire subtree, in directory order
// (e.g., order added).
// The function can filter out directories to prune the tree.
// If func is nil, all leaf Nodes are returned.
func (dir *Node) NodesFunc(include func(nd *Node) bool) []*Node {
if err := dir.mustDir("NodesFunc", ""); err != nil {
return nil
}
var nds []*Node
for _, it := range dir.nodes.Values {
if include != nil && !include(it) {
continue
}
if it.IsDir() {
subs := it.NodesFunc(include)
nds = append(nds, subs...)
} else {
nds = append(nds, it)
}
}
return nds
}
// ValuesAlphaFunc returns all Value nodes (tensors) in given directory,
// recursively descending into directories to return a flat list of
// the entire subtree, filtered by given function, with nodes at each
// directory level traversed in alphabetical order.
// The function can filter out directories to prune the tree.
// If func is nil, all Values are returned.
func (dir *Node) ValuesAlphaFunc(include func(nd *Node) bool) []tensor.Tensor {
if err := dir.mustDir("ValuesAlphaFunc", ""); err != nil {
return nil
}
names := dir.dirNamesAlpha()
var nds []tensor.Tensor
for _, nm := range names {
it := dir.nodes.At(nm)
if include != nil && !include(it) {
continue
}
if it.IsDir() {
subs := it.ValuesAlphaFunc(include)
nds = append(nds, subs...)
} else {
nds = append(nds, it.Tensor)
}
}
return nds
}
// NodesAlphaFunc returns leaf nodes under given directory, filtered by
// given include function, recursively descending into directories
// to return a flat list of the entire subtree, with nodes at each
// directory level traversed in alphabetical order.
// The function can filter out directories to prune the tree.
// If func is nil, all leaf Nodes are returned.
func (dir *Node) NodesAlphaFunc(include func(nd *Node) bool) []*Node {
if err := dir.mustDir("NodesAlphaFunc", ""); err != nil {
return nil
}
names := dir.dirNamesAlpha()
var nds []*Node
for _, nm := range names {
it := dir.nodes.At(nm)
if include != nil && !include(it) {
continue
}
if it.IsDir() {
subs := it.NodesAlphaFunc(include)
nds = append(nds, subs...)
} else {
nds = append(nds, it)
}
}
return nds
}
// todo: these must handle going up the tree using ..
// DirAtPath returns directory at given relative path
// from this starting dir.
func (dir *Node) DirAtPath(dirPath string) (*Node, error) {
var err error
dirPath = path.Clean(dirPath)
sdf, err := dir.Sub(dirPath) // this ensures that dir is a dir
if err != nil {
return nil, err
}
return sdf.(*Node), nil
}
// NodeAtPath returns node at given relative path from this starting dir.
func (dir *Node) NodeAtPath(name string) (*Node, error) {
if err := dir.mustDir("NodeAtPath", name); err != nil {
return nil, err
}
if !fs.ValidPath(name) {
return nil, &fs.PathError{Op: "NodeAtPath", Path: name, Err: errors.New("invalid path")}
}
dirPath, file := path.Split(name)
sd, err := dir.DirAtPath(dirPath)
if err != nil {
return nil, err
}
nd, ok := sd.nodes.AtTry(file)
if !ok {
if dirPath == "" && (file == dir.name || file == ".") {
return dir, nil
}
return nil, &fs.PathError{Op: "NodeAtPath", Path: name, Err: errors.New("file not found")}
}
return nd, nil
}
// Path returns the full path to this data node
func (dir *Node) Path() string {
pt := dir.name
cur := dir.Parent
loops := make(map[*Node]struct{})
for {
if cur == nil {
return pt
}
if _, ok := loops[cur]; ok {
return pt
}
pt = path.Join(cur.name, pt)
loops[cur] = struct{}{}
cur = cur.Parent
}
}
// dirNamesAlpha returns the names of nodes in the directory
// sorted alphabetically. Node must be dir by this point.
func (dir *Node) dirNamesAlpha() []string {
names := slices.Clone(dir.nodes.Keys)
sort.Strings(names)
return names
}
// dirNamesByTime returns the names of nodes in the directory
// sorted by modTime. Node must be dir by this point.
func (dir *Node) dirNamesByTime() []string {
names := slices.Clone(dir.nodes.Keys)
slices.SortFunc(names, func(a, b string) int {
return dir.nodes.At(a).ModTime().Compare(dir.nodes.At(b).ModTime())
})
return names
}
// mustDir returns an error for given operation and path
// if this data node is not a directory.
func (dir *Node) mustDir(op, path string) error {
if !dir.IsDir() {
return &fs.PathError{Op: op, Path: path, Err: errors.New("tensorfs node is not a directory")}
}
return nil
}
// Add adds a node to this directory data node.
// The only errors are if this node is not a directory,
// or if the name already exists, in which case [fs.ErrExist] is returned.
// Names must be unique within a directory.
func (dir *Node) Add(it *Node) error {
if err := dir.mustDir("Add", it.name); err != nil {
return err
}
err := dir.nodes.Add(it.name, it)
if err != nil {
return fs.ErrExist
}
return nil
}
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensorfs
import (
"bytes"
"io"
"io/fs"
)
// File represents a data item for reading, as an [fs.File].
// All io functionality is handled by [bytes.Reader].
type File struct {
bytes.Reader
Node *Node
dirEntries []fs.DirEntry
dirsRead int
}
func (f *File) Stat() (fs.FileInfo, error) {
return f.Node, nil
}
func (f *File) Close() error {
f.Reader.Reset(f.Node.Bytes())
return nil
}
// DirFile represents a directory data item for reading, as fs.ReadDirFile.
type DirFile struct {
File
dirEntries []fs.DirEntry
dirsRead int
}
func (f *DirFile) ReadDir(n int) ([]fs.DirEntry, error) {
if err := f.Node.mustDir("DirFile:ReadDir", ""); err != nil {
return nil, err
}
if f.dirEntries == nil {
f.dirEntries, _ = f.Node.ReadDir(".")
f.dirsRead = 0
}
ne := len(f.dirEntries)
if n <= 0 {
if f.dirsRead >= ne {
return nil, nil
}
re := f.dirEntries[f.dirsRead:]
f.dirsRead = ne
return re, nil
}
if f.dirsRead >= ne {
return nil, io.EOF
}
mx := min(n+f.dirsRead, ne)
re := f.dirEntries[f.dirsRead:mx]
f.dirsRead = mx
return re, nil
}
func (f *DirFile) Close() error {
f.dirsRead = 0
return nil
}
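// The following is an illustrative sketch, not part of the tensorfs API:
// because Node implements [fs.FS] and directory files implement
// [fs.ReadDirFile], the standard [fs.WalkDir] can traverse a tensorfs tree.
func exampleWalk(root *Node) ([]string, error) {
	var paths []string
	err := fs.WalkDir(root, ".", func(p string, d fs.DirEntry, err error) error {
		if err != nil {
			return err
		}
		paths = append(paths, p)
		return nil
	})
	return paths, err
}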
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensorfs
import (
"bytes"
"errors"
"io/fs"
"slices"
"time"
"cogentcore.org/core/base/fileinfo"
"cogentcore.org/core/base/fsx"
)
// fs.go contains all the io/fs interface implementations, and other fs functionality.
// Open opens the given node at given path within this tensorfs filesystem.
func (nd *Node) Open(name string) (fs.File, error) {
itm, err := nd.NodeAtPath(name)
if err != nil {
return nil, err
}
if itm.IsDir() {
return &DirFile{File: File{Reader: *bytes.NewReader(itm.Bytes()), Node: itm}}, nil
}
return &File{Reader: *bytes.NewReader(itm.Bytes()), Node: itm}, nil
}
// Stat returns a FileInfo describing the file.
// If there is an error, it should be of type *PathError.
func (nd *Node) Stat(name string) (fs.FileInfo, error) {
return nd.NodeAtPath(name)
}
// Sub returns a data FS corresponding to the subtree rooted at dir.
func (nd *Node) Sub(dir string) (fs.FS, error) {
if err := nd.mustDir("Sub", dir); err != nil {
return nil, err
}
// todo: this does not allow .. expressions, so we can't use it:
// if !fs.ValidPath(dir) {
// return nil, &fs.PathError{Op: "Sub", Path: dir, Err: errors.New("invalid path")}
// }
if dir == "." || dir == "" || dir == nd.name { // todo: this last condition seems bad.
// need tests
return nd, nil
}
cd := dir
cur := nd
root, rest := fsx.SplitRootPathFS(dir)
if root == "." || root == nd.name {
cd = rest
}
for {
if cd == "." || cd == "" {
return cur, nil
}
root, rest := fsx.SplitRootPathFS(cd)
if root == "." && rest == "" {
return cur, nil
}
cd = rest
if root == ".." {
if cur.Parent != nil {
cur = cur.Parent
continue
} else {
return nil, &fs.PathError{Op: "Sub", Path: dir, Err: errors.New("already at root")}
}
}
sd, ok := cur.nodes.AtTry(root)
if !ok {
return nil, &fs.PathError{Op: "Sub", Path: dir, Err: errors.New("directory not found")}
}
if !sd.IsDir() {
return nil, &fs.PathError{Op: "Sub", Path: dir, Err: errors.New("is not a directory")}
}
cur = sd
}
}
// ReadDir returns the contents of the given directory within this filesystem.
// Use "." (or "") to refer to the current directory.
func (nd *Node) ReadDir(dir string) ([]fs.DirEntry, error) {
sd, err := nd.DirAtPath(dir)
if err != nil {
return nil, err
}
names := sd.dirNamesAlpha()
ents := make([]fs.DirEntry, len(names))
for i, nm := range names {
ents[i] = sd.nodes.At(nm)
}
return ents, nil
}
// ReadFile reads the named file and returns its contents.
// A successful call returns a nil error, not io.EOF.
// (Because ReadFile reads the whole file, the expected EOF
// from the final Read is not treated as an error to be reported.)
//
// The caller is permitted to modify the returned byte slice.
// This method should return a copy of the underlying data.
func (nd *Node) ReadFile(name string) ([]byte, error) {
itm, err := nd.NodeAtPath(name)
if err != nil {
return nil, err
}
if itm.IsDir() {
return nil, &fs.PathError{Op: "ReadFile", Path: name, Err: errors.New("Node is a directory")}
}
return slices.Clone(itm.Bytes()), nil
}
//////// FileInfo interface:
func (nd *Node) Name() string { return nd.name }
// Size returns the size of the node's tensor data in bytes,
// using the Sizeof method on its values, or 0 if there is no tensor.
func (nd *Node) Size() int64 {
if nd.Tensor == nil {
return 0
}
return nd.Tensor.AsValues().Sizeof()
}
func (nd *Node) IsDir() bool {
return nd.nodes != nil
}
func (nd *Node) ModTime() time.Time {
return nd.modTime
}
func (nd *Node) Mode() fs.FileMode {
if nd.IsDir() {
return 0755 | fs.ModeDir
}
return 0444
}
// Sys returns the Tensor for a value node, or the nodes map for a directory.
func (nd *Node) Sys() any {
if nd.Tensor != nil {
return nd.Tensor
}
return nd.nodes
}
//////// DirEntry interface
func (nd *Node) Type() fs.FileMode {
return nd.Mode().Type()
}
func (nd *Node) Info() (fs.FileInfo, error) {
return nd, nil
}
//////// Misc
func (nd *Node) KnownFileInfo() fileinfo.Known {
if nd.Tensor == nil {
return fileinfo.Unknown
}
tsr := nd.Tensor
if tsr.Len() > 1 {
return fileinfo.Tensor
}
// scalars by type
if tsr.IsString() {
return fileinfo.String
}
return fileinfo.Number
}
// Bytes returns the byte-wise representation of the data Value.
// This is the actual underlying data, so make a copy if it could be
// unintentionally modified, or retained beyond immediate use.
func (nd *Node) Bytes() []byte {
if nd.Tensor == nil || nd.Tensor.NumDims() == 0 || nd.Tensor.Len() == 0 {
return nil
}
return nd.Tensor.AsValues().Bytes()
}
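// The following is an illustrative sketch, not part of the tensorfs API:
// it contrasts [Node.ReadFile], which returns a copy that the caller may
// modify freely, with [Node.Bytes], which exposes the underlying tensor
// storage directly and therefore must not be modified or retained.
func exampleReadCopy(dir *Node, name string) (copied, shared []byte, err error) {
	copied, err = dir.ReadFile(name)
	if err != nil {
		return nil, nil, err
	}
	nd, err := dir.NodeAtPath(name)
	if err != nil {
		return nil, nil, err
	}
	shared = nd.Bytes() // shared with the tensor: treat as read-only
	return copied, shared, nil
}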
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensorfs
import (
"strings"
"cogentcore.org/core/base/indent"
)
const (
// Short is used as a named arg for the [Node.List] method
// for a short, name-only listing, vs. [Long].
Short = false
// Long is used as a named arg for the [Node.List] method
// for a long, name and size listing, vs. [Short].
Long = true
// DirOnly is used as a named arg for the [Node.List] method
// for only listing the current directory, vs. [Recursive].
DirOnly = false
// Recursive is used as a named arg for the [Node.List] method
// for listing all directories recursively, vs. [DirOnly].
Recursive = true
)
func (nd *Node) String() string {
if !nd.IsDir() {
lb := nd.Tensor.Label()
if !strings.HasPrefix(lb, nd.name) {
lb = nd.name + " " + lb
}
return lb
}
return nd.List(Short, DirOnly)
}
// ListAll returns a Long, Recursive listing of nodes in the given directory.
func (dir *Node) ListAll() string {
return dir.listLong(true, 0)
}
// List returns a listing of nodes in the given directory.
// - long = include detailed information about each node, vs just the name.
// - recursive = descend into subdirectories.
func (dir *Node) List(long, recursive bool) string {
if long {
return dir.listLong(recursive, 0)
}
return dir.listShort(recursive, 0)
}
// listShort returns a name-only listing of given directory.
func (dir *Node) listShort(recursive bool, ident int) string {
var b strings.Builder
nodes, _ := dir.Nodes()
for _, it := range nodes {
b.WriteString(indent.Tabs(ident))
if it.IsDir() {
if recursive {
b.WriteString("\n" + it.listShort(recursive, ident+1))
} else {
b.WriteString(it.name + "/ ")
}
} else {
b.WriteString(it.name + " ")
}
}
return b.String()
}
// listLong returns a detailed listing of given directory.
func (dir *Node) listLong(recursive bool, ident int) string {
var b strings.Builder
nodes, _ := dir.Nodes()
for _, it := range nodes {
b.WriteString(indent.Tabs(ident))
if it.IsDir() {
b.WriteString(it.name + "/\n")
if recursive {
b.WriteString(it.listLong(recursive, ident+1))
}
} else {
b.WriteString(it.String() + "\n")
}
}
return b.String()
}
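// The following is an illustrative sketch, not part of the tensorfs API:
// it shows the named boolean constants used as readable arguments to
// [Node.List].
func exampleListings(dir *Node) (names, detailed string) {
	names = dir.List(Short, DirOnly)     // names only, current directory
	detailed = dir.List(Long, Recursive) // name and size, entire subtree
	return names, detailed
}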
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensorfs
import (
"cogentcore.org/core/base/errors"
"cogentcore.org/lab/tensor"
)
// This file provides standardized metadata options for frequent
// use cases, using codified key names to eliminate typos.
// SetMetaItems sets given metadata for Value items in given directory
// with given names. Returns error for any items not found.
func (d *Node) SetMetaItems(key string, value any, names ...string) error {
tsrs, err := d.Values(names...)
for _, tsr := range tsrs {
tsr.Metadata().Set(key, value)
}
return err
}
// CalcAll calls the function set by [Node.SetCalcFunc] for all value items
// in this directory and all of its subdirectories, calling [tensor.Calc]
// on each item returned by ValuesFunc(nil). Any errors are joined and returned.
func (d *Node) CalcAll() error {
var errs []error
items := d.ValuesFunc(nil)
for _, it := range items {
err := tensor.Calc(it)
if err != nil {
errs = append(errs, err)
}
}
return errors.Join(errs...)
}
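// The following is an illustrative sketch, not part of the tensorfs API:
// it tags two value nodes with a metadata key. The key "doc" and the node
// names "x" and "y" are hypothetical; per the [Node.SetMetaItems] docs, any
// names not found are reported in the returned error.
func exampleMeta(dir *Node) error {
	return dir.SetMetaItems("doc", "simulated input data", "x", "y")
}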
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensorfs
import (
"io/fs"
"path"
"reflect"
"time"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/fsx"
"cogentcore.org/core/base/metadata"
"cogentcore.org/lab/table"
"cogentcore.org/lab/tensor"
)
// Node is the element type for the filesystem, which can represent either
// a [tensor] Value as a "file" equivalent, or a "directory" containing other Nodes.
// The [tensor.Tensor] can represent everything from a single scalar value up to
// n-dimensional collections of patterns, in a range of data types.
// Directories have an ordered map of nodes.
type Node struct {
// Parent is the parent data directory.
Parent *Node
	// name is the name of this node. It is not a path.
	name string
	// modTime tracks the time this node was added to its directory, used for ordering.
modTime time.Time
// Tensor is the tensor value for a file or leaf Node in the FS,
// represented using the universal [tensor] data type of
// [tensor.Tensor], which can represent anything from a scalar
// to n-dimensional data, in a range of data types.
Tensor tensor.Tensor
// nodes is for directory nodes, with all the nodes in the directory.
nodes *Nodes
// DirTable is a summary [table.Table] with columns comprised of Value
// nodes in the directory, which can be used for plotting or other operations.
DirTable *table.Table
}
// newNode returns a new Node in the given directory Node, which can be nil.
// If dir is not a directory, returns nil and an error.
// If a node already exists in dir with that name, that node is returned
// with an [fs.ErrExist] error, and the caller can decide how to proceed.
// The modTime is set to now. The name must be unique within parent.
func newNode(dir *Node, name string) (*Node, error) {
if dir == nil {
return &Node{name: name, modTime: time.Now()}, nil
}
if err := dir.mustDir("newNode", name); err != nil {
return nil, err
}
if ex, ok := dir.nodes.AtTry(name); ok {
return ex, fs.ErrExist
}
d := &Node{Parent: dir, name: name, modTime: time.Now()}
dir.nodes.Add(name, d)
return d, nil
}
// Value creates / returns a Node with given name as a [tensor.Tensor]
// of given data type and shape sizes, in given directory Node.
// If it already exists, it is returned as-is (no checking against the
// type or sizes provided, for efficiency -- if there is doubt, check!),
// otherwise a new tensor is created. It is fine to not pass any sizes and
// use `SetShapeSizes` method later to set the size.
func Value[T tensor.DataTypes](dir *Node, name string, sizes ...int) tensor.Values {
it := dir.Node(name)
if it != nil {
return it.Tensor.(tensor.Values)
}
tsr := tensor.New[T](sizes...)
metadata.SetName(tsr, name)
nd, err := newNode(dir, name)
if errors.Log(err) != nil {
return nil
}
nd.Tensor = tsr
return tsr
}
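// The following is an illustrative sketch, not part of the tensorfs API:
// it creates (or recycles) a 3x4 float64 value named "weights" in dir.
// The name and sizes are hypothetical; if a node with this name already
// exists, it is returned as-is without any shape or type checking.
func exampleValue(dir *Node) tensor.Values {
	return Value[float64](dir, "weights", 3, 4)
}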
// NewValues makes new tensor Node value(s) (as a [tensor.Tensor])
// of given data type and shape sizes, in given directory.
// Any existing nodes with the same names are recycled without checking
// or updating the data type or sizes.
// See the [Value] documentation for more info.
func NewValues[T tensor.DataTypes](dir *Node, shape []int, names ...string) {
for _, nm := range names {
Value[T](dir, nm, shape...)
}
}
// Scalar returns a scalar Node value (as a [tensor.Tensor])
// of given data type, in given directory and name.
// If it already exists, it is returned without checking against args,
// else a new one is made. See the [Value] documentation for more info.
func Scalar[T tensor.DataTypes](dir *Node, name string) tensor.Values {
return Value[T](dir, name, 1)
}
// ValueType creates / returns a Node with given name as a [tensor.Tensor]
// of given data type specified as a reflect.Kind, with shape sizes,
// in given directory Node.
// Supported types are string, bool (for [Bool]), float32, float64, int, int32, and byte.
// If it already exists, it is returned as-is (no checking against the
// type or sizes provided, for efficiency -- if there is doubt, check!),
// otherwise a new tensor is created. It is fine to not pass any sizes and
// use `SetShapeSizes` method later to set the size.
func ValueType(dir *Node, name string, typ reflect.Kind, sizes ...int) tensor.Values {
it := dir.Node(name)
if it != nil {
return it.Tensor.(tensor.Values)
}
tsr := tensor.NewOfType(typ, sizes...)
metadata.SetName(tsr, name)
nd, err := newNode(dir, name)
if errors.Log(err) != nil {
return nil
}
nd.Tensor = tsr
return tsr
}
// SetTensor creates / recycles a node and sets to given existing tensor with given name.
func SetTensor(dir *Node, tsr tensor.Tensor, name string) *Node {
	nd := dir.Node(name)
	if nd == nil {
		nd, _ = newNode(dir, name)
	}
	nd.Tensor = tsr
	return nd
}
// Set creates / returns a Node with given name setting value to given Tensor,
// in given directory [Node]. Calls [SetTensor].
func (dir *Node) Set(name string, tsr tensor.Tensor) *Node {
return SetTensor(dir, tsr, name)
}
// DirTable returns a [table.Table] with all of the tensor values under
// the given directory, with one column per Value node in the directory
// and any subdirectories, filtered by the given function (nil includes all).
// This is a convenient mechanism for creating a plot of all the data
// in a given directory.
// If a matching table was previously constructed, the cached version stored
// in the DirTable field is returned, with its row count updated to the
// current maximum. Set DirTable = nil to force regeneration.
func DirTable(dir *Node, fun func(node *Node) bool) *table.Table {
nds := dir.NodesFunc(fun)
if dir.DirTable != nil {
if dir.DirTable.NumColumns() == len(nds) {
dir.DirTable.SetNumRowsToMax()
return dir.DirTable
}
}
dt := table.New(fsx.DirAndFile(string(dir.Path())))
for _, it := range nds {
if it.Tensor == nil || it.Tensor.NumDims() == 0 {
continue
}
tsr := it.Tensor
rows := tsr.DimSize(0)
if dt.Columns.Rows < rows {
dt.Columns.Rows = rows
dt.SetNumRows(dt.Columns.Rows)
}
nm := it.name
if it.Parent != dir {
nm = fsx.DirAndFile(string(it.Path()))
}
dt.AddColumn(nm, tsr.AsValues())
}
dir.DirTable = dt
return dt
}
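// The following is an illustrative sketch, not part of the tensorfs API:
// it builds a summary table over the value nodes in dir, keeping all
// directories (so the tree is not pruned) and only leaf nodes that have
// a non-nil tensor.
func exampleDirTable(dir *Node) *table.Table {
	return DirTable(dir, func(nd *Node) bool {
		return nd.IsDir() || nd.Tensor != nil
	})
}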
// DirFromTable sets tensor values under given directory node to the
// columns of the given [table.Table]. Also sets the DirTable to this table.
func DirFromTable(dir *Node, dt *table.Table) {
for i, cl := range dt.Columns.Values {
nm := dt.Columns.Keys[i]
dr, fn := path.Split(nm)
pdir := dir
if dr != "" {
dr = path.Dir(nm)
pdir = dir.Dir(dr)
}
nd, err := newNode(pdir, fn)
if err == nil || err == fs.ErrExist {
metadata.SetName(cl, fn)
nd.Tensor = cl
}
}
dir.DirTable = dt
}
// Float64 creates / returns a Node with given name as a [tensor.Float64]
// for given shape sizes, in given directory [Node].
// See the [Value] function for more info.
func (dir *Node) Float64(name string, sizes ...int) *tensor.Float64 {
return Value[float64](dir, name, sizes...).(*tensor.Float64)
}
// Float32 creates / returns a Node with given name as a [tensor.Float32]
// for given shape sizes, in given directory [Node].
// See the [Value] function for more info.
func (dir *Node) Float32(name string, sizes ...int) *tensor.Float32 {
return Value[float32](dir, name, sizes...).(*tensor.Float32)
}
// Int creates / returns a Node with given name as a [tensor.Int]
// for given shape sizes, in given directory [Node].
// See the [Value] function for more info.
func (dir *Node) Int(name string, sizes ...int) *tensor.Int {
return Value[int](dir, name, sizes...).(*tensor.Int)
}
// StringValue creates / returns a Node with given name as a [tensor.String]
// for given shape sizes, in given directory [Node].
// See the [Value] function for more info.
func (dir *Node) StringValue(name string, sizes ...int) *tensor.String {
return Value[string](dir, name, sizes...).(*tensor.String)
}
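// The following is an illustrative sketch, not part of the tensorfs API:
// it uses the typed convenience constructors; the names and sizes are
// hypothetical, and the same recycling rules as [Value] apply.
func exampleTyped(dir *Node) (*tensor.Float64, *tensor.String) {
	loss := dir.Float64("loss", 100)
	labels := dir.StringValue("labels", 100)
	return loss, labels
}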
// Copyright (c) 2025, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tensorfs
import (
"archive/tar"
"compress/gzip"
"io"
"io/fs"
"path"
"path/filepath"
"time"
"cogentcore.org/core/base/errors"
"cogentcore.org/lab/tensor"
)
// AllFiles returns all regular file names within the given directory,
// including subdirectories, excluding those matching the given glob expressions.
// Paths are as reported by [filepath.WalkDir], and thus include dir as a prefix.
func AllFiles(dir string, exclude ...string) ([]string, error) {
var files []string
err := filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
if !d.Type().IsRegular() {
return nil
}
for _, ex := range exclude {
if errors.Log1(filepath.Match(ex, path)) {
return nil
}
}
files = append(files, path)
return nil
})
return files, err
}
// note: the Tar code was significantly informed by Steve Domino's examples:
// https://medium.com/@skdomino/taring-untaring-files-in-go-6b07cf56bc07
// Tar writes a tar file to given writer, from given source directory,
// using given include function to select nodes to include (all if nil).
// If gz is true, then tar is gzipped.
// The tensor data is written using the [tensor.ToBinary] format, so the
// files are effectively opaque binary files.
func Tar(w io.Writer, dir *Node, gz bool, include func(nd *Node) bool) error {
ow := w
if gz {
gzw := gzip.NewWriter(w)
defer gzw.Close()
ow = gzw
}
tw := tar.NewWriter(ow)
defer tw.Close()
return tarWrite(tw, dir, "", include)
}
func tarWrite(w *tar.Writer, dir *Node, parPath string, include func(nd *Node) bool) error {
var errs []error
for _, it := range dir.nodes.Values {
if include != nil && !include(it) {
continue
}
		if it.IsDir() {
			if err := tarWrite(w, it, path.Join(parPath, it.name), include); err != nil {
				errs = append(errs, err)
			}
			continue
		}
vtsr := it.Tensor.AsValues()
b := tensor.ToBinary(vtsr)
fname := path.Join(parPath, it.name)
now := time.Now()
hdr := &tar.Header{
Name: fname,
Mode: 0666,
Size: int64(len(b)),
Format: tar.FormatPAX,
ModTime: now,
AccessTime: now,
ChangeTime: now,
}
if err := w.WriteHeader(hdr); err != nil {
errs = append(errs, err)
break
}
_, err := w.Write(b)
if err != nil {
errs = append(errs, err)
}
}
return errors.Join(errs...)
}
// Untar extracts a tar file from given reader, into given directory node.
// If gz is true, then tar is gzipped.
func Untar(r io.Reader, dir *Node, gz bool) error {
or := r
if gz {
gzr, err := gzip.NewReader(r)
if err != nil {
return err
}
or = gzr
defer gzr.Close()
}
tr := tar.NewReader(or)
var errs []error
addErr := func(err error) error { // if != nil, return
if err == nil {
return nil
}
errs = append(errs, err)
if len(errs) > 10 {
return errors.Join(errs...)
}
return nil
}
for {
hdr, err := tr.Next()
switch {
case err == io.EOF:
return errors.Join(errs...)
case err != nil:
if allErr := addErr(err); allErr != nil {
return allErr
}
continue
case hdr == nil:
continue
}
fname := hdr.Name
switch hdr.Typeflag {
case tar.TypeDir:
dir.Dir(fname)
case tar.TypeReg:
b := make([]byte, hdr.Size)
			_, err := io.ReadFull(tr, b) // Read alone may return a short read; ReadFull reads the full entry
if err != nil && err != io.EOF {
if allErr := addErr(err); allErr != nil {
return allErr
}
continue
}
dr, fn := path.Split(fname)
pdir := dir
if dr != "" {
dr = path.Dir(fname)
pdir = dir.Dir(dr)
}
nd, err := newNode(pdir, fn)
if err != nil && err != fs.ErrExist {
if allErr := addErr(err); allErr != nil {
return allErr
}
continue
}
nd.Tensor = tensor.FromBinary(b)
}
}
}
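// The following is an illustrative sketch, not part of the tensorfs API:
// it writes a gzipped tar of all value nodes under src to w, using an
// include function that keeps every directory and every leaf that has data.
func exampleTarGz(w io.Writer, src *Node) error {
	return Tar(w, src, true, func(nd *Node) bool {
		return nd.IsDir() || nd.Tensor != nil
	})
}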
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package vector provides standard vector math functions that
// always operate on 1D views of tensor inputs regardless of the original
// vector shape.
package vector
import (
"math"
"cogentcore.org/lab/tensor"
"cogentcore.org/lab/tensor/tmath"
)
// Mul multiplies two vectors element-wise, using a 1D vector
// view of the two vectors, returning the output 1D vector.
func Mul(a, b tensor.Tensor) tensor.Values {
return tensor.CallOut2Float64(MulOut, a, b)
}
// MulOut multiplies two vectors element-wise, using a 1D vector
// view of the two vectors, filling in values in the output 1D vector.
func MulOut(a, b tensor.Tensor, out tensor.Values) error {
return tmath.MulOut(tensor.As1D(a), tensor.As1D(b), out)
}
// Sum returns the sum of all values in the tensor, as a scalar.
func Sum(a tensor.Tensor) tensor.Values {
n := a.Len()
sum := 0.0
tensor.Vectorize(func(tsr ...tensor.Tensor) int { return n },
func(idx int, tsr ...tensor.Tensor) {
sum += tsr[0].Float1D(idx)
}, a)
return tensor.NewFloat64Scalar(sum)
}
// Dot performs the vector dot product: the [Sum] of the [Mul] product
// of the two tensors, returning a scalar value. Also known as the inner product.
func Dot(a, b tensor.Tensor) tensor.Values {
return Sum(Mul(a, b))
}
// L2Norm returns the length of the vector as the L2 Norm:
// square root of the sum of squared values of the vector, as a scalar.
// This is the Sqrt of the [Dot] product of the vector with itself.
func L2Norm(a tensor.Tensor) tensor.Values {
dot := Dot(a, a).Float1D(0)
return tensor.NewFloat64Scalar(math.Sqrt(dot))
}
// L1Norm returns the length of the vector as the L1 Norm:
// sum of the absolute values of the tensor, as a scalar.
func L1Norm(a tensor.Tensor) tensor.Values {
n := a.Len()
sum := 0.0
tensor.Vectorize(func(tsr ...tensor.Tensor) int { return n },
func(idx int, tsr ...tensor.Tensor) {
sum += math.Abs(tsr[0].Float1D(idx))
}, a)
return tensor.NewFloat64Scalar(sum)
}
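// The following is an illustrative sketch, not part of this package's API:
// cosine similarity computed from [Dot] and [L2Norm], returning 0 when
// either vector has zero length.
func exampleCosine(a, b tensor.Tensor) float64 {
	dot := Dot(a, b).Float1D(0)
	na := L2Norm(a).Float1D(0)
	nb := L2Norm(b).Float1D(0)
	if na == 0 || nb == 0 {
		return 0
	}
	return dot / (na * nb)
}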
// Code generated by 'yaegi extract cogentcore.org/lab/lab'. DO NOT EDIT.
package labsymbols
import (
"cogentcore.org/core/core"
"cogentcore.org/lab/lab"
"reflect"
)
func init() {
Symbols["cogentcore.org/lab/lab/lab"] = map[string]reflect.Value{
// function, constant and variable definitions
"AsDataTree": reflect.ValueOf(lab.AsDataTree),
"DirAndFileNoSlash": reflect.ValueOf(lab.DirAndFileNoSlash),
"FirstComment": reflect.ValueOf(lab.FirstComment),
"IsTableFile": reflect.ValueOf(lab.IsTableFile),
"Lab": reflect.ValueOf(&lab.Lab).Elem(),
"LabBrowser": reflect.ValueOf(&lab.LabBrowser).Elem(),
"NewBasic": reflect.ValueOf(lab.NewBasic),
"NewBasicWindow": reflect.ValueOf(lab.NewBasicWindow),
"NewDataTree": reflect.ValueOf(lab.NewDataTree),
"NewDiffBrowserDirs": reflect.ValueOf(lab.NewDiffBrowserDirs),
"NewFileNode": reflect.ValueOf(lab.NewFileNode),
"NewPlot": reflect.ValueOf(lab.NewPlot),
"NewPlotFrom": reflect.ValueOf(lab.NewPlotFrom),
"NewPlotWidget": reflect.ValueOf(lab.NewPlotWidget),
"NewTabs": reflect.ValueOf(lab.NewTabs),
"PromptOKCancel": reflect.ValueOf(lab.PromptOKCancel),
"PromptString": reflect.ValueOf(lab.PromptString),
"PromptStruct": reflect.ValueOf(lab.PromptStruct),
"RunScript": reflect.ValueOf(&lab.RunScript).Elem(),
"RunScriptCode": reflect.ValueOf(&lab.RunScriptCode).Elem(),
"TensorFS": reflect.ValueOf(lab.TensorFS),
"TrimOrderPrefix": reflect.ValueOf(lab.TrimOrderPrefix),
// type definitions
"Basic": reflect.ValueOf((*lab.Basic)(nil)),
"Browser": reflect.ValueOf((*lab.Browser)(nil)),
"DataTree": reflect.ValueOf((*lab.DataTree)(nil)),
"FileNode": reflect.ValueOf((*lab.FileNode)(nil)),
"Tabber": reflect.ValueOf((*lab.Tabber)(nil)),
"Tabs": reflect.ValueOf((*lab.Tabs)(nil)),
"Treer": reflect.ValueOf((*lab.Treer)(nil)),
// interface wrapper definitions
"_Tabber": reflect.ValueOf((*_cogentcore_org_lab_lab_Tabber)(nil)),
"_Treer": reflect.ValueOf((*_cogentcore_org_lab_lab_Treer)(nil)),
}
}
// _cogentcore_org_lab_lab_Tabber is an interface wrapper for Tabber type
type _cogentcore_org_lab_lab_Tabber struct {
IValue interface{}
WAsCoreTabs func() *core.Tabs
WAsLab func() *lab.Tabs
}
func (W _cogentcore_org_lab_lab_Tabber) AsCoreTabs() *core.Tabs { return W.WAsCoreTabs() }
func (W _cogentcore_org_lab_lab_Tabber) AsLab() *lab.Tabs { return W.WAsLab() }
// _cogentcore_org_lab_lab_Treer is an interface wrapper for Treer type
type _cogentcore_org_lab_lab_Treer struct {
IValue interface{}
WAsDataTree func() *lab.DataTree
}
func (W _cogentcore_org_lab_lab_Treer) AsDataTree() *lab.DataTree { return W.WAsDataTree() }
// Code generated by 'yaegi extract cogentcore.org/lab/physics/builder'. DO NOT EDIT.
package labsymbols
import (
"cogentcore.org/lab/physics/builder"
"reflect"
)
func init() {
Symbols["cogentcore.org/lab/physics/builder/builder"] = map[string]reflect.Value{
// function, constant and variable definitions
"MakePoseToolbar": reflect.ValueOf(builder.MakePoseToolbar),
"NewBuilder": reflect.ValueOf(builder.NewBuilder),
// type definitions
"Body": reflect.ValueOf((*builder.Body)(nil)),
"Builder": reflect.ValueOf((*builder.Builder)(nil)),
"Controls": reflect.ValueOf((*builder.Controls)(nil)),
"DoF": reflect.ValueOf((*builder.DoF)(nil)),
"Joint": reflect.ValueOf((*builder.Joint)(nil)),
"Object": reflect.ValueOf((*builder.Object)(nil)),
"Physics": reflect.ValueOf((*builder.Physics)(nil)),
"Pose": reflect.ValueOf((*builder.Pose)(nil)),
"World": reflect.ValueOf((*builder.World)(nil)),
}
}
// Code generated by 'yaegi extract cogentcore.org/lab/physics/phyxyz'. DO NOT EDIT.
package labsymbols
import (
"cogentcore.org/lab/physics/phyxyz"
"reflect"
)
func init() {
Symbols["cogentcore.org/lab/physics/phyxyz/phyxyz"] = map[string]reflect.Value{
// function, constant and variable definitions
"DepthImage": reflect.ValueOf(phyxyz.DepthImage),
"DepthNorm": reflect.ValueOf(phyxyz.DepthNorm),
"NewEditor": reflect.ValueOf(phyxyz.NewEditor),
"NewScene": reflect.ValueOf(phyxyz.NewScene),
"NoDisplayScene": reflect.ValueOf(phyxyz.NoDisplayScene),
// type definitions
"Camera": reflect.ValueOf((*phyxyz.Camera)(nil)),
"Editor": reflect.ValueOf((*phyxyz.Editor)(nil)),
"Scene": reflect.ValueOf((*phyxyz.Scene)(nil)),
"Skin": reflect.ValueOf((*phyxyz.Skin)(nil)),
}
}
// Code generated by 'yaegi extract cogentcore.org/lab/physics'. DO NOT EDIT.
package labsymbols
import (
"cogentcore.org/lab/physics"
"go/constant"
"go/token"
"reflect"
)
func init() {
Symbols["cogentcore.org/lab/physics/physics"] = map[string]reflect.Value{
// function, constant and variable definitions
"AddBroadContacts": reflect.ValueOf(physics.AddBroadContacts),
"AngularAccelAt": reflect.ValueOf(physics.AngularAccelAt),
"AngularCorrection": reflect.ValueOf(physics.AngularCorrection),
"AngularVelocityAt": reflect.ValueOf(physics.AngularVelocityAt),
"Ball": reflect.ValueOf(physics.Ball),
"Bodies": reflect.ValueOf(&physics.Bodies).Elem(),
"BodiesVar": reflect.ValueOf(physics.BodiesVar),
"BodyBounce": reflect.ValueOf(physics.BodyBounce),
"BodyCollidePairs": reflect.ValueOf(&physics.BodyCollidePairs).Elem(),
"BodyCollidePairsVar": reflect.ValueOf(physics.BodyCollidePairsVar),
"BodyCom": reflect.ValueOf(physics.BodyCom),
"BodyComX": reflect.ValueOf(physics.BodyComX),
"BodyComY": reflect.ValueOf(physics.BodyComY),
"BodyComZ": reflect.ValueOf(physics.BodyComZ),
"BodyDynamic": reflect.ValueOf(physics.BodyDynamic),
"BodyDynamicPos": reflect.ValueOf(physics.BodyDynamicPos),
"BodyDynamicQuat": reflect.ValueOf(physics.BodyDynamicQuat),
"BodyFriction": reflect.ValueOf(physics.BodyFriction),
"BodyFrictionRolling": reflect.ValueOf(physics.BodyFrictionRolling),
"BodyFrictionTortion": reflect.ValueOf(physics.BodyFrictionTortion),
"BodyGroup": reflect.ValueOf(physics.BodyGroup),
"BodyHSize": reflect.ValueOf(physics.BodyHSize),
"BodyHSizeX": reflect.ValueOf(physics.BodyHSizeX),
"BodyHSizeY": reflect.ValueOf(physics.BodyHSizeY),
"BodyHSizeZ": reflect.ValueOf(physics.BodyHSizeZ),
"BodyInertia": reflect.ValueOf(physics.BodyInertia),
"BodyInertiaXX": reflect.ValueOf(physics.BodyInertiaXX),
"BodyInertiaXY": reflect.ValueOf(physics.BodyInertiaXY),
"BodyInertiaXZ": reflect.ValueOf(physics.BodyInertiaXZ),
"BodyInertiaYX": reflect.ValueOf(physics.BodyInertiaYX),
"BodyInertiaYY": reflect.ValueOf(physics.BodyInertiaYY),
"BodyInertiaYZ": reflect.ValueOf(physics.BodyInertiaYZ),
"BodyInertiaZX": reflect.ValueOf(physics.BodyInertiaZX),
"BodyInertiaZY": reflect.ValueOf(physics.BodyInertiaZY),
"BodyInertiaZZ": reflect.ValueOf(physics.BodyInertiaZZ),
"BodyInvInertia": reflect.ValueOf(physics.BodyInvInertia),
"BodyInvInertiaXX": reflect.ValueOf(physics.BodyInvInertiaXX),
"BodyInvInertiaXY": reflect.ValueOf(physics.BodyInvInertiaXY),
"BodyInvInertiaXZ": reflect.ValueOf(physics.BodyInvInertiaXZ),
"BodyInvInertiaYX": reflect.ValueOf(physics.BodyInvInertiaYX),
"BodyInvInertiaYY": reflect.ValueOf(physics.BodyInvInertiaYY),
"BodyInvInertiaYZ": reflect.ValueOf(physics.BodyInvInertiaYZ),
"BodyInvInertiaZX": reflect.ValueOf(physics.BodyInvInertiaZX),
"BodyInvInertiaZY": reflect.ValueOf(physics.BodyInvInertiaZY),
"BodyInvInertiaZZ": reflect.ValueOf(physics.BodyInvInertiaZZ),
"BodyInvMass": reflect.ValueOf(physics.BodyInvMass),
"BodyJoints": reflect.ValueOf(&physics.BodyJoints).Elem(),
"BodyJointsVar": reflect.ValueOf(physics.BodyJointsVar),
"BodyMass": reflect.ValueOf(physics.BodyMass),
"BodyPos": reflect.ValueOf(physics.BodyPos),
"BodyPosX": reflect.ValueOf(physics.BodyPosX),
"BodyPosY": reflect.ValueOf(physics.BodyPosY),
"BodyPosZ": reflect.ValueOf(physics.BodyPosZ),
"BodyQuat": reflect.ValueOf(physics.BodyQuat),
"BodyQuatW": reflect.ValueOf(physics.BodyQuatW),
"BodyQuatX": reflect.ValueOf(physics.BodyQuatX),
"BodyQuatY": reflect.ValueOf(physics.BodyQuatY),
"BodyQuatZ": reflect.ValueOf(physics.BodyQuatZ),
"BodyRadius": reflect.ValueOf(physics.BodyRadius),
"BodyShape": reflect.ValueOf(physics.BodyShape),
"BodyThick": reflect.ValueOf(physics.BodyThick),
"BodyVarsN": reflect.ValueOf(physics.BodyVarsN),
"BodyVarsValues": reflect.ValueOf(physics.BodyVarsValues),
"BodyWorld": reflect.ValueOf(physics.BodyWorld),
"BorrowedGPU": reflect.ValueOf(&physics.BorrowedGPU).Elem(),
"Box": reflect.ValueOf(physics.Box),
"BoxEdge": reflect.ValueOf(physics.BoxEdge),
"BoxSDF": reflect.ValueOf(physics.BoxSDF),
"BoxSDFGrad": reflect.ValueOf(physics.BoxSDFGrad),
"BoxVertex": reflect.ValueOf(physics.BoxVertex),
"BroadContactVarsN": reflect.ValueOf(physics.BroadContactVarsN),
"BroadContacts": reflect.ValueOf(&physics.BroadContacts).Elem(),
"BroadContactsN": reflect.ValueOf(&physics.BroadContactsN).Elem(),
"BroadContactsNVar": reflect.ValueOf(physics.BroadContactsNVar),
"BroadContactsVar": reflect.ValueOf(physics.BroadContactsVar),
"Capsule": reflect.ValueOf(physics.Capsule),
"CapsuleSDF": reflect.ValueOf(physics.CapsuleSDF),
"ClosestEdgeBox": reflect.ValueOf(physics.ClosestEdgeBox),
"ClosestEdgeCapsule": reflect.ValueOf(physics.ClosestEdgeCapsule),
"ClosestEdgePlane": reflect.ValueOf(physics.ClosestEdgePlane),
"ClosestPointBox": reflect.ValueOf(physics.ClosestPointBox),
"ClosestPointLineSegment": reflect.ValueOf(physics.ClosestPointLineSegment),
"ClosestPointPlane": reflect.ValueOf(physics.ClosestPointPlane),
"ColBoxBox": reflect.ValueOf(physics.ColBoxBox),
"ColBoxCapsule": reflect.ValueOf(physics.ColBoxCapsule),
"ColBoxPlane": reflect.ValueOf(physics.ColBoxPlane),
"ColCapsuleCapsule": reflect.ValueOf(physics.ColCapsuleCapsule),
"ColCapsulePlane": reflect.ValueOf(physics.ColCapsulePlane),
"ColCylinderPlane": reflect.ValueOf(physics.ColCylinderPlane),
"ColSphereBox": reflect.ValueOf(physics.ColSphereBox),
"ColSphereCapsule": reflect.ValueOf(physics.ColSphereCapsule),
"ColSpherePlane": reflect.ValueOf(physics.ColSpherePlane),
"ColSphereSphere": reflect.ValueOf(physics.ColSphereSphere),
"CollisionBroad": reflect.ValueOf(physics.CollisionBroad),
"CollisionNarrow": reflect.ValueOf(physics.CollisionNarrow),
"ComputeGPU": reflect.ValueOf(&physics.ComputeGPU).Elem(),
"Cone": reflect.ValueOf(physics.Cone),
"ConeSDF": reflect.ValueOf(physics.ConeSDF),
"ContactA": reflect.ValueOf(physics.ContactA),
"ContactAAngDelta": reflect.ValueOf(physics.ContactAAngDelta),
"ContactAAngDeltaX": reflect.ValueOf(physics.ContactAAngDeltaX),
"ContactAAngDeltaY": reflect.ValueOf(physics.ContactAAngDeltaY),
"ContactAAngDeltaZ": reflect.ValueOf(physics.ContactAAngDeltaZ),
"ContactADelta": reflect.ValueOf(physics.ContactADelta),
"ContactADeltaX": reflect.ValueOf(physics.ContactADeltaX),
"ContactADeltaY": reflect.ValueOf(physics.ContactADeltaY),
"ContactADeltaZ": reflect.ValueOf(physics.ContactADeltaZ),
"ContactAOff": reflect.ValueOf(physics.ContactAOff),
"ContactAOffX": reflect.ValueOf(physics.ContactAOffX),
"ContactAOffY": reflect.ValueOf(physics.ContactAOffY),
"ContactAOffZ": reflect.ValueOf(physics.ContactAOffZ),
"ContactAPoint": reflect.ValueOf(physics.ContactAPoint),
"ContactAPointX": reflect.ValueOf(physics.ContactAPointX),
"ContactAPointY": reflect.ValueOf(physics.ContactAPointY),
"ContactAPointZ": reflect.ValueOf(physics.ContactAPointZ),
"ContactAThick": reflect.ValueOf(physics.ContactAThick),
"ContactB": reflect.ValueOf(physics.ContactB),
"ContactBAngDelta": reflect.ValueOf(physics.ContactBAngDelta),
"ContactBAngDeltaX": reflect.ValueOf(physics.ContactBAngDeltaX),
"ContactBAngDeltaY": reflect.ValueOf(physics.ContactBAngDeltaY),
"ContactBAngDeltaZ": reflect.ValueOf(physics.ContactBAngDeltaZ),
"ContactBDelta": reflect.ValueOf(physics.ContactBDelta),
"ContactBDeltaX": reflect.ValueOf(physics.ContactBDeltaX),
"ContactBDeltaY": reflect.ValueOf(physics.ContactBDeltaY),
"ContactBDeltaZ": reflect.ValueOf(physics.ContactBDeltaZ),
"ContactBOff": reflect.ValueOf(physics.ContactBOff),
"ContactBOffX": reflect.ValueOf(physics.ContactBOffX),
"ContactBOffY": reflect.ValueOf(physics.ContactBOffY),
"ContactBOffZ": reflect.ValueOf(physics.ContactBOffZ),
"ContactBPoint": reflect.ValueOf(physics.ContactBPoint),
"ContactBPointX": reflect.ValueOf(physics.ContactBPointX),
"ContactBPointY": reflect.ValueOf(physics.ContactBPointY),
"ContactBPointZ": reflect.ValueOf(physics.ContactBPointZ),
"ContactBThick": reflect.ValueOf(physics.ContactBThick),
"ContactConstraint": reflect.ValueOf(physics.ContactConstraint),
"ContactNorm": reflect.ValueOf(physics.ContactNorm),
"ContactNormX": reflect.ValueOf(physics.ContactNormX),
"ContactNormY": reflect.ValueOf(physics.ContactNormY),
"ContactNormZ": reflect.ValueOf(physics.ContactNormZ),
"ContactPointIdx": reflect.ValueOf(physics.ContactPointIdx),
"ContactPoints": reflect.ValueOf(physics.ContactPoints),
"ContactVarsN": reflect.ValueOf(physics.ContactVarsN),
"ContactVarsValues": reflect.ValueOf(physics.ContactVarsValues),
"ContactWeight": reflect.ValueOf(physics.ContactWeight),
"Contacts": reflect.ValueOf(&physics.Contacts).Elem(),
"ContactsN": reflect.ValueOf(&physics.ContactsN).Elem(),
"ContactsNVar": reflect.ValueOf(physics.ContactsNVar),
"ContactsVar": reflect.ValueOf(physics.ContactsVar),
"CurModel": reflect.ValueOf(&physics.CurModel).Elem(),
"Cylinder": reflect.ValueOf(physics.Cylinder),
"CylinderSDF": reflect.ValueOf(physics.CylinderSDF),
"D6": reflect.ValueOf(physics.D6),
"Distance": reflect.ValueOf(physics.Distance),
"DynAccX": reflect.ValueOf(physics.DynAccX),
"DynAccY": reflect.ValueOf(physics.DynAccY),
"DynAccZ": reflect.ValueOf(physics.DynAccZ),
"DynAngAccX": reflect.ValueOf(physics.DynAngAccX),
"DynAngAccY": reflect.ValueOf(physics.DynAngAccY),
"DynAngAccZ": reflect.ValueOf(physics.DynAngAccZ),
"DynAngDeltaX": reflect.ValueOf(physics.DynAngDeltaX),
"DynAngDeltaY": reflect.ValueOf(physics.DynAngDeltaY),
"DynAngDeltaZ": reflect.ValueOf(physics.DynAngDeltaZ),
"DynAngVelX": reflect.ValueOf(physics.DynAngVelX),
"DynAngVelY": reflect.ValueOf(physics.DynAngVelY),
"DynAngVelZ": reflect.ValueOf(physics.DynAngVelZ),
"DynBody": reflect.ValueOf(physics.DynBody),
"DynContactWeight": reflect.ValueOf(physics.DynContactWeight),
"DynDeltaX": reflect.ValueOf(physics.DynDeltaX),
"DynDeltaY": reflect.ValueOf(physics.DynDeltaY),
"DynDeltaZ": reflect.ValueOf(physics.DynDeltaZ),
"DynForceX": reflect.ValueOf(physics.DynForceX),
"DynForceY": reflect.ValueOf(physics.DynForceY),
"DynForceZ": reflect.ValueOf(physics.DynForceZ),
"DynPosX": reflect.ValueOf(physics.DynPosX),
"DynPosY": reflect.ValueOf(physics.DynPosY),
"DynPosZ": reflect.ValueOf(physics.DynPosZ),
"DynQuatW": reflect.ValueOf(physics.DynQuatW),
"DynQuatX": reflect.ValueOf(physics.DynQuatX),
"DynQuatY": reflect.ValueOf(physics.DynQuatY),
"DynQuatZ": reflect.ValueOf(physics.DynQuatZ),
"DynTorqueX": reflect.ValueOf(physics.DynTorqueX),
"DynTorqueY": reflect.ValueOf(physics.DynTorqueY),
"DynTorqueZ": reflect.ValueOf(physics.DynTorqueZ),
"DynVelX": reflect.ValueOf(physics.DynVelX),
"DynVelY": reflect.ValueOf(physics.DynVelY),
"DynVelZ": reflect.ValueOf(physics.DynVelZ),
"DynamicAcc": reflect.ValueOf(physics.DynamicAcc),
"DynamicAngAcc": reflect.ValueOf(physics.DynamicAngAcc),
"DynamicAngDelta": reflect.ValueOf(physics.DynamicAngDelta),
"DynamicAngVel": reflect.ValueOf(physics.DynamicAngVel),
"DynamicBody": reflect.ValueOf(physics.DynamicBody),
"DynamicDelta": reflect.ValueOf(physics.DynamicDelta),
"DynamicForce": reflect.ValueOf(physics.DynamicForce),
"DynamicPos": reflect.ValueOf(physics.DynamicPos),
"DynamicQuat": reflect.ValueOf(physics.DynamicQuat),
"DynamicTorque": reflect.ValueOf(physics.DynamicTorque),
"DynamicVarsN": reflect.ValueOf(physics.DynamicVarsN),
"DynamicVarsValues": reflect.ValueOf(physics.DynamicVarsValues),
"DynamicVel": reflect.ValueOf(physics.DynamicVel),
"Dynamics": reflect.ValueOf(&physics.Dynamics).Elem(),
"DynamicsCurToNext": reflect.ValueOf(physics.DynamicsCurToNext),
"DynamicsVar": reflect.ValueOf(physics.DynamicsVar),
"Fixed": reflect.ValueOf(physics.Fixed),
"ForcesFromJoints": reflect.ValueOf(physics.ForcesFromJoints),
"Free": reflect.ValueOf(physics.Free),
"GPUInit": reflect.ValueOf(physics.GPUInit),
"GPUInitialized": reflect.ValueOf(&physics.GPUInitialized).Elem(),
"GPURelease": reflect.ValueOf(physics.GPURelease),
"GPUSystem": reflect.ValueOf(&physics.GPUSystem).Elem(),
"GPUVarsN": reflect.ValueOf(physics.GPUVarsN),
"GPUVarsValues": reflect.ValueOf(physics.GPUVarsValues),
"GetBodyDynamic": reflect.ValueOf(physics.GetBodyDynamic),
"GetBodyGroup": reflect.ValueOf(physics.GetBodyGroup),
"GetBodyShape": reflect.ValueOf(physics.GetBodyShape),
"GetBodyWorld": reflect.ValueOf(physics.GetBodyWorld),
"GetBroadContactA": reflect.ValueOf(physics.GetBroadContactA),
"GetBroadContactB": reflect.ValueOf(physics.GetBroadContactB),
"GetBroadContactPointIdx": reflect.ValueOf(physics.GetBroadContactPointIdx),
"GetContactA": reflect.ValueOf(physics.GetContactA),
"GetContactB": reflect.ValueOf(physics.GetContactB),
"GetContactPointIdx": reflect.ValueOf(physics.GetContactPointIdx),
"GetJointAngularDoFN": reflect.ValueOf(physics.GetJointAngularDoFN),
"GetJointEnabled": reflect.ValueOf(physics.GetJointEnabled),
"GetJointLinearDoFN": reflect.ValueOf(physics.GetJointLinearDoFN),
"GetJointNoLinearRotation": reflect.ValueOf(physics.GetJointNoLinearRotation),
"GetJointParentFixed": reflect.ValueOf(physics.GetJointParentFixed),
"GetJointTargetPos": reflect.ValueOf(physics.GetJointTargetPos),
"GetJointTargetVel": reflect.ValueOf(physics.GetJointTargetVel),
"GetJointType": reflect.ValueOf(physics.GetJointType),
"GetParams": reflect.ValueOf(physics.GetParams),
"GroupsCollide": reflect.ValueOf(physics.GroupsCollide),
"InitDynamics": reflect.ValueOf(physics.InitDynamics),
"InitGeomData": reflect.ValueOf(physics.InitGeomData),
"JointAngLambda": reflect.ValueOf(physics.JointAngLambda),
"JointAngLambdaX": reflect.ValueOf(physics.JointAngLambdaX),
"JointAngLambdaY": reflect.ValueOf(physics.JointAngLambdaY),
"JointAngLambdaZ": reflect.ValueOf(physics.JointAngLambdaZ),
"JointAngularDoFN": reflect.ValueOf(physics.JointAngularDoFN),
"JointAxis": reflect.ValueOf(physics.JointAxis),
"JointAxisDoF": reflect.ValueOf(physics.JointAxisDoF),
"JointAxisLimitsUpdate": reflect.ValueOf(physics.JointAxisLimitsUpdate),
"JointAxisTarget": reflect.ValueOf(physics.JointAxisTarget),
"JointAxisX": reflect.ValueOf(physics.JointAxisX),
"JointAxisY": reflect.ValueOf(physics.JointAxisY),
"JointAxisZ": reflect.ValueOf(physics.JointAxisZ),
"JointCForce": reflect.ValueOf(physics.JointCForce),
"JointCForceX": reflect.ValueOf(physics.JointCForceX),
"JointCForceY": reflect.ValueOf(physics.JointCForceY),
"JointCForceZ": reflect.ValueOf(physics.JointCForceZ),
"JointCPos": reflect.ValueOf(physics.JointCPos),
"JointCPosX": reflect.ValueOf(physics.JointCPosX),
"JointCPosY": reflect.ValueOf(physics.JointCPosY),
"JointCPosZ": reflect.ValueOf(physics.JointCPosZ),
"JointCQuat": reflect.ValueOf(physics.JointCQuat),
"JointCQuatW": reflect.ValueOf(physics.JointCQuatW),
"JointCQuatX": reflect.ValueOf(physics.JointCQuatX),
"JointCQuatY": reflect.ValueOf(physics.JointCQuatY),
"JointCQuatZ": reflect.ValueOf(physics.JointCQuatZ),
"JointCTorque": reflect.ValueOf(physics.JointCTorque),
"JointCTorqueX": reflect.ValueOf(physics.JointCTorqueX),
"JointCTorqueY": reflect.ValueOf(physics.JointCTorqueY),
"JointCTorqueZ": reflect.ValueOf(physics.JointCTorqueZ),
"JointChild": reflect.ValueOf(physics.JointChild),
"JointChildIndex": reflect.ValueOf(physics.JointChildIndex),
"JointControl": reflect.ValueOf(physics.JointControl),
"JointControlForce": reflect.ValueOf(physics.JointControlForce),
"JointControlVarsN": reflect.ValueOf(physics.JointControlVarsN),
"JointControlVarsValues": reflect.ValueOf(physics.JointControlVarsValues),
"JointControls": reflect.ValueOf(&physics.JointControls).Elem(),
"JointControlsVar": reflect.ValueOf(physics.JointControlsVar),
"JointDoF": reflect.ValueOf(physics.JointDoF),
"JointDoF1": reflect.ValueOf(physics.JointDoF1),
"JointDoF2": reflect.ValueOf(physics.JointDoF2),
"JointDoF3": reflect.ValueOf(physics.JointDoF3),
"JointDoF4": reflect.ValueOf(physics.JointDoF4),
"JointDoF5": reflect.ValueOf(physics.JointDoF5),
"JointDoF6": reflect.ValueOf(physics.JointDoF6),
"JointDoFIndex": reflect.ValueOf(physics.JointDoFIndex),
"JointDoFVarsN": reflect.ValueOf(physics.JointDoFVarsN),
"JointDoFVarsValues": reflect.ValueOf(physics.JointDoFVarsValues),
"JointDoFs": reflect.ValueOf(&physics.JointDoFs).Elem(),
"JointDoFsVar": reflect.ValueOf(physics.JointDoFsVar),
"JointEnabled": reflect.ValueOf(physics.JointEnabled),
"JointLimitLower": reflect.ValueOf(physics.JointLimitLower),
"JointLimitUnlimited": reflect.ValueOf(constant.MakeFromLiteral("10000000000", token.FLOAT, 0)),
"JointLimitUpper": reflect.ValueOf(physics.JointLimitUpper),
"JointLinLambda": reflect.ValueOf(physics.JointLinLambda),
"JointLinLambdaX": reflect.ValueOf(physics.JointLinLambdaX),
"JointLinLambdaY": reflect.ValueOf(physics.JointLinLambdaY),
"JointLinLambdaZ": reflect.ValueOf(physics.JointLinLambdaZ),
"JointLinearDoFN": reflect.ValueOf(physics.JointLinearDoFN),
"JointNoLinearRotation": reflect.ValueOf(physics.JointNoLinearRotation),
"JointPForce": reflect.ValueOf(physics.JointPForce),
"JointPForceX": reflect.ValueOf(physics.JointPForceX),
"JointPForceY": reflect.ValueOf(physics.JointPForceY),
"JointPForceZ": reflect.ValueOf(physics.JointPForceZ),
"JointPPos": reflect.ValueOf(physics.JointPPos),
"JointPPosX": reflect.ValueOf(physics.JointPPosX),
"JointPPosY": reflect.ValueOf(physics.JointPPosY),
"JointPPosZ": reflect.ValueOf(physics.JointPPosZ),
"JointPQuat": reflect.ValueOf(physics.JointPQuat),
"JointPQuatW": reflect.ValueOf(physics.JointPQuatW),
"JointPQuatX": reflect.ValueOf(physics.JointPQuatX),
"JointPQuatY": reflect.ValueOf(physics.JointPQuatY),
"JointPQuatZ": reflect.ValueOf(physics.JointPQuatZ),
"JointPTorque": reflect.ValueOf(physics.JointPTorque),
"JointPTorqueX": reflect.ValueOf(physics.JointPTorqueX),
"JointPTorqueY": reflect.ValueOf(physics.JointPTorqueY),
"JointPTorqueZ": reflect.ValueOf(physics.JointPTorqueZ),
"JointParent": reflect.ValueOf(physics.JointParent),
"JointParentFixed": reflect.ValueOf(physics.JointParentFixed),
"JointParentIndex": reflect.ValueOf(physics.JointParentIndex),
"JointTargetDamp": reflect.ValueOf(physics.JointTargetDamp),
"JointTargetPos": reflect.ValueOf(physics.JointTargetPos),
"JointTargetPosCur": reflect.ValueOf(physics.JointTargetPosCur),
"JointTargetStiff": reflect.ValueOf(physics.JointTargetStiff),
"JointTargetVel": reflect.ValueOf(physics.JointTargetVel),
"JointType": reflect.ValueOf(physics.JointType),
"JointTypesN": reflect.ValueOf(physics.JointTypesN),
"JointTypesValues": reflect.ValueOf(physics.JointTypesValues),
"JointVarsN": reflect.ValueOf(physics.JointVarsN),
"JointVarsValues": reflect.ValueOf(physics.JointVarsValues),
"Joints": reflect.ValueOf(&physics.Joints).Elem(),
"JointsVar": reflect.ValueOf(physics.JointsVar),
"LimitDelta": reflect.ValueOf(physics.LimitDelta),
"NewGeomData": reflect.ValueOf(physics.NewGeomData),
"NewModel": reflect.ValueOf(physics.NewModel),
"Objects": reflect.ValueOf(&physics.Objects).Elem(),
"ObjectsVar": reflect.ValueOf(physics.ObjectsVar),
"OneIfNonzero": reflect.ValueOf(physics.OneIfNonzero),
"Params": reflect.ValueOf(&physics.Params).Elem(),
"ParamsVar": reflect.ValueOf(physics.ParamsVar),
"Plane": reflect.ValueOf(physics.Plane),
"PlaneEdge": reflect.ValueOf(physics.PlaneEdge),
"PlaneSDF": reflect.ValueOf(physics.PlaneSDF),
"PlaneXZ": reflect.ValueOf(physics.PlaneXZ),
"PositionalCorrection": reflect.ValueOf(physics.PositionalCorrection),
"Prismatic": reflect.ValueOf(physics.Prismatic),
"ReadFromGPU": reflect.ValueOf(physics.ReadFromGPU),
"Revolute": reflect.ValueOf(physics.Revolute),
"RunCollisionBroad": reflect.ValueOf(physics.RunCollisionBroad),
"RunCollisionBroadCPU": reflect.ValueOf(physics.RunCollisionBroadCPU),
"RunCollisionBroadGPU": reflect.ValueOf(physics.RunCollisionBroadGPU),
"RunCollisionNarrow": reflect.ValueOf(physics.RunCollisionNarrow),
"RunCollisionNarrowCPU": reflect.ValueOf(physics.RunCollisionNarrowCPU),
"RunCollisionNarrowGPU": reflect.ValueOf(physics.RunCollisionNarrowGPU),
"RunDone": reflect.ValueOf(physics.RunDone),
"RunDynamicsCurToNext": reflect.ValueOf(physics.RunDynamicsCurToNext),
"RunDynamicsCurToNextCPU": reflect.ValueOf(physics.RunDynamicsCurToNextCPU),
"RunDynamicsCurToNextGPU": reflect.ValueOf(physics.RunDynamicsCurToNextGPU),
"RunForcesFromJoints": reflect.ValueOf(physics.RunForcesFromJoints),
"RunForcesFromJointsCPU": reflect.ValueOf(physics.RunForcesFromJointsCPU),
"RunForcesFromJointsGPU": reflect.ValueOf(physics.RunForcesFromJointsGPU),
"RunGPUSync": reflect.ValueOf(physics.RunGPUSync),
"RunInitDynamics": reflect.ValueOf(physics.RunInitDynamics),
"RunInitDynamicsCPU": reflect.ValueOf(physics.RunInitDynamicsCPU),
"RunInitDynamicsGPU": reflect.ValueOf(physics.RunInitDynamicsGPU),
"RunOneCollisionBroad": reflect.ValueOf(physics.RunOneCollisionBroad),
"RunOneCollisionNarrow": reflect.ValueOf(physics.RunOneCollisionNarrow),
"RunOneDynamicsCurToNext": reflect.ValueOf(physics.RunOneDynamicsCurToNext),
"RunOneForcesFromJoints": reflect.ValueOf(physics.RunOneForcesFromJoints),
"RunOneInitDynamics": reflect.ValueOf(physics.RunOneInitDynamics),
"RunOneStepBodyContactDeltas": reflect.ValueOf(physics.RunOneStepBodyContactDeltas),
"RunOneStepBodyContacts": reflect.ValueOf(physics.RunOneStepBodyContacts),
"RunOneStepInit": reflect.ValueOf(physics.RunOneStepInit),
"RunOneStepIntegrateBodies": reflect.ValueOf(physics.RunOneStepIntegrateBodies),
"RunOneStepJointForces": reflect.ValueOf(physics.RunOneStepJointForces),
"RunOneStepSolveJoints": reflect.ValueOf(physics.RunOneStepSolveJoints),
"RunStepBodyContactDeltas": reflect.ValueOf(physics.RunStepBodyContactDeltas),
"RunStepBodyContactDeltasCPU": reflect.ValueOf(physics.RunStepBodyContactDeltasCPU),
"RunStepBodyContactDeltasGPU": reflect.ValueOf(physics.RunStepBodyContactDeltasGPU),
"RunStepBodyContacts": reflect.ValueOf(physics.RunStepBodyContacts),
"RunStepBodyContactsCPU": reflect.ValueOf(physics.RunStepBodyContactsCPU),
"RunStepBodyContactsGPU": reflect.ValueOf(physics.RunStepBodyContactsGPU),
"RunStepInit": reflect.ValueOf(physics.RunStepInit),
"RunStepInitCPU": reflect.ValueOf(physics.RunStepInitCPU),
"RunStepInitGPU": reflect.ValueOf(physics.RunStepInitGPU),
"RunStepIntegrateBodies": reflect.ValueOf(physics.RunStepIntegrateBodies),
"RunStepIntegrateBodiesCPU": reflect.ValueOf(physics.RunStepIntegrateBodiesCPU),
"RunStepIntegrateBodiesGPU": reflect.ValueOf(physics.RunStepIntegrateBodiesGPU),
"RunStepJointForces": reflect.ValueOf(physics.RunStepJointForces),
"RunStepJointForcesCPU": reflect.ValueOf(physics.RunStepJointForcesCPU),
"RunStepJointForcesGPU": reflect.ValueOf(physics.RunStepJointForcesGPU),
"RunStepSolveJoints": reflect.ValueOf(physics.RunStepSolveJoints),
"RunStepSolveJointsCPU": reflect.ValueOf(physics.RunStepSolveJointsCPU),
"RunStepSolveJointsGPU": reflect.ValueOf(physics.RunStepSolveJointsGPU),
"SetBodyBounce": reflect.ValueOf(physics.SetBodyBounce),
"SetBodyCom": reflect.ValueOf(physics.SetBodyCom),
"SetBodyDynamic": reflect.ValueOf(physics.SetBodyDynamic),
"SetBodyFriction": reflect.ValueOf(physics.SetBodyFriction),
"SetBodyFrictionRolling": reflect.ValueOf(physics.SetBodyFrictionRolling),
"SetBodyFrictionTortion": reflect.ValueOf(physics.SetBodyFrictionTortion),
"SetBodyGroup": reflect.ValueOf(physics.SetBodyGroup),
"SetBodyHSize": reflect.ValueOf(physics.SetBodyHSize),
"SetBodyInertia": reflect.ValueOf(physics.SetBodyInertia),
"SetBodyInvInertia": reflect.ValueOf(physics.SetBodyInvInertia),
"SetBodyPos": reflect.ValueOf(physics.SetBodyPos),
"SetBodyQuat": reflect.ValueOf(physics.SetBodyQuat),
"SetBodyShape": reflect.ValueOf(physics.SetBodyShape),
"SetBodyThick": reflect.ValueOf(physics.SetBodyThick),
"SetBodyWorld": reflect.ValueOf(physics.SetBodyWorld),
"SetBroadContactA": reflect.ValueOf(physics.SetBroadContactA),
"SetBroadContactB": reflect.ValueOf(physics.SetBroadContactB),
"SetBroadContactPointIdx": reflect.ValueOf(physics.SetBroadContactPointIdx),
"SetContactA": reflect.ValueOf(physics.SetContactA),
"SetContactAAngDelta": reflect.ValueOf(physics.SetContactAAngDelta),
"SetContactADelta": reflect.ValueOf(physics.SetContactADelta),
"SetContactAOff": reflect.ValueOf(physics.SetContactAOff),
"SetContactAPoint": reflect.ValueOf(physics.SetContactAPoint),
"SetContactB": reflect.ValueOf(physics.SetContactB),
"SetContactBAngDelta": reflect.ValueOf(physics.SetContactBAngDelta),
"SetContactBDelta": reflect.ValueOf(physics.SetContactBDelta),
"SetContactBOff": reflect.ValueOf(physics.SetContactBOff),
"SetContactBPoint": reflect.ValueOf(physics.SetContactBPoint),
"SetContactNorm": reflect.ValueOf(physics.SetContactNorm),
"SetContactPointIdx": reflect.ValueOf(physics.SetContactPointIdx),
"SetDynamicAcc": reflect.ValueOf(physics.SetDynamicAcc),
"SetDynamicAngAcc": reflect.ValueOf(physics.SetDynamicAngAcc),
"SetDynamicAngDelta": reflect.ValueOf(physics.SetDynamicAngDelta),
"SetDynamicAngVel": reflect.ValueOf(physics.SetDynamicAngVel),
"SetDynamicBody": reflect.ValueOf(physics.SetDynamicBody),
"SetDynamicDelta": reflect.ValueOf(physics.SetDynamicDelta),
"SetDynamicForce": reflect.ValueOf(physics.SetDynamicForce),
"SetDynamicPos": reflect.ValueOf(physics.SetDynamicPos),
"SetDynamicQuat": reflect.ValueOf(physics.SetDynamicQuat),
"SetDynamicTorque": reflect.ValueOf(physics.SetDynamicTorque),
"SetDynamicVel": reflect.ValueOf(physics.SetDynamicVel),
"SetJointAngLambda": reflect.ValueOf(physics.SetJointAngLambda),
"SetJointAngularDoFN": reflect.ValueOf(physics.SetJointAngularDoFN),
"SetJointAxis": reflect.ValueOf(physics.SetJointAxis),
"SetJointAxisDoF": reflect.ValueOf(physics.SetJointAxisDoF),
"SetJointCForce": reflect.ValueOf(physics.SetJointCForce),
"SetJointCPos": reflect.ValueOf(physics.SetJointCPos),
"SetJointCQuat": reflect.ValueOf(physics.SetJointCQuat),
"SetJointCTorque": reflect.ValueOf(physics.SetJointCTorque),
"SetJointChild": reflect.ValueOf(physics.SetJointChild),
"SetJointControl": reflect.ValueOf(physics.SetJointControl),
"SetJointControlForce": reflect.ValueOf(physics.SetJointControlForce),
"SetJointDoF": reflect.ValueOf(physics.SetJointDoF),
"SetJointDoFIndex": reflect.ValueOf(physics.SetJointDoFIndex),
"SetJointEnabled": reflect.ValueOf(physics.SetJointEnabled),
"SetJointLinLambda": reflect.ValueOf(physics.SetJointLinLambda),
"SetJointLinearDoFN": reflect.ValueOf(physics.SetJointLinearDoFN),
"SetJointNoLinearRotation": reflect.ValueOf(physics.SetJointNoLinearRotation),
"SetJointPForce": reflect.ValueOf(physics.SetJointPForce),
"SetJointPPos": reflect.ValueOf(physics.SetJointPPos),
"SetJointPQuat": reflect.ValueOf(physics.SetJointPQuat),
"SetJointPTorque": reflect.ValueOf(physics.SetJointPTorque),
"SetJointParent": reflect.ValueOf(physics.SetJointParent),
"SetJointParentFixed": reflect.ValueOf(physics.SetJointParentFixed),
"SetJointTargetAngle": reflect.ValueOf(physics.SetJointTargetAngle),
"SetJointTargetPos": reflect.ValueOf(physics.SetJointTargetPos),
"SetJointTargetVel": reflect.ValueOf(physics.SetJointTargetVel),
"SetJointType": reflect.ValueOf(physics.SetJointType),
"ShapePairContacts": reflect.ValueOf(physics.ShapePairContacts),
"ShapesN": reflect.ValueOf(physics.ShapesN),
"ShapesValues": reflect.ValueOf(physics.ShapesValues),
"Sphere": reflect.ValueOf(physics.Sphere),
"SphereSDF": reflect.ValueOf(physics.SphereSDF),
"StepBodyContactDeltas": reflect.ValueOf(physics.StepBodyContactDeltas),
"StepBodyContacts": reflect.ValueOf(physics.StepBodyContacts),
"StepBodyDeltas": reflect.ValueOf(physics.StepBodyDeltas),
"StepBodyKinetics": reflect.ValueOf(physics.StepBodyKinetics),
"StepInit": reflect.ValueOf(physics.StepInit),
"StepIntegrateBodies": reflect.ValueOf(physics.StepIntegrateBodies),
"StepJointForces": reflect.ValueOf(physics.StepJointForces),
"StepSolveJoint": reflect.ValueOf(physics.StepSolveJoint),
"StepSolveJoints": reflect.ValueOf(physics.StepSolveJoints),
"StepsToMsec": reflect.ValueOf(physics.StepsToMsec),
"SyncFromGPU": reflect.ValueOf(physics.SyncFromGPU),
"TensorStrides": reflect.ValueOf(&physics.TensorStrides).Elem(),
"ToGPU": reflect.ValueOf(physics.ToGPU),
"ToGPUTensorStrides": reflect.ValueOf(physics.ToGPUTensorStrides),
"UseGPU": reflect.ValueOf(&physics.UseGPU).Elem(),
"VelocityAtPoint": reflect.ValueOf(physics.VelocityAtPoint),
"WorldsCollide": reflect.ValueOf(physics.WorldsCollide),
// type definitions
"BodyVars": reflect.ValueOf((*physics.BodyVars)(nil)),
"ContactVars": reflect.ValueOf((*physics.ContactVars)(nil)),
"DynamicVars": reflect.ValueOf((*physics.DynamicVars)(nil)),
"GPUVars": reflect.ValueOf((*physics.GPUVars)(nil)),
"GeomData": reflect.ValueOf((*physics.GeomData)(nil)),
"JointControlVars": reflect.ValueOf((*physics.JointControlVars)(nil)),
"JointDoFVars": reflect.ValueOf((*physics.JointDoFVars)(nil)),
"JointTypes": reflect.ValueOf((*physics.JointTypes)(nil)),
"JointVars": reflect.ValueOf((*physics.JointVars)(nil)),
"Model": reflect.ValueOf((*physics.Model)(nil)),
"PhysicsParams": reflect.ValueOf((*physics.PhysicsParams)(nil)),
"Shapes": reflect.ValueOf((*physics.Shapes)(nil)),
}
}
// Code generated by 'yaegi extract cogentcore.org/lab/plot/plots'. DO NOT EDIT.
package labsymbols
import (
"cogentcore.org/lab/plot/plots"
"go/constant"
"go/token"
"reflect"
)
func init() {
Symbols["cogentcore.org/lab/plot/plots/plots"] = map[string]reflect.Value{
// function, constant and variable definitions
"BarType": reflect.ValueOf(constant.MakeFromLiteral("\"Bar\"", token.STRING, 0)),
"LabelsType": reflect.ValueOf(constant.MakeFromLiteral("\"Labels\"", token.STRING, 0)),
"NewBar": reflect.ValueOf(plots.NewBar),
"NewLabels": reflect.ValueOf(plots.NewLabels),
"NewLine": reflect.ValueOf(plots.NewLine),
"NewPointLine": reflect.ValueOf(plots.NewPointLine),
"NewScatter": reflect.ValueOf(plots.NewScatter),
"NewXErrorBars": reflect.ValueOf(plots.NewXErrorBars),
"NewXY": reflect.ValueOf(plots.NewXY),
"NewYErrorBars": reflect.ValueOf(plots.NewYErrorBars),
"XErrorBarsType": reflect.ValueOf(constant.MakeFromLiteral("\"XErrorBars\"", token.STRING, 0)),
"XYType": reflect.ValueOf(constant.MakeFromLiteral("\"XY\"", token.STRING, 0)),
"YErrorBarsType": reflect.ValueOf(constant.MakeFromLiteral("\"YErrorBars\"", token.STRING, 0)),
// type definitions
"Bar": reflect.ValueOf((*plots.Bar)(nil)),
"Labels": reflect.ValueOf((*plots.Labels)(nil)),
"XErrorBars": reflect.ValueOf((*plots.XErrorBars)(nil)),
"XY": reflect.ValueOf((*plots.XY)(nil)),
"YErrorBars": reflect.ValueOf((*plots.YErrorBars)(nil)),
}
}
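// Illustrative sketch (not part of any generated file): a minimal example of
// how the Symbols maps populated by the init functions above are typically
// consumed, assuming the standard yaegi API (interp.New, (*Interpreter).Use,
// (*Interpreter).Eval) and that Symbols in this package is the usual
// map[string]map[string]reflect.Value table. Error handling is kept minimal;
// the function name useSymbolsSketch is hypothetical.
package labsymbols

import (
	"fmt"

	"github.com/traefik/yaegi/interp"
)

// useSymbolsSketch registers this package's symbol tables with a yaegi
// interpreter so that interpreted scripts can import the pre-compiled
// packages listed above (e.g. cogentcore.org/lab/plot/plots).
func useSymbolsSketch() {
	i := interp.New(interp.Options{})
	if err := i.Use(Symbols); err != nil { // register generated symbol tables
		fmt.Println("use error:", err)
		return
	}
	// Interpreted code can now import a registered package path.
	if _, err := i.Eval(`import "cogentcore.org/lab/plot/plots"`); err != nil {
		fmt.Println("eval error:", err)
	}
}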
// Code generated by 'yaegi extract cogentcore.org/lab/plot'. DO NOT EDIT.
package labsymbols
import (
"cogentcore.org/core/math32/minmax"
"cogentcore.org/lab/plot"
"reflect"
)
func init() {
Symbols["cogentcore.org/lab/plot/plot"] = map[string]reflect.Value{
// function, constant and variable definitions
"AxisScalesN": reflect.ValueOf(plot.AxisScalesN),
"AxisScalesValues": reflect.ValueOf(plot.AxisScalesValues),
"BasicStylers": reflect.ValueOf(plot.BasicStylers),
"Box": reflect.ValueOf(plot.Box),
"CheckFloats": reflect.ValueOf(plot.CheckFloats),
"CheckNaNs": reflect.ValueOf(plot.CheckNaNs),
"Circle": reflect.ValueOf(plot.Circle),
"CopyRole": reflect.ValueOf(plot.CopyRole),
"CopyRoleLabels": reflect.ValueOf(plot.CopyRoleLabels),
"CopyValues": reflect.ValueOf(plot.CopyValues),
"Cross": reflect.ValueOf(plot.Cross),
"DataOrValuer": reflect.ValueOf(plot.DataOrValuer),
"Default": reflect.ValueOf(plot.Default),
"DefaultFontFamily": reflect.ValueOf(&plot.DefaultFontFamily).Elem(),
"DefaultOffOnN": reflect.ValueOf(plot.DefaultOffOnN),
"DefaultOffOnValues": reflect.ValueOf(plot.DefaultOffOnValues),
"DrawBox": reflect.ValueOf(plot.DrawBox),
"DrawCircle": reflect.ValueOf(plot.DrawCircle),
"DrawCross": reflect.ValueOf(plot.DrawCross),
"DrawPlus": reflect.ValueOf(plot.DrawPlus),
"DrawPyramid": reflect.ValueOf(plot.DrawPyramid),
"DrawRing": reflect.ValueOf(plot.DrawRing),
"DrawSquare": reflect.ValueOf(plot.DrawSquare),
"DrawTriangle": reflect.ValueOf(plot.DrawTriangle),
"ErrInfinity": reflect.ValueOf(&plot.ErrInfinity).Elem(),
"ErrNoData": reflect.ValueOf(&plot.ErrNoData).Elem(),
"GetStylers": reflect.ValueOf(plot.GetStylers),
"GetStylersFromData": reflect.ValueOf(plot.GetStylersFromData),
"High": reflect.ValueOf(plot.High),
"InverseLinear": reflect.ValueOf(plot.InverseLinear),
"InverseLog": reflect.ValueOf(plot.InverseLog),
"Label": reflect.ValueOf(plot.Label),
"Linear": reflect.ValueOf(plot.Linear),
"Log": reflect.ValueOf(plot.Log),
"Low": reflect.ValueOf(plot.Low),
"MidStep": reflect.ValueOf(plot.MidStep),
"MustCopyRole": reflect.ValueOf(plot.MustCopyRole),
"New": reflect.ValueOf(plot.New),
"NewPlotter": reflect.ValueOf(plot.NewPlotter),
"NewStyle": reflect.ValueOf(plot.NewStyle),
"NewTablePlot": reflect.ValueOf(plot.NewTablePlot),
"NewXY": reflect.ValueOf(plot.NewXY),
"NewY": reflect.ValueOf(plot.NewY),
"NoRole": reflect.ValueOf(plot.NoRole),
"NoStep": reflect.ValueOf(plot.NoStep),
"Off": reflect.ValueOf(plot.Off),
"On": reflect.ValueOf(plot.On),
"PlotX": reflect.ValueOf(plot.PlotX),
"PlotY": reflect.ValueOf(plot.PlotY),
"PlotYR": reflect.ValueOf(plot.PlotYR),
"PlotterByType": reflect.ValueOf(plot.PlotterByType),
"Plotters": reflect.ValueOf(&plot.Plotters).Elem(),
"Plus": reflect.ValueOf(plot.Plus),
"PostStep": reflect.ValueOf(plot.PostStep),
"PreStep": reflect.ValueOf(plot.PreStep),
"Pyramid": reflect.ValueOf(plot.Pyramid),
"Range": reflect.ValueOf(plot.Range),
"RangeClamp": reflect.ValueOf(plot.RangeClamp),
"RegisterPlotter": reflect.ValueOf(plot.RegisterPlotter),
"Ring": reflect.ValueOf(plot.Ring),
"RolesN": reflect.ValueOf(plot.RolesN),
"RolesValues": reflect.ValueOf(plot.RolesValues),
"SetBasicStylers": reflect.ValueOf(plot.SetBasicStylers),
"SetFirstStyler": reflect.ValueOf(plot.SetFirstStyler),
"SetStyler": reflect.ValueOf(plot.SetStyler),
"ShapesN": reflect.ValueOf(plot.ShapesN),
"ShapesValues": reflect.ValueOf(plot.ShapesValues),
"Size": reflect.ValueOf(plot.Size),
"Split": reflect.ValueOf(plot.Split),
"Square": reflect.ValueOf(plot.Square),
"StepKindN": reflect.ValueOf(plot.StepKindN),
"StepKindValues": reflect.ValueOf(plot.StepKindValues),
"Styler": reflect.ValueOf(plot.Styler),
"Triangle": reflect.ValueOf(plot.Triangle),
"U": reflect.ValueOf(plot.U),
"UTCUnixTime": reflect.ValueOf(&plot.UTCUnixTime).Elem(),
"UnixTimeIn": reflect.ValueOf(plot.UnixTimeIn),
"V": reflect.ValueOf(plot.V),
"W": reflect.ValueOf(plot.W),
"X": reflect.ValueOf(plot.X),
"Y": reflect.ValueOf(plot.Y),
"Z": reflect.ValueOf(plot.Z),
// type definitions
"Axis": reflect.ValueOf((*plot.Axis)(nil)),
"AxisScales": reflect.ValueOf((*plot.AxisScales)(nil)),
"AxisStyle": reflect.ValueOf((*plot.AxisStyle)(nil)),
"ConstantTicks": reflect.ValueOf((*plot.ConstantTicks)(nil)),
"Data": reflect.ValueOf((*plot.Data)(nil)),
"DefaultOffOn": reflect.ValueOf((*plot.DefaultOffOn)(nil)),
"DefaultTicks": reflect.ValueOf((*plot.DefaultTicks)(nil)),
"InvertedScale": reflect.ValueOf((*plot.InvertedScale)(nil)),
"Labels": reflect.ValueOf((*plot.Labels)(nil)),
"Legend": reflect.ValueOf((*plot.Legend)(nil)),
"LegendEntry": reflect.ValueOf((*plot.LegendEntry)(nil)),
"LegendPosition": reflect.ValueOf((*plot.LegendPosition)(nil)),
"LegendStyle": reflect.ValueOf((*plot.LegendStyle)(nil)),
"LineStyle": reflect.ValueOf((*plot.LineStyle)(nil)),
"LinearScale": reflect.ValueOf((*plot.LinearScale)(nil)),
"LogScale": reflect.ValueOf((*plot.LogScale)(nil)),
"LogTicks": reflect.ValueOf((*plot.LogTicks)(nil)),
"Normalizer": reflect.ValueOf((*plot.Normalizer)(nil)),
"PanZoom": reflect.ValueOf((*plot.PanZoom)(nil)),
"Plot": reflect.ValueOf((*plot.Plot)(nil)),
"PlotStyle": reflect.ValueOf((*plot.PlotStyle)(nil)),
"Plotter": reflect.ValueOf((*plot.Plotter)(nil)),
"PlotterName": reflect.ValueOf((*plot.PlotterName)(nil)),
"PlotterType": reflect.ValueOf((*plot.PlotterType)(nil)),
"PointStyle": reflect.ValueOf((*plot.PointStyle)(nil)),
"Roles": reflect.ValueOf((*plot.Roles)(nil)),
"Shapes": reflect.ValueOf((*plot.Shapes)(nil)),
"StepKind": reflect.ValueOf((*plot.StepKind)(nil)),
"Style": reflect.ValueOf((*plot.Style)(nil)),
"Stylers": reflect.ValueOf((*plot.Stylers)(nil)),
"Text": reflect.ValueOf((*plot.Text)(nil)),
"TextStyle": reflect.ValueOf((*plot.TextStyle)(nil)),
"Thumbnailer": reflect.ValueOf((*plot.Thumbnailer)(nil)),
"Tick": reflect.ValueOf((*plot.Tick)(nil)),
"Ticker": reflect.ValueOf((*plot.Ticker)(nil)),
"TimeTicks": reflect.ValueOf((*plot.TimeTicks)(nil)),
"Valuer": reflect.ValueOf((*plot.Valuer)(nil)),
"Values": reflect.ValueOf((*plot.Values)(nil)),
"VirtualAxis": reflect.ValueOf((*plot.VirtualAxis)(nil)),
"VirtualAxisStyle": reflect.ValueOf((*plot.VirtualAxisStyle)(nil)),
"WidthStyle": reflect.ValueOf((*plot.WidthStyle)(nil)),
"XAxisStyle": reflect.ValueOf((*plot.XAxisStyle)(nil)),
// interface wrapper definitions
"_Normalizer": reflect.ValueOf((*_cogentcore_org_lab_plot_Normalizer)(nil)),
"_Plotter": reflect.ValueOf((*_cogentcore_org_lab_plot_Plotter)(nil)),
"_Thumbnailer": reflect.ValueOf((*_cogentcore_org_lab_plot_Thumbnailer)(nil)),
"_Ticker": reflect.ValueOf((*_cogentcore_org_lab_plot_Ticker)(nil)),
"_Valuer": reflect.ValueOf((*_cogentcore_org_lab_plot_Valuer)(nil)),
}
}
// _cogentcore_org_lab_plot_Normalizer is an interface wrapper for Normalizer type
type _cogentcore_org_lab_plot_Normalizer struct {
IValue interface{}
WNormalize func(min float64, max float64, x float64) float64
}
func (W _cogentcore_org_lab_plot_Normalizer) Normalize(min float64, max float64, x float64) float64 {
return W.WNormalize(min, max, x)
}
// _cogentcore_org_lab_plot_Plotter is an interface wrapper for Plotter type
type _cogentcore_org_lab_plot_Plotter struct {
IValue interface{}
WApplyStyle func(plotStyle *plot.PlotStyle, idx int)
WData func() (data plot.Data, pixX []float32, pixY []float32)
WPlot func(pt *plot.Plot)
WStylers func() *plot.Stylers
WUpdateRange func(plt *plot.Plot, x *minmax.F64, y *minmax.F64, yr *minmax.F64, z *minmax.F64, size *minmax.F64)
}
func (W _cogentcore_org_lab_plot_Plotter) ApplyStyle(plotStyle *plot.PlotStyle, idx int) {
W.WApplyStyle(plotStyle, idx)
}
func (W _cogentcore_org_lab_plot_Plotter) Data() (data plot.Data, pixX []float32, pixY []float32) {
return W.WData()
}
func (W _cogentcore_org_lab_plot_Plotter) Plot(pt *plot.Plot) { W.WPlot(pt) }
func (W _cogentcore_org_lab_plot_Plotter) Stylers() *plot.Stylers { return W.WStylers() }
func (W _cogentcore_org_lab_plot_Plotter) UpdateRange(plt *plot.Plot, x *minmax.F64, y *minmax.F64, yr *minmax.F64, z *minmax.F64, size *minmax.F64) {
W.WUpdateRange(plt, x, y, yr, z, size)
}
// _cogentcore_org_lab_plot_Thumbnailer is an interface wrapper for Thumbnailer type
type _cogentcore_org_lab_plot_Thumbnailer struct {
IValue interface{}
WThumbnail func(pt *plot.Plot)
}
func (W _cogentcore_org_lab_plot_Thumbnailer) Thumbnail(pt *plot.Plot) { W.WThumbnail(pt) }
// _cogentcore_org_lab_plot_Ticker is an interface wrapper for Ticker type
type _cogentcore_org_lab_plot_Ticker struct {
IValue interface{}
WTicks func(mn float64, mx float64, nticks int) []plot.Tick
}
func (W _cogentcore_org_lab_plot_Ticker) Ticks(mn float64, mx float64, nticks int) []plot.Tick {
return W.WTicks(mn, mx, nticks)
}
// _cogentcore_org_lab_plot_Valuer is an interface wrapper for Valuer type
type _cogentcore_org_lab_plot_Valuer struct {
IValue interface{}
WFloat1D func(i int) float64
WLen func() int
WString1D func(i int) string
}
func (W _cogentcore_org_lab_plot_Valuer) Float1D(i int) float64 { return W.WFloat1D(i) }
func (W _cogentcore_org_lab_plot_Valuer) Len() int { return W.WLen() }
func (W _cogentcore_org_lab_plot_Valuer) String1D(i int) string { return W.WString1D(i) }
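// Illustrative sketch (not part of any generated file): the interface wrapper
// structs above let an interpreter-side implementation stand in for a compiled
// interface; each W-prefixed field holds a closure and the wrapper's methods
// simply delegate to it. The helper below is hypothetical and only shows the
// delegation pattern by building a plot.Ticker from a plain closure; it
// returns no ticks, since a real implementation would construct []plot.Tick
// values here.
package labsymbols

import "cogentcore.org/lab/plot"

// emptyTicker returns a plot.Ticker backed by the generated wrapper struct,
// with the Ticks behavior supplied as a closure.
func emptyTicker() plot.Ticker {
	return _cogentcore_org_lab_plot_Ticker{
		WTicks: func(mn, mx float64, nticks int) []plot.Tick {
			// A real Ticker would compute ticks for the [mn, mx] range here.
			return nil
		},
	}
}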
// Code generated by 'yaegi extract cogentcore.org/lab/plotcore'. DO NOT EDIT.
package labsymbols
import (
"cogentcore.org/lab/plotcore"
"reflect"
)
func init() {
Symbols["cogentcore.org/lab/plotcore/plotcore"] = map[string]reflect.Value{
// function, constant and variable definitions
"NewEditor": reflect.ValueOf(plotcore.NewEditor),
"NewPlot": reflect.ValueOf(plotcore.NewPlot),
"NewPlotterChooser": reflect.ValueOf(plotcore.NewPlotterChooser),
"NewSubPlot": reflect.ValueOf(plotcore.NewSubPlot),
// type definitions
"Editor": reflect.ValueOf((*plotcore.Editor)(nil)),
"Plot": reflect.ValueOf((*plotcore.Plot)(nil)),
"PlotterChooser": reflect.ValueOf((*plotcore.PlotterChooser)(nil)),
}
}
// Code generated by 'yaegi extract cogentcore.org/lab/tensorcore'. DO NOT EDIT.
package labsymbols
import (
"cogentcore.org/lab/tensorcore"
"go/constant"
"go/token"
"reflect"
)
func init() {
Symbols["cogentcore.org/lab/tensorcore/tensorcore"] = map[string]reflect.Value{
// function, constant and variable definitions
"AddGridStylerTo": reflect.ValueOf(tensorcore.AddGridStylerTo),
"GetGridStylersFrom": reflect.ValueOf(tensorcore.GetGridStylersFrom),
"LabelSpace": reflect.ValueOf(constant.MakeFromLiteral("8", token.INT, 0)),
"NewGridStyle": reflect.ValueOf(tensorcore.NewGridStyle),
"NewTable": reflect.ValueOf(tensorcore.NewTable),
"NewTableButton": reflect.ValueOf(tensorcore.NewTableButton),
"NewTensorButton": reflect.ValueOf(tensorcore.NewTensorButton),
"NewTensorEditor": reflect.ValueOf(tensorcore.NewTensorEditor),
"NewTensorGrid": reflect.ValueOf(tensorcore.NewTensorGrid),
"RepeatsToBlank": reflect.ValueOf(tensorcore.RepeatsToBlank),
"SetGridStylersTo": reflect.ValueOf(tensorcore.SetGridStylersTo),
// type definitions
"GridStyle": reflect.ValueOf((*tensorcore.GridStyle)(nil)),
"GridStylers": reflect.ValueOf((*tensorcore.GridStylers)(nil)),
"Layout": reflect.ValueOf((*tensorcore.Layout)(nil)),
"Table": reflect.ValueOf((*tensorcore.Table)(nil)),
"TableButton": reflect.ValueOf((*tensorcore.TableButton)(nil)),
"TensorButton": reflect.ValueOf((*tensorcore.TensorButton)(nil)),
"TensorEditor": reflect.ValueOf((*tensorcore.TensorEditor)(nil)),
"TensorGrid": reflect.ValueOf((*tensorcore.TensorGrid)(nil)),
}
}
// Code generated by 'yaegi extract cogentcore.org/lab/goal/goalib'. DO NOT EDIT.
package tensorsymbols
import (
"cogentcore.org/lab/goal/goalib"
"reflect"
)
func init() {
Symbols["cogentcore.org/lab/goal/goalib/goalib"] = map[string]reflect.Value{
// function, constant and variable definitions
"AllFiles": reflect.ValueOf(goalib.AllFiles),
"FileExists": reflect.ValueOf(goalib.FileExists),
"ReadFile": reflect.ValueOf(goalib.ReadFile),
"ReplaceInFile": reflect.ValueOf(goalib.ReplaceInFile),
"SplitLines": reflect.ValueOf(goalib.SplitLines),
"StringsToAnys": reflect.ValueOf(goalib.StringsToAnys),
"WriteFile": reflect.ValueOf(goalib.WriteFile),
}
}
// Code generated by 'yaegi extract cogentcore.org/lab/matrix'. DO NOT EDIT.
package tensorsymbols
import (
"cogentcore.org/lab/matrix"
"reflect"
)
func init() {
Symbols["cogentcore.org/lab/matrix/matrix"] = map[string]reflect.Value{
// function, constant and variable definitions
"CallOut1": reflect.ValueOf(matrix.CallOut1),
"CallOut2": reflect.ValueOf(matrix.CallOut2),
"CopyFromDense": reflect.ValueOf(matrix.CopyFromDense),
"Det": reflect.ValueOf(matrix.Det),
"Diagonal": reflect.ValueOf(matrix.Diagonal),
"DiagonalIndices": reflect.ValueOf(matrix.DiagonalIndices),
"DiagonalN": reflect.ValueOf(matrix.DiagonalN),
"Eig": reflect.ValueOf(matrix.Eig),
"EigOut": reflect.ValueOf(matrix.EigOut),
"EigSym": reflect.ValueOf(matrix.EigSym),
"EigSymOut": reflect.ValueOf(matrix.EigSymOut),
"Identity": reflect.ValueOf(matrix.Identity),
"Inverse": reflect.ValueOf(matrix.Inverse),
"InverseOut": reflect.ValueOf(matrix.InverseOut),
"LogDet": reflect.ValueOf(matrix.LogDet),
"Mul": reflect.ValueOf(matrix.Mul),
"MulOut": reflect.ValueOf(matrix.MulOut),
"NewDense": reflect.ValueOf(matrix.NewDense),
"NewMatrix": reflect.ValueOf(matrix.NewMatrix),
"NewSymmetric": reflect.ValueOf(matrix.NewSymmetric),
"ProjectOnMatrixColumn": reflect.ValueOf(matrix.ProjectOnMatrixColumn),
"ProjectOnMatrixColumnOut": reflect.ValueOf(matrix.ProjectOnMatrixColumnOut),
"SVD": reflect.ValueOf(matrix.SVD),
"SVDOut": reflect.ValueOf(matrix.SVDOut),
"SVDValues": reflect.ValueOf(matrix.SVDValues),
"SVDValuesOut": reflect.ValueOf(matrix.SVDValuesOut),
"StringCheck": reflect.ValueOf(matrix.StringCheck),
"Trace": reflect.ValueOf(matrix.Trace),
"Tri": reflect.ValueOf(matrix.Tri),
"TriL": reflect.ValueOf(matrix.TriL),
"TriLIndicies": reflect.ValueOf(matrix.TriLIndicies),
"TriLNum": reflect.ValueOf(matrix.TriLNum),
"TriLView": reflect.ValueOf(matrix.TriLView),
"TriU": reflect.ValueOf(matrix.TriU),
"TriUIndicies": reflect.ValueOf(matrix.TriUIndicies),
"TriUNum": reflect.ValueOf(matrix.TriUNum),
"TriUView": reflect.ValueOf(matrix.TriUView),
"TriUpper": reflect.ValueOf(matrix.TriUpper),
// type definitions
"Matrix": reflect.ValueOf((*matrix.Matrix)(nil)),
"Symmetric": reflect.ValueOf((*matrix.Symmetric)(nil)),
}
}
// Code generated by 'yaegi extract cogentcore.org/lab/stats/cluster'. DO NOT EDIT.
package tensorsymbols
import (
"cogentcore.org/lab/stats/cluster"
"reflect"
)
func init() {
Symbols["cogentcore.org/lab/stats/cluster/cluster"] = map[string]reflect.Value{
// function, constant and variable definitions
"Avg": reflect.ValueOf(cluster.Avg),
"AvgFunc": reflect.ValueOf(cluster.AvgFunc),
"Cluster": reflect.ValueOf(cluster.Cluster),
"Contrast": reflect.ValueOf(cluster.Contrast),
"ContrastFunc": reflect.ValueOf(cluster.ContrastFunc),
"Glom": reflect.ValueOf(cluster.Glom),
"InitAllLeaves": reflect.ValueOf(cluster.InitAllLeaves),
"Max": reflect.ValueOf(cluster.Max),
"MaxFunc": reflect.ValueOf(cluster.MaxFunc),
"MetricsN": reflect.ValueOf(cluster.MetricsN),
"MetricsValues": reflect.ValueOf(cluster.MetricsValues),
"Min": reflect.ValueOf(cluster.Min),
"MinFunc": reflect.ValueOf(cluster.MinFunc),
"NewNode": reflect.ValueOf(cluster.NewNode),
"Plot": reflect.ValueOf(cluster.Plot),
"PlotFromTable": reflect.ValueOf(cluster.PlotFromTable),
"PlotFromTableToTable": reflect.ValueOf(cluster.PlotFromTableToTable),
// type definitions
"MetricFunc": reflect.ValueOf((*cluster.MetricFunc)(nil)),
"Metrics": reflect.ValueOf((*cluster.Metrics)(nil)),
"Node": reflect.ValueOf((*cluster.Node)(nil)),
}
}
// Code generated by 'yaegi extract cogentcore.org/lab/stats/convolve'. DO NOT EDIT.
package tensorsymbols
import (
"cogentcore.org/lab/stats/convolve"
"reflect"
)
func init() {
Symbols["cogentcore.org/lab/stats/convolve/convolve"] = map[string]reflect.Value{
// function, constant and variable definitions
"GaussianKernel32": reflect.ValueOf(convolve.GaussianKernel32),
"GaussianKernel64": reflect.ValueOf(convolve.GaussianKernel64),
"Slice32": reflect.ValueOf(convolve.Slice32),
"Slice64": reflect.ValueOf(convolve.Slice64),
}
}
// Code generated by 'yaegi extract cogentcore.org/lab/stats/glm'. DO NOT EDIT.
package tensorsymbols
import (
"cogentcore.org/lab/stats/glm"
"reflect"
)
func init() {
Symbols["cogentcore.org/lab/stats/glm/glm"] = map[string]reflect.Value{
// function, constant and variable definitions
"NewGLM": reflect.ValueOf(glm.NewGLM),
// type definitions
"GLM": reflect.ValueOf((*glm.GLM)(nil)),
}
}
// Code generated by 'yaegi extract cogentcore.org/lab/stats/histogram'. DO NOT EDIT.
package tensorsymbols
import (
"cogentcore.org/lab/stats/histogram"
"reflect"
)
func init() {
Symbols["cogentcore.org/lab/stats/histogram/histogram"] = map[string]reflect.Value{
// function, constant and variable definitions
"F64": reflect.ValueOf(histogram.F64),
"F64Table": reflect.ValueOf(histogram.F64Table),
}
}
// Code generated by 'yaegi extract cogentcore.org/lab/stats/metric'. DO NOT EDIT.
package tensorsymbols
import (
"cogentcore.org/lab/stats/metric"
"reflect"
)
func init() {
Symbols["cogentcore.org/lab/stats/metric/metric"] = map[string]reflect.Value{
// function, constant and variable definitions
"AsMetricFunc": reflect.ValueOf(metric.AsMetricFunc),
"AsMetricOutFunc": reflect.ValueOf(metric.AsMetricOutFunc),
"ClosestRow": reflect.ValueOf(metric.ClosestRow),
"ClosestRowOut": reflect.ValueOf(metric.ClosestRowOut),
"Correlation": reflect.ValueOf(metric.Correlation),
"CorrelationOut": reflect.ValueOf(metric.CorrelationOut),
"CorrelationOut64": reflect.ValueOf(metric.CorrelationOut64),
"Cosine": reflect.ValueOf(metric.Cosine),
"CosineOut": reflect.ValueOf(metric.CosineOut),
"CosineOut64": reflect.ValueOf(metric.CosineOut64),
"Covariance": reflect.ValueOf(metric.Covariance),
"CovarianceMatrix": reflect.ValueOf(metric.CovarianceMatrix),
"CovarianceMatrixOut": reflect.ValueOf(metric.CovarianceMatrixOut),
"CovarianceOut": reflect.ValueOf(metric.CovarianceOut),
"CrossEntropy": reflect.ValueOf(metric.CrossEntropy),
"CrossEntropyOut": reflect.ValueOf(metric.CrossEntropyOut),
"CrossMatrix": reflect.ValueOf(metric.CrossMatrix),
"CrossMatrixOut": reflect.ValueOf(metric.CrossMatrixOut),
"DotProduct": reflect.ValueOf(metric.DotProduct),
"DotProductOut": reflect.ValueOf(metric.DotProductOut),
"Hamming": reflect.ValueOf(metric.Hamming),
"HammingOut": reflect.ValueOf(metric.HammingOut),
"InvCorrelation": reflect.ValueOf(metric.InvCorrelation),
"InvCorrelationOut": reflect.ValueOf(metric.InvCorrelationOut),
"InvCosine": reflect.ValueOf(metric.InvCosine),
"InvCosineOut": reflect.ValueOf(metric.InvCosineOut),
"L1Norm": reflect.ValueOf(metric.L1Norm),
"L1NormOut": reflect.ValueOf(metric.L1NormOut),
"L2Norm": reflect.ValueOf(metric.L2Norm),
"L2NormBinTol": reflect.ValueOf(metric.L2NormBinTol),
"L2NormBinTolOut": reflect.ValueOf(metric.L2NormBinTolOut),
"L2NormOut": reflect.ValueOf(metric.L2NormOut),
"Matrix": reflect.ValueOf(metric.Matrix),
"MatrixOut": reflect.ValueOf(metric.MatrixOut),
"MetricCorrelation": reflect.ValueOf(metric.MetricCorrelation),
"MetricCosine": reflect.ValueOf(metric.MetricCosine),
"MetricCovariance": reflect.ValueOf(metric.MetricCovariance),
"MetricCrossEntropy": reflect.ValueOf(metric.MetricCrossEntropy),
"MetricDotProduct": reflect.ValueOf(metric.MetricDotProduct),
"MetricHamming": reflect.ValueOf(metric.MetricHamming),
"MetricInvCorrelation": reflect.ValueOf(metric.MetricInvCorrelation),
"MetricInvCosine": reflect.ValueOf(metric.MetricInvCosine),
"MetricL1Norm": reflect.ValueOf(metric.MetricL1Norm),
"MetricL2Norm": reflect.ValueOf(metric.MetricL2Norm),
"MetricL2NormBinTol": reflect.ValueOf(metric.MetricL2NormBinTol),
"MetricSumSquares": reflect.ValueOf(metric.MetricSumSquares),
"MetricSumSquaresBinTol": reflect.ValueOf(metric.MetricSumSquaresBinTol),
"MetricsN": reflect.ValueOf(metric.MetricsN),
"MetricsValues": reflect.ValueOf(metric.MetricsValues),
"SumSquares": reflect.ValueOf(metric.SumSquares),
"SumSquaresBinTol": reflect.ValueOf(metric.SumSquaresBinTol),
"SumSquaresBinTolOut": reflect.ValueOf(metric.SumSquaresBinTolOut),
"SumSquaresBinTolScaleOut64": reflect.ValueOf(metric.SumSquaresBinTolScaleOut64),
"SumSquaresOut": reflect.ValueOf(metric.SumSquaresOut),
"SumSquaresOut64": reflect.ValueOf(metric.SumSquaresOut64),
"SumSquaresScaleOut64": reflect.ValueOf(metric.SumSquaresScaleOut64),
"Vectorize2Out64": reflect.ValueOf(metric.Vectorize2Out64),
"Vectorize3Out64": reflect.ValueOf(metric.Vectorize3Out64),
"VectorizeOut64": reflect.ValueOf(metric.VectorizeOut64),
"VectorizePre3Out64": reflect.ValueOf(metric.VectorizePre3Out64),
"VectorizePreOut64": reflect.ValueOf(metric.VectorizePreOut64),
// type definitions
"MetricFunc": reflect.ValueOf((*metric.MetricFunc)(nil)),
"MetricOutFunc": reflect.ValueOf((*metric.MetricOutFunc)(nil)),
"Metrics": reflect.ValueOf((*metric.Metrics)(nil)),
}
}
// Code generated by 'yaegi extract cogentcore.org/lab/stats/stats'. DO NOT EDIT.
package tensorsymbols
import (
"cogentcore.org/lab/stats/stats"
"reflect"
)
func init() {
Symbols["cogentcore.org/lab/stats/stats/stats"] = map[string]reflect.Value{
// function, constant and variable definitions
"AsStatsFunc": reflect.ValueOf(stats.AsStatsFunc),
"Binarize": reflect.ValueOf(stats.Binarize),
"BinarizeOut": reflect.ValueOf(stats.BinarizeOut),
"Clamp": reflect.ValueOf(stats.Clamp),
"ClampOut": reflect.ValueOf(stats.ClampOut),
"Count": reflect.ValueOf(stats.Count),
"CountOut": reflect.ValueOf(stats.CountOut),
"CountOut64": reflect.ValueOf(stats.CountOut64),
"Describe": reflect.ValueOf(stats.Describe),
"DescribeTable": reflect.ValueOf(stats.DescribeTable),
"DescribeTableAll": reflect.ValueOf(stats.DescribeTableAll),
"DescriptiveStats": reflect.ValueOf(&stats.DescriptiveStats).Elem(),
"Final": reflect.ValueOf(stats.Final),
"FinalOut": reflect.ValueOf(stats.FinalOut),
"First": reflect.ValueOf(stats.First),
"FirstOut": reflect.ValueOf(stats.FirstOut),
"GroupAll": reflect.ValueOf(stats.GroupAll),
"GroupDescribe": reflect.ValueOf(stats.GroupDescribe),
"GroupStats": reflect.ValueOf(stats.GroupStats),
"GroupStatsAsTable": reflect.ValueOf(stats.GroupStatsAsTable),
"GroupStatsAsTableNoStatName": reflect.ValueOf(stats.GroupStatsAsTableNoStatName),
"Groups": reflect.ValueOf(stats.Groups),
"L1Norm": reflect.ValueOf(stats.L1Norm),
"L1NormOut": reflect.ValueOf(stats.L1NormOut),
"L2Norm": reflect.ValueOf(stats.L2Norm),
"L2NormOut": reflect.ValueOf(stats.L2NormOut),
"L2NormOut64": reflect.ValueOf(stats.L2NormOut64),
"Max": reflect.ValueOf(stats.Max),
"MaxAbs": reflect.ValueOf(stats.MaxAbs),
"MaxAbsOut": reflect.ValueOf(stats.MaxAbsOut),
"MaxOut": reflect.ValueOf(stats.MaxOut),
"Mean": reflect.ValueOf(stats.Mean),
"MeanOut": reflect.ValueOf(stats.MeanOut),
"MeanOut64": reflect.ValueOf(stats.MeanOut64),
"MeanTables": reflect.ValueOf(stats.MeanTables),
"Median": reflect.ValueOf(stats.Median),
"MedianOut": reflect.ValueOf(stats.MedianOut),
"Min": reflect.ValueOf(stats.Min),
"MinAbs": reflect.ValueOf(stats.MinAbs),
"MinAbsOut": reflect.ValueOf(stats.MinAbsOut),
"MinOut": reflect.ValueOf(stats.MinOut),
"Prod": reflect.ValueOf(stats.Prod),
"ProdOut": reflect.ValueOf(stats.ProdOut),
"Q1": reflect.ValueOf(stats.Q1),
"Q1Out": reflect.ValueOf(stats.Q1Out),
"Q3": reflect.ValueOf(stats.Q3),
"Q3Out": reflect.ValueOf(stats.Q3Out),
"Quantiles": reflect.ValueOf(stats.Quantiles),
"QuantilesOut": reflect.ValueOf(stats.QuantilesOut),
"Sem": reflect.ValueOf(stats.Sem),
"SemOut": reflect.ValueOf(stats.SemOut),
"SemPop": reflect.ValueOf(stats.SemPop),
"SemPopOut": reflect.ValueOf(stats.SemPopOut),
"StatCount": reflect.ValueOf(stats.StatCount),
"StatFinal": reflect.ValueOf(stats.StatFinal),
"StatFirst": reflect.ValueOf(stats.StatFirst),
"StatL1Norm": reflect.ValueOf(stats.StatL1Norm),
"StatL2Norm": reflect.ValueOf(stats.StatL2Norm),
"StatMax": reflect.ValueOf(stats.StatMax),
"StatMaxAbs": reflect.ValueOf(stats.StatMaxAbs),
"StatMean": reflect.ValueOf(stats.StatMean),
"StatMedian": reflect.ValueOf(stats.StatMedian),
"StatMin": reflect.ValueOf(stats.StatMin),
"StatMinAbs": reflect.ValueOf(stats.StatMinAbs),
"StatProd": reflect.ValueOf(stats.StatProd),
"StatQ1": reflect.ValueOf(stats.StatQ1),
"StatQ3": reflect.ValueOf(stats.StatQ3),
"StatSem": reflect.ValueOf(stats.StatSem),
"StatSemPop": reflect.ValueOf(stats.StatSemPop),
"StatStd": reflect.ValueOf(stats.StatStd),
"StatStdPop": reflect.ValueOf(stats.StatStdPop),
"StatSum": reflect.ValueOf(stats.StatSum),
"StatSumSq": reflect.ValueOf(stats.StatSumSq),
"StatVar": reflect.ValueOf(stats.StatVar),
"StatVarPop": reflect.ValueOf(stats.StatVarPop),
"StatsN": reflect.ValueOf(stats.StatsN),
"StatsValues": reflect.ValueOf(stats.StatsValues),
"Std": reflect.ValueOf(stats.Std),
"StdOut": reflect.ValueOf(stats.StdOut),
"StdOut64": reflect.ValueOf(stats.StdOut64),
"StdPop": reflect.ValueOf(stats.StdPop),
"StdPopOut": reflect.ValueOf(stats.StdPopOut),
"StripPackage": reflect.ValueOf(stats.StripPackage),
"Sum": reflect.ValueOf(stats.Sum),
"SumOut": reflect.ValueOf(stats.SumOut),
"SumOut64": reflect.ValueOf(stats.SumOut64),
"SumSq": reflect.ValueOf(stats.SumSq),
"SumSqDevOut64": reflect.ValueOf(stats.SumSqDevOut64),
"SumSqOut": reflect.ValueOf(stats.SumSqOut),
"SumSqOut64": reflect.ValueOf(stats.SumSqOut64),
"SumSqScaleOut64": reflect.ValueOf(stats.SumSqScaleOut64),
"TableGroupDescribe": reflect.ValueOf(stats.TableGroupDescribe),
"TableGroupStats": reflect.ValueOf(stats.TableGroupStats),
"TableGroups": reflect.ValueOf(stats.TableGroups),
"UnitNorm": reflect.ValueOf(stats.UnitNorm),
"UnitNormOut": reflect.ValueOf(stats.UnitNormOut),
"Var": reflect.ValueOf(stats.Var),
"VarOut": reflect.ValueOf(stats.VarOut),
"VarOut64": reflect.ValueOf(stats.VarOut64),
"VarPop": reflect.ValueOf(stats.VarPop),
"VarPopOut": reflect.ValueOf(stats.VarPopOut),
"VarPopOut64": reflect.ValueOf(stats.VarPopOut64),
"Vectorize2Out64": reflect.ValueOf(stats.Vectorize2Out64),
"VectorizeOut64": reflect.ValueOf(stats.VectorizeOut64),
"VectorizePreOut64": reflect.ValueOf(stats.VectorizePreOut64),
"ZScore": reflect.ValueOf(stats.ZScore),
"ZScoreOut": reflect.ValueOf(stats.ZScoreOut),
// type definitions
"Stats": reflect.ValueOf((*stats.Stats)(nil)),
"StatsFunc": reflect.ValueOf((*stats.StatsFunc)(nil)),
"StatsOutFunc": reflect.ValueOf((*stats.StatsOutFunc)(nil)),
}
}
// Code generated by 'yaegi extract cogentcore.org/lab/table'. DO NOT EDIT.
package tensorsymbols
import (
"cogentcore.org/lab/table"
"reflect"
)
func init() {
Symbols["cogentcore.org/lab/table/table"] = map[string]reflect.Value{
// function, constant and variable definitions
"CleanCatTSV": reflect.ValueOf(table.CleanCatTSV),
"ConfigFromDataValues": reflect.ValueOf(table.ConfigFromDataValues),
"ConfigFromHeaders": reflect.ValueOf(table.ConfigFromHeaders),
"ConfigFromTableHeaders": reflect.ValueOf(table.ConfigFromTableHeaders),
"DetectTableHeaders": reflect.ValueOf(table.DetectTableHeaders),
"ErrLogNoNewRows": reflect.ValueOf(&table.ErrLogNoNewRows).Elem(),
"Headers": reflect.ValueOf(table.Headers),
"InferDataType": reflect.ValueOf(table.InferDataType),
"New": reflect.ValueOf(table.New),
"NewColumns": reflect.ValueOf(table.NewColumns),
"NewSliceTable": reflect.ValueOf(table.NewSliceTable),
"NewView": reflect.ValueOf(table.NewView),
"NoHeaders": reflect.ValueOf(table.NoHeaders),
"ShapeFromString": reflect.ValueOf(table.ShapeFromString),
"TableColumnType": reflect.ValueOf(table.TableColumnType),
"TableHeaderChar": reflect.ValueOf(table.TableHeaderChar),
"TableHeaderToType": reflect.ValueOf(&table.TableHeaderToType).Elem(),
"UpdateSliceTable": reflect.ValueOf(table.UpdateSliceTable),
// type definitions
"Columns": reflect.ValueOf((*table.Columns)(nil)),
"FilterFunc": reflect.ValueOf((*table.FilterFunc)(nil)),
"Table": reflect.ValueOf((*table.Table)(nil)),
}
}
// Code generated by 'yaegi extract cogentcore.org/lab/tensor/tmath'. DO NOT EDIT.
package tensorsymbols
import (
"cogentcore.org/lab/tensor/tmath"
"reflect"
)
func init() {
Symbols["cogentcore.org/lab/tensor/tmath/tmath"] = map[string]reflect.Value{
// function, constant and variable definitions
"Abs": reflect.ValueOf(tmath.Abs),
"AbsOut": reflect.ValueOf(tmath.AbsOut),
"Acos": reflect.ValueOf(tmath.Acos),
"AcosOut": reflect.ValueOf(tmath.AcosOut),
"Acosh": reflect.ValueOf(tmath.Acosh),
"AcoshOut": reflect.ValueOf(tmath.AcoshOut),
"Add": reflect.ValueOf(tmath.Add),
"AddAssign": reflect.ValueOf(tmath.AddAssign),
"AddOut": reflect.ValueOf(tmath.AddOut),
"And": reflect.ValueOf(tmath.And),
"AndOut": reflect.ValueOf(tmath.AndOut),
"Asin": reflect.ValueOf(tmath.Asin),
"AsinOut": reflect.ValueOf(tmath.AsinOut),
"Asinh": reflect.ValueOf(tmath.Asinh),
"AsinhOut": reflect.ValueOf(tmath.AsinhOut),
"Assign": reflect.ValueOf(tmath.Assign),
"Atan": reflect.ValueOf(tmath.Atan),
"Atan2": reflect.ValueOf(tmath.Atan2),
"Atan2Out": reflect.ValueOf(tmath.Atan2Out),
"AtanOut": reflect.ValueOf(tmath.AtanOut),
"Atanh": reflect.ValueOf(tmath.Atanh),
"AtanhOut": reflect.ValueOf(tmath.AtanhOut),
"Cbrt": reflect.ValueOf(tmath.Cbrt),
"CbrtOut": reflect.ValueOf(tmath.CbrtOut),
"Ceil": reflect.ValueOf(tmath.Ceil),
"CeilOut": reflect.ValueOf(tmath.CeilOut),
"Copysign": reflect.ValueOf(tmath.Copysign),
"CopysignOut": reflect.ValueOf(tmath.CopysignOut),
"Cos": reflect.ValueOf(tmath.Cos),
"CosOut": reflect.ValueOf(tmath.CosOut),
"Cosh": reflect.ValueOf(tmath.Cosh),
"CoshOut": reflect.ValueOf(tmath.CoshOut),
"Dec": reflect.ValueOf(tmath.Dec),
"Dim": reflect.ValueOf(tmath.Dim),
"DimOut": reflect.ValueOf(tmath.DimOut),
"Div": reflect.ValueOf(tmath.Div),
"DivAssign": reflect.ValueOf(tmath.DivAssign),
"DivOut": reflect.ValueOf(tmath.DivOut),
"Equal": reflect.ValueOf(tmath.Equal),
"EqualOut": reflect.ValueOf(tmath.EqualOut),
"Erf": reflect.ValueOf(tmath.Erf),
"ErfOut": reflect.ValueOf(tmath.ErfOut),
"Erfc": reflect.ValueOf(tmath.Erfc),
"ErfcOut": reflect.ValueOf(tmath.ErfcOut),
"Erfcinv": reflect.ValueOf(tmath.Erfcinv),
"ErfcinvOut": reflect.ValueOf(tmath.ErfcinvOut),
"Erfinv": reflect.ValueOf(tmath.Erfinv),
"ErfinvOut": reflect.ValueOf(tmath.ErfinvOut),
"Exp": reflect.ValueOf(tmath.Exp),
"Exp2": reflect.ValueOf(tmath.Exp2),
"Exp2Out": reflect.ValueOf(tmath.Exp2Out),
"ExpOut": reflect.ValueOf(tmath.ExpOut),
"Expm1": reflect.ValueOf(tmath.Expm1),
"Expm1Out": reflect.ValueOf(tmath.Expm1Out),
"Floor": reflect.ValueOf(tmath.Floor),
"FloorOut": reflect.ValueOf(tmath.FloorOut),
"Gamma": reflect.ValueOf(tmath.Gamma),
"GammaOut": reflect.ValueOf(tmath.GammaOut),
"Greater": reflect.ValueOf(tmath.Greater),
"GreaterEqual": reflect.ValueOf(tmath.GreaterEqual),
"GreaterEqualOut": reflect.ValueOf(tmath.GreaterEqualOut),
"GreaterOut": reflect.ValueOf(tmath.GreaterOut),
"Hypot": reflect.ValueOf(tmath.Hypot),
"HypotOut": reflect.ValueOf(tmath.HypotOut),
"Inc": reflect.ValueOf(tmath.Inc),
"J0": reflect.ValueOf(tmath.J0),
"J0Out": reflect.ValueOf(tmath.J0Out),
"J1": reflect.ValueOf(tmath.J1),
"J1Out": reflect.ValueOf(tmath.J1Out),
"Less": reflect.ValueOf(tmath.Less),
"LessEqual": reflect.ValueOf(tmath.LessEqual),
"LessEqualOut": reflect.ValueOf(tmath.LessEqualOut),
"LessOut": reflect.ValueOf(tmath.LessOut),
"Log": reflect.ValueOf(tmath.Log),
"Log10": reflect.ValueOf(tmath.Log10),
"Log10Out": reflect.ValueOf(tmath.Log10Out),
"Log1p": reflect.ValueOf(tmath.Log1p),
"Log1pOut": reflect.ValueOf(tmath.Log1pOut),
"Log2": reflect.ValueOf(tmath.Log2),
"Log2Out": reflect.ValueOf(tmath.Log2Out),
"LogOut": reflect.ValueOf(tmath.LogOut),
"Logb": reflect.ValueOf(tmath.Logb),
"LogbOut": reflect.ValueOf(tmath.LogbOut),
"Max": reflect.ValueOf(tmath.Max),
"MaxOut": reflect.ValueOf(tmath.MaxOut),
"Min": reflect.ValueOf(tmath.Min),
"MinOut": reflect.ValueOf(tmath.MinOut),
"Mod": reflect.ValueOf(tmath.Mod),
"ModAssign": reflect.ValueOf(tmath.ModAssign),
"ModOut": reflect.ValueOf(tmath.ModOut),
"Mul": reflect.ValueOf(tmath.Mul),
"MulAssign": reflect.ValueOf(tmath.MulAssign),
"MulOut": reflect.ValueOf(tmath.MulOut),
"Negate": reflect.ValueOf(tmath.Negate),
"NegateOut": reflect.ValueOf(tmath.NegateOut),
"Nextafter": reflect.ValueOf(tmath.Nextafter),
"NextafterOut": reflect.ValueOf(tmath.NextafterOut),
"Not": reflect.ValueOf(tmath.Not),
"NotEqual": reflect.ValueOf(tmath.NotEqual),
"NotEqualOut": reflect.ValueOf(tmath.NotEqualOut),
"NotOut": reflect.ValueOf(tmath.NotOut),
"Or": reflect.ValueOf(tmath.Or),
"OrOut": reflect.ValueOf(tmath.OrOut),
"Pow": reflect.ValueOf(tmath.Pow),
"PowOut": reflect.ValueOf(tmath.PowOut),
"Remainder": reflect.ValueOf(tmath.Remainder),
"RemainderOut": reflect.ValueOf(tmath.RemainderOut),
"Round": reflect.ValueOf(tmath.Round),
"RoundOut": reflect.ValueOf(tmath.RoundOut),
"RoundToEven": reflect.ValueOf(tmath.RoundToEven),
"RoundToEvenOut": reflect.ValueOf(tmath.RoundToEvenOut),
"Sin": reflect.ValueOf(tmath.Sin),
"SinOut": reflect.ValueOf(tmath.SinOut),
"Sinh": reflect.ValueOf(tmath.Sinh),
"SinhOut": reflect.ValueOf(tmath.SinhOut),
"Sqrt": reflect.ValueOf(tmath.Sqrt),
"SqrtOut": reflect.ValueOf(tmath.SqrtOut),
"Sub": reflect.ValueOf(tmath.Sub),
"SubAssign": reflect.ValueOf(tmath.SubAssign),
"SubOut": reflect.ValueOf(tmath.SubOut),
"Tan": reflect.ValueOf(tmath.Tan),
"TanOut": reflect.ValueOf(tmath.TanOut),
"Tanh": reflect.ValueOf(tmath.Tanh),
"TanhOut": reflect.ValueOf(tmath.TanhOut),
"Trunc": reflect.ValueOf(tmath.Trunc),
"TruncOut": reflect.ValueOf(tmath.TruncOut),
"Y0": reflect.ValueOf(tmath.Y0),
"Y0Out": reflect.ValueOf(tmath.Y0Out),
"Y1": reflect.ValueOf(tmath.Y1),
"Y1Out": reflect.ValueOf(tmath.Y1Out),
}
}
// Code generated by 'yaegi extract cogentcore.org/lab/tensor'. DO NOT EDIT.
package tensorsymbols
import (
"cogentcore.org/core/base/metadata"
"cogentcore.org/lab/tensor"
"reflect"
)
func init() {
Symbols["cogentcore.org/lab/tensor/tensor"] = map[string]reflect.Value{
// function, constant and variable definitions
"AddFunc": reflect.ValueOf(tensor.AddFunc),
"AddShapes": reflect.ValueOf(tensor.AddShapes),
"AlignForAssign": reflect.ValueOf(tensor.AlignForAssign),
"AlignShapes": reflect.ValueOf(tensor.AlignShapes),
"AnySlice": reflect.ValueOf(tensor.AnySlice),
"As1D": reflect.ValueOf(tensor.As1D),
"AsFloat32": reflect.ValueOf(tensor.AsFloat32),
"AsFloat64": reflect.ValueOf(tensor.AsFloat64),
"AsFloat64Scalar": reflect.ValueOf(tensor.AsFloat64Scalar),
"AsFloat64Slice": reflect.ValueOf(tensor.AsFloat64Slice),
"AsIndexed": reflect.ValueOf(tensor.AsIndexed),
"AsInt": reflect.ValueOf(tensor.AsInt),
"AsIntScalar": reflect.ValueOf(tensor.AsIntScalar),
"AsIntSlice": reflect.ValueOf(tensor.AsIntSlice),
"AsMasked": reflect.ValueOf(tensor.AsMasked),
"AsReshaped": reflect.ValueOf(tensor.AsReshaped),
"AsRows": reflect.ValueOf(tensor.AsRows),
"AsSliced": reflect.ValueOf(tensor.AsSliced),
"AsString": reflect.ValueOf(tensor.AsString),
"AsStringScalar": reflect.ValueOf(tensor.AsStringScalar),
"AsStringSlice": reflect.ValueOf(tensor.AsStringSlice),
"Ascending": reflect.ValueOf(tensor.Ascending),
"BoolFloatsFunc": reflect.ValueOf(tensor.BoolFloatsFunc),
"BoolFloatsFuncOut": reflect.ValueOf(tensor.BoolFloatsFuncOut),
"BoolIntsFunc": reflect.ValueOf(tensor.BoolIntsFunc),
"BoolIntsFuncOut": reflect.ValueOf(tensor.BoolIntsFuncOut),
"BoolStringsFunc": reflect.ValueOf(tensor.BoolStringsFunc),
"BoolStringsFuncOut": reflect.ValueOf(tensor.BoolStringsFuncOut),
"BoolToFloat64": reflect.ValueOf(tensor.BoolToFloat64),
"BoolToInt": reflect.ValueOf(tensor.BoolToInt),
"Calc": reflect.ValueOf(tensor.Calc),
"CallOut1": reflect.ValueOf(tensor.CallOut1),
"CallOut1Float64": reflect.ValueOf(tensor.CallOut1Float64),
"CallOut2": reflect.ValueOf(tensor.CallOut2),
"CallOut2Bool": reflect.ValueOf(tensor.CallOut2Bool),
"CallOut2Float64": reflect.ValueOf(tensor.CallOut2Float64),
"CallOut3": reflect.ValueOf(tensor.CallOut3),
"Cells1D": reflect.ValueOf(tensor.Cells1D),
"CellsSize": reflect.ValueOf(tensor.CellsSize),
"Clone": reflect.ValueOf(tensor.Clone),
"ColumnMajorStrides": reflect.ValueOf(tensor.ColumnMajorStrides),
"Comma": reflect.ValueOf(tensor.Comma),
"ContainsFloat": reflect.ValueOf(tensor.ContainsFloat),
"ContainsInt": reflect.ValueOf(tensor.ContainsInt),
"ContainsString": reflect.ValueOf(tensor.ContainsString),
"CopyFromLargerShape": reflect.ValueOf(tensor.CopyFromLargerShape),
"DefaultNumThreads": reflect.ValueOf(tensor.DefaultNumThreads),
"DelimsN": reflect.ValueOf(tensor.DelimsN),
"DelimsValues": reflect.ValueOf(tensor.DelimsValues),
"Descending": reflect.ValueOf(tensor.Descending),
"Detect": reflect.ValueOf(tensor.Detect),
"Ellipsis": reflect.ValueOf(tensor.Ellipsis),
"Flatten": reflect.ValueOf(tensor.Flatten),
"Float64ToBool": reflect.ValueOf(tensor.Float64ToBool),
"Float64ToString": reflect.ValueOf(tensor.Float64ToString),
"FloatAssignFunc": reflect.ValueOf(tensor.FloatAssignFunc),
"FloatBinaryFunc": reflect.ValueOf(tensor.FloatBinaryFunc),
"FloatBinaryFuncOut": reflect.ValueOf(tensor.FloatBinaryFuncOut),
"FloatFunc": reflect.ValueOf(tensor.FloatFunc),
"FloatFuncOut": reflect.ValueOf(tensor.FloatFuncOut),
"FloatPromoteType": reflect.ValueOf(tensor.FloatPromoteType),
"FloatSetFunc": reflect.ValueOf(tensor.FloatSetFunc),
"FromBinary": reflect.ValueOf(tensor.FromBinary),
"FullAxis": reflect.ValueOf(tensor.FullAxis),
"FuncByName": reflect.ValueOf(tensor.FuncByName),
"Funcs": reflect.ValueOf(&tensor.Funcs).Elem(),
"IntToBool": reflect.ValueOf(tensor.IntToBool),
"Mask": reflect.ValueOf(tensor.Mask),
"MaxPrintLineWidth": reflect.ValueOf(&tensor.MaxPrintLineWidth).Elem(),
"MaxSprintLength": reflect.ValueOf(&tensor.MaxSprintLength).Elem(),
"MustBeSameShape": reflect.ValueOf(tensor.MustBeSameShape),
"MustBeValues": reflect.ValueOf(tensor.MustBeValues),
"NFirstLen": reflect.ValueOf(tensor.NFirstLen),
"NFirstRows": reflect.ValueOf(tensor.NFirstRows),
"NMinLen": reflect.ValueOf(tensor.NMinLen),
"NegIndex": reflect.ValueOf(tensor.NegIndex),
"NewAxis": reflect.ValueOf(tensor.NewAxis),
"NewBool": reflect.ValueOf(tensor.NewBool),
"NewBoolFromValues": reflect.ValueOf(tensor.NewBoolFromValues),
"NewBoolShape": reflect.ValueOf(tensor.NewBoolShape),
"NewByte": reflect.ValueOf(tensor.NewByte),
"NewFloat32": reflect.ValueOf(tensor.NewFloat32),
"NewFloat32FromValues": reflect.ValueOf(tensor.NewFloat32FromValues),
"NewFloat32Scalar": reflect.ValueOf(tensor.NewFloat32Scalar),
"NewFloat64": reflect.ValueOf(tensor.NewFloat64),
"NewFloat64FromValues": reflect.ValueOf(tensor.NewFloat64FromValues),
"NewFloat64Full": reflect.ValueOf(tensor.NewFloat64Full),
"NewFloat64Ones": reflect.ValueOf(tensor.NewFloat64Ones),
"NewFloat64Rand": reflect.ValueOf(tensor.NewFloat64Rand),
"NewFloat64Scalar": reflect.ValueOf(tensor.NewFloat64Scalar),
"NewFloat64SpacedLinear": reflect.ValueOf(tensor.NewFloat64SpacedLinear),
"NewFromValues": reflect.ValueOf(tensor.NewFromValues),
"NewFunc": reflect.ValueOf(tensor.NewFunc),
"NewIndexed": reflect.ValueOf(tensor.NewIndexed),
"NewInt": reflect.ValueOf(tensor.NewInt),
"NewInt32": reflect.ValueOf(tensor.NewInt32),
"NewIntFromValues": reflect.ValueOf(tensor.NewIntFromValues),
"NewIntFull": reflect.ValueOf(tensor.NewIntFull),
"NewIntRange": reflect.ValueOf(tensor.NewIntRange),
"NewIntScalar": reflect.ValueOf(tensor.NewIntScalar),
"NewMasked": reflect.ValueOf(tensor.NewMasked),
"NewOfType": reflect.ValueOf(tensor.NewOfType),
"NewReshaped": reflect.ValueOf(tensor.NewReshaped),
"NewRowCellsView": reflect.ValueOf(tensor.NewRowCellsView),
"NewRows": reflect.ValueOf(tensor.NewRows),
"NewShape": reflect.ValueOf(tensor.NewShape),
"NewSlice": reflect.ValueOf(tensor.NewSlice),
"NewSliced": reflect.ValueOf(tensor.NewSliced),
"NewString": reflect.ValueOf(tensor.NewString),
"NewStringFromValues": reflect.ValueOf(tensor.NewStringFromValues),
"NewStringFull": reflect.ValueOf(tensor.NewStringFull),
"NewStringScalar": reflect.ValueOf(tensor.NewStringScalar),
"NewStringShape": reflect.ValueOf(tensor.NewStringShape),
"NewUint32": reflect.ValueOf(tensor.NewUint32),
"NumThreads": reflect.ValueOf(&tensor.NumThreads).Elem(),
"OnedColumn": reflect.ValueOf(tensor.OnedColumn),
"OnedRow": reflect.ValueOf(tensor.OnedRow),
"OpenCSV": reflect.ValueOf(tensor.OpenCSV),
"OpenFS": reflect.ValueOf(tensor.OpenFS),
"Precision": reflect.ValueOf(tensor.Precision),
"Projection2DCoords": reflect.ValueOf(tensor.Projection2DCoords),
"Projection2DDimShapes": reflect.ValueOf(tensor.Projection2DDimShapes),
"Projection2DIndex": reflect.ValueOf(tensor.Projection2DIndex),
"Projection2DSet": reflect.ValueOf(tensor.Projection2DSet),
"Projection2DSetString": reflect.ValueOf(tensor.Projection2DSetString),
"Projection2DShape": reflect.ValueOf(tensor.Projection2DShape),
"Projection2DString": reflect.ValueOf(tensor.Projection2DString),
"Projection2DValue": reflect.ValueOf(tensor.Projection2DValue),
"Range": reflect.ValueOf(tensor.Range),
"ReadCSV": reflect.ValueOf(tensor.ReadCSV),
"Reshape": reflect.ValueOf(tensor.Reshape),
"Reslice": reflect.ValueOf(tensor.Reslice),
"RowMajorStrides": reflect.ValueOf(tensor.RowMajorStrides),
"SaveCSV": reflect.ValueOf(tensor.SaveCSV),
"SetAllFloat64": reflect.ValueOf(tensor.SetAllFloat64),
"SetAllInt": reflect.ValueOf(tensor.SetAllInt),
"SetAllString": reflect.ValueOf(tensor.SetAllString),
"SetCalcFunc": reflect.ValueOf(tensor.SetCalcFunc),
"SetPrecision": reflect.ValueOf(tensor.SetPrecision),
"SetShape": reflect.ValueOf(tensor.SetShape),
"SetShapeFrom": reflect.ValueOf(tensor.SetShapeFrom),
"SetShapeNames": reflect.ValueOf(tensor.SetShapeNames),
"SetShapeSizesFromTensor": reflect.ValueOf(tensor.SetShapeSizesFromTensor),
"ShapeNames": reflect.ValueOf(tensor.ShapeNames),
"SlicesMagicN": reflect.ValueOf(tensor.SlicesMagicN),
"SlicesMagicValues": reflect.ValueOf(tensor.SlicesMagicValues),
"Space": reflect.ValueOf(tensor.Space),
"SplitAtInnerDims": reflect.ValueOf(tensor.SplitAtInnerDims),
"Sprintf": reflect.ValueOf(tensor.Sprintf),
"Squeeze": reflect.ValueOf(tensor.Squeeze),
"StableSort": reflect.ValueOf(tensor.StableSort),
"StringAssignFunc": reflect.ValueOf(tensor.StringAssignFunc),
"StringBinaryFunc": reflect.ValueOf(tensor.StringBinaryFunc),
"StringBinaryFuncOut": reflect.ValueOf(tensor.StringBinaryFuncOut),
"StringToFloat64": reflect.ValueOf(tensor.StringToFloat64),
"Tab": reflect.ValueOf(tensor.Tab),
"ThreadingThreshold": reflect.ValueOf(&tensor.ThreadingThreshold).Elem(),
"ToBinary": reflect.ValueOf(tensor.ToBinary),
"Transpose": reflect.ValueOf(tensor.Transpose),
"UnstableSort": reflect.ValueOf(tensor.UnstableSort),
"Vectorize": reflect.ValueOf(tensor.Vectorize),
"VectorizeOnThreads": reflect.ValueOf(tensor.VectorizeOnThreads),
"VectorizeThreaded": reflect.ValueOf(tensor.VectorizeThreaded),
"WrapIndex1D": reflect.ValueOf(tensor.WrapIndex1D),
"WriteCSV": reflect.ValueOf(tensor.WriteCSV),
// type definitions
"Arg": reflect.ValueOf((*tensor.Arg)(nil)),
"Bool": reflect.ValueOf((*tensor.Bool)(nil)),
"Byte": reflect.ValueOf((*tensor.Byte)(nil)),
"Delims": reflect.ValueOf((*tensor.Delims)(nil)),
"FilterFunc": reflect.ValueOf((*tensor.FilterFunc)(nil)),
"Float32": reflect.ValueOf((*tensor.Float32)(nil)),
"Float64": reflect.ValueOf((*tensor.Float64)(nil)),
"Func": reflect.ValueOf((*tensor.Func)(nil)),
"Indexed": reflect.ValueOf((*tensor.Indexed)(nil)),
"Int": reflect.ValueOf((*tensor.Int)(nil)),
"Int32": reflect.ValueOf((*tensor.Int32)(nil)),
"Masked": reflect.ValueOf((*tensor.Masked)(nil)),
"Reshaped": reflect.ValueOf((*tensor.Reshaped)(nil)),
"RowMajor": reflect.ValueOf((*tensor.RowMajor)(nil)),
"Rows": reflect.ValueOf((*tensor.Rows)(nil)),
"Shape": reflect.ValueOf((*tensor.Shape)(nil)),
"Slice": reflect.ValueOf((*tensor.Slice)(nil)),
"Sliced": reflect.ValueOf((*tensor.Sliced)(nil)),
"SlicesMagic": reflect.ValueOf((*tensor.SlicesMagic)(nil)),
"String": reflect.ValueOf((*tensor.String)(nil)),
"StringMatch": reflect.ValueOf((*tensor.StringMatch)(nil)),
"Tensor": reflect.ValueOf((*tensor.Tensor)(nil)),
"Uint32": reflect.ValueOf((*tensor.Uint32)(nil)),
"Values": reflect.ValueOf((*tensor.Values)(nil)),
// interface wrapper definitions
"_RowMajor": reflect.ValueOf((*_cogentcore_org_lab_tensor_RowMajor)(nil)),
"_Tensor": reflect.ValueOf((*_cogentcore_org_lab_tensor_Tensor)(nil)),
"_Values": reflect.ValueOf((*_cogentcore_org_lab_tensor_Values)(nil)),
}
}
// _cogentcore_org_lab_tensor_RowMajor is an interface wrapper for RowMajor type
type _cogentcore_org_lab_tensor_RowMajor struct {
IValue interface{}
WAppendRow func(val tensor.Values)
WAppendRowFloat func(val ...float64)
WAppendRowInt func(val ...int)
WAppendRowString func(val ...string)
WAsValues func() tensor.Values
WDataType func() reflect.Kind
WDimSize func(dim int) int
WFloat func(i ...int) float64
WFloat1D func(i int) float64
WFloatRow func(row int, cell int) float64
WInt func(i ...int) int
WInt1D func(i int) int
WIntRow func(row int, cell int) int
WIsString func() bool
WLabel func() string
WLen func() int
WMetadata func() *metadata.Data
WNumDims func() int
WRowTensor func(row int) tensor.Values
WSetFloat func(val float64, i ...int)
WSetFloat1D func(val float64, i int)
WSetFloatRow func(val float64, row int, cell int)
WSetInt func(val int, i ...int)
WSetInt1D func(val int, i int)
WSetIntRow func(val int, row int, cell int)
WSetRowTensor func(val tensor.Values, row int)
WSetString func(val string, i ...int)
WSetString1D func(val string, i int)
WSetStringRow func(val string, row int, cell int)
WShape func() *tensor.Shape
WShapeSizes func() []int
WString func() string
WString1D func(i int) string
WStringRow func(row int, cell int) string
WStringValue func(i ...int) string
WSubSpace func(offs ...int) tensor.Values
}
func (W _cogentcore_org_lab_tensor_RowMajor) AppendRow(val tensor.Values) { W.WAppendRow(val) }
func (W _cogentcore_org_lab_tensor_RowMajor) AppendRowFloat(val ...float64) {
W.WAppendRowFloat(val...)
}
func (W _cogentcore_org_lab_tensor_RowMajor) AppendRowInt(val ...int) { W.WAppendRowInt(val...) }
func (W _cogentcore_org_lab_tensor_RowMajor) AppendRowString(val ...string) {
W.WAppendRowString(val...)
}
func (W _cogentcore_org_lab_tensor_RowMajor) AsValues() tensor.Values { return W.WAsValues() }
func (W _cogentcore_org_lab_tensor_RowMajor) DataType() reflect.Kind { return W.WDataType() }
func (W _cogentcore_org_lab_tensor_RowMajor) DimSize(dim int) int { return W.WDimSize(dim) }
func (W _cogentcore_org_lab_tensor_RowMajor) Float(i ...int) float64 { return W.WFloat(i...) }
func (W _cogentcore_org_lab_tensor_RowMajor) Float1D(i int) float64 { return W.WFloat1D(i) }
func (W _cogentcore_org_lab_tensor_RowMajor) FloatRow(row int, cell int) float64 {
return W.WFloatRow(row, cell)
}
func (W _cogentcore_org_lab_tensor_RowMajor) Int(i ...int) int { return W.WInt(i...) }
func (W _cogentcore_org_lab_tensor_RowMajor) Int1D(i int) int { return W.WInt1D(i) }
func (W _cogentcore_org_lab_tensor_RowMajor) IntRow(row int, cell int) int {
return W.WIntRow(row, cell)
}
func (W _cogentcore_org_lab_tensor_RowMajor) IsString() bool { return W.WIsString() }
func (W _cogentcore_org_lab_tensor_RowMajor) Label() string { return W.WLabel() }
func (W _cogentcore_org_lab_tensor_RowMajor) Len() int { return W.WLen() }
func (W _cogentcore_org_lab_tensor_RowMajor) Metadata() *metadata.Data { return W.WMetadata() }
func (W _cogentcore_org_lab_tensor_RowMajor) NumDims() int { return W.WNumDims() }
func (W _cogentcore_org_lab_tensor_RowMajor) RowTensor(row int) tensor.Values {
return W.WRowTensor(row)
}
func (W _cogentcore_org_lab_tensor_RowMajor) SetFloat(val float64, i ...int) { W.WSetFloat(val, i...) }
func (W _cogentcore_org_lab_tensor_RowMajor) SetFloat1D(val float64, i int) { W.WSetFloat1D(val, i) }
func (W _cogentcore_org_lab_tensor_RowMajor) SetFloatRow(val float64, row int, cell int) {
W.WSetFloatRow(val, row, cell)
}
func (W _cogentcore_org_lab_tensor_RowMajor) SetInt(val int, i ...int) { W.WSetInt(val, i...) }
func (W _cogentcore_org_lab_tensor_RowMajor) SetInt1D(val int, i int) { W.WSetInt1D(val, i) }
func (W _cogentcore_org_lab_tensor_RowMajor) SetIntRow(val int, row int, cell int) {
W.WSetIntRow(val, row, cell)
}
func (W _cogentcore_org_lab_tensor_RowMajor) SetRowTensor(val tensor.Values, row int) {
W.WSetRowTensor(val, row)
}
func (W _cogentcore_org_lab_tensor_RowMajor) SetString(val string, i ...int) { W.WSetString(val, i...) }
func (W _cogentcore_org_lab_tensor_RowMajor) SetString1D(val string, i int) { W.WSetString1D(val, i) }
func (W _cogentcore_org_lab_tensor_RowMajor) SetStringRow(val string, row int, cell int) {
W.WSetStringRow(val, row, cell)
}
func (W _cogentcore_org_lab_tensor_RowMajor) Shape() *tensor.Shape { return W.WShape() }
func (W _cogentcore_org_lab_tensor_RowMajor) ShapeSizes() []int { return W.WShapeSizes() }
func (W _cogentcore_org_lab_tensor_RowMajor) String() string {
if W.WString == nil {
return ""
}
return W.WString()
}
func (W _cogentcore_org_lab_tensor_RowMajor) String1D(i int) string { return W.WString1D(i) }
func (W _cogentcore_org_lab_tensor_RowMajor) StringRow(row int, cell int) string {
return W.WStringRow(row, cell)
}
func (W _cogentcore_org_lab_tensor_RowMajor) StringValue(i ...int) string {
return W.WStringValue(i...)
}
func (W _cogentcore_org_lab_tensor_RowMajor) SubSpace(offs ...int) tensor.Values {
return W.WSubSpace(offs...)
}
// _cogentcore_org_lab_tensor_Tensor is an interface wrapper for Tensor type
type _cogentcore_org_lab_tensor_Tensor struct {
IValue interface{}
WAsValues func() tensor.Values
WDataType func() reflect.Kind
WDimSize func(dim int) int
WFloat func(i ...int) float64
WFloat1D func(i int) float64
WInt func(i ...int) int
WInt1D func(i int) int
WIsString func() bool
WLabel func() string
WLen func() int
WMetadata func() *metadata.Data
WNumDims func() int
WSetFloat func(val float64, i ...int)
WSetFloat1D func(val float64, i int)
WSetInt func(val int, i ...int)
WSetInt1D func(val int, i int)
WSetString func(val string, i ...int)
WSetString1D func(val string, i int)
WShape func() *tensor.Shape
WShapeSizes func() []int
WString func() string
WString1D func(i int) string
WStringValue func(i ...int) string
}
func (W _cogentcore_org_lab_tensor_Tensor) AsValues() tensor.Values { return W.WAsValues() }
func (W _cogentcore_org_lab_tensor_Tensor) DataType() reflect.Kind { return W.WDataType() }
func (W _cogentcore_org_lab_tensor_Tensor) DimSize(dim int) int { return W.WDimSize(dim) }
func (W _cogentcore_org_lab_tensor_Tensor) Float(i ...int) float64 { return W.WFloat(i...) }
func (W _cogentcore_org_lab_tensor_Tensor) Float1D(i int) float64 { return W.WFloat1D(i) }
func (W _cogentcore_org_lab_tensor_Tensor) Int(i ...int) int { return W.WInt(i...) }
func (W _cogentcore_org_lab_tensor_Tensor) Int1D(i int) int { return W.WInt1D(i) }
func (W _cogentcore_org_lab_tensor_Tensor) IsString() bool { return W.WIsString() }
func (W _cogentcore_org_lab_tensor_Tensor) Label() string { return W.WLabel() }
func (W _cogentcore_org_lab_tensor_Tensor) Len() int { return W.WLen() }
func (W _cogentcore_org_lab_tensor_Tensor) Metadata() *metadata.Data { return W.WMetadata() }
func (W _cogentcore_org_lab_tensor_Tensor) NumDims() int { return W.WNumDims() }
func (W _cogentcore_org_lab_tensor_Tensor) SetFloat(val float64, i ...int) { W.WSetFloat(val, i...) }
func (W _cogentcore_org_lab_tensor_Tensor) SetFloat1D(val float64, i int) { W.WSetFloat1D(val, i) }
func (W _cogentcore_org_lab_tensor_Tensor) SetInt(val int, i ...int) { W.WSetInt(val, i...) }
func (W _cogentcore_org_lab_tensor_Tensor) SetInt1D(val int, i int) { W.WSetInt1D(val, i) }
func (W _cogentcore_org_lab_tensor_Tensor) SetString(val string, i ...int) { W.WSetString(val, i...) }
func (W _cogentcore_org_lab_tensor_Tensor) SetString1D(val string, i int) { W.WSetString1D(val, i) }
func (W _cogentcore_org_lab_tensor_Tensor) Shape() *tensor.Shape { return W.WShape() }
func (W _cogentcore_org_lab_tensor_Tensor) ShapeSizes() []int { return W.WShapeSizes() }
func (W _cogentcore_org_lab_tensor_Tensor) String() string {
if W.WString == nil {
return ""
}
return W.WString()
}
func (W _cogentcore_org_lab_tensor_Tensor) String1D(i int) string { return W.WString1D(i) }
func (W _cogentcore_org_lab_tensor_Tensor) StringValue(i ...int) string { return W.WStringValue(i...) }
// _cogentcore_org_lab_tensor_Values is an interface wrapper for Values type
type _cogentcore_org_lab_tensor_Values struct {
IValue interface{}
WAppendFrom func(from tensor.Values) tensor.Values
WAppendRow func(val tensor.Values)
WAppendRowFloat func(val ...float64)
WAppendRowInt func(val ...int)
WAppendRowString func(val ...string)
WAsValues func() tensor.Values
WBytes func() []byte
WClone func() tensor.Values
WCopyCellsFrom func(from tensor.Values, to int, start int, n int)
WCopyFrom func(from tensor.Values)
WDataType func() reflect.Kind
WDimSize func(dim int) int
WFloat func(i ...int) float64
WFloat1D func(i int) float64
WFloatRow func(row int, cell int) float64
WInt func(i ...int) int
WInt1D func(i int) int
WIntRow func(row int, cell int) int
WIsString func() bool
WLabel func() string
WLen func() int
WMetadata func() *metadata.Data
WNumDims func() int
WRowTensor func(row int) tensor.Values
WSetFloat func(val float64, i ...int)
WSetFloat1D func(val float64, i int)
WSetFloatRow func(val float64, row int, cell int)
WSetFromBytes func(b []byte)
WSetInt func(val int, i ...int)
WSetInt1D func(val int, i int)
WSetIntRow func(val int, row int, cell int)
WSetNumRows func(rows int)
WSetRowTensor func(val tensor.Values, row int)
WSetShapeSizes func(sizes ...int)
WSetString func(val string, i ...int)
WSetString1D func(val string, i int)
WSetStringRow func(val string, row int, cell int)
WSetZeros func()
WShape func() *tensor.Shape
WShapeSizes func() []int
WSizeof func() int64
WString func() string
WString1D func(i int) string
WStringRow func(row int, cell int) string
WStringValue func(i ...int) string
WSubSpace func(offs ...int) tensor.Values
}
func (W _cogentcore_org_lab_tensor_Values) AppendFrom(from tensor.Values) tensor.Values {
return W.WAppendFrom(from)
}
func (W _cogentcore_org_lab_tensor_Values) AppendRow(val tensor.Values) { W.WAppendRow(val) }
func (W _cogentcore_org_lab_tensor_Values) AppendRowFloat(val ...float64) { W.WAppendRowFloat(val...) }
func (W _cogentcore_org_lab_tensor_Values) AppendRowInt(val ...int) { W.WAppendRowInt(val...) }
func (W _cogentcore_org_lab_tensor_Values) AppendRowString(val ...string) { W.WAppendRowString(val...) }
func (W _cogentcore_org_lab_tensor_Values) AsValues() tensor.Values { return W.WAsValues() }
func (W _cogentcore_org_lab_tensor_Values) Bytes() []byte { return W.WBytes() }
func (W _cogentcore_org_lab_tensor_Values) Clone() tensor.Values { return W.WClone() }
func (W _cogentcore_org_lab_tensor_Values) CopyCellsFrom(from tensor.Values, to int, start int, n int) {
W.WCopyCellsFrom(from, to, start, n)
}
func (W _cogentcore_org_lab_tensor_Values) CopyFrom(from tensor.Values) { W.WCopyFrom(from) }
func (W _cogentcore_org_lab_tensor_Values) DataType() reflect.Kind { return W.WDataType() }
func (W _cogentcore_org_lab_tensor_Values) DimSize(dim int) int { return W.WDimSize(dim) }
func (W _cogentcore_org_lab_tensor_Values) Float(i ...int) float64 { return W.WFloat(i...) }
func (W _cogentcore_org_lab_tensor_Values) Float1D(i int) float64 { return W.WFloat1D(i) }
func (W _cogentcore_org_lab_tensor_Values) FloatRow(row int, cell int) float64 {
return W.WFloatRow(row, cell)
}
func (W _cogentcore_org_lab_tensor_Values) Int(i ...int) int { return W.WInt(i...) }
func (W _cogentcore_org_lab_tensor_Values) Int1D(i int) int { return W.WInt1D(i) }
func (W _cogentcore_org_lab_tensor_Values) IntRow(row int, cell int) int { return W.WIntRow(row, cell) }
func (W _cogentcore_org_lab_tensor_Values) IsString() bool { return W.WIsString() }
func (W _cogentcore_org_lab_tensor_Values) Label() string { return W.WLabel() }
func (W _cogentcore_org_lab_tensor_Values) Len() int { return W.WLen() }
func (W _cogentcore_org_lab_tensor_Values) Metadata() *metadata.Data { return W.WMetadata() }
func (W _cogentcore_org_lab_tensor_Values) NumDims() int { return W.WNumDims() }
func (W _cogentcore_org_lab_tensor_Values) RowTensor(row int) tensor.Values { return W.WRowTensor(row) }
func (W _cogentcore_org_lab_tensor_Values) SetFloat(val float64, i ...int) { W.WSetFloat(val, i...) }
func (W _cogentcore_org_lab_tensor_Values) SetFloat1D(val float64, i int) { W.WSetFloat1D(val, i) }
func (W _cogentcore_org_lab_tensor_Values) SetFloatRow(val float64, row int, cell int) {
W.WSetFloatRow(val, row, cell)
}
func (W _cogentcore_org_lab_tensor_Values) SetFromBytes(b []byte) { W.WSetFromBytes(b) }
func (W _cogentcore_org_lab_tensor_Values) SetInt(val int, i ...int) { W.WSetInt(val, i...) }
func (W _cogentcore_org_lab_tensor_Values) SetInt1D(val int, i int) { W.WSetInt1D(val, i) }
func (W _cogentcore_org_lab_tensor_Values) SetIntRow(val int, row int, cell int) {
W.WSetIntRow(val, row, cell)
}
func (W _cogentcore_org_lab_tensor_Values) SetNumRows(rows int) { W.WSetNumRows(rows) }
func (W _cogentcore_org_lab_tensor_Values) SetRowTensor(val tensor.Values, row int) {
W.WSetRowTensor(val, row)
}
func (W _cogentcore_org_lab_tensor_Values) SetShapeSizes(sizes ...int) { W.WSetShapeSizes(sizes...) }
func (W _cogentcore_org_lab_tensor_Values) SetString(val string, i ...int) { W.WSetString(val, i...) }
func (W _cogentcore_org_lab_tensor_Values) SetString1D(val string, i int) { W.WSetString1D(val, i) }
func (W _cogentcore_org_lab_tensor_Values) SetStringRow(val string, row int, cell int) {
W.WSetStringRow(val, row, cell)
}
func (W _cogentcore_org_lab_tensor_Values) SetZeros() { W.WSetZeros() }
func (W _cogentcore_org_lab_tensor_Values) Shape() *tensor.Shape { return W.WShape() }
func (W _cogentcore_org_lab_tensor_Values) ShapeSizes() []int { return W.WShapeSizes() }
func (W _cogentcore_org_lab_tensor_Values) Sizeof() int64 { return W.WSizeof() }
func (W _cogentcore_org_lab_tensor_Values) String() string {
if W.WString == nil {
return ""
}
return W.WString()
}
func (W _cogentcore_org_lab_tensor_Values) String1D(i int) string { return W.WString1D(i) }
func (W _cogentcore_org_lab_tensor_Values) StringRow(row int, cell int) string {
return W.WStringRow(row, cell)
}
func (W _cogentcore_org_lab_tensor_Values) StringValue(i ...int) string { return W.WStringValue(i...) }
func (W _cogentcore_org_lab_tensor_Values) SubSpace(offs ...int) tensor.Values {
return W.WSubSpace(offs...)
}
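// The W-prefixed fields above are method stand-ins: yaegi fills each one
// with a closure that dispatches back into interpreted code, so a value
// defined inside the interpreter can be handed to compiled APIs as a
// tensor.Values. A hand-built wrapper (purely illustrative; yaegi normally
// constructs these itself, and most unset fields panic if called, while
// String checks for nil) would look roughly like:
//
//	w := _cogentcore_org_lab_tensor_Values{
//		WLen:      func() int { return 0 },
//		WNumDims:  func() int { return 1 },
//		WIsString: func() bool { return false },
//		// ... remaining W fields as needed ...
//	}
//	var v tensor.Values = w // compiles because all methods are defined above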
// Code generated by 'yaegi extract cogentcore.org/lab/tensorfs'. DO NOT EDIT.
package tensorsymbols
import (
"cogentcore.org/lab/tensorfs"
"reflect"
)
func init() {
Symbols["cogentcore.org/lab/tensorfs/tensorfs"] = map[string]reflect.Value{
// function, constant and variable definitions
"AllFiles": reflect.ValueOf(tensorfs.AllFiles),
"Chdir": reflect.ValueOf(tensorfs.Chdir),
"CurDir": reflect.ValueOf(&tensorfs.CurDir).Elem(),
"CurRoot": reflect.ValueOf(&tensorfs.CurRoot).Elem(),
"DirFromTable": reflect.ValueOf(tensorfs.DirFromTable),
"DirOnly": reflect.ValueOf(tensorfs.DirOnly),
"DirTable": reflect.ValueOf(tensorfs.DirTable),
"Get": reflect.ValueOf(tensorfs.Get),
"List": reflect.ValueOf(tensorfs.List),
"ListOutput": reflect.ValueOf(&tensorfs.ListOutput).Elem(),
"Long": reflect.ValueOf(tensorfs.Long),
"Mkdir": reflect.ValueOf(tensorfs.Mkdir),
"NewDir": reflect.ValueOf(tensorfs.NewDir),
"Overwrite": reflect.ValueOf(tensorfs.Overwrite),
"Preserve": reflect.ValueOf(tensorfs.Preserve),
"Record": reflect.ValueOf(tensorfs.Record),
"Recursive": reflect.ValueOf(tensorfs.Recursive),
"Set": reflect.ValueOf(tensorfs.Set),
"SetCopy": reflect.ValueOf(tensorfs.SetCopy),
"SetTensor": reflect.ValueOf(tensorfs.SetTensor),
"Short": reflect.ValueOf(tensorfs.Short),
"Tar": reflect.ValueOf(tensorfs.Tar),
"Untar": reflect.ValueOf(tensorfs.Untar),
"ValueType": reflect.ValueOf(tensorfs.ValueType),
// type definitions
"DirFile": reflect.ValueOf((*tensorfs.DirFile)(nil)),
"File": reflect.ValueOf((*tensorfs.File)(nil)),
"Node": reflect.ValueOf((*tensorfs.Node)(nil)),
"Nodes": reflect.ValueOf((*tensorfs.Nodes)(nil)),
}
}
// Code generated by 'yaegi extract cogentcore.org/lab/vector'. DO NOT EDIT.
package tensorsymbols
import (
"cogentcore.org/lab/vector"
"reflect"
)
func init() {
Symbols["cogentcore.org/lab/vector/vector"] = map[string]reflect.Value{
// function, constant and variable definitions
"Dot": reflect.ValueOf(vector.Dot),
"L1Norm": reflect.ValueOf(vector.L1Norm),
"L2Norm": reflect.ValueOf(vector.L2Norm),
"Mul": reflect.ValueOf(vector.Mul),
"MulOut": reflect.ValueOf(vector.MulOut),
"Sum": reflect.ValueOf(vector.Sum),
}
}
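// The Symbols maps built in these generated files follow yaegi's
// interp.Exports layout (map[import path]map[identifier]reflect.Value).
// From outside this package, loading them is a single Use call; a minimal
// sketch, assuming the cogentcore yaegi fork keeps the upstream
// New/Use/Eval signatures:
//
//	i := interp.New(interp.Options{})
//	if err := i.Use(tensorsymbols.Symbols); err != nil {
//		// handle error
//	}
//	// interpreted code can now import the exported packages, e.g.:
//	_, err := i.Eval(`import "cogentcore.org/lab/vector"`)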
// Copyright (c) 2024, Cogent Core. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package yaegilab provides functions connecting
// https://github.com/cogentcore/yaegi to Cogent Lab.
package yaegilab
import (
"reflect"
"cogentcore.org/core/base/errors"
"cogentcore.org/core/yaegicore"
"cogentcore.org/lab/goal/interpreter"
"cogentcore.org/lab/tensorfs"
"cogentcore.org/lab/yaegilab/labsymbols"
"cogentcore.org/lab/yaegilab/tensorsymbols"
"github.com/cogentcore/yaegi/interp"
)
// init registers the Goal [Interpreter] constructor in [yaegicore.Interpreters].
func init() {
yaegicore.Interpreters["Goal"] = func(options interp.Options) yaegicore.Interpreter {
return NewInterpreter(options)
}
}
// Interpreter implements [yaegicore.Interpreter] using the [interpreter.Interpreter] for Goal.
type Interpreter struct {
*interpreter.Interpreter
}
// NewInterpreter returns a new [Interpreter] initialized with the given options.
func NewInterpreter(options interp.Options) *Interpreter {
return &Interpreter{interpreter.NewInterpreter(options)}
}
// Use imports the given symbol values into the underlying yaegi interpreter.
func (in *Interpreter) Use(values interp.Exports) error {
return in.Interp.Use(values)
}
// ImportUsed loads the tensorsymbols and labsymbols exports into the
// interpreter and then runs its configuration step.
func (in *Interpreter) ImportUsed() {
errors.Log(in.Use(tensorsymbols.Symbols))
errors.Log(in.Use(labsymbols.Symbols))
in.Config()
}
// Eval evaluates the given source code, directing tensorfs list output to
// the interpreter's standard output for the duration of the call and
// enabling math recording.
func (in *Interpreter) Eval(src string) (res reflect.Value, err error) {
tensorfs.ListOutput = in.Goal.Config.StdIO.Out
in.Interpreter.Goal.TrState.MathRecord = true
res, _, err = in.Interpreter.Eval(src)
tensorfs.ListOutput = nil
return
}
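// evalSketch is a minimal, illustrative sketch (not part of the package API)
// of driving [Interpreter] directly rather than through the
// [yaegicore.Interpreters]["Goal"] registration above: construct it, load the
// lab and tensor symbols, and evaluate a Goal expression.
func evalSketch(src string) (reflect.Value, error) {
	in := NewInterpreter(interp.Options{})
	in.ImportUsed() // load tensorsymbols and labsymbols, then configure
	return in.Eval(src)
}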