package main
import (
"os"
engine "github.com/klothoplatform/klotho/pkg/engine"
"github.com/spf13/cobra"
)
// main builds the engine CLI root command and executes it. Any error surfaced by
// Execute only turns into a non-zero exit status; nothing is printed here.
func main() {
	if err := newRootCmd().Execute(); err != nil {
		// Shouldn't happen, the engine CLI should handle errors
		os.Exit(1)
	}
}
// newRootCmd returns an empty root cobra command after handing it to
// engine.EngineMain.AddEngineCli to be populated with the engine CLI.
func newRootCmd() *cobra.Command {
	main := &engine.EngineMain{}
	cmd := new(cobra.Command)
	main.AddEngineCli(cmd)
	return cmd
}
package main
import (
"github.com/klothoplatform/klotho/pkg/infra"
"github.com/spf13/cobra"
)
// main registers the IaC CLI on a fresh root command and executes it, panicking on
// either a registration or an execution error.
func main() {
	root := &cobra.Command{}
	if err := infra.AddIacCli(root); err != nil {
		panic(err)
	}
	if err := root.Execute(); err != nil {
		panic(err)
	}
}
package main
import (
"bytes"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"sort"
"github.com/dominikbraun/graph"
"github.com/klothoplatform/klotho/pkg/dot"
"github.com/klothoplatform/klotho/pkg/graph_addons"
"github.com/klothoplatform/klotho/pkg/knowledgebase"
"github.com/klothoplatform/klotho/pkg/knowledgebase/properties"
)
// dotAttributes builds the Graphviz attributes for a resource-template node.
//
// Every vertex attribute is copied except "rank" (KbToDot emits rank in its own
// subgraph statement). The label is then forced to the template's qualified type
// name, and the shape is forced to a box.
func dotAttributes(tmpl *knowledgebase.ResourceTemplate, props graph.VertexProperties) map[string]string {
	attrs := make(map[string]string, len(props.Attributes)+2)
	for key, val := range props.Attributes {
		if key == "rank" {
			continue
		}
		attrs[key] = val
	}
	attrs["label"] = tmpl.QualifiedTypeName
	attrs["shape"] = "box"
	return attrs
}
// dotEdgeAttributes builds the Graphviz attributes for an edge-template edge.
//
// It starts from a copy of the edge's own attributes. Edges whose deployment order is
// reversed are drawn dashed, and an "source -> target" tooltip is set. When the source
// template has a property that can reference the target type (searched recursively
// through collection items, map keys/values and sub-properties), that property's path
// is appended to the label.
func dotEdgeAttributes(
	kb knowledgebase.TemplateKB,
	e *knowledgebase.EdgeTemplate,
	props graph.EdgeProperties,
) map[string]string {
	a := make(map[string]string, len(props.Attributes)+2)
	for k, v := range props.Attributes {
		a[k] = v
	}
	if e.DeploymentOrderReversed {
		a["style"] = "dashed"
	}
	a["edgetooltip"] = fmt.Sprintf("%s -> %s", e.Source, e.Target)

	source, err := kb.GetResourceTemplate(e.Source)
	if err != nil {
		return a
	}

	// isTarget returns the first property (depth-first) that can reference e.Target,
	// or nil if none can.
	var isTarget func(ps knowledgebase.Properties) knowledgebase.Property
	isTarget = func(ps knowledgebase.Properties) knowledgebase.Property {
		// Visit properties in sorted name order so the chosen property, and therefore
		// the rendered label, does not depend on Go's randomised map iteration order.
		names := make([]string, 0, len(ps))
		for name := range ps {
			names = append(names, name)
		}
		sort.Strings(names)

		for _, name := range names {
			p := ps[name]
			switch inst := p.(type) {
			case *properties.ResourceProperty:
				if inst.AllowedTypes.MatchesAny(e.Target) {
					return p
				}
			case knowledgebase.CollectionProperty:
				if ip := inst.Item(); ip != nil {
					if ret := isTarget(knowledgebase.Properties{"item": ip}); ret != nil {
						return ret
					}
				}
			case knowledgebase.MapProperty:
				mapProps := make(knowledgebase.Properties)
				if kp := inst.Key(); kp != nil {
					mapProps["key"] = kp
				}
				if vp := inst.Value(); vp != nil {
					mapProps["value"] = vp
				}
				if ret := isTarget(mapProps); ret != nil {
					return ret
				}
			}
			// Only stop on a match; otherwise continue with the next sibling property.
			if ret := isTarget(p.SubProperties()); ret != nil {
				return ret
			}
		}
		return nil
	}

	prop := isTarget(source.Properties)
	if prop == nil {
		return a
	}
	if label, ok := a["label"]; ok {
		a["label"] = label + "\n" + prop.Details().Path
	} else {
		a["label"] = prop.Details().Path
	}
	return a
}
// KbToDot writes kb's resource-template graph to out in Graphviz dot format.
//
// kb must expose a Graph() method returning its template graph; otherwise an error is
// returned. Nodes are written in the order produced by graph_addons.TopologicalSort
// (with IDs compared lexically). Edges are sorted by the topological position of their
// source, then of their target. Write errors, vertex lookup failures and edges without an
// EdgeTemplate are collected, and are returned joined after the whole graph is written.
func KbToDot(kb knowledgebase.TemplateKB, out io.Writer) error {
	hasGraph, ok := kb.(interface {
		Graph() graph.Graph[string, *knowledgebase.ResourceTemplate]
	})
	if !ok {
		return fmt.Errorf("knowledgebase does not have a graph")
	}
	g := hasGraph.Graph()
	ids, err := graph_addons.TopologicalSort(g, func(a, b string) bool {
		return a < b
	})
	if err != nil {
		return err
	}

	var errs error
	printf := func(s string, args ...any) {
		_, err := fmt.Fprintf(out, s, args...)
		errs = errors.Join(errs, err)
	}
	printf(`digraph {
rankdir = TB
`)
	for _, id := range ids {
		t, props, err := g.VertexWithProperties(id)
		if err != nil {
			errs = errors.Join(errs, err)
			continue
		}
		attrs := dot.AttributesToString(dotAttributes(t, props))
		// Name nodes by vertex id in both branches: edges below reference e.Source and
		// e.Target, which are vertex ids.
		if rank, ok := props.Attributes["rank"]; ok {
			printf(" { rank = %s; %q%s; }\n", rank, id, attrs)
		} else {
			printf(" %q%s;\n", id, attrs)
		}
	}

	// Precompute topological positions so the sort comparator is O(1) per lookup
	// instead of scanning ids on every comparison.
	order := make(map[string]int, len(ids))
	for i, id := range ids {
		order[id] = i
	}
	topoIndex := func(id string) int {
		if i, ok := order[id]; ok {
			return i
		}
		return -1
	}

	edges, err := g.Edges()
	if err != nil {
		return err
	}
	sort.Slice(edges, func(i, j int) bool {
		ti, tj := topoIndex(edges[i].Source), topoIndex(edges[j].Source)
		if ti != tj {
			return ti < tj
		}
		return topoIndex(edges[i].Target) < topoIndex(edges[j].Target)
	})
	for _, e := range edges {
		et, ok := e.Properties.Data.(*knowledgebase.EdgeTemplate)
		if !ok {
			errs = errors.Join(errs, fmt.Errorf("edge %q -> %q has no EdgeTemplate", e.Source, e.Target))
			continue
		}
		printf(" %q -> %q%s\n", e.Source, e.Target, dot.AttributesToString(dotEdgeAttributes(kb, et, e.Properties)))
	}
	printf("}\n")
	return errs
}
// KbToSVG renders kb to "<prefix>.gv" (dot source) and "<prefix>.gv.svg".
// When KLOTHO_DEBUG_DIR is set, prefix is joined onto that directory first.
// Returned errors wrap the underlying cause with %w so callers can inspect it.
func KbToSVG(kb knowledgebase.TemplateKB, prefix string) error {
	if debugDir := os.Getenv("KLOTHO_DEBUG_DIR"); debugDir != "" {
		prefix = filepath.Join(debugDir, prefix)
	}
	dotPath := prefix + ".gv"
	svgPath := prefix + ".gv.svg"

	f, err := os.Create(dotPath)
	if err != nil {
		return err
	}
	defer f.Close()

	// Tee the dot output: one copy is written to disk, the other is fed to ExecPan.
	dotContent := new(bytes.Buffer)
	if err := KbToDot(kb, io.MultiWriter(f, dotContent)); err != nil {
		return fmt.Errorf("could not render graph to file %s: %w", dotPath, err)
	}
	svgContent, err := dot.ExecPan(bytes.NewReader(dotContent.Bytes()))
	if err != nil {
		return fmt.Errorf("could not run 'dot' for %s: %w", dotPath, err)
	}

	svgFile, err := os.Create(svgPath)
	if err != nil {
		return fmt.Errorf("could not create file %s: %w", svgPath, err)
	}
	if _, err := fmt.Fprint(svgFile, svgContent); err != nil {
		_ = svgFile.Close() // the write error is the more useful one to report
		return err
	}
	// Report close errors: they can be the only sign that buffered data was not flushed.
	return svgFile.Close()
}
package main
import (
"errors"
"fmt"
"github.com/alecthomas/kong"
"github.com/dominikbraun/graph"
"github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/engine"
"github.com/klothoplatform/klotho/pkg/engine/path_selection"
"github.com/klothoplatform/klotho/pkg/graph_addons"
"github.com/klothoplatform/klotho/pkg/knowledgebase"
"github.com/klothoplatform/klotho/pkg/knowledgebase/reader"
"github.com/klothoplatform/klotho/pkg/logging"
"github.com/klothoplatform/klotho/pkg/templates"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
// Args holds the visualiser's command-line arguments, populated by kong.Parse in main.
//
// Keep the field order: Source and Target are positional arguments and are filled in
// declaration order.
type Args struct {
	// Verbose is forwarded to logging.LogOpts.Verbose.
	Verbose bool `short:"v" help:"Enable verbose mode"`
	// Distance bounds how far filterSingleKb walks from Source: a walk path is cut off
	// once its length reaches Distance.
	Distance int `short:"d" help:"Distance from single type to display" default:"2"`
	// Classification, when set with both Source and Target, renders the path-selection
	// graph for that classification instead of the plain path view.
	Classification string `short:"c" help:"Classification to filter for (like path expansion)"`
	// Source is the qualified type name to centre on (alone) or to start paths from.
	Source string `arg:"" optional:""`
	// Target, when set, is the qualified type name that paths from Source end at.
	Target string `arg:"" optional:""`
}
// main parses the command line with kong, installs a global zap logger configured
// for console output, and runs the visualiser, panicking if it fails.
func main() {
	var args Args
	kctx := kong.Parse(&args)

	opts := logging.LogOpts{
		Verbose:         args.Verbose,
		CategoryLogsDir: "",
		Encoding:        "pretty_console",
		DefaultLevels: map[string]zapcore.Level{
			"lsp":       zap.WarnLevel,
			"lsp/pylsp": zap.WarnLevel,
		},
	}
	zap.ReplaceGlobals(opts.NewLogger())
	defer zap.L().Sync() //nolint:errcheck

	if runErr := args.Run(kctx); runErr != nil {
		panic(runErr)
	}
}
// Run loads the built-in knowledgebase and renders it, or a filtered view of it, to SVG:
//
//   - no Source/Target: the whole knowledgebase, written via KbToSVG as "knowledgebase";
//   - Source only: the neighbourhood of Source (see filterSingleKb);
//   - Source and Target with Classification: the path-selection graph (see renderClassPaths);
//   - Source and Target without Classification: paths between them (see filterPathKB).
//
// ctx is accepted for signature compatibility and is not used.
func (args Args) Run(ctx *kong.Context) error {
	kb, err := reader.NewKBFromFs(templates.ResourceTemplates, templates.EdgeTemplates, templates.Models)
	if err != nil {
		return err
	}
	switch {
	case args.Source == "" && args.Target == "":
		// Render the unfiltered knowledgebase.
	case args.Target == "":
		if args.Classification != "" {
			return fmt.Errorf("classification can only be used with two types (for now)")
		}
		kb = args.filterSingleKb(kb)
	case args.Classification != "":
		return args.renderClassPaths(kb)
	default:
		kb = args.filterPathKB(kb)
	}
	return KbToSVG(kb, "knowledgebase")
}

// renderClassPaths builds a graph of every path from args.Source to args.Target that
// path_selection.ClassPaths yields for args.Classification. Intermediate types are
// represented by shared "phantom" resources, and edges are weighted with
// path_selection.CalculateEdgeWeight. The graph is rendered via engine.GraphToSVG as
// "kb_path_selection".
func (args Args) renderClassPaths(kb *knowledgebase.KnowledgeBase) error {
	var edge construct.SimpleEdge
	if err := edge.Source.UnmarshalText([]byte(args.Source)); err != nil {
		return fmt.Errorf("could not parse source: %w", err)
	}
	edge.Source.Name = "source"
	if err := edge.Target.UnmarshalText([]byte(args.Target)); err != nil {
		return fmt.Errorf("could not parse target: %w", err)
	}
	edge.Target.Name = "target"

	resultGraph := construct.NewGraph()
	err := resultGraph.AddVertex(
		&construct.Resource{ID: edge.Source, Properties: make(construct.Properties)},
		graph.VertexAttributes(map[string]string{
			"rank":     "source",
			"color":    "green",
			"penwidth": "2",
		}),
	)
	if err != nil {
		return fmt.Errorf("failed to add source vertex to path selection graph for %s: %w", edge, err)
	}
	err = resultGraph.AddVertex(
		&construct.Resource{ID: edge.Target, Properties: make(construct.Properties)},
		graph.VertexAttributes(map[string]string{
			"rank":     "sink",
			"color":    "green",
			"penwidth": "2",
		}),
	)
	if err != nil {
		return fmt.Errorf("failed to add target vertex to path selection graph for %s: %w", edge, err)
	}

	satisfiedPaths := 0
	addPath := func(path []string) error {
		var prevId construct.ResourceId
		for i, typeName := range path {
			tmpl, err := kb.Graph().Vertex(typeName)
			if err != nil {
				return fmt.Errorf("failed to get template for path[%d]: %w", i, err)
			}
			var id construct.ResourceId
			switch i {
			case 0:
				prevId = edge.Source
				continue
			case len(path) - 1:
				id = edge.Target
			default:
				// Intermediate types share one "phantom" resource per type.
				id = tmpl.Id()
				id.Name = "phantom"
				if _, err := resultGraph.Vertex(id); errors.Is(err, graph.ErrVertexNotFound) {
					res := &construct.Resource{ID: id, Properties: make(construct.Properties)}
					if err := resultGraph.AddVertex(res); err != nil {
						return fmt.Errorf("failed to add phantom vertex for path[%d]: %w", i, err)
					}
				}
			}
			if _, err := resultGraph.Edge(prevId, id); errors.Is(err, graph.ErrEdgeNotFound) {
				weight := graph.EdgeWeight(path_selection.CalculateEdgeWeight(edge, prevId, id, 0, 0, args.Classification, kb))
				if err := resultGraph.AddEdge(prevId, id, weight); err != nil {
					return fmt.Errorf("failed to add edge[%d] %s -> %s: %w", i-1, prevId, id, err)
				}
			}
			prevId = id
		}
		satisfiedPaths++
		return nil
	}
	err = path_selection.ClassPaths(kb.Graph(), args.Source, args.Target, args.Classification, addPath)
	if err != nil {
		return err
	}
	zap.S().Debugf("Found %d paths for %s :: %s", satisfiedPaths, edge, args.Classification)
	return engine.GraphToSVG(kb, resultGraph, "kb_path_selection")
}
// filterPathKB returns a new knowledgebase holding the templates on the paths that
// kb.AllPaths finds from args.Source to args.Target. Paths more than twice as long as
// graph.ShortestPath's result are skipped.
//
// Source and Target are drawn green and ranked "source"/"sink". Edges on paths whose
// length equals the shortest path's length are drawn green as well. Parse and
// path-finding failures panic. Failures copying individual vertices or edges are
// collected and logged as a warning.
func (args Args) filterPathKB(kb *knowledgebase.KnowledgeBase) *knowledgebase.KnowledgeBase {
	var source, target construct.ResourceId
	if err := source.UnmarshalText([]byte(args.Source)); err != nil {
		panic(fmt.Errorf("could not parse source: %w", err))
	}
	if err := target.UnmarshalText([]byte(args.Target)); err != nil {
		panic(fmt.Errorf("could not parse target: %w", err))
	}
	paths, err := kb.AllPaths(source, target)
	if err != nil {
		panic(err)
	}
	shortestPath, err := graph.ShortestPath(kb.Graph(), args.Source, args.Target)
	if err != nil {
		panic(err)
	}

	filteredKb := knowledgebase.NewKB()
	g := filteredKb.Graph()

	// addV copies a template into the filtered graph; already-present vertices are not an error.
	addV := func(t *knowledgebase.ResourceTemplate) (err error) {
		if t.QualifiedTypeName == args.Source || t.QualifiedTypeName == args.Target {
			attribs := map[string]string{
				"color":    "green",
				"penwidth": "2",
			}
			if t.QualifiedTypeName == args.Source {
				attribs["rank"] = "source"
			} else {
				attribs["rank"] = "sink"
			}
			err = g.AddVertex(t, graph.VertexAttributes(attribs))
		} else {
			err = g.AddVertex(t)
		}
		if errors.Is(err, graph.ErrVertexAlreadyExists) {
			return nil
		}
		return err
	}

	// addE copies the t1 -> t2 edge into the filtered graph; already-present edges are not an error.
	addE := func(path []*knowledgebase.ResourceTemplate, t1, t2 *knowledgebase.ResourceTemplate) error {
		edge, err := kb.Graph().Edge(t1.QualifiedTypeName, t2.QualifiedTypeName)
		if err != nil {
			return err
		}
		err = g.AddEdge(t1.QualifiedTypeName, t2.QualifiedTypeName, func(ep *graph.EdgeProperties) {
			*ep = edge.Properties
			// edge.Properties.Attributes is the source graph's own map. Copy it so that
			// highlighting does not recolour kb itself (and a nil map cannot panic).
			attrs := make(map[string]string, len(edge.Properties.Attributes)+2)
			for k, v := range edge.Properties.Attributes {
				attrs[k] = v
			}
			ep.Attributes = attrs
			if len(path) == len(shortestPath) {
				ep.Attributes["color"] = "green"
				ep.Attributes["penwidth"] = "2"
			}
		})
		if errors.Is(err, graph.ErrEdgeAlreadyExists) {
			return nil
		}
		return err
	}

	var errs error
	for _, path := range paths {
		if len(path) == 0 || len(path) > len(shortestPath)*2 {
			continue
		}
		errs = errors.Join(errs, addV(path[0]))
		for i, t := range path[1:] {
			// path[i] is the predecessor of t because the range starts at path[1].
			errs = errors.Join(
				errs,
				addV(t),
				addE(path, path[i], t),
			)
		}
	}
	if errs != nil {
		zap.S().Warnf("could not add some paths to the filtered knowledgebase: %v", errs)
	}
	return filteredKb
}
// filterSingleKb returns a new knowledgebase holding args.Source, drawn green, and the
// templates reached by graph_addons.WalkUp and graph_addons.WalkDown from it.
//
// Walking up follows edges into the current vertex; walking down follows edges out of
// it. A walk path is cut off (graph_addons.SkipPath) once its length reaches
// args.Distance. Any failure panics.
func (args Args) filterSingleKb(kb *knowledgebase.KnowledgeBase) *knowledgebase.KnowledgeBase {
	filteredKb := knowledgebase.NewKB()
	g := filteredKb.Graph()
	r, props, err := kb.Graph().VertexWithProperties(args.Source)
	if err != nil {
		panic(err)
	}
	err = g.AddVertex(r, func(vp *graph.VertexProperties) {
		*vp = props
		// props.Attributes is the source graph's own map. Copy it before highlighting so
		// kb itself is not modified (and a nil map cannot panic).
		attrs := make(map[string]string, len(props.Attributes)+2)
		for k, v := range props.Attributes {
			attrs[k] = v
		}
		attrs["color"] = "green"
		attrs["penwidth"] = "2"
		vp.Attributes = attrs
	})
	if err != nil {
		panic(err)
	}

	// addV copies the template with id s into the filtered graph; duplicates are ignored.
	addV := func(s string) (err error) {
		t, err := kb.Graph().Vertex(s)
		if err != nil {
			return err
		}
		err = g.AddVertex(t)
		if errors.Is(err, graph.ErrVertexAlreadyExists) {
			return nil
		}
		return err
	}

	// walkFunc returns the per-path callback. When up is true, the edge between the last
	// two path entries is looked up in reverse, because the walk moves against edge direction.
	// The nerr argument passed by the walker is not inspected.
	walkFunc := func(up bool) func(p graph_addons.Path[string], nerr error) error {
		edge := func(a, b string) (graph.Edge[*knowledgebase.ResourceTemplate], error) {
			if up {
				a, b = b, a
			}
			return kb.Graph().Edge(a, b)
		}
		return func(p graph_addons.Path[string], nerr error) error {
			// An edge needs two endpoints; guard against a single-vertex path.
			if len(p) < 2 {
				return nil
			}
			last := p[len(p)-1]
			if err := addV(last); err != nil {
				return err
			}
			edge, err := edge(p[len(p)-2], last)
			if err != nil {
				return err
			}
			err = g.AddEdge(edge.Source.QualifiedTypeName, edge.Target.QualifiedTypeName, func(ep *graph.EdgeProperties) {
				*ep = edge.Properties
			})
			if err != nil && !errors.Is(err, graph.ErrEdgeAlreadyExists) {
				return err
			}
			if len(p) >= args.Distance {
				return graph_addons.SkipPath
			}
			return nil
		}
	}

	err = errors.Join(
		graph_addons.WalkUp(kb.Graph(), args.Source, walkFunc(true)),
		graph_addons.WalkDown(kb.Graph(), args.Source, walkFunc(false)),
	)
	if err != nil {
		panic(err)
	}
	return filteredKb
}
package main
import (
"fmt"
"github.com/klothoplatform/klotho/pkg/k2/model"
"gopkg.in/yaml.v3"
)
// irCmd reads the IR file at filePath and returns it re-encoded as YAML.
// Failures are reported in-band: the returned string is then an error description
// instead of YAML.
func irCmd(filePath string) string {
	ir, readErr := model.ReadIRFile(filePath)
	if readErr != nil {
		return fmt.Sprintf("error reading IR file: %s", readErr)
	}
	encoded, marshalErr := yaml.Marshal(ir)
	if marshalErr != nil {
		return fmt.Sprintf("error marshalling IR: %s", marshalErr)
	}
	return string(encoded)
}
package main
import (
"context"
"errors"
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"github.com/klothoplatform/klotho/pkg/logging"
pulumi "github.com/pulumi/pulumi/sdk/v3"
"github.com/pulumi/pulumi/sdk/v3/go/auto"
)
// CliDependency names an external command-line tool the CLI relies on.
type CliDependency string

// CliDependencyConfig requests one CliDependency. Optional flags it as optional;
// see InstallDependencies for how that is treated.
type CliDependencyConfig struct {
	Dependency CliDependency
	Optional   bool
}

// Known CLI dependencies.
const (
	CliDependencyDocker CliDependency = "docker"
	CliDependencyPulumi CliDependency = "pulumi"
)

// installFunc attempts to install one dependency, returning an error on failure.
type installFunc func(ctx context.Context) error
// InstallDependencies installs every dependency in configs that is not already present.
//
// Unknown dependencies are ignored. A failure to install a dependency marked Optional
// is logged as a warning. Failures for required dependencies are joined and returned.
func InstallDependencies(ctx context.Context, configs []CliDependencyConfig) error {
	type pendingInstall struct {
		dependency CliDependency
		optional   bool
		install    installFunc
	}

	var pending []pendingInstall
	for _, config := range configs {
		var install installFunc
		switch config.Dependency {
		case CliDependencyDocker:
			if isDockerInstalled() {
				continue
			}
			install = installDocker
		case CliDependencyPulumi:
			if isPulumiInstalled() {
				continue
			}
			install = installPulumi
		default:
			continue
		}
		pending = append(pending, pendingInstall{
			dependency: config.Dependency,
			optional:   config.Optional,
			install:    install,
		})
	}

	log := logging.GetLogger(ctx).Sugar()
	if len(pending) > 0 {
		log.Infof("Installing CLI dependencies...")
	}

	var err error
	for _, p := range pending {
		e := p.install(ctx)
		if e == nil {
			continue
		}
		if p.optional {
			// Honour CliDependencyConfig.Optional: don't fail the command for it.
			log.Warnf("Could not install optional dependency %s: %v", p.dependency, e)
			continue
		}
		err = errors.Join(err, e)
	}
	return err
}
// installDocker cannot install Docker automatically. It always returns an error:
// either one naming the Docker Desktop install guide for the current OS, or
// "unsupported OS" for any OS other than darwin, linux or windows.
func installDocker(ctx context.Context) error {
	guides := map[string]string{
		"darwin":  "https://docs.docker.com/desktop/install/mac-install/",
		"linux":   "https://docs.docker.com/desktop/install/linux-install/",
		"windows": "https://docs.docker.com/desktop/install/windows-install/",
	}
	guide, known := guides[runtime.GOOS]
	if !known {
		return errors.New("unsupported OS")
	}
	return fmt.Errorf("install docker from %s", guide)
}
// installPulumi installs the pulumi CLI into the versioned directory returned by
// pulumiInstallDir, using the automation API's installer.
func installPulumi(ctx context.Context) error {
	dir, err := pulumiInstallDir()
	if err != nil {
		return err
	}
	opts := &auto.PulumiCommandOptions{Root: dir}
	if _, installErr := auto.InstallPulumiCommand(ctx, opts); installErr != nil {
		return fmt.Errorf("failed to install pulumi: %w", installErr)
	}
	return nil
}
// isDockerInstalled reports whether `docker ps` runs successfully. That command only
// succeeds when the docker client is available and the daemon is running.
func isDockerInstalled() bool {
	return exec.Command("docker", "ps").Run() == nil
}
// isPulumiInstalled reports whether a pulumi CLI exists in pulumiInstallDir whose
// version exactly equals the pulumi SDK version compiled into this binary.
// Any lookup error counts as "not installed".
func isPulumiInstalled() bool {
	dir, err := pulumiInstallDir()
	if err != nil {
		return false
	}
	opts := &auto.PulumiCommandOptions{Root: dir}
	pulumiCmd, cmdErr := auto.NewPulumiCommand(opts)
	if cmdErr != nil {
		return false
	}
	// The installed version must be the same as the current SDK version
	installed := pulumiCmd.Version()
	return installed.EQ(pulumi.Version)
}
// pulumiInstallDir returns ~/.k2/pulumi/versions/<sdk version>, the directory where the
// pulumi CLI matching the compiled-in SDK version is installed.
func pulumiInstallDir() (string, error) {
	home, err := os.UserHomeDir()
	if err != nil {
		return "", err
	}
	versionDir := pulumi.Version.String()
	return filepath.Join(home, ".k2", "pulumi", "versions", versionDir), nil
}
package main
import (
"context"
"errors"
"fmt"
"os"
"path/filepath"
"time"
"github.com/klothoplatform/klotho/pkg/engine/debug"
"github.com/klothoplatform/klotho/pkg/k2/language_host"
pb "github.com/klothoplatform/klotho/pkg/k2/language_host/go"
"github.com/klothoplatform/klotho/pkg/k2/model"
"github.com/klothoplatform/klotho/pkg/k2/orchestration"
"github.com/klothoplatform/klotho/pkg/k2/stack"
"github.com/klothoplatform/klotho/pkg/logging"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
// downConfig holds the "down" command's flag values (bound in newDownCmd).
var downConfig struct {
	stateDir, debugMode string
	debugPort           int
}
// newDownCmd builds the "down" command and binds its flags to downConfig.
//
// It accepts either a program path alone, or a program path followed by the project,
// app and environment names.
func newDownCmd() *cobra.Command {
	downCommand := &cobra.Command{
		Use:   "down",
		Short: "Run the down command",
		// down reads args[0] and only understands 1 or 4 arguments; reject anything
		// else before RunE instead of panicking on a missing argument.
		Args: func(cmd *cobra.Command, args []string) error {
			if n := len(args); n != 1 && n != 4 {
				return fmt.Errorf("accepts 1 or 4 arg(s), received %d", n)
			}
			return nil
		},
		RunE: down,
	}
	flags := downCommand.Flags()
	flags.StringVar(&downConfig.stateDir, "state-directory", "", "State directory")
	flags.StringVar(&downConfig.debugMode, "debug", "", "Debug mode")
	flags.IntVar(&downConfig.debugPort, "debug-port", 5678, "Language Host Debug port")
	return downCommand
}
// getProjectPath starts the Python language host for the program at inputPath, asks it
// for the program's IR over gRPC, and returns model.UrnPath of the IR's app URN.
//
// Outside debug mode, server start-up is bounded to 30s and the IR request to 10s.
// In debug mode there is no timeout, so breakpoints in the language host don't trip it.
// The language-host process is killed and the gRPC connection closed before returning.
func getProjectPath(ctx context.Context, inputPath string) (string, error) {
	langHost, srvState, err := language_host.StartPythonClient(ctx, language_host.DebugConfig{
		Port: downConfig.debugPort,
		Mode: downConfig.debugMode,
	}, filepath.Dir(inputPath))
	if err != nil {
		return "", err
	}
	defer func() {
		if err := langHost.Process.Kill(); err != nil {
			zap.L().Warn("failed to kill Python client", zap.Error(err))
		}
	}()

	log := logging.GetLogger(ctx).Sugar()
	log.Debug("Waiting for Python server to start")
	if downConfig.debugMode != "" {
		// Don't add a timeout in case there are breakpoints in the language host before an address is printed
		<-srvState.Done
	} else {
		select {
		case <-srvState.Done:
		case <-time.After(30 * time.Second):
			return "", errors.New("timeout waiting for Python server to start")
		}
	}
	if srvState.Error != nil {
		return "", srvState.Error
	}

	conn, err := grpc.NewClient(srvState.Address, grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		return "", fmt.Errorf("failed to connect to Python server: %w", err)
	}
	defer func() {
		// Use a local error so the deferred close never writes to the function's err.
		if cerr := conn.Close(); cerr != nil {
			zap.L().Error("failed to close connection", zap.Error(cerr))
		}
	}()

	client := pb.NewKlothoServiceClient(conn)
	// make sure the ctx used later doesn't have the timeout (which is only for the IR request)
	irCtx := ctx
	if downConfig.debugMode == "" {
		var cancel context.CancelFunc
		irCtx, cancel = context.WithTimeout(irCtx, time.Second*10)
		defer cancel()
	}
	req := &pb.IRRequest{Filename: inputPath}
	res, err := client.SendIR(irCtx, req)
	if err != nil {
		return "", fmt.Errorf("error sending IR request: %w", err)
	}
	ir, err := model.ParseIRFile([]byte(res.GetYamlPayload()))
	if err != nil {
		return "", fmt.Errorf("error parsing IR file: %w", err)
	}
	appUrnPath, err := model.UrnPath(ir.AppURN)
	if err != nil {
		return "", fmt.Errorf("error getting URN path: %w", err)
	}
	return appUrnPath, nil
}
// down runs the "down" command.
//
// Arguments are either <program> (the project path is derived from the program's IR via
// the language host) or <program> <project> <app> <env>. It loads the state file at
// <stateDir>/<project path>/state.yaml (stateDir defaults to ~/.k2/projects), builds a
// stack reference per construct, and hands them to the down orchestrator.
func down(cmd *cobra.Command, args []string) error {
	// Validate before indexing args so a missing argument is an error, not a panic.
	if len(args) != 1 && len(args) != 4 {
		return fmt.Errorf("invalid number of arguments (%d) expected 1 or 4", len(args))
	}
	filePath := args[0]
	// Any stat failure (not only non-existence) means the program cannot be used.
	if _, err := os.Stat(filePath); err != nil {
		return err
	}
	absolutePath, err := filepath.Abs(filePath)
	if err != nil {
		return err
	}

	var projectPath string
	switch len(args) {
	case 1:
		projectPath, err = getProjectPath(cmd.Context(), absolutePath)
		if err != nil {
			return fmt.Errorf("error getting project path: %w", err)
		}
	case 4:
		project := args[1]
		app := args[2]
		env := args[3]
		projectPath = filepath.Join(project, app, env)
	}

	if downConfig.stateDir == "" {
		homeDir, err := os.UserHomeDir()
		if err != nil {
			return err
		}
		downConfig.stateDir = filepath.Join(homeDir, ".k2", "projects")
	}
	debugDir := debug.GetDebugDir(cmd.Context())
	if debugDir == "" {
		debugDir = downConfig.stateDir
		cmd.SetContext(debug.WithDebugDir(cmd.Context(), debugDir))
	}
	ctx := cmd.Context()

	err = InstallDependencies(ctx, []CliDependencyConfig{
		{Dependency: CliDependencyPulumi, Optional: false},
	})
	if err != nil {
		return fmt.Errorf("error installing dependencies: %w", err)
	}

	stateFile := filepath.Join(downConfig.stateDir, projectPath, "state.yaml")
	osfs := afero.NewOsFs()
	sm := model.NewStateManager(osfs, stateFile)
	if !sm.CheckStateFileExists() {
		return fmt.Errorf("state file does not exist: %s", stateFile)
	}
	if err = sm.LoadState(); err != nil {
		return fmt.Errorf("error loading state: %w", err)
	}

	var stackReferences []stack.Reference
	for name, construct := range sm.GetAllConstructs() {
		constructPath := filepath.Join(downConfig.stateDir, projectPath, name)
		stackReferences = append(stackReferences, stack.Reference{
			ConstructURN: *construct.URN,
			Name:         name,
			IacDirectory: constructPath,
		})
	}

	o := orchestration.NewDownOrchestrator(sm, osfs, downConfig.stateDir)
	err = o.RunDownCommand(
		ctx,
		orchestration.DownRequest{StackReferences: stackReferences, DryRun: model.DryRun(commonCfg.dryRun)},
		5,
	)
	if err != nil {
		return fmt.Errorf("error running down command: %w", err)
	}
	return nil
}
package main
import (
"context"
"fmt"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/x/term"
clicommon "github.com/klothoplatform/klotho/pkg/cli_common"
"github.com/klothoplatform/klotho/pkg/k2/initialize"
"github.com/klothoplatform/klotho/pkg/tui/prompt"
"github.com/spf13/cobra"
"os"
"path/filepath"
)
var (
	// initConfig holds the "init" command's flag values (bound in newInitCommand).
	initConfig struct {
		projectName, appName, environment, outputDirectory string
		defaultRegion, programFileName                     string
		interactive, nonInteractive, skipInstall           bool
	}

	// awsDefaultRegions are the suggestions offered when prompting for --default-region.
	awsDefaultRegions = []string{
		"us-east-1",
		"us-east-2",
		"us-west-1",
		"us-west-2",
		"af-south-1",
		"ap-east-1",
		"ap-south-1",
		"ap-northeast-1",
		"ap-northeast-2",
		"ap-northeast-3",
		"ap-southeast-1",
		"ap-southeast-2",
		"ca-central-1",
		"eu-central-1",
		"eu-west-1",
		"eu-west-2",
		"eu-west-3",
		"eu-south-1",
		"eu-north-1",
		"me-south-1",
		"sa-east-1",
	}
)
// newInitCommand builds the "init" command, binds its flags to initConfig, and marks
// app/program/project as required plus output/program for directory/file completion.
// Any failure marking a flag exits the process via exitOnError.
func newInitCommand() *cobra.Command {
	cmd := &cobra.Command{
		Use:     "init",
		Short:   "Initialize a new Klotho application",
		PreRunE: prerunInit,
		RunE:    runInit,
	}

	f := cmd.Flags()
	f.StringVarP(&initConfig.appName, "app", "a", "", "App name")
	f.StringVarP(&initConfig.environment, "environment", "e", "", "Environment")
	f.BoolVarP(&initConfig.interactive, "interactive", "i", false, "Interactive mode")
	f.StringVarP(&initConfig.outputDirectory, "output", "o", "", "Output directory")
	f.StringVarP(&initConfig.programFileName, "program", "p", "infra.py", "Program file name")
	f.StringVarP(&initConfig.projectName, "project", "P", "default-project", "Project name")
	f.StringVarP(&initConfig.defaultRegion, "default-region", "R", "", "AWS default region")
	f.BoolVarP(&initConfig.nonInteractive, "non-interactive", "", false, "Non-interactive mode")
	f.BoolVarP(&initConfig.skipInstall, "skip-install", "", false, "Skip installing dependencies")

	for _, name := range []string{"app", "program", "project"} {
		exitOnError(cmd.MarkFlagRequired(name))
	}
	exitOnError(cmd.MarkFlagDirname("output"))
	exitOnError(cmd.MarkFlagFilename("program"))
	return cmd
}
// prerunInit validates the interactive flags and, when stdout is a terminal and
// --non-interactive was not given, prompts for the init inputs.
func prerunInit(cmd *cobra.Command, args []string) error {
	if initConfig.interactive && initConfig.nonInteractive {
		return fmt.Errorf("cannot specify both interactive and non-interactive flags")
	}
	isTTY := term.IsTerminal(os.Stdout.Fd())
	switch {
	case !isTTY && initConfig.interactive:
		return fmt.Errorf("interactive mode is only supported in a terminal environment")
	case !isTTY, initConfig.nonInteractive:
		return nil
	default:
		return promptInputs(cmd)
	}
}
// runInit scaffolds a new Python Klotho application from the values in initConfig
// (supplied by flags and/or the interactive prompts).
func runInit(cmd *cobra.Command, args []string) error {
	fmt.Println("Initializing Klotho application...")
	// Use the command's context, rather than a fresh Background context, so that
	// cancellation and anything attached to it (e.g. by the root PersistentPreRun)
	// reach initialisation.
	ctx := cmd.Context()
	if ctx == nil {
		ctx = context.Background()
	}
	err := initialize.Application(
		initialize.ApplicationRequest{
			Context:         ctx,
			ProjectName:     initConfig.projectName,
			AppName:         initConfig.appName,
			OutputDirectory: initConfig.outputDirectory,
			DefaultRegion:   initConfig.defaultRegion,
			Runtime:         "python",
			ProgramFileName: initConfig.programFileName,
			Environment:     initConfig.environment,
			SkipInstall:     initConfig.skipInstall,
		})
	if err != nil {
		// Return (don't also print) the error; the root command reports it.
		return fmt.Errorf("error initializing Klotho application: %w", err)
	}
	fmt.Println("Klotho application initialized successfully")
	return nil
}
// promptInputs interactively collects init flag values with a bubbletea multi-prompt UI.
//
// In --interactive mode every flag in interactiveFlags is prompted. Otherwise only
// required flags that are still empty are prompted. Non-prompted flags with a non-empty
// value are marked Changed so that required flags with defaults satisfy cobra's checks.
// Returns an error if the UI fails or the user quits.
func promptInputs(cmd *cobra.Command) error {
	helpers := map[string]prompt.Helper{
		"program": {
			SuggestionResolverFunc: func(input string) []string {
				var err error
				outputDir := initConfig.outputDirectory
				if outputDir == "" {
					outputDir, err = os.Getwd()
					if err != nil {
						outputDir = "."
					}
				}
				if err == nil {
					_, err = os.Stat(filepath.Join(outputDir, initConfig.programFileName))
				}
				// The default program file already exists: suggest one named after the app.
				if err == nil {
					return []string{fmt.Sprintf("%s.py", initConfig.appName)}
				}
				return []string{"infra.py"}
			},
			ValidateFunc: func(input string) error {
				outputDir := initConfig.outputDirectory
				if outputDir == "" {
					var err error
					outputDir, err = os.Getwd()
					if err != nil {
						return nil
					}
				}
				if _, err := os.Stat(filepath.Join(outputDir, input)); err == nil {
					return fmt.Errorf("file '%s' already exists", input)
				}
				return nil
			},
		},
		"environment": {
			SuggestionResolverFunc: func(input string) []string {
				return []string{"dev", "staging", "prod", "test", "qa", "default", "development", "production"}
			},
		},
		"default-region": {
			SuggestionResolverFunc: func(input string) []string {
				return awsDefaultRegions
			},
		},
	}

	interactiveFlags := []string{"project", "app", "program", "default-region", "environment"}
	var flagsToPrompt []string
	for _, flagName := range interactiveFlags {
		flag := cmd.Flags().Lookup(flagName)
		if flag == nil {
			return fmt.Errorf("unknown flag %q", flagName)
		}
		required, found := flag.Annotations[cobra.BashCompOneRequiredFlag]
		isRequired := found && len(required) > 0 && required[0] == "true"
		if initConfig.interactive || (isRequired && flag.Value.String() == "") {
			flagsToPrompt = append(flagsToPrompt, flagName)
			continue
		}
		// Set any flags that have a default value to changed if they are not empty strings to allow a flag to both be required and have a default value
		if flag.Value.String() != "" {
			flag.Changed = true
		}
	}
	if len(flagsToPrompt) == 0 {
		return nil
	}

	promptCreator := func(flagName string) prompt.FlagPromptModel {
		flag := cmd.Flags().Lookup(flagName)
		return prompt.CreatePromptModel(flag, helpers[flagName], clicommon.IsFlagRequired(flag))
	}
	firstPrompt := promptCreator(flagsToPrompt[0])
	model := prompt.MultiFlagPromptModel{
		Prompts:       []prompt.FlagPromptModel{firstPrompt},
		FlagNames:     flagsToPrompt,
		Cmd:           cmd,
		Helpers:       helpers,
		PromptCreator: promptCreator,
	}
	p := tea.NewProgram(model, tea.WithAltScreen())
	finalModel, err := p.Run()
	if err != nil {
		return err
	}
	// Checked assertion: an unexpected model type becomes an error instead of a panic.
	finalMultiPromptModel, ok := finalModel.(prompt.MultiFlagPromptModel)
	if !ok {
		return fmt.Errorf("unexpected prompt model type %T", finalModel)
	}
	if finalMultiPromptModel.Quit {
		return fmt.Errorf("operation cancelled by user")
	}
	return nil
}
// exitOnError does nothing for a nil err. Otherwise it prints err to stderr and exits
// the process with status 1. It is used for flag set-up calls in newInitCommand.
func exitOnError(err error) {
	if err == nil {
		return
	}
	// Diagnostics belong on stderr so they don't mix with command output on stdout.
	fmt.Fprintln(os.Stderr, err)
	os.Exit(1)
}
package main
import (
"os"
"syscall"
"github.com/klothoplatform/klotho/pkg/k2/cleanup"
"github.com/klothoplatform/klotho/pkg/logging"
"go.uber.org/zap"
clicommon "github.com/klothoplatform/klotho/pkg/cli_common"
"github.com/spf13/cobra"
)
// commonCfg holds flag values shared across the CLI's subcommands.
var commonCfg struct {
	// CommonConfig is registered on the up and down commands via clicommon.SetupCoreCommand in cli.
	clicommon.CommonConfig
	// dryRun backs the persistent --dry-run/-n flag registered in cli.
	dryRun clicommon.LevelledFlag
}
// cli builds the root command with the init, up, down and ir subcommands, executes it,
// and returns the process exit status: 0 on success, 1 if the command failed (the error
// is logged).
//
// If a panic escapes, cleanup.Execute(syscall.SIGTERM) runs before the panic is re-raised.
// Cleanup functions returned by clicommon.SetupCoreCommand run when cli returns.
func cli() int {
	// Set up signal and panic handling to ensure cleanup is executed
	defer func() {
		if r := recover(); r != nil {
			_ = cleanup.Execute(syscall.SIGTERM)
			panic(r)
		}
	}()

	var rootCmd = &cobra.Command{
		Use:           "app",
		SilenceUsage:  true,
		SilenceErrors: true,
	}
	rootCmd.PersistentPreRun = func(cmd *cobra.Command, args []string) {
		cmd.SetContext(cleanup.InitializeHandler(cmd.Context()))
		cmd.SilenceErrors = true
	}

	flags := rootCmd.PersistentFlags()
	dryRunFlag := flags.VarPF(&commonCfg.dryRun, "dry-run", "n", "Dry run (once for pulumi preview, twice for tsc)")
	dryRunFlag.NoOptDefVal = "true" // Allow -n to be used without a value

	var cleanupFuncs []func()
	defer func() {
		for _, f := range cleanupFuncs {
			f()
		}
	}()

	initCommand := newInitCommand()
	upCommand := newUpCmd()
	cleanupFuncs = append(cleanupFuncs, clicommon.SetupCoreCommand(upCommand, &commonCfg.CommonConfig))
	downCommand := newDownCmd()
	cleanupFuncs = append(cleanupFuncs, clicommon.SetupCoreCommand(downCommand, &commonCfg.CommonConfig))

	var irCommand = &cobra.Command{
		Use:   "ir [file path]",
		Short: "Run the IR command",
		// RunE reads args[0]; reject a missing path up front instead of panicking.
		Args: cobra.ExactArgs(1),
		RunE: func(cmd *cobra.Command, args []string) error {
			filePath := args[0]
			if _, err := os.Stat(filePath); err != nil {
				return err
			}
			// irCmd returns either the IR as YAML or an error description; write it out
			// so the command actually produces output.
			_, err := cmd.OutOrStdout().Write([]byte(irCmd(filePath)))
			return err
		},
	}

	rootCmd.AddCommand(initCommand)
	rootCmd.AddCommand(upCommand)
	rootCmd.AddCommand(downCommand)
	rootCmd.AddCommand(irCommand)

	if err := rootCmd.Execute(); err != nil {
		logging.GetLogger(rootCmd.Context()).Error("Failed to execute command", zap.Error(err))
		return 1
	}
	return 0
}
// main runs the CLI and exits with the status code cli returns.
func main() {
	status := cli()
	os.Exit(status)
}
package main
import (
"fmt"
"os"
"path/filepath"
"github.com/klothoplatform/klotho/pkg/k2/language_host"
"github.com/klothoplatform/klotho/pkg/logging"
"go.uber.org/zap"
"golang.org/x/sync/semaphore"
"github.com/klothoplatform/klotho/pkg/engine/debug"
pb "github.com/klothoplatform/klotho/pkg/k2/language_host/go"
"github.com/klothoplatform/klotho/pkg/k2/model"
"github.com/klothoplatform/klotho/pkg/k2/orchestration"
"github.com/spf13/afero"
"github.com/spf13/cobra"
)
// upConfig holds the "up" command's flag values (bound in newUpCmd).
var upConfig struct {
	stateDir, debugMode string
	debugPort           int
}
// newUpCmd builds the "up" command, which takes the program file path as its first
// argument, and binds its flags to upConfig.
func newUpCmd() *cobra.Command {
	var upCommand = &cobra.Command{
		Use:   "up",
		Short: "Run the up command",
		// up reads args[0]; require it so a missing path is a usage error, not a panic.
		// Extra arguments remain accepted (and ignored) as before.
		Args: cobra.MinimumNArgs(1),
		RunE: up,
	}
	flags := upCommand.Flags()
	flags.StringVar(&upConfig.stateDir, "state-directory", "", "State directory")
	flags.StringVar(&upConfig.debugMode, "debug", "", "Debug mode")
	flags.IntVar(&upConfig.debugPort, "debug-port", 5678, "Language Host Debug port")
	return upCommand
}
// up runs the "up" command for the program whose path is args[0].
//
// It resolves the state directory (default ~/.k2/projects) and ensures pulumi is
// installed. It then starts the language host to fetch the program's IR, and creates or
// loads <stateDir>/<app URN path>/state.yaml. Finally it runs the up orchestrator with a
// weighted semaphore of 5 (presumably bounding concurrent work — confirm in orchestration).
// In verbose mode with no debug dir set, the state directory is used as the debug dir.
func up(cmd *cobra.Command, args []string) error {
	if len(args) < 1 {
		return fmt.Errorf("missing required argument: program file path")
	}
	filePath := args[0]
	// Any stat failure (not only non-existence) means the program cannot be used.
	if _, err := os.Stat(filePath); err != nil {
		return err
	}
	absolutePath, err := filepath.Abs(filePath)
	if err != nil {
		return err
	}
	inputPath := absolutePath
	ctx := cmd.Context()

	if upConfig.stateDir == "" {
		homeDir, err := os.UserHomeDir()
		if err != nil {
			return err
		}
		upConfig.stateDir = filepath.Join(homeDir, ".k2", "projects")
	}
	if debugDir := debug.GetDebugDir(ctx); debugDir == "" && commonCfg.CommonConfig.Verbose > 0 {
		ctx = debug.WithDebugDir(ctx, upConfig.stateDir)
		cmd.SetContext(ctx)
	}

	err = InstallDependencies(ctx, []CliDependencyConfig{
		{Dependency: CliDependencyPulumi, Optional: false},
	})
	if err != nil {
		return fmt.Errorf("error installing dependencies: %w", err)
	}

	log := logging.GetLogger(ctx).Sugar()
	var langHost language_host.LanguageHost
	err = langHost.Start(ctx, language_host.DebugConfig{
		Port: upConfig.debugPort,
		Mode: upConfig.debugMode,
	}, filepath.Dir(inputPath))
	if err != nil {
		return err
	}
	defer func() {
		if err := langHost.Close(); err != nil {
			// Warnw takes structured fields; Warnf would treat zap.Error(err) as a stray
			// format argument and print "%!(EXTRA ...)".
			log.Warnw("Error closing language host", zap.Error(err))
		}
	}()

	ir, err := langHost.GetIR(ctx, &pb.IRRequest{Filename: inputPath})
	if err != nil {
		return fmt.Errorf("error getting IR: %w", err)
	}

	// Take the IR -- generate and save a state file and stored in the
	// output directory, the path should include the environment name and
	// the project URN
	appUrnPath, err := model.UrnPath(ir.AppURN)
	if err != nil {
		return fmt.Errorf("error getting URN path: %w", err)
	}
	appDir := filepath.Join(upConfig.stateDir, appUrnPath)

	// Create the app state directory
	if err = os.MkdirAll(appDir, 0755); err != nil {
		return fmt.Errorf("error creating app directory: %w", err)
	}

	stateFile := filepath.Join(appDir, "state.yaml")
	osfs := afero.NewOsFs()
	sm := model.NewStateManager(osfs, stateFile)
	if !sm.CheckStateFileExists() {
		sm.InitState(ir)
		if err = sm.SaveState(); err != nil {
			return fmt.Errorf("error saving state: %w", err)
		}
	} else if err = sm.LoadState(); err != nil {
		return fmt.Errorf("error loading state: %w", err)
	}

	o, err := orchestration.NewUpOrchestrator(sm, langHost.NewClient(), osfs, appDir)
	if err != nil {
		return fmt.Errorf("error creating up orchestrator: %w", err)
	}
	err = o.RunUpCommand(ctx, ir, model.DryRun(commonCfg.dryRun), semaphore.NewWeighted(5))
	if err != nil {
		return fmt.Errorf("error running up command: %w", err)
	}
	return nil
}
package async
import "sync"
type (
	// ConcurrentMap is a map guarded by a sync.RWMutex. Its zero value is an empty map
	// ready for use; the underlying map is only allocated on the first write. It contains
	// a mutex, so it must not be copied after first use.
	ConcurrentMap[K comparable, V any] struct {
		mu sync.RWMutex
		m  map[K]V
	}

	// MapEntry is a key/value pair, as returned by ConcurrentMap.Entries.
	MapEntry[K comparable, V any] struct {
		Key   K
		Value V
	}
)

// initForWrite caller must hold a write lock. Read operations do not need to call this - they must all be compatible
// with a nil map.
func (cf *ConcurrentMap[K, V]) initForWrite() {
	if cf.m == nil {
		cf.m = make(map[K]V)
	}
}

// Len returns the number of entries. It only reads, so the shared read lock suffices.
func (cf *ConcurrentMap[K, V]) Len() int {
	cf.mu.RLock()
	defer cf.mu.RUnlock()
	return len(cf.m)
}

// Set stores v under k.
func (cf *ConcurrentMap[K, V]) Set(k K, v V) {
	cf.mu.Lock()
	defer cf.mu.Unlock()
	cf.initForWrite()
	cf.m[k] = v
}

// AddAll stores every entry of entries, overwriting existing keys, under a single lock.
func (cf *ConcurrentMap[K, V]) AddAll(entries map[K]V) {
	cf.mu.Lock()
	defer cf.mu.Unlock()
	cf.initForWrite()
	for key, value := range entries {
		cf.m[key] = value
	}
}

// Compute sets the value of key 'k' to the result of the supplied computeFunc.
// computeFunc receives the current value (the zero value if k is absent). It runs while
// the write lock is held, so it must not call back into this ConcurrentMap.
// If the value of 'ok' is 'false', the entry for key 'k' will be removed from the ConcurrentMap.
func (cf *ConcurrentMap[K, V]) Compute(k K, computeFunc func(k K, v V) (val V, ok bool)) {
	cf.mu.Lock()
	defer cf.mu.Unlock()
	cf.initForWrite()
	if val, ok := computeFunc(k, cf.m[k]); ok {
		cf.m[k] = val
	} else {
		// Delete in place: calling cf.Delete would try to re-acquire the
		// non-reentrant mutex held above and deadlock.
		delete(cf.m, k)
	}
}

// Delete removes k, returning the removed value and whether k was present.
func (cf *ConcurrentMap[K, V]) Delete(k K) (v V, existed bool) {
	cf.mu.Lock()
	defer cf.mu.Unlock()
	if cf.m != nil {
		v, existed = cf.m[k]
		delete(cf.m, k)
	}
	return
}

// Get returns the value stored under k and whether it was present.
func (cf *ConcurrentMap[K, V]) Get(k K) (v V, ok bool) {
	cf.mu.RLock()
	defer cf.mu.RUnlock()
	if cf.m != nil {
		v, ok = cf.m[k]
	}
	return
}

// Keys returns a snapshot of the keys in unspecified order, or nil if the map was never written.
func (cf *ConcurrentMap[K, V]) Keys() []K {
	cf.mu.RLock()
	defer cf.mu.RUnlock()
	if cf.m == nil {
		return nil
	}
	ks := make([]K, 0, len(cf.m))
	for k := range cf.m {
		ks = append(ks, k)
	}
	return ks
}

// Values returns a snapshot of the values in unspecified order, or nil if the map was never written.
func (cf *ConcurrentMap[K, V]) Values() []V {
	cf.mu.RLock()
	defer cf.mu.RUnlock()
	if cf.m == nil {
		return nil
	}
	vs := make([]V, 0, len(cf.m))
	for _, v := range cf.m {
		vs = append(vs, v)
	}
	return vs
}

// Entries returns a snapshot of the key/value pairs in unspecified order, or nil if the
// map was never written.
func (cf *ConcurrentMap[K, V]) Entries() []MapEntry[K, V] {
	cf.mu.RLock()
	defer cf.mu.RUnlock()
	if cf.m == nil {
		return nil
	}
	kvs := make([]MapEntry[K, V], 0, len(cf.m))
	for k, v := range cf.m {
		kvs = append(kvs, MapEntry[K, V]{Key: k, Value: v})
	}
	return kvs
}
// Each executes `f` for each key-value pair in the map, while holding the lock.
// ! Avoid doing expensive operations in `f`, instead create a copy (eg via `Entries()`).
func (cf *ConcurrentMap[K, V]) Each(f func(k K, v V) (stop bool)) {
cf.mu.RLock()
defer cf.mu.RUnlock()
for k, v := range cf.m {
if stop := f(k, v); stop {
return
}
}
}
package auth
import (
"context"
"crypto/rand"
_ "embed"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/pkg/browser"
"go.uber.org/zap"
"golang.org/x/oauth2"
)
// env returns the value of the environment variable name, or dflt when the variable is
// not present at all. A variable that is present but empty yields "" (not dflt).
func env(name, dflt string) string {
	if value, present := os.LookupEnv(name); present {
		return value
	}
	return dflt
}
// OAuth/OIDC client settings. The environment-backed values are read once, when the
// package is initialised, falling back to the defaults given here (see env).
var (
	// domain is the Auth0 tenant host; newAuth uses https://<domain>/ as the OIDC
	// issuer. Override with AUTH_DOMAIN.
	domain = env("AUTH_DOMAIN", "klotho.us.auth0.com")
	// clientId is the OAuth2 client ID. Override with AUTH_CLIENT_ID.
	clientId = env("AUTH_CLIENT_ID", "AeIvquQVLg9jy2V6Jq5Bz48cKQOmIPDw")
	// browserEnv is $BROWSER; the value "none" (case-insensitive) makes GetAuthToken
	// print the login URL instead of opening a browser.
	browserEnv = env("BROWSER", "")
	// clientSecret is embedded at build time from auth0_client_secret.key; newAuth
	// refuses to run when it is empty.
	//go:embed auth0_client_secret.key
	clientSecret string
)
// GetAuthToken returns an OAuth2 token for the configured Auth0 tenant, together with an
// HTTP client that authenticates requests with it (see Authenticator.HTTPClient).
//
// A previously cached token (readCachedToken) is returned when one can be decoded.
// Otherwise an interactive authorization-code login runs:
//   - A local HTTP server is started on the redirect URL's host.
//   - The authorization URL is opened in the browser, or printed when $BROWSER is "none".
//   - The function waits until the /callback request delivers a verified token, the
//     server fails, or ctx is done.
//
// The new token's AccessToken is replaced by the OIDC ID token, and the token is cached.
func GetAuthToken(ctx context.Context) (*oauth2.Token, *http.Client, error) {
	log := zap.S().Named("auth")
	auth, err := newAuth(ctx)
	if err != nil {
		return nil, nil, err
	}
	if token := readCachedToken(); token != nil {
		return token, auth.HTTPClient(ctx, token), nil
	}
	// Buffered, with a non-blocking send in the handler, so a handler never blocks:
	// a blocked handler would keep srv.Shutdown (deferred below) waiting.
	tokenCh := make(chan *oauth2.Token, 1)
	callbackUrl, err := url.Parse(auth.RedirectURL)
	if err != nil {
		return nil, nil, err
	}
	state, err := generateRandomState()
	if err != nil {
		return nil, nil, err
	}
	srv := &http.Server{Addr: callbackUrl.Host}
	srv.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		log.Debugf("received request: %s %s", r.Method, r.URL.Path)
		if r.URL.Path != "/callback" {
			http.Error(w, "not found", http.StatusNotFound)
			return
		}
		// Only accept callbacks carrying the state value generated for this login.
		if reqState := r.URL.Query().Get("state"); reqState != state {
			log.Warnf("Got mismatched state: expected %s, got %s", state, reqState)
			http.Error(w, "state did not match", http.StatusBadRequest)
			return
		}
		token, err := auth.Exchange(r.Context(), r.URL.Query().Get("code"))
		if err != nil {
			http.Error(w, "failed to exchange token", http.StatusInternalServerError)
			log.Errorf("failed to exchange token: %+v", err)
			return
		}
		_, err = auth.VerifyIDToken(r.Context(), token)
		if err != nil {
			http.Error(w, "failed to verify token", http.StatusInternalServerError)
			log.Errorf("failed to verify token: %+v", err)
			return
		}
		// VerifyIDToken succeeded, so "id_token" is present as a string.
		token.AccessToken = token.Extra("id_token").(string)
		select {
		case tokenCh <- token:
		default:
			// A token was already delivered (e.g. the callback was hit twice); keep the first.
		}
		fmt.Fprint(w, `<html><body>Success, you can now close this window</body></html>`)
		log.Debugf("successfully authenticated")
	})
	defer func() {
		if err := srv.Shutdown(ctx); err != nil {
			zap.S().Errorw("failed to shutdown server", "error", err)
		}
	}()
	// Both channels are buffered so the listener goroutine can always deliver its result
	// and exit, even once this function has returned and nobody receives any more.
	ready := make(chan struct{}, 1)
	srvErrCh := make(chan error, 1)
	go func() {
		ln, err := net.Listen("tcp", srv.Addr)
		if err != nil {
			srvErrCh <- err
			return
		}
		ready <- struct{}{}
		srvErrCh <- srv.Serve(ln)
	}()
	for {
		select {
		case <-ctx.Done():
			return nil, nil, ctx.Err()
		case <-ready:
			dest := auth.AuthCodeURL(state)
			if strings.ToLower(browserEnv) == "none" {
				fmt.Printf("To authenticate, please visit:\n\t%s\n", dest)
				continue
			}
			err := browser.OpenURL(dest)
			if err != nil {
				return nil, nil, err
			}
		case err := <-srvErrCh:
			if err == nil || errors.Is(err, http.ErrServerClosed) {
				// No callback can arrive once the server has stopped; don't wait forever.
				err = errors.New("authentication callback server stopped before a token was received")
			}
			return nil, nil, err
		case token := <-tokenCh:
			writeCachedToken(token)
			return token, auth.HTTPClient(ctx, token), nil
		}
	}
}
// readCachedToken loads the token saved by writeCachedToken from
// <user cache dir>/klotho/token. It returns nil, logging the reason at debug level, if
// the cache directory is unknown or the file cannot be opened or decoded.
//
// NOTE(review): the token's expiry is not checked here.
func readCachedToken() *oauth2.Token {
	log := zap.S().Named("auth.cache")
	dir, err := os.UserCacheDir()
	if err != nil {
		log.Debugf("failed to get cache dir: %v", err)
		return nil
	}
	f, err := os.Open(filepath.Join(dir, "klotho", "token"))
	if err != nil {
		log.Debugf("failed to open cache file: %v", err)
		return nil
	}
	defer f.Close()
	token := new(oauth2.Token)
	if decodeErr := json.NewDecoder(f).Decode(token); decodeErr != nil {
		log.Debugf("failed to decode token: %v", decodeErr)
		return nil
	}
	log.Debugf("using cached token")
	return token
}
// writeCachedToken saves token as JSON to <user cache dir>/klotho/token, creating the
// directory (0700) and file (0600) as needed. Caching is best-effort: failures are
// logged at debug level and otherwise ignored.
func writeCachedToken(token *oauth2.Token) {
	log := zap.S().Named("auth.cache")
	cacheDir, err := os.UserCacheDir()
	if err != nil {
		log.Debugf("failed to get cache dir: %v", err)
		return
	}
	cacheLoc := filepath.Join(cacheDir, "klotho", "token")
	err = os.MkdirAll(filepath.Dir(cacheLoc), 0700)
	if err != nil {
		log.Debugf("failed to create cache dir: %v", err)
		return
	}
	// O_TRUNC: without it, a token shorter than the previous one would leave the old
	// token's trailing bytes in the file.
	cacheFile, err := os.OpenFile(cacheLoc, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
	if err != nil {
		log.Debugf("failed to open cache file: %v", err)
		return
	}
	err = json.NewEncoder(cacheFile).Encode(token)
	if err != nil {
		log.Debugf("failed to write token: %v", err)
	}
	// Close explicitly rather than deferring so that a failure to flush is reported too.
	if err := cacheFile.Close(); err != nil {
		log.Debugf("failed to close cache file: %v", err)
	}
}
// Authenticator bundles the OIDC provider discovered for the Auth0 tenant with the OAuth2
// client configuration used for the authorization-code flow. Both are embedded, so their
// methods (e.g. Verifier, Exchange, AuthCodeURL, TokenSource) are promoted onto it.
type Authenticator struct {
	*oidc.Provider
	oauth2.Config
}
// newAuth builds an Authenticator for the configured Auth0 tenant. It:
//   - performs OIDC discovery against https://<domain>/;
//   - configures the OAuth2 client with the client ID and the embedded client secret;
//   - redirects to http://localhost:3000/callback;
//   - requests the openid and profile scopes.
//
// It fails without contacting the provider when no client secret is embedded.
func newAuth(ctx context.Context) (*Authenticator, error) {
	// Check local configuration first: OIDC discovery is a network round trip that is
	// pointless when the client secret is missing.
	if clientSecret == "" {
		return nil, errors.New("missing client secret (pkg/auth/auth0_client_secret.key not embedded)")
	}
	provider, err := oidc.NewProvider(
		ctx,
		"https://"+domain+"/",
	)
	if err != nil {
		return nil, err
	}
	conf := oauth2.Config{
		ClientID:     clientId,
		ClientSecret: clientSecret,
		RedirectURL:  "http://localhost:3000/callback",
		Endpoint:     provider.Endpoint(),
		Scopes:       []string{oidc.ScopeOpenID, "profile"},
	}
	return &Authenticator{
		Provider: provider,
		Config:   conf,
	}, nil
}
// VerifyIDToken verifies that an *oauth2.Token is a valid *oidc.IDToken.
// It extracts the raw "id_token" extra field and checks it with the provider's verifier,
// configured for this client's ID. A token without a string "id_token" is rejected.
func (a *Authenticator) VerifyIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) {
	raw, ok := token.Extra("id_token").(string)
	if !ok {
		return nil, errors.New("no id_token field in oauth2 token")
	}
	verifier := a.Verifier(&oidc.Config{ClientID: a.ClientID})
	return verifier.Verify(ctx, raw)
}
// generateRandomState returns a fresh, unguessable value for the OAuth2 `state`
// parameter, built from 32 bytes of crypto/rand output.
//
// The bytes are encoded with unpadded URL-safe base64 (RFC 4648 §5), so the value uses
// only [A-Za-z0-9_-]. It therefore round-trips through the authorization URL and the
// callback query string without escaping concerns. Standard base64 may contain '+',
// '/' and '=', which are easily mangled in query strings.
func generateRandomState() (string, error) {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(b), nil
}
// idTokenSource wraps another oauth2.TokenSource and substitutes the OIDC ID token (the
// "id_token" extra field) for the access token.
type idTokenSource struct {
	src oauth2.TokenSource
}

// Token fetches a token from the wrapped source. If the token carries a string
// "id_token" extra, a copy is returned with AccessToken set to that ID token. Otherwise
// the wrapped token is returned unchanged.
func (s *idTokenSource) Token() (*oauth2.Token, error) {
	tok, err := s.src.Token()
	if err != nil {
		return nil, err
	}
	idToken, hasID := tok.Extra("id_token").(string)
	if !hasID {
		return tok, nil
	}
	// per TokenSource contract, we must return a copy if modifying
	modified := *tok
	modified.AccessToken = idToken
	return &modified, nil
}
// HTTPClient returns an *http.Client that authenticates requests with token.
// It reuses token until it expires, then refreshes via the OAuth2 config. Each refreshed
// token has its ID token substituted as the access token (see idTokenSource).
func (a *Authenticator) HTTPClient(ctx context.Context, token *oauth2.Token) *http.Client {
	refreshing := a.Config.TokenSource(ctx, token)
	source := oauth2.ReuseTokenSource(token, &idTokenSource{src: refreshing})
	return oauth2.NewClient(ctx, source)
}
package clicommon
import (
"fmt"
"os"
"path/filepath"
"runtime/pprof"
"github.com/spf13/pflag"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/x/term"
"github.com/klothoplatform/klotho/pkg/logging"
"github.com/klothoplatform/klotho/pkg/tui"
"github.com/spf13/cobra"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
type (
	// CommonConfig holds the persistent flags shared by every command configured with
	// SetupCoreCommand.
	CommonConfig struct {
		// jsonLog selects JSON log encoding (--json-log).
		jsonLog bool
		// Verbose is the verbosity level; -v may be repeated to raise it (--verbose/-v).
		Verbose LevelledFlag
		// logsDir is passed to the logger as the per-category logs directory (--logs-dir).
		logsDir string
		// profileTo is the CPU profile output path; empty disables profiling (--profiling).
		profileTo string
		// color controls output colorization: "auto", "on" or "off" (--color).
		color string
	}
)
// setupProfiling starts CPU profiling into commonCfg.profileTo when that path is set,
// creating its parent directory first. It returns a function that stops the profile and
// closes the file; when profiling is disabled the returned function does nothing.
// Any failure while setting up profiling panics.
func setupProfiling(commonCfg *CommonConfig) func() {
	target := commonCfg.profileTo
	if target == "" {
		return func() {}
	}
	if err := os.MkdirAll(filepath.Dir(target), 0755); err != nil {
		panic(fmt.Errorf("failed to create profile directory: %w", err))
	}
	out, err := os.OpenFile(target, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
	if err != nil {
		panic(fmt.Errorf("failed to open profile file: %w", err))
	}
	if err := pprof.StartCPUProfile(out); err != nil {
		panic(fmt.Errorf("failed to start profile: %w", err))
	}
	return func() {
		pprof.StopCPUProfile()
		out.Close()
	}
}
// SetupCoreCommand registers the persistent flags shared by all commands on root, binding
// them into commonCfg. It installs a PersistentPreRun hook that, once flags are parsed,
// configures global logging and optional CPU profiling.
//
// When stderr is a terminal, logs are routed through a bubbletea TUI program. The program
// runs in its own goroutine and is attached to the command's context via tui.WithProgram.
// Otherwise a plain logger built from the log options replaces the global zap logger.
//
// The returned cleanup function should be called after the command finishes. It quits the
// TUI and waits for it (if one was started), stops profiling, and syncs the global logger.
//
// NOTE(review): cobra runs only the nearest PersistentPreRun, so a subcommand defining
// its own hook would bypass this setup — confirm no subcommand does.
func SetupCoreCommand(root *cobra.Command, commonCfg *CommonConfig) func() {
	flags := root.PersistentFlags()
	verbosity := flags.VarPF(&commonCfg.Verbose, "verbose", "v", "Enable verbose logging")
	verbosity.NoOptDefVal = "true" // Allow -v to be used without a value
	flags.BoolVar(&commonCfg.jsonLog, "json-log", false, "Enable JSON logging")
	flags.StringVar(&commonCfg.logsDir, "logs-dir", "", "Directory to write logs to")
	flags.StringVar(&commonCfg.profileTo, "profiling", "", "Profile to file")
	flags.StringVar(&commonCfg.color, "color", "auto", "Colorize output (auto, on, off)")
	// Placeholders reassigned by the PersistentPreRun hook. The returned closure reads
	// them when it runs, so it calls whatever the hook installed.
	profileClose := func() {}
	tuiClose := func() {}
	root.PersistentPreRun = func(cmd *cobra.Command, args []string) {
		cmd.SilenceUsage = true // Silence usage after args have been parsed
		// Shadows the *pflag.Flag `verbosity` above: this is the parsed verbosity level.
		verbosity := tui.Verbosity(commonCfg.Verbose)
		logOpts := logging.LogOpts{
			Verbose:         verbosity.LogLevel() <= zapcore.DebugLevel,
			Color:           commonCfg.color,
			CategoryLogsDir: commonCfg.logsDir,
			// Per-category default levels; these categories are set to Warn.
			DefaultLevels: map[string]zapcore.Level{
				"kb.load":       zap.WarnLevel,
				"engine.opeval": zap.WarnLevel,
				"dot":           zap.WarnLevel,
				"npm":           zap.WarnLevel,
				"pulumi.events": zap.WarnLevel,
			},
		}
		if commonCfg.jsonLog {
			logOpts.Encoding = "json"
		}
		if term.IsTerminal(os.Stderr.Fd()) {
			// Interactive: render via the TUI. Signal handling is left to the caller
			// (WithoutSignalHandler) and the program stops when root's context is done.
			prog := tea.NewProgram(
				tui.NewModel(verbosity),
				tea.WithoutSignalHandler(),
				tea.WithContext(root.Context()),
				tea.WithOutput(os.Stderr),
			)
			log := zap.New(tui.NewLogCore(logOpts, verbosity, prog))
			zap.ReplaceGlobals(log)
			go func() {
				_, err := prog.Run()
				if err != nil {
					zap.S().With(zap.Error(err)).Error("TUI exited with error")
				} else {
					zap.S().Debug("TUI exited")
				}
			}()
			zap.S().Debug("Starting TUI")
			cmd.SetContext(tui.WithProgram(cmd.Context(), prog))
			tuiClose = func() {
				zap.L().Debug("Shutting down TUI")
				prog.Quit()
				prog.Wait()
			}
		} else {
			log := logOpts.NewLogger()
			zap.ReplaceGlobals(log)
		}
		profileClose = setupProfiling(commonCfg)
	}
	return func() {
		tuiClose()
		profileClose()
		_ = zap.L().Sync()
	}
}
// IsFlagRequired reports whether flag has been marked required, i.e. whether its
// cobra.BashCompOneRequiredFlag annotation is present and its first value is "true".
// A missing or empty annotation means the flag is not required.
func IsFlagRequired(flag *pflag.Flag) bool {
	required, found := flag.Annotations[cobra.BashCompOneRequiredFlag]
	// Check the length before indexing: an annotation present with no values would
	// otherwise panic.
	return found && len(required) > 0 && required[0] == "true"
}
package clicommon
import "strconv"
// LevelledFlag is a pflag.Value counter for flags such as -v that may be repeated to raise
// a level. Boolean values nudge the level by one; integer values set it outright.
type LevelledFlag int

// Set implements pflag.Value.
//
// s is first parsed with strconv.ParseBool, so "1"/"0" count as booleans. True increments
// the level; false decrements it but never below zero. Any other value must be a base-10
// integer, which replaces the level. If neither parse succeeds, the boolean parse error
// is returned and the level is unchanged.
func (f *LevelledFlag) Set(s string) error {
	asBool, boolErr := strconv.ParseBool(s)
	if boolErr == nil {
		switch {
		case asBool:
			*f++
		case *f > 0:
			*f--
		}
		return nil
	}
	n, intErr := strconv.ParseInt(s, 10, 64)
	if intErr != nil {
		return boolErr
	}
	*f = LevelledFlag(n)
	return nil
}

// Type implements pflag.Value.
func (f *LevelledFlag) Type() string {
	return "levelled_flag"
}

// String implements pflag.Value, rendering the current level in base 10.
func (f *LevelledFlag) String() string {
	level := int64(*f)
	return strconv.FormatInt(level, 10)
}
package cli_config
import (
"os"
"os/user"
"path"
)
// KlothoConfigPath returns the path of file inside the current user's Klotho configuration
// directory, i.e. ~/.klotho/<file>. It only computes the path; nothing is created on disk.
func KlothoConfigPath(file string) (string, error) {
	current, err := user.Current()
	if err != nil {
		return "", err
	}
	return path.Join(current.HomeDir, ".klotho", file), nil
}
// CreateKlothoConfigPath ensures the current user's ~/.klotho directory exists, creating
// it and any missing parents with mode os.ModePerm (subject to the process umask).
// It returns nil when the directory already exists.
func CreateKlothoConfigPath() error {
	osUser, err := user.Current()
	if err != nil {
		return err
	}
	klothoPath := path.Join(osUser.HomeDir, ".klotho")
	// MkdirAll is a no-op for an existing directory, so no separate Stat is needed. This
	// avoids a check-then-create race and reports an error when the path exists but is
	// not a directory.
	return os.MkdirAll(klothoPath, os.ModePerm)
}
package cli_config
import (
"os"
"strings"
)
// EnvVar names an environment variable: the string value is the variable's key. Its methods
// wrap os.Getenv/os.LookupEnv; use GetOr to read the value with a fallback when it is
// unset or empty.
type EnvVar string

// GetOr uses os.Getenv to get the env var specified by the target EnvVar. If that env var's
// value is unset or empty, it returns the defaultValue.
func (s EnvVar) GetOr(defaultValue string) string {
	if value := os.Getenv(string(s)); value != "" {
		return value
	}
	return defaultValue
}

// GetBool returns the env var as a boolean.
//
// The value is false if the env var is:
//
//   - unset
//   - the empty string ("")
//   - "0" or "false" (case-insensitive)
//
// The value is true for all other values, including other false-looking strings like "no".
func (s EnvVar) GetBool() bool {
	normalized := strings.ToLower(os.Getenv(string(s)))
	return normalized != "" && normalized != "0" && normalized != "false"
}

// IsSet reports whether the variable is present in the environment, even if its value is empty.
func (s EnvVar) IsSet() bool {
	_, present := os.LookupEnv(string(s))
	return present
}
package closenicely
import (
"github.com/pkg/errors"
"go.uber.org/zap"
"io"
"syscall"
)
// OrDebug closes closer, logging a failure at debug level instead of returning it. It is
// meant for best-effort cleanup (e.g. in a defer) where a close error is not actionable.
func OrDebug(closer io.Closer) {
	FuncOrDebug(func() error { return closer.Close() })
}

// FuncOrDebug calls closer and logs any error it returns at debug level, except
// syscall.ENOTTY, which is ignored.
func FuncOrDebug(closer func() error) {
	err := closer()
	if err == nil {
		return
	}
	// zap.Logger.Sync() always returns a syscall.ENOTTY error when logging to stdout
	// see: https://github.com/uber-go/zap/issues/991#issuecomment-962098428
	if errors.Is(err, syscall.ENOTTY) {
		return
	}
	zap.L().Debug("Failed to close resource", zap.Error(err))
}
package collectionutil
// Keys returns the keys of m in unspecified order. A nil or empty map yields an empty,
// non-nil slice.
func Keys[K comparable, V any](m map[K]V) []K {
	out := make([]K, len(m))
	i := 0
	for k := range m {
		out[i] = k
		i++
	}
	return out
}
// GetOneEntry returns some key/value pair from m. If m has several entries, which one is
// returned is unspecified. If m is empty (or nil), the zero values of K and V are returned.
func GetOneEntry[K comparable, V any](m map[K]V) (key K, value V) {
	for key, value = range m {
		break
	}
	return key, value
}
// ExtendingMap is a map used as the source of a fluent copy: Extend(src).Into(dst).
type ExtendingMap[K comparable, V any] map[K]V

// Extend wraps m, without copying it, so that its entries can be copied with Into.
func Extend[K comparable, V any](m map[K]V) ExtendingMap[K, V] {
	var wrapped ExtendingMap[K, V] = m
	return wrapped
}

// Into copies every entry of source into target, overwriting existing keys. target must be
// non-nil whenever source is non-empty (writing to a nil map panics).
func (source ExtendingMap[K, V]) Into(target map[K]V) {
	for key := range source {
		target[key] = source[key]
	}
}
// CopyMap returns a shallow copy of m. The result is always a new, non-nil map, even when
// m is nil.
func CopyMap[K comparable, V any](m map[K]V) map[K]V {
	dup := make(map[K]V, len(m))
	for k, v := range m {
		dup[k] = v
	}
	return dup
}
package collectionutil
// FlattenUnique concatenates the given slices, keeping only the first occurrence of each
// element, in order of first appearance. It returns nil if there are no elements at all.
//
// Examples:
//   - FlattenUnique([]int{1, 2, 3}, []int{4, 3, 4}) => []int{1, 2, 3, 4}
//   - FlattenUnique([]int{1, 2, 2}, []int{3, 4}) => []int{1, 2, 3, 4}
func FlattenUnique[E comparable](slices ...[]E) []E {
	var out []E
	seen := make(map[E]struct{})
	for _, s := range slices {
		for _, e := range s {
			if _, dup := seen[e]; dup {
				continue
			}
			seen[e] = struct{}{}
			out = append(out, e)
		}
	}
	return out
}
// Contains reports whether elem occurs in slice, comparing with ==.
func Contains[E comparable](slice []E, elem E) bool {
	for i := range slice {
		if slice[i] == elem {
			return true
		}
	}
	return false
}
package command
import (
"os/exec"
"syscall"
)
// SetProcAttr configures cmd so that its child process is placed in a new process group
// (Setpgid) and is sent SIGTERM by the kernel when its parent dies (Pdeathsig). Any
// SysProcAttr previously set on cmd is replaced.
//
// NOTE(review): Pdeathsig is Linux-only. Per prctl(2) PR_SET_PDEATHSIG, the "parent" is
// the OS thread that started the child, so the signal may fire when that thread exits
// even though the Go process is still running.
func SetProcAttr(cmd *exec.Cmd) {
	attrs := new(syscall.SysProcAttr)
	attrs.Setpgid = true
	attrs.Pdeathsig = syscall.SIGTERM
	cmd.SysProcAttr = attrs
}
package construct
import (
"bytes"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"sort"
"github.com/klothoplatform/klotho/pkg/dot"
)
// dotAttributes returns the Graphviz node attributes for r: a box labelled with the
// resource's ID.
func dotAttributes(r *Resource) map[string]string {
	return map[string]string{
		"label": r.ID.String(),
		"shape": "box",
	}
}
// dotEdgeAttributes returns the Graphviz attributes for edge e. The label is built as follows:
//   - The first property path on the source resource whose value equals the target's ID
//     (the walk stops there).
//   - If the weight is positive, it is appended on a new line, or used alone when no
//     property references the target.
//
// With neither, no label attribute is set.
func dotEdgeAttributes(e ResourceEdge) map[string]string {
	attrs := make(map[string]string)
	label, hasLabel := "", false
	// The walk error is deliberately ignored: the label is best-effort decoration.
	_ = e.Source.WalkProperties(func(path PropertyPath, nerr error) error {
		if val, _ := path.Get(); val == e.Target.ID {
			label, hasLabel = path.String(), true
			return StopWalk
		}
		return nil
	})
	if weight := e.Properties.Weight; weight > 0 {
		if label == "" {
			label = fmt.Sprintf("%d", weight)
		} else {
			label = fmt.Sprintf("%s\n%d", label, weight)
		}
		hasLabel = true
	}
	if hasLabel {
		attrs["label"] = label
	}
	return attrs
}
// GraphToDOT writes g to out in Graphviz DOT format. Nodes are emitted in topological
// order. Edges are sorted by the topological position of their source, then of their
// target, so the output is stable for a given graph. Topological-sort and
// vertex-resolution errors abort early. Write errors and per-edge lookup errors are
// collected and returned together.
func GraphToDOT(g Graph, out io.Writer) error {
	ids, err := TopologicalSort(g)
	if err != nil {
		return err
	}
	nodes, err := ResolveIds(g, ids)
	if err != nil {
		return err
	}
	var errs error
	printf := func(s string, args ...any) {
		_, err := fmt.Fprintf(out, s, args...)
		errs = errors.Join(errs, err)
	}
	printf(`digraph {
rankdir = TB
`)
	for _, n := range nodes {
		printf(" %q%s\n", n.ID, dot.AttributesToString(dotAttributes(n)))
	}
	// Precompute each ID's topological position so that every comparison in the edge sort
	// is a map lookup rather than a linear scan of ids.
	topoPos := make(map[ResourceId]int, len(ids))
	for i, id := range ids {
		if _, seen := topoPos[id]; !seen {
			topoPos[id] = i
		}
	}
	topoIndex := func(id ResourceId) int {
		if i, ok := topoPos[id]; ok {
			return i
		}
		return -1
	}
	edges, err := g.Edges()
	if err != nil {
		return err
	}
	sort.Slice(edges, func(i, j int) bool {
		ti, tj := topoIndex(edges[i].Source), topoIndex(edges[j].Source)
		if ti != tj {
			return ti < tj
		}
		ti, tj = topoIndex(edges[i].Target), topoIndex(edges[j].Target)
		return ti < tj
	})
	for _, e := range edges {
		edge, err := g.Edge(e.Source, e.Target)
		if err != nil {
			errs = errors.Join(errs, err)
			continue
		}
		printf(" %q -> %q%s\n", e.Source, e.Target, dot.AttributesToString(dotEdgeAttributes(edge)))
	}
	printf("}\n")
	return errs
}
// GraphToSVG writes g as DOT to "<prefix>.gv" and renders it with dot.ExecPan into
// "<prefix>.gv.svg". When KLOTHO_DEBUG_DIR is set, prefix is taken relative to that
// directory. Errors are wrapped with %w so callers can inspect the cause. Errors from
// closing either file are reported when nothing else failed.
func GraphToSVG(g Graph, prefix string) (err error) {
	if debugDir := os.Getenv("KLOTHO_DEBUG_DIR"); debugDir != "" {
		prefix = filepath.Join(debugDir, prefix)
	}
	dotPath := prefix + ".gv"
	svgPath := prefix + ".gv.svg"

	f, err := os.Create(dotPath)
	if err != nil {
		return err
	}
	defer func() {
		if cerr := f.Close(); cerr != nil && err == nil {
			err = fmt.Errorf("could not close file %s: %w", dotPath, cerr)
		}
	}()
	dotContent := new(bytes.Buffer)
	err = GraphToDOT(g, io.MultiWriter(f, dotContent))
	if err != nil {
		return fmt.Errorf("could not render graph to file %s: %w", dotPath, err)
	}
	svgContent, err := dot.ExecPan(bytes.NewReader(dotContent.Bytes()))
	if err != nil {
		return fmt.Errorf("could not run 'dot' for %s: %w", dotPath, err)
	}
	svgFile, err := os.Create(svgPath)
	if err != nil {
		return fmt.Errorf("could not create file %s: %w", svgPath, err)
	}
	defer func() {
		// The SVG was just written, so a close failure may mean it was not flushed.
		if cerr := svgFile.Close(); cerr != nil && err == nil {
			err = fmt.Errorf("could not close file %s: %w", svgPath, cerr)
		}
	}()
	_, err = fmt.Fprint(svgFile, svgContent)
	return err
}
package construct
// EdgeData is the payload stored on construct graph edges.
type EdgeData struct {
	ConnectionType string `yaml:"connection_type,omitempty" json:"connection_type,omitempty"`
}

// Equals implements an interface used in [graph_addons.MemoryStore] to determine whether edges are equal
// to allow for idempotent edge addition. Only another EdgeData value (not a pointer) with
// identical fields is considered equal.
func (ed EdgeData) Equals(other any) bool {
	switch o := other.(type) {
	case EdgeData:
		return ed == o
	default:
		return false
	}
}
package construct
import (
"crypto/sha256"
"fmt"
"io"
"sort"
"strings"
"github.com/dominikbraun/graph"
"github.com/klothoplatform/klotho/pkg/graph_addons"
)
type (
	// Graph is the construct resource graph: vertices are *Resource values, hashed by their
	// ResourceId (see ResourceHasher).
	Graph = graph.Graph[ResourceId, *Resource]
	// Edge is a graph edge whose endpoints are resource IDs.
	Edge = graph.Edge[ResourceId]
	// ResourceEdge is a graph edge whose endpoints are the resolved *Resource values.
	ResourceEdge = graph.Edge[*Resource]
)
// NewGraphWithOptions creates a Graph backed by a graph_addons memory store. Vertices are
// hashed with ResourceHasher and the given trait options are applied as-is; unlike
// NewGraph, the graph is not forced to be directed.
func NewGraphWithOptions(options ...func(*graph.Traits)) Graph {
	store := graph_addons.NewMemoryStore[ResourceId, *Resource]()
	return graph.NewWithStore(ResourceHasher, store, options...)
}
// NewGraph creates a directed Graph. The caller's options are applied first, followed by
// graph.Directed().
func NewGraph(options ...func(*graph.Traits)) Graph {
	// Build a fresh slice: appending to `options` directly could write into spare
	// capacity of a slice the caller passed with `opts...`.
	opts := make([]func(*graph.Traits), 0, len(options)+1)
	opts = append(opts, options...)
	opts = append(opts, graph.Directed())
	return NewGraphWithOptions(opts...)
}
// NewAcyclicGraph creates a directed Graph that rejects edges which would form a cycle.
// The caller's options are applied first, followed by graph.Directed() and
// graph.PreventCycles().
func NewAcyclicGraph(options ...func(*graph.Traits)) Graph {
	// Fresh slice so the caller's backing array is never written to by append.
	opts := make([]func(*graph.Traits), 0, len(options)+2)
	opts = append(opts, options...)
	opts = append(opts, graph.Directed(), graph.PreventCycles())
	return NewGraphWithOptions(opts...)
}
// ResourceHasher is the vertex hash function for construct graphs: a resource's hash is
// its ID, so vertices are keyed (and looked up) by ResourceId.
func ResourceHasher(r *Resource) ResourceId {
	return r.ID
}
// Hash returns the SHA-256 digest of g's canonical text form (the same text String
// produces). The digest is returned even on error, covering whatever was written first.
func Hash(g Graph) ([]byte, error) {
	digest := sha256.New()
	if err := stringTo(g, digest); err != nil {
		return digest.Sum(nil), err
	}
	return digest.Sum(nil), nil
}
// String returns g's canonical text form: vertices in topological order with their
// outgoing edges. Whatever was produced before an error is returned alongside it.
func String(g Graph) (string, error) {
	var sb strings.Builder
	err := stringTo(g, &sb)
	return sb.String(), err
}
// stringTo writes a canonical text form of g to w. Vertices are written in topological
// order, each followed by its outgoing edges sorted by target ID; an edge's weight is
// noted when it is greater than 1. Sort and adjacency errors are returned immediately.
// The first write error is returned, and no further output is attempted after it.
//
// NOTE(review): a vertex without outgoing edges is written without a trailing newline,
// so consecutive leaves run together (e.g. `"a""b"`). IDs are %q-quoted, so the text stays
// unambiguous, but changing the format would change every Hash value.
func stringTo(g Graph, w io.Writer) error {
	topo, err := TopologicalSort(g)
	if err != nil {
		return err
	}
	adjacent, err := g.AdjacencyMap()
	if err != nil {
		return err
	}
	// Previously write errors were collected but never returned; track and return the
	// first one instead.
	var writeErr error
	write := func(format string, args ...any) {
		if writeErr != nil {
			return
		}
		_, writeErr = fmt.Fprintf(w, format, args...)
	}
	for _, id := range topo {
		write("%q", id)
		targets := make([]ResourceId, 0, len(adjacent[id]))
		for t := range adjacent[id] {
			targets = append(targets, t)
		}
		sort.Sort(SortedIds(targets))
		if len(targets) > 1 {
			write("\n")
		} else if len(targets) == 1 {
			write(" ")
		}
		for _, t := range targets {
			e := adjacent[id][t]
			// Adjacent edges always have `id` as the source, so just write the target.
			write("-> %q", t)
			if e.Properties.Weight > 1 {
				write(" (weight=%d)", e.Properties.Weight)
			}
			write("\n")
		}
	}
	return writeErr
}
// IdResolutionError maps each ResourceId that could not be resolved to the error returned
// while resolving it (see ResolveIds).
type IdResolutionError map[ResourceId]error

// Error formats the failures. A single failure is reported on one line. Several are
// listed one per line, sorted by ID so the message is deterministic despite map
// iteration order.
func (e IdResolutionError) Error() string {
	ids := make([]ResourceId, 0, len(e))
	for id := range e {
		ids = append(ids, id)
	}
	sort.Sort(SortedIds(ids))
	if len(ids) == 1 {
		return fmt.Sprintf("failed to resolve ID %s: %v", ids[0], e[ids[0]])
	}
	var b strings.Builder
	b.WriteString("failed to resolve IDs:\n")
	for _, id := range ids {
		fmt.Fprintf(&b, " %s: %v\n", id, e[id])
	}
	return b.String()
}
// ResolveIds looks up the vertex for each ID in ids, preserving order. IDs that cannot be
// resolved are skipped and reported together as an IdResolutionError. The resources that
// were resolved are returned even when that error is non-nil.
func ResolveIds(g Graph, ids []ResourceId) ([]*Resource, error) {
	var resolved []*Resource
	failed := IdResolutionError{}
	for _, id := range ids {
		r, err := g.Vertex(id)
		if err != nil {
			failed[id] = err
			continue
		}
		resolved = append(resolved, r)
	}
	if len(failed) == 0 {
		return resolved, nil
	}
	return resolved, failed
}
// ResourceEdgeToKeyEdge converts an edge between *Resource values into the equivalent edge
// between their IDs, carrying the edge properties over as-is.
func ResourceEdgeToKeyEdge(re ResourceEdge) Edge {
	var e Edge
	e.Source, e.Target = re.Source.ID, re.Target.ID
	e.Properties = re.Properties
	return e
}
package construct
import (
"errors"
)
// GraphBatch can be used to batch adding vertices and edges to the graph,
// collecting errors in the [Err] field instead of stopping at the first failure.
// Prefer NewGraphBatch, though a literal GraphBatch{Graph: g} also works.
type GraphBatch struct {
	Graph
	Err error
	// errorAdding is to keep track on which resources we failed to add to the graph
	// so that we can ignore them when adding edges to not pollute the errors.
	errorAdding map[ResourceId]struct{}
}

// NewGraphBatch returns a GraphBatch that adds to g.
func NewGraphBatch(g Graph) *GraphBatch {
	return &GraphBatch{
		Graph:       g,
		errorAdding: make(map[ResourceId]struct{}),
	}
}

// AddVertices adds each resource to the graph. Failures are joined into b.Err and the
// failed IDs are remembered so that AddEdges skips edges touching them.
func (b *GraphBatch) AddVertices(rs ...*Resource) {
	for _, r := range rs {
		err := b.Graph.AddVertex(r)
		if err == nil {
			continue
		}
		b.Err = errors.Join(b.Err, err)
		// Allocate lazily so a GraphBatch built without NewGraphBatch does not panic on a
		// nil map.
		if b.errorAdding == nil {
			b.errorAdding = make(map[ResourceId]struct{})
		}
		b.errorAdding[r.ID] = struct{}{}
	}
}

// AddEdges adds each edge with a copy of its properties, joining failures into b.Err.
// Edges whose source or target previously failed in AddVertices are skipped silently.
func (b *GraphBatch) AddEdges(es ...Edge) {
	for _, e := range es {
		// Lookups on a nil errorAdding map are safe and simply find nothing.
		if _, ok := b.errorAdding[e.Source]; ok {
			continue
		}
		if _, ok := b.errorAdding[e.Target]; ok {
			continue
		}
		err := b.Graph.AddEdge(e.Source, e.Target, CopyEdgeProps(e.Properties))
		b.Err = errors.Join(b.Err, err)
	}
}
package construct
import (
"sort"
"github.com/klothoplatform/klotho/pkg/set"
)
// AllDownstreamDependencies returns all downstream dependencies of the given resource:
// every resource reachable from r by following edges forward, in the breadth-first
// order produced by allDependencies.
// Downstream means that for A -> B -> C -> D the downstream dependencies of B are [C, D].
func AllDownstreamDependencies(g Graph, r ResourceId) ([]ResourceId, error) {
	successors, err := g.AdjacencyMap()
	if err != nil {
		return nil, err
	}
	return allDependencies(successors, r), nil
}
// DirectDownstreamDependencies returns the direct downstream dependencies of the given
// resource (the targets of its outgoing edges), sorted by ID; nil if there are none.
// Direct means that for A -> B -> C -> D the direct downstream dependencies of B are [C].
func DirectDownstreamDependencies(g Graph, r ResourceId) ([]ResourceId, error) {
	edges, err := g.Edges()
	if err != nil {
		return nil, err
	}
	var targets []ResourceId
	for _, edge := range edges {
		if edge.Source != r {
			continue
		}
		targets = append(targets, edge.Target)
	}
	sort.Sort(SortedIds(targets))
	return targets, nil
}
// AllUpstreamDependencies returns all upstream dependencies of the given resource: every
// resource from which r can be reached, in the breadth-first order produced by
// allDependencies over the predecessor map.
// Upstream means that for A -> B -> C -> D the upstream dependencies of C are [B, A] (in that order).
func AllUpstreamDependencies(g Graph, r ResourceId) ([]ResourceId, error) {
	predecessors, err := g.PredecessorMap()
	if err != nil {
		return nil, err
	}
	return allDependencies(predecessors, r), nil
}
// DirectUpstreamDependencies returns the direct upstream dependencies of the given
// resource (the sources of its incoming edges), sorted by ID; nil if there are none.
// Direct means that for A -> B -> C -> D the direct upstream dependencies of C are [B].
func DirectUpstreamDependencies(g Graph, r ResourceId) ([]ResourceId, error) {
	edges, err := g.Edges()
	if err != nil {
		return nil, err
	}
	var sources []ResourceId
	for _, edge := range edges {
		if edge.Target != r {
			continue
		}
		sources = append(sources, edge.Source)
	}
	sort.Sort(SortedIds(sources))
	return sources, nil
}
// allDependencies returns every ID reachable from r in deps (an adjacency or predecessor
// map, depending on direction), in breadth-first order. r's direct dependencies come
// first, sorted by ID. Then, for each dequeued ID, its not-yet-seen dependencies follow,
// sorted by ID. Each ID appears at most once. r itself is included only if it can reach
// itself through a cycle.
func allDependencies(deps map[ResourceId]map[ResourceId]Edge, r ResourceId) []ResourceId {
	// IDs are marked when enqueued rather than when dequeued. Marking on dequeue lets a
	// node reachable by two paths (e.g. a diamond A->B->D, A->C->D) be enqueued, and
	// reported, twice.
	seen := make(map[ResourceId]struct{})
	var queue []ResourceId
	enqueueSorted := func(candidates map[ResourceId]Edge) {
		var fresh []ResourceId
		for d := range candidates {
			if _, ok := seen[d]; ok {
				continue
			}
			seen[d] = struct{}{}
			fresh = append(fresh, d)
		}
		sort.Sort(SortedIds(fresh))
		queue = append(queue, fresh...)
	}
	enqueueSorted(deps[r])
	var ids []ResourceId
	for len(queue) > 0 {
		id := queue[0]
		queue = queue[1:]
		ids = append(ids, id)
		enqueueSorted(deps[id])
	}
	return ids
}
// Neighbors returns the IDs directly connected to r. upstream holds the sources of r's
// incoming edges and downstream the targets of its outgoing edges. On success both sets
// are non-nil, though possibly empty.
func Neighbors(g Graph, r ResourceId) (upstream, downstream set.Set[ResourceId], err error) {
	successors, err := g.AdjacencyMap()
	if err != nil {
		return nil, nil, err
	}
	predecessors, err := g.PredecessorMap()
	if err != nil {
		return nil, nil, err
	}
	downstream = make(set.Set[ResourceId], len(successors[r]))
	for id := range successors[r] {
		downstream.Add(id)
	}
	upstream = make(set.Set[ResourceId], len(predecessors[r]))
	for id := range predecessors[r] {
		upstream.Add(id)
	}
	return upstream, downstream, nil
}
package construct
import (
"errors"
"fmt"
"sort"
"strings"
"github.com/dominikbraun/graph"
"github.com/klothoplatform/klotho/pkg/yaml_util"
"gopkg.in/yaml.v3"
)
// YamlGraph is the YAML (de)serialisation wrapper for a construct graph plus its named
// outputs; see its MarshalYAML and UnmarshalYAML methods for the document layout.
type YamlGraph struct {
	Graph   Graph
	Outputs map[string]Output
}
// nullNode is a YAML null scalar with an empty value, used to render as nothing in the
// YAML output. It is useful for empty mappings: for example, instead of `resources: {}`
// it renders as `resources:`. A small change, but it helps reduce visual clutter.
//
// NOTE(review): this single pointer is shared by every marshalled graph, so nothing may
// mutate it.
var nullNode = &yaml.Node{
	Kind:  yaml.ScalarNode,
	Tag:   "!!null",
	Value: "",
}
// MarshalYAML renders g as a YAML mapping with three keys:
//
//   - resources: resource ID -> properties (keys sorted), in topological order. Imported
//     resources additionally get `imported: true`.
//   - edges: "source -> target" -> EdgeData, ordered by source (topological), then by
//     target (sorted). Edges without data have an empty value.
//   - outputs: output name (sorted) -> either {ref: <ref>} or {value: <value>}.
//
// Empty resources/edges sections render as a bare key (see nullNode). Entries that fail
// to encode are omitted and their errors are joined into the returned error.
func (g YamlGraph) MarshalYAML() (interface{}, error) {
	topo, err := TopologicalSort(g.Graph)
	if err != nil {
		return nil, err
	}
	adj, err := g.Graph.AdjacencyMap()
	if err != nil {
		return nil, err
	}
	var errs error

	resources := &yaml.Node{
		Kind: yaml.MappingNode,
	}
	for _, rid := range topo {
		r, err := g.Graph.Vertex(rid)
		if err != nil {
			errs = errors.Join(errs, err)
			continue
		}
		// Add the "imported" marker to a copy. Writing it into r.Properties would mutate
		// the live resource in the graph as a side effect of marshalling, and would panic
		// on a nil property map.
		resProps := r.Properties
		if r.Imported {
			resProps = make(Properties, len(r.Properties)+1)
			for k, v := range r.Properties {
				resProps[k] = v
			}
			resProps["imported"] = r.Imported
		}
		props, err := yaml_util.MarshalMap(resProps, func(a, b string) bool { return a < b })
		if err != nil {
			errs = errors.Join(errs, err)
			continue
		}
		resources.Content = append(resources.Content,
			&yaml.Node{
				Kind:  yaml.ScalarNode,
				Value: rid.String(),
			},
			props,
		)
	}
	if len(resources.Content) == 0 {
		resources = nullNode
	}

	edges := &yaml.Node{
		Kind: yaml.MappingNode,
	}
	for _, source := range topo {
		targets := make([]ResourceId, 0, len(adj[source]))
		for t := range adj[source] {
			targets = append(targets, t)
		}
		sort.Sort(SortedIds(targets))
		for _, target := range targets {
			edgeValue := nullNode
			edge := adj[source][target]
			if data, ok := edge.Properties.Data.(EdgeData); ok && data != (EdgeData{}) {
				edgeValue = &yaml.Node{}
				if err := edgeValue.Encode(data); err != nil {
					errs = errors.Join(errs, err)
					continue
				}
			}
			edges.Content = append(edges.Content,
				&yaml.Node{
					Kind:  yaml.ScalarNode,
					Value: fmt.Sprintf("%s -> %s", source, target),
				},
				edgeValue)
		}
	}
	if len(edges.Content) == 0 {
		edges = nullNode
	}

	outputs := &yaml.Node{
		Kind: yaml.MappingNode,
	}
	// Iterate names in sorted order so the document is deterministic.
	names := make([]string, 0, len(g.Outputs))
	for name := range g.Outputs {
		names = append(names, name)
	}
	sort.Strings(names)
	for _, name := range names {
		output := g.Outputs[name]
		outputMap := &yaml.Node{
			Kind: yaml.MappingNode,
		}
		if !output.Ref.IsZero() {
			outputMap.Content = append(outputMap.Content,
				&yaml.Node{
					Kind:  yaml.ScalarNode,
					Value: "ref",
				},
				&yaml.Node{
					Kind:  yaml.ScalarNode,
					Value: output.Ref.String(),
				},
			)
		} else {
			value := &yaml.Node{}
			if err := value.Encode(output.Value); err != nil {
				errs = errors.Join(errs, err)
				// Skip the whole entry. The name key is only appended below, so the
				// mapping keeps its key/value pairing.
				continue
			}
			outputMap.Content = append(outputMap.Content,
				&yaml.Node{
					Kind:  yaml.ScalarNode,
					Value: "value",
				},
				value,
			)
		}
		outputs.Content = append(outputs.Content,
			&yaml.Node{
				Kind:  yaml.ScalarNode,
				Value: name,
			},
			outputMap,
		)
	}

	return &yaml.Node{
		Kind: yaml.MappingNode,
		Content: []*yaml.Node{
			{
				Kind:  yaml.ScalarNode,
				Value: "resources",
			},
			resources,
			{
				Kind:  yaml.ScalarNode,
				Value: "edges",
			},
			edges,
			{
				Kind:  yaml.ScalarNode,
				Value: "outputs",
			},
			outputs,
		},
	}, errs
}
// UnmarshalYAML populates g from the document layout written by MarshalYAML.
//
// If g.Graph is nil a new directed graph is created; otherwise resources and edges are
// added to the existing graph. An `imported` property is moved onto Resource.Imported.
// A non-bool `imported` value is reported as an error, but the resource is still added
// (as not imported). Outputs are merged into g.Outputs. Per-resource and per-edge
// errors are joined and returned after everything possible has been added.
func (g *YamlGraph) UnmarshalYAML(n *yaml.Node) error {
	var doc struct {
		Resources map[ResourceId]Properties `yaml:"resources"`
		Edges     map[SimpleEdge]EdgeData   `yaml:"edges"`
		Outputs   map[string]Output         `yaml:"outputs"`
	}
	if err := n.Decode(&doc); err != nil {
		return err
	}
	if g.Graph == nil {
		g.Graph = NewGraph()
	}
	var errs error
	for rid, props := range doc.Resources {
		imported := false
		if raw, present := props["imported"]; present {
			asBool, isBool := raw.(bool)
			if !isBool {
				errs = errors.Join(errs, fmt.Errorf("unable to parse imported value as boolean for resource %s", rid))
				// Keep going so the vertex is still added; otherwise edges referencing it
				// would fail spuriously.
			}
			imported = asBool
			delete(props, "imported")
		}
		errs = errors.Join(errs, g.Graph.AddVertex(&Resource{
			ID:         rid,
			Properties: props,
			Imported:   imported,
		}))
	}
	for edge, data := range doc.Edges {
		errs = errors.Join(errs, g.Graph.AddEdge(edge.Source, edge.Target, func(ep *graph.EdgeProperties) {
			ep.Data = data
		}))
	}
	if g.Outputs == nil {
		g.Outputs = make(map[string]Output)
	}
	for name, out := range doc.Outputs {
		g.Outputs[name] = Output{Ref: out.Ref, Value: out.Value}
	}
	return errs
}
// SimpleEdge identifies an edge purely by its endpoint IDs. Its text form is
// "source -> target"; this is also how edge keys are written in the YAML graph format.
type SimpleEdge struct {
	Source ResourceId
	Target ResourceId
}

// ToSimpleEdge keeps only e's endpoints, dropping its properties.
func ToSimpleEdge(e Edge) SimpleEdge {
	var se SimpleEdge
	se.Source, se.Target = e.Source, e.Target
	return se
}

// String formats the edge as "source -> target".
func (e SimpleEdge) String() string {
	return fmt.Sprintf("%s -> %s", e.Source, e.Target)
}

// MarshalText returns the String form.
//
// NOTE(review): encoding.TextMarshaler requires MarshalText() ([]byte, error). Because this
// method returns a string, SimpleEdge does not satisfy that interface, and encoders such as
// encoding/json will not use it. Fixing the signature would affect existing callers.
func (e SimpleEdge) MarshalText() (string, error) {
	return e.String(), nil
}

// Less orders edges by source ID, then by target ID (via ResourceIdLess).
func (e SimpleEdge) Less(other SimpleEdge) bool {
	if e.Source == other.Source {
		return ResourceIdLess(e.Target, other.Target)
	}
	return ResourceIdLess(e.Source, other.Source)
}

// Parse reads either "source -> target" or "target <- source" into e. Errors from parsing
// both endpoints are reported together. Parse does not validate the IDs; UnmarshalText
// does.
func (e *SimpleEdge) Parse(s string) error {
	src, dst, ok := strings.Cut(s, " -> ")
	if !ok {
		// Reverse notation: the target is written first.
		if dst, src, ok = strings.Cut(s, " <- "); !ok {
			return errors.New("invalid edge format, expected either `source -> target` or `target <- source`")
		}
	}
	return errors.Join(e.Source.Parse(src), e.Target.Parse(dst))
}

// Validate checks both endpoint IDs, reporting both errors together.
func (e *SimpleEdge) Validate() error {
	return errors.Join(e.Source.Validate(), e.Target.Validate())
}

// UnmarshalText implements encoding.TextUnmarshaler: it parses data with Parse and then
// validates the result.
func (e *SimpleEdge) UnmarshalText(data []byte) error {
	err := e.Parse(string(data))
	if err == nil {
		err = e.Validate()
	}
	return err
}

// ToEdge converts e to an Edge with zero-valued properties.
func (e SimpleEdge) ToEdge() Edge {
	var out Edge
	out.Source = e.Source
	out.Target = e.Target
	return out
}
// EdgeKeys returns the keys of m sorted with SimpleEdge.Less (by source, then target).
func EdgeKeys[V any](m map[SimpleEdge]V) []SimpleEdge {
	sorted := make([]SimpleEdge, 0, len(m))
	for edge := range m {
		sorted = append(sorted, edge)
	}
	less := func(a, b int) bool { return sorted[a].Less(sorted[b]) }
	sort.Slice(sorted, less)
	return sorted
}
package construct
import (
"errors"
"fmt"
"github.com/dominikbraun/graph"
"github.com/klothoplatform/klotho/pkg/graph_addons"
)
// CopyVertexProps returns a vertex option that overwrites the target's VertexProperties
// with p. The struct is copied by value, so reference-typed fields (such as the Attributes
// map) are shared with p rather than deep-copied.
func CopyVertexProps(p graph.VertexProperties) func(*graph.VertexProperties) {
	return func(dst *graph.VertexProperties) {
		*dst = p
	}
}

// CopyEdgeProps returns an edge option that overwrites the target's EdgeProperties with p.
// As with CopyVertexProps, this is a shallow copy: maps and Data values are shared with p.
func CopyEdgeProps(p graph.EdgeProperties) func(*graph.EdgeProperties) {
	return func(dst *graph.EdgeProperties) {
		*dst = p
	}
}
// ReplaceResource replaces the resource identified by `oldId` with `newRes` in the graph
// (via graph_addons.ReplaceVertex). It then rewrites property references to oldId so they
// point to newRes.ID instead. A reference is either a [ResourceId] or a [PropertyRef] whose
// Resource is oldId, and it may appear as a value or as a map key. Every resource in the
// graph is walked, not only neighbours of the replaced one. No-op when the IDs are equal.
func ReplaceResource(g Graph, oldId ResourceId, newRes *Resource) error {
	if oldId == newRes.ID {
		return nil
	}
	if err := graph_addons.ReplaceVertex(g, oldId, newRes, ResourceHasher); err != nil {
		return fmt.Errorf("could not update resource %s to %s: %w", oldId, newRes.ID, err)
	}
	// retarget rewrites the value at path if it refers to oldId.
	retarget := func(path PropertyPathItem) error {
		current, _ := path.Get()
		switch v := current.(type) {
		case ResourceId:
			if v == oldId {
				return path.Set(newRes.ID)
			}
		case PropertyRef:
			if v.Resource == oldId {
				v.Resource = newRes.ID
				return path.Set(v)
			}
		}
		return nil
	}
	return WalkGraph(g, func(id ResourceId, resource *Resource, nerr error) error {
		walkErr := resource.WalkProperties(func(path PropertyPath, err error) error {
			err = errors.Join(err, retarget(path))
			if kv, ok := path.Last().(PropertyKVItem); ok {
				err = errors.Join(err, retarget(kv.Key()))
			}
			return err
		})
		return errors.Join(nerr, walkErr)
	})
}
// PropagateUpdatedId is used after a resource's ID has been changed in place, so that the
// vertex stored under `old` now carries a different ID. It re-keys that resource in the
// graph and updates property references (as [ResourceId] or [PropertyRef]) from the old ID
// to the new one in every resource, via ReplaceResource. No-op if the ID is unchanged.
func PropagateUpdatedId(g Graph, old ResourceId) error {
	res, err := g.Vertex(old)
	switch {
	case err != nil:
		return err
	case res.ID == old:
		// Short circuit if the resource ID hasn't changed.
		return nil
	default:
		return ReplaceResource(g, old, res)
	}
}
// RemoveResource removes all edges to and from the resource, then any property references (as [ResourceId] or
// [PropertyRef], including map keys) to it held by any resource in the graph, and finally the resource itself.
// It is a no-op if the resource is not in the graph.
func RemoveResource(g Graph, id ResourceId) error {
	adj, err := g.AdjacencyMap()
	if err != nil {
		return err
	}
	if _, ok := adj[id]; !ok {
		return nil
	}
	for _, edge := range adj[id] {
		err = errors.Join(err, g.RemoveEdge(edge.Source, edge.Target))
	}
	if err != nil {
		return err
	}
	pred, err := g.PredecessorMap()
	if err != nil {
		return err
	}
	for _, edge := range pred[id] {
		err = errors.Join(err, g.RemoveEdge(edge.Source, edge.Target))
	}
	if err != nil {
		return err
	}

	// removeId removes the path item itself (not just its value) when it references `id`.
	removeId := func(path PropertyPathItem) (bool, error) {
		itemVal, _ := path.Get()
		if itemId, ok := itemVal.(ResourceId); ok && itemId == id {
			return true, path.Remove(nil)
		}
		if itemRef, ok := itemVal.(PropertyRef); ok && itemRef.Resource == id {
			return true, path.Remove(nil)
		}
		return false, nil
	}

	// The adjacency map is keyed by every vertex in the graph, so every resource's properties are scrubbed.
	for resId := range adj {
		res, err := g.Vertex(resId)
		if err != nil {
			return err
		}
		err = res.WalkProperties(func(path PropertyPath, nerr error) error {
			removed, err := removeId(path)
			nerr = errors.Join(nerr, err)
			if removed {
				return SkipProperty
			}
			kv, ok := path.Last().(PropertyKVItem)
			if !ok {
				// Return the accumulated errors, not only this item's, so earlier failures are not dropped.
				return nerr
			}
			removed, err = removeId(kv.Key())
			nerr = errors.Join(nerr, err)
			if removed {
				return SkipProperty
			}
			return nerr
		})
		if err != nil {
			return err
		}
	}
	return g.RemoveVertex(id)
}
package construct
import (
"errors"
"github.com/dominikbraun/graph"
"github.com/klothoplatform/klotho/pkg/graph_addons"
)
// TopologicalSort returns the resource IDs of g in a stable topological order, delegating to
// [graph_addons.TopologicalSort] with [ResourceIdLess] as the ID ordering.
func TopologicalSort[T any](g graph.Graph[ResourceId, T]) ([]ResourceId, error) {
	ids, err := graph_addons.TopologicalSort(g, ResourceIdLess)
	return ids, err
}
// ReverseTopologicalSort is the reverse-order counterpart of TopologicalSort, delegating to
// [graph_addons.ReverseTopologicalSort] with [ResourceIdLess]. This is primarily useful for IaC
// graphs to determine the order in which resources should be created.
func ReverseTopologicalSort[T any](g graph.Graph[ResourceId, T]) ([]ResourceId, error) {
	ids, err := graph_addons.ReverseTopologicalSort(g, ResourceIdLess)
	return ids, err
}
// WalkGraphFunc is the callback used by [WalkGraph] and [WalkGraphReverse], modelled on `fs.WalkDirFunc`.
// It receives each resource's ID, the resource itself, and the error returned by the previous invocation
// (nil for the first one). Return [StopWalk] to end the walk early.
type WalkGraphFunc func(id ResourceId, resource *Resource, nerr error) error

var (
	// StopWalk is a sentinel error a WalkGraphFunc may return to stop walking the graph.
	// The walk then returns the error that was passed into that final invocation.
	StopWalk = errors.New("stop walking")
)
// walkGraph visits ids in order, calling fn with each resource and the error returned by the previous
// call. A vertex lookup failure aborts immediately with that error. When fn returns StopWalk (matched
// with errors.Is), the walk ends and returns the error from before that call.
func walkGraph(g Graph, ids []ResourceId, fn WalkGraphFunc) (nerr error) {
	for idx := range ids {
		id := ids[idx]
		res, lookupErr := g.Vertex(id)
		if lookupErr != nil {
			return lookupErr
		}
		result := fn(id, res, nerr)
		if errors.Is(result, StopWalk) {
			return nerr
		}
		nerr = result
	}
	return nerr
}
// WalkGraph calls fn for every resource in g in topological order (see [TopologicalSort]).
func WalkGraph(g Graph, fn WalkGraphFunc) error {
	order, sortErr := TopologicalSort(g)
	if sortErr != nil {
		return sortErr
	}
	return walkGraph(g, order, fn)
}

// WalkGraphReverse calls fn for every resource in g in reverse topological order
// (see [ReverseTopologicalSort]).
func WalkGraphReverse(g Graph, fn WalkGraphFunc) error {
	order, sortErr := ReverseTopologicalSort(g)
	if sortErr != nil {
		return sortErr
	}
	return walkGraph(g, order, fn)
}
package graphtest
import (
"errors"
"fmt"
"testing"
"github.com/dominikbraun/graph"
"github.com/klothoplatform/klotho/pkg/construct"
"github.com/stretchr/testify/assert"
)
// AssertGraphEqual asserts that expect and actual have the same order (vertex count), size (edge count),
// and string representation (via construct.String). When message is non-empty, it and args are used to
// prefix every assertion message.
func AssertGraphEqual(t *testing.T, expect, actual construct.Graph, message string, args ...any) {
	t.Helper()
	a := assert.New(t)
	check := func(v any, err error) any {
		if err != nil {
			t.Fatal(err)
		}
		return v
	}
	withContext := func(detail string) []any {
		if message != "" {
			return append([]any{message + ": " + detail}, args...)
		}
		return []any{detail}
	}

	a.Equal(check(expect.Order()), check(actual.Order()), withContext("order (# of nodes) mismatch")...)
	a.Equal(check(expect.Size()), check(actual.Size()), withContext("size (# of edges) mismatch")...)
	// Comparing the string forms rather than the graphs directly gives far more readable diffs.
	expectStr := check(construct.String(expect))
	actualStr := check(construct.String(actual))
	a.Equal(expectStr, actualStr, withContext("graph mismatch")...)
}
// AssertGraphContains asserts that every vertex and every edge of expect is also present in actual.
// actual may contain additional vertices and edges. Failures name the missing vertex or edge.
func AssertGraphContains(t *testing.T, expect, actual construct.Graph) {
	t.Helper()
	a := assert.New(t)

	expectVs, err := construct.TopologicalSort(expect)
	if err != nil {
		t.Fatal(err)
	}
	for _, expectV := range expectVs {
		_, err := actual.Vertex(expectV)
		a.NoError(err, "missing vertex %s", expectV)
	}

	expectEs, err := expect.Edges()
	if err != nil {
		t.Fatal(err)
	}
	for _, expectE := range expectEs {
		_, err := actual.Edge(expectE.Source, expectE.Target)
		a.NoError(err, "missing edge %s -> %s", expectE.Source, expectE.Target)
	}
}
// StringToGraphElement parses e as a [construct.ResourceId] if that yields a valid ID, otherwise as a
// [construct.Path] (IDs separated by " -> "). If the path does not parse cleanly either, the parse errors
// from both attempts are returned joined.
func StringToGraphElement(e string) (any, error) {
	var id construct.ResourceId
	idErr := id.Parse(e)
	if id.Validate() == nil {
		return id, nil
	}

	var path construct.Path
	pathErr := path.Parse(e)
	// Path.Parse always produces at least one element, so a non-empty result alone does not mean
	// success; require a clean parse to avoid silently adding garbage vertices.
	if pathErr == nil && len(path) > 0 {
		return path, nil
	}
	return nil, errors.Join(idErr, pathErr)
}
// AddElement is a utility function for adding an element to a graph. See [MakeGraph] for the supported
// element types. Endpoints of edges and paths are added as empty resources when missing.
// Returns whether adding the element failed; unexpected graph errors abort the test via t.Fatal.
func AddElement(t *testing.T, g construct.Graph, e any) (failed bool) {
	t.Helper()
	must := func(err error) {
		t.Helper()
		if err != nil {
			t.Fatal(err)
		}
	}
	if estr, ok := e.(string); ok {
		var err error
		e, err = StringToGraphElement(estr)
		if err != nil {
			// Report the original string: `e` now holds the (failed, nil) parse result.
			t.Errorf("invalid element %q (type %[1]T) Parse errors: %v", estr, err)
			return true
		}
	}
	addIfMissing := func(res *construct.Resource) {
		t.Helper()
		if _, err := g.Vertex(res.ID); errors.Is(err, graph.ErrVertexNotFound) {
			must(g.AddVertex(res))
		} else if err != nil {
			t.Fatal(fmt.Errorf("could not check vertex %s: %w", res.ID, err))
		}
	}

	switch e := e.(type) {
	case construct.ResourceId:
		addIfMissing(&construct.Resource{ID: e})
	case construct.Resource:
		must(g.AddVertex(&e))
	case *construct.Resource:
		must(g.AddVertex(e))
	case construct.Edge:
		addIfMissing(&construct.Resource{ID: e.Source})
		addIfMissing(&construct.Resource{ID: e.Target})
		must(g.AddEdge(e.Source, e.Target))
	case construct.ResourceEdge:
		addIfMissing(e.Source)
		addIfMissing(e.Target)
		must(g.AddEdge(e.Source.ID, e.Target.ID))
	case construct.SimpleEdge:
		addIfMissing(&construct.Resource{ID: e.Source})
		addIfMissing(&construct.Resource{ID: e.Target})
		must(g.AddEdge(e.Source, e.Target))
	case construct.Path:
		for i, id := range e {
			addIfMissing(&construct.Resource{ID: id})
			if i > 0 {
				must(g.AddEdge(e[i-1], id))
			}
		}
	default:
		t.Errorf("invalid element of type %T", e)
		return true
	}
	return false
}
// MakeGraph is a utility function for creating a graph from a list of elements which can be of types:
//   - ResourceId : adds an empty resource with the given ID
//   - Resource, *Resource : adds the given resource
//   - Edge, SimpleEdge : adds the given edge (and empty resources for missing endpoints)
//   - ResourceEdge : adds the given edge (and its endpoint resources, if missing)
//   - Path : adds all the edges in the path
//   - string : parses the string as either a ResourceId or a Path and adds it as above
//
// The caller supplies the graph, so it can be created via either NewGraph or NewAcyclicGraph.
// Every element is attempted; if any fail, the test is stopped with t.FailNow.
// Users are encouraged to wrap this function for the specific test function for ease of use, such as:
//
//	makeGraph := func(elements ...any) Graph {
//		return MakeGraph(t, NewGraph(), elements...)
//	}
func MakeGraph(t *testing.T, g construct.Graph, elements ...any) construct.Graph {
	t.Helper()
	failed := false
	for i, e := range elements {
		if AddElement(t, g, e) {
			t.Errorf("failed to add element[%d] (%v) to graph", i, e)
			failed = true
		}
	}
	if failed {
		// Fail now because if the graph wasn't built correctly, the rest of the test is likely to fail too.
		t.FailNow()
	}
	return g
}
package graphtest
import (
"fmt"
"testing"
"github.com/dominikbraun/graph"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/stretchr/testify/assert"
)
// GraphChanges wraps a construct.Graph, forwarding every call to it while recording the vertices and
// edges successfully added or removed through the methods below.
type GraphChanges struct {
	construct.Graph
	Added        []construct.ResourceId
	Removed      []construct.ResourceId
	AddedEdges   []construct.Edge
	RemovedEdges []construct.Edge
}

// RecordChanges returns a GraphChanges that records mutations made through it to inner.
func RecordChanges(inner construct.Graph) *GraphChanges {
	return &GraphChanges{
		Graph: inner,
	}
}

// AddVertex adds value to the wrapped graph and records its ID if the add succeeded.
func (c *GraphChanges) AddVertex(value *construct.Resource, options ...func(*graph.VertexProperties)) error {
	err := c.Graph.AddVertex(value, options...)
	if err == nil {
		c.Added = append(c.Added, value.ID)
	}
	return err
}

// AddVerticesFrom adds all vertices of g to the wrapped graph and, if that succeeds, records every
// vertex of g as added.
func (c *GraphChanges) AddVerticesFrom(g construct.Graph) error {
	adj, err := g.AdjacencyMap()
	if err != nil {
		return err
	}
	if err := c.Graph.AddVerticesFrom(g); err != nil {
		return err
	}
	for v := range adj {
		c.Added = append(c.Added, v)
	}
	return nil
}

// RemoveVertex removes hash from the wrapped graph and records it if the removal succeeded.
func (c *GraphChanges) RemoveVertex(hash construct.ResourceId) error {
	err := c.Graph.RemoveVertex(hash)
	if err == nil {
		c.Removed = append(c.Removed, hash)
	}
	return err
}

// AddEdge adds the edge to the wrapped graph and records it if the add succeeded.
func (c *GraphChanges) AddEdge(
	sourceHash, targetHash construct.ResourceId,
	options ...func(*graph.EdgeProperties),
) error {
	err := c.Graph.AddEdge(sourceHash, targetHash, options...)
	if err == nil {
		c.AddedEdges = append(c.AddedEdges, construct.Edge{Source: sourceHash, Target: targetHash})
	}
	return err
}

// AddEdgesFrom adds all edges of g to the wrapped graph and, if that succeeds, records them as added.
func (c *GraphChanges) AddEdgesFrom(g construct.Graph) error {
	edges, err := g.Edges()
	if err != nil {
		return err
	}
	if err := c.Graph.AddEdgesFrom(g); err != nil {
		return err
	}
	c.AddedEdges = append(c.AddedEdges, edges...)
	return nil
}

// RemoveEdge removes the edge from the wrapped graph and records it if the removal succeeded.
func (c *GraphChanges) RemoveEdge(source, target construct.ResourceId) error {
	err := c.Graph.RemoveEdge(source, target)
	if err == nil {
		c.RemovedEdges = append(c.RemovedEdges, construct.Edge{Source: source, Target: target})
	}
	return err
}

// AssertEqual asserts that actual recorded the same changes as c (the expected changes), ignoring order.
func (c *GraphChanges) AssertEqual(t *testing.T, actual *GraphChanges) {
	t.Helper()
	// The following two helpers make the diffs nicer to read, instead of printing the whole structs.
	ids := func(s []construct.ResourceId) []string {
		out := make([]string, len(s))
		for i, id := range s {
			out[i] = id.String()
		}
		return out
	}
	edges := func(s []construct.Edge) []string {
		out := make([]string, len(s))
		for i, e := range s {
			out[i] = fmt.Sprintf("%s -> %s", e.Source, e.Target)
		}
		return out
	}
	assert.ElementsMatch(t, ids(c.Added), ids(actual.Added), "added vertices")
	assert.ElementsMatch(t, ids(c.Removed), ids(actual.Removed), "removed vertices")
	assert.ElementsMatch(t, edges(c.AddedEdges), edges(actual.AddedEdges), "added edges")
	assert.ElementsMatch(t, edges(c.RemovedEdges), edges(actual.RemovedEdges), "removed edges")
}
package graphtest
import (
"testing"
construct "github.com/klothoplatform/klotho/pkg/construct"
)
// ParseId parses str as a [construct.ResourceId], failing the test immediately if it cannot be parsed.
func ParseId(t *testing.T, str string) (id construct.ResourceId) {
	t.Helper()
	if err := id.Parse(str); err != nil {
		t.Fatalf("failed to parse resource id %q: %v", str, err)
	}
	return id
}

// ParseEdge parses str as a [construct.SimpleEdge] and returns it as a [construct.Edge], failing the
// test immediately if it cannot be parsed.
func ParseEdge(t *testing.T, str string) construct.Edge {
	t.Helper()
	var edge construct.SimpleEdge
	if err := edge.Parse(str); err != nil {
		t.Fatalf("failed to parse edge %q: %v", str, err)
	}
	return construct.Edge{
		Source: edge.Source,
		Target: edge.Target,
	}
}

// ParseRef parses str as a [construct.PropertyRef], failing the test immediately if it cannot be parsed.
func ParseRef(t *testing.T, str string) construct.PropertyRef {
	t.Helper()
	var ref construct.PropertyRef
	if err := ref.Parse(str); err != nil {
		t.Fatalf("failed to parse property ref %q: %v", str, err)
	}
	return ref
}

// ParsePath parses str as a [construct.Path], failing the test immediately if it cannot be parsed.
func ParsePath(t *testing.T, str string) construct.Path {
	t.Helper()
	var path construct.Path
	if err := path.Parse(str); err != nil {
		t.Fatalf("failed to parse path %q: %v", str, err)
	}
	return path
}
package construct
// SortedIds implements sort.Interface over ResourceIds, ordering them purely by their content
// (see [ResourceIdLess]). Use it when deterministic ordering is desired and no other source of
// ordering is available.
type SortedIds []ResourceId

func (s SortedIds) Len() int { return len(s) }

// ResourceIdLess reports whether a sorts before b. It compares Provider, then Type, then Namespace,
// and finally Name, stopping at the first field that differs.
func ResourceIdLess(a, b ResourceId) bool {
	switch {
	case a.Provider != b.Provider:
		return a.Provider < b.Provider
	case a.Type != b.Type:
		return a.Type < b.Type
	case a.Namespace != b.Namespace:
		return a.Namespace < b.Namespace
	default:
		return a.Name < b.Name
	}
}

func (s SortedIds) Less(i, j int) bool { return ResourceIdLess(s[i], s[j]) }

func (s SortedIds) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
package construct
import (
"errors"
"fmt"
"math"
"strings"
"github.com/dominikbraun/graph"
"github.com/klothoplatform/klotho/pkg/graph_addons"
"github.com/klothoplatform/klotho/pkg/set"
)
// Path is an ordered sequence of resource IDs. Its text form joins the IDs with " -> ".
type Path graph_addons.Path[ResourceId]

// Dependencies records the paths found from Resource to its dependencies. All holds every resource
// that appears on any recorded path, except Resource itself.
type Dependencies struct {
	Resource ResourceId
	Paths    []Path
	All      set.Set[ResourceId]
}

// String renders the path as "a -> b -> c" using each ID's String form.
func (p Path) String() string {
	var sb strings.Builder
	for i, id := range p {
		if i > 0 {
			sb.WriteString(" -> ")
		}
		sb.WriteString(id.String())
	}
	return sb.String()
}

// Contains reports whether id appears anywhere in the path.
func (p Path) Contains(id ResourceId) bool {
	for i := range p {
		if p[i] == id {
			return true
		}
	}
	return false
}

// MarshalText encodes the path using its String form.
func (p Path) MarshalText() ([]byte, error) {
	text := p.String()
	return []byte(text), nil
}

// Parse replaces p with the IDs parsed from s, which is split on " -> ". Every segment is parsed even
// if earlier ones fail; all parse errors are returned joined, and p still holds one entry per segment.
func (p *Path) Parse(s string) error {
	segments := strings.Split(s, " -> ")
	parsed := make(Path, len(segments))
	var errs error
	for i, segment := range segments {
		var id ResourceId
		if err := id.Parse(segment); err != nil {
			errs = errors.Join(errs, fmt.Errorf("could not parse path[%d]: %w", i, err))
		}
		parsed[i] = id
	}
	*p = parsed
	return errs
}

// Validate checks every ID in the path, returning all validation failures joined.
func (p *Path) Validate() error {
	var errs error
	for i := range *p {
		id := (*p)[i]
		if err := id.Validate(); err != nil {
			errs = errors.Join(errs, fmt.Errorf("path[%d] invalid: %w", i, err))
		}
	}
	return errs
}

// UnmarshalText parses text into p and then validates the result.
func (p *Path) UnmarshalText(text []byte) error {
	err := p.Parse(string(text))
	if err == nil {
		err = p.Validate()
	}
	return err
}

// Add records p as a dependency path and adds each of its IDs (other than d.Resource) to d.All.
func (d *Dependencies) Add(p Path) {
	d.Paths = append(d.Paths, p)
	for _, id := range p {
		if id == d.Resource {
			continue
		}
		d.All.Add(id)
	}
}
// newDependencies computes shortest paths from start (ignoring edges for which skipEdge returns true) and
// records, for every key of deps that is reachable from start, the shortest path to it. Only the keys of
// deps are used; unreachable keys are skipped silently.
func newDependencies(
	g Graph,
	start ResourceId,
	skipEdge func(Edge) bool,
	deps map[ResourceId]map[ResourceId]Edge,
) (*Dependencies, error) {
	paths, err := bellmanFord(g, start, skipEdge)
	if err != nil {
		return nil, err
	}
	result := &Dependencies{Resource: start, All: make(set.Set[ResourceId])}
	for target := range deps {
		path, pathErr := paths.ShortestPath(target)
		switch {
		case errors.Is(pathErr, graph.ErrTargetNotReachable):
			continue
		case pathErr != nil:
			return nil, fmt.Errorf("could not get shortest path from %s to %s: %w", start, target, pathErr)
		}
		result.Add(path)
	}
	return result, nil
}
// UpstreamDependencies returns shortest paths from start to every vertex keyed in the graph's
// predecessor map that is reachable from start, skipping edges for which skipEdge returns true.
//
// NOTE(review): newDependencies only uses the map's keys, and bellmanFord follows edges in their forward
// direction from start. This therefore appears to compute the same paths as DownstreamDependencies.
// Confirm whether an upstream (reverse-edge) traversal was intended.
func UpstreamDependencies(g Graph, start ResourceId, skipEdge func(Edge) bool) (*Dependencies, error) {
	predecessors, mapErr := g.PredecessorMap()
	if mapErr != nil {
		return nil, mapErr
	}
	return newDependencies(g, start, skipEdge, predecessors)
}

// DownstreamDependencies returns shortest paths from start to every vertex keyed in the graph's
// adjacency map that is reachable from start, skipping edges for which skipEdge returns true.
func DownstreamDependencies(g Graph, start ResourceId, skipEdge func(Edge) bool) (*Dependencies, error) {
	adjacency, mapErr := g.AdjacencyMap()
	if mapErr != nil {
		return nil, mapErr
	}
	return newDependencies(g, start, skipEdge, adjacency)
}
// ShortestPather reconstructs shortest paths from a fixed source to arbitrary targets.
type ShortestPather interface {
	ShortestPath(target ResourceId) (Path, error)
}

// ShortestPaths computes shortest paths from source over g, ignoring edges for which skipEdge
// returns true, and returns a ShortestPather for querying paths to individual targets.
// On error the returned ShortestPather is a true nil interface.
func ShortestPaths(
	g Graph,
	source ResourceId,
	skipEdge func(Edge) bool,
) (ShortestPather, error) {
	result, err := bellmanFord(g, source, skipEdge)
	if err != nil {
		// Avoid wrapping a nil *bellmanFordResult in a non-nil interface.
		return nil, err
	}
	return result, nil
}

// DontSkipEdges is a skipEdge predicate that keeps every edge.
func DontSkipEdges(_ Edge) bool {
	return false
}
// bellmanFordResult holds the predecessor links produced by bellmanFord for a single source,
// from which shortest paths can be reconstructed.
type bellmanFordResult struct {
	source ResourceId
	prev   map[ResourceId]ResourceId
}

// bellmanFord computes single-source shortest paths from source using the Bellman-Ford algorithm.
//
// Edges for which skipEdge returns true, and self-loops, are ignored during relaxation. Edge weights are
// used only if the graph is weighted; otherwise every edge costs 1. Ties between equally short paths are
// broken by preferring the predecessor that sorts first under ResourceIdLess. An error is returned if a
// negative-weight cycle reachable from source is detected.
func bellmanFord(g Graph, source ResourceId, skipEdge func(Edge) bool) (*bellmanFordResult, error) {
	// unreachable marks vertices with no known path from source. Relaxation never starts from such a
	// vertex. That avoids int overflow (where int is 32 bits) and stops negative-weight edges between
	// unreachable vertices from making them look reachable or triggering a false cycle report.
	const unreachable = math.MaxInt32

	adjacencyMap, err := g.AdjacencyMap()
	if err != nil {
		return nil, fmt.Errorf("could not get adjacency map: %w", err)
	}
	isWeighted := g.Traits().IsWeighted
	weight := func(edge Edge) int {
		if !isWeighted {
			return 1
		}
		return edge.Properties.Weight
	}

	dist := make(map[ResourceId]int, len(adjacencyMap))
	prev := make(map[ResourceId]ResourceId)
	for key := range adjacencyMap {
		dist[key] = unreachable
	}
	dist[source] = 0

	for i := 0; i < len(adjacencyMap)-1; i++ {
		changed := false
		for key, edges := range adjacencyMap {
			if dist[key] == unreachable {
				continue
			}
			for _, edge := range edges {
				if skipEdge(edge) || edge.Source == edge.Target {
					continue
				}
				newDist := dist[key] + weight(edge)
				if newDist < dist[edge.Target] {
					dist[edge.Target] = newDist
					prev[edge.Target] = key
					changed = true
				} else if newDist == dist[edge.Target] && ResourceIdLess(key, prev[edge.Target]) {
					prev[edge.Target] = key
					changed = true
				}
			}
		}
		if !changed {
			// A full pass without updates means every later pass would make none either.
			break
		}
	}

	for key, edges := range adjacencyMap {
		if dist[key] == unreachable {
			continue
		}
		for _, edge := range edges {
			if skipEdge(edge) {
				continue
			}
			if dist[key]+weight(edge) < dist[edge.Target] {
				return nil, errors.New("graph contains a negative-weight cycle")
			}
		}
	}
	return &bellmanFordResult{
		source: source,
		prev:   prev,
	}, nil
}
// ShortestPath reconstructs the shortest path from the result's source to target by following the
// predecessor links. It returns graph.ErrTargetNotReachable if target has no predecessor chain back to
// the source. As a safeguard against corrupted predecessor links, walks longer than 5000 steps are
// aborted with an error that describes the cycle, if one is found.
func (b bellmanFordResult) ShortestPath(target ResourceId) (Path, error) {
	// Collect the path backwards (target first) and reverse once at the end; prepending on every
	// step would make reconstruction quadratic in the path length.
	var reversed []ResourceId
	for u := target; u != b.source; u = b.prev[u] {
		if _, ok := b.prev[u]; !ok {
			return nil, graph.ErrTargetNotReachable
		}
		if len(reversed) > 5000 {
			// This is "slow" but if there's this many path elements, something's wrong
			// and this debug info will be useful. Restore source->target order first so the
			// reported cycle reads in the same direction as a returned path would.
			forward := make([]ResourceId, len(reversed))
			for i, id := range reversed {
				forward[len(reversed)-1-i] = id
			}
			for i, e := range forward {
				for j := i - 1; j >= 0; j-- {
					if forward[j] == e {
						return nil, fmt.Errorf("path contains a cycle: %s", Path(forward[j:i+1]))
					}
				}
			}
			return nil, errors.New("path too long")
		}
		reversed = append(reversed, u)
	}
	reversed = append(reversed, b.source)

	path := make(Path, len(reversed))
	for i, id := range reversed {
		path[len(reversed)-1-i] = id
	}
	return path, nil
}
package construct
import (
"errors"
"fmt"
"reflect"
"sort"
"strconv"
"strings"
"github.com/klothoplatform/klotho/pkg/reflectutil"
"github.com/klothoplatform/klotho/pkg/set"
"github.com/klothoplatform/klotho/pkg/yaml_util"
)
// Properties holds a resource's property values under their top-level keys. Nested values are
// addressed with [Properties.PropertyPath].
type Properties map[string]any
// SetProperty sets the value at pathStr in r's properties. It is a convenience wrapper around
// [PropertyPath.Set].
func (r *Resource) SetProperty(pathStr string, value any) error {
	path, err := r.PropertyPath(pathStr)
	if err == nil {
		err = path.Set(value)
	}
	return err
}
// GetProperty is a wrapper around [PropertyPath.Get] for convenience.
//
// Unlike [Properties.GetProperty], it does NOT return ErrPropertyDoesNotExist: a missing property yields
// (nil, nil) for backwards compatibility. An error is returned only if pathStr cannot be turned into a
// path. Note that r.Properties is initialized to an empty map if it was nil (see [Resource.PropertyPath]).
func (r *Resource) GetProperty(pathStr string) (any, error) {
	path, err := r.PropertyPath(pathStr)
	if err != nil {
		return nil, err
	}
	if value, ok := path.Get(); ok {
		return value, nil
	}
	// Backwards compatibility: if the property does not exist, return nil instead of an error.
	return nil, nil
}
// ErrPropertyDoesNotExist is returned by [Properties.GetProperty] when nothing exists at the requested path.
var ErrPropertyDoesNotExist error = errors.New("property does not exist")
// AppendProperty appends value at pathStr in r's properties. It is a convenience wrapper around
// [PropertyPath.Append].
func (r *Resource) AppendProperty(pathStr string, value any) error {
	path, err := r.PropertyPath(pathStr)
	if err == nil {
		err = path.Append(value)
	}
	return err
}

// RemoveProperty removes value (or, if value is nil, the item itself) at pathStr in r's properties.
// It is a convenience wrapper around [PropertyPath.Remove].
func (r *Resource) RemoveProperty(pathStr string, value any) error {
	path, err := r.PropertyPath(pathStr)
	if err == nil {
		err = path.Remove(value)
	}
	return err
}
// PropertyPath parses pathStr against r's properties (see [Properties.PropertyPath]). If r has no
// properties yet, an empty map is assigned first so the returned path has somewhere to write.
func (r *Resource) PropertyPath(pathStr string) (PropertyPath, error) {
	if r.Properties == nil {
		r.Properties = make(Properties)
	}
	return r.Properties.PropertyPath(pathStr)
}
// Equals reports whether other is a Properties with the same number of top-level keys and whether, for
// every property path walked in p, the same path in other holds an equal value. Values are compared in
// this order:
//   - nil only equals nil;
//   - values with an Equals(any) bool method use it;
//   - otherwise == is used when both are comparable, falling back to reflect.DeepEqual.
func (p Properties) Equals(other any) (equal bool) {
	rhs, ok := other.(Properties)
	if !ok || len(p) != len(rhs) {
		return false
	}
	equal = true
	_ = p.WalkProperties(func(path PropertyPath, _ error) error {
		rhsPath, err := rhs.PropertyPath(path.String())
		if err != nil {
			equal = false
			return StopWalk
		}
		lhsVal, _ := path.Get()
		rhsVal, _ := rhsPath.Get()
		switch eq, hasEquals := lhsVal.(interface{ Equals(any) bool }); {
		case lhsVal == nil || rhsVal == nil:
			equal = lhsVal == rhsVal
		case hasEquals:
			equal = eq.Equals(rhsVal)
		default:
			bothComparable := reflect.ValueOf(lhsVal).Comparable() && reflect.ValueOf(rhsVal).Comparable()
			if bothComparable && lhsVal == rhsVal {
				return nil
			}
			equal = reflect.DeepEqual(lhsVal, rhsVal)
		}
		if !equal {
			return StopWalk
		}
		return nil
	})
	return equal
}
// SetProperty sets the value at pathStr (see [Properties.PropertyPath]).
func (p Properties) SetProperty(pathStr string, value any) error {
	path, err := p.PropertyPath(pathStr)
	if err == nil {
		err = path.Set(value)
	}
	return err
}

// GetProperty returns the value at pathStr, or ErrPropertyDoesNotExist if nothing exists there.
func (p *Properties) GetProperty(pathStr string) (any, error) {
	path, err := p.PropertyPath(pathStr)
	if err != nil {
		return nil, err
	}
	value, found := path.Get()
	if !found {
		return nil, ErrPropertyDoesNotExist
	}
	return value, nil
}

// AppendProperty appends value at pathStr (see [PropertyPath.Append]).
func (p Properties) AppendProperty(pathStr string, value any) error {
	path, err := p.PropertyPath(pathStr)
	if err == nil {
		err = path.Append(value)
	}
	return err
}

// RemoveProperty removes value (or, if value is nil, the item itself) at pathStr
// (see [PropertyPath.Remove]).
func (p Properties) RemoveProperty(pathStr string, value any) error {
	path, err := p.PropertyPath(pathStr)
	if err == nil {
		err = path.Remove(value)
	}
	return err
}
// PropertyPathItem is a single step of a PropertyPath. It can read, overwrite, remove, or append to the
// value it addresses, and links back to the item before it via parent.
type PropertyPathItem interface {
	Get() (value any, ok bool)
	Set(value any) error
	Remove(value any) error
	Append(value any) error
	parent() PropertyPathItem
}

// PropertyKVItem is implemented by path items that address a map entry; Key returns an item that
// addresses the entry's key instead of its value.
type PropertyKVItem interface {
	Key() PropertyPathItem
}

// PropertyPath represents a path into a resource's properties. See [Resource.PropertyPath] for
// more information.
type PropertyPath []PropertyPathItem

// mapValuePathItem addresses the value stored under key in map m.
type mapValuePathItem struct {
	_parent PropertyPathItem
	m       reflect.Value
	key     reflect.Value
}

// mapKeyPathItem addresses the key (rather than the value) of a mapValuePathItem's entry.
type mapKeyPathItem mapValuePathItem

// arrayIndexPathItem addresses element index of the array or slice a.
type arrayIndexPathItem struct {
	_parent PropertyPathItem
	a       reflect.Value
	index   int
}
// PropertyPath interprets a string path to index (potentially deeply) into [Resource.Properties].
// The result can be used to get, set, append, or remove values.
//
// The string is split with reflectutil.SplitPath. Each part is interpreted as follows:
//   - a part starting with '.', or an unprefixed first part, indexes a map (or set.HashedSet);
//   - a part `[N]` with integer N indexes an array or slice;
//   - any other `[key]` part indexes a map by key, which allows keys containing periods
//     (e.g. `MyMap[key.with.periods]`).
//
// Map parts may refer to values that do not exist yet. Array indexes must refer to an existing element.
// Errors are returned as *PropertyPathError, whose Path holds the parts locating the offending value.
func (p Properties) PropertyPath(pathStr string) (PropertyPath, error) {
	pathParts := reflectutil.SplitPath(pathStr)
	if len(pathParts) == 0 {
		return nil, fmt.Errorf("empty path")
	}
	path := make(PropertyPath, len(pathParts))
	value := reflect.ValueOf(p)

	// deref unwraps interfaces and pointers so that value refers to the concrete container.
	deref := func() {
		for value.Kind() == reflect.Interface || value.Kind() == reflect.Ptr {
			value = value.Elem()
		}
	}

	// setMap records path[i] as a lookup of key in the current (map) value and advances value to the entry.
	setMap := func(i int, key string) error {
		deref()
		if value.IsValid() {
			if value.Kind() == reflect.Struct {
				hs, ok := value.Interface().(set.HashedSet[string, any])
				if !ok {
					return &PropertyPathError{
						Path:  pathParts[:i],
						Cause: fmt.Errorf("expected HashedSet as struct, got %T", value.Interface()),
					}
				}
				// NOTE: this depends on the internals of set.HashedSet
				value = reflect.ValueOf(hs.M)
			} else if value.Kind() != reflect.Map {
				return &PropertyPathError{
					// The offending value is the one addressed by the parts before i.
					Path:  pathParts[:i],
					Cause: fmt.Errorf("expected map, got %s", value.Type()),
				}
			}
		}
		item := &mapValuePathItem{
			m:   value,
			key: reflect.ValueOf(key),
		}
		if i > 0 {
			item._parent = path[i-1]
		}
		path[i] = item
		if value.IsValid() {
			value = value.MapIndex(item.key)
		}
		return nil
	}

	for i, part := range pathParts {
		if part == "" {
			return nil, &PropertyPathError{
				Path:  pathParts[:i],
				Cause: fmt.Errorf("empty path part"),
			}
		}
		switch part[0] {
		case '.':
			if err := setMap(i, part[1:]); err != nil {
				return nil, err
			}

		case '[':
			if len(part) < 2 || part[len(part)-1] != ']' {
				return nil, &PropertyPathError{
					Path:  pathParts[:i],
					Cause: fmt.Errorf("invalid array index format, got %q", part),
				}
			}
			idxStr := part[1 : len(part)-1]
			idx, convErr := strconv.Atoi(idxStr)
			if convErr != nil {
				// for `MyMap[key.with.periods]` form
				if err := setMap(i, idxStr); err != nil {
					return nil, err
				}
				continue
			}
			deref()
			if value.IsValid() && value.Kind() != reflect.Slice && value.Kind() != reflect.Array {
				hs, ok := value.Interface().(set.HashedSet[string, any])
				if !ok {
					return nil, &PropertyPathError{
						Path:  pathParts[:i],
						Cause: fmt.Errorf("expected array, got %s", value.Type()),
					}
				}
				value = reflect.ValueOf(hs.ToSlice())
			}
			if !value.IsValid() || (value.Kind() == reflect.Slice && value.IsNil()) {
				// Nothing exists here yet; an empty slice makes the bounds check below report it.
				value = reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf((*any)(nil)).Elem()), 0, 0)
			}
			if idx < 0 || idx >= value.Len() {
				return nil, &PropertyPathError{
					Path:  pathParts[:i],
					Cause: fmt.Errorf("array index out of bounds: %d (length %d)", idx, value.Len()),
				}
			}
			item := &arrayIndexPathItem{
				a:     value,
				index: idx,
			}
			if i > 0 {
				item._parent = path[i-1]
			}
			path[i] = item
			value = value.Index(idx)

		default:
			if i > 0 {
				return nil, &PropertyPathError{
					Path:  pathParts[:i],
					Cause: fmt.Errorf("expected '.' or '[' to start path part, got %q", part),
				}
			}
			if err := setMap(i, part); err != nil {
				return nil, err
			}
		}
	}
	return path, nil
}
// PropertyPathError describes a failure while building or operating on a PropertyPath. Path holds the
// path parts locating the failure (concatenated without separators when printed), and Cause holds the
// underlying error.
type PropertyPathError struct {
	Path  []string
	Cause error
}

func (e PropertyPathError) Error() string {
	location := strings.Join(e.Path, "")
	return fmt.Sprintf("error in path %s: %v", location, e.Cause)
}
// itemToPath renders the path leading to item i as string parts. A PropertyPath is rendered directly;
// any other item is rendered by following parent links back to the root.
func itemToPath(i PropertyPathItem) []string {
	if path, ok := i.(PropertyPath); ok {
		return path.Parts()
	}
	depth := 0
	for cur := i; cur != nil; cur = cur.parent() {
		depth++
	}
	// Fill from the back so the root ends up first.
	chain := make(PropertyPath, depth)
	for cur := i; cur != nil; cur = cur.parent() {
		depth--
		chain[depth] = cur
	}
	return chain.Parts()
}
// Unwrap exposes Cause so errors.Is and errors.As can match the underlying error.
func (e PropertyPathError) Unwrap() error { return e.Cause }
// pathPanicRecover is deferred by path-item operations. It converts a panic into a *PropertyPathError
// stored in *err, naming the operation and the item's path. It must be the deferred function itself
// (not called from one) for recover to take effect.
func pathPanicRecover(i PropertyPathItem, operation string, err *error) {
	r := recover()
	if r == nil {
		return
	}
	cause, isErr := r.(error)
	if !isErr {
		cause = fmt.Errorf("panic: %v", r)
	}
	*err = &PropertyPathError{
		Path:  itemToPath(i),
		Cause: fmt.Errorf("recovered panic during '%s': %w", operation, cause),
	}
}
// ensureMap lazily creates the map this item indexes into (keyed by the key's type, with `any`
// values) when it does not exist yet, and stores it into the parent item.
func (i *mapValuePathItem) ensureMap() error {
	if i.m.IsValid() {
		return nil
	}
	anyType := reflect.TypeOf((*any)(nil)).Elem()
	i.m = reflect.MakeMap(reflect.MapOf(i.key.Type(), anyType))
	return i._parent.Set(i.m.Interface())
}
// Set stores value under this item's key, creating the containing map first if needed. A nil value
// produces the zero reflect.Value, which SetMapIndex treats as deleting the key.
func (i *mapValuePathItem) Set(value any) (err error) {
	defer pathPanicRecover(i, "Set on map", &err)
	if mapErr := i.ensureMap(); mapErr != nil {
		return mapErr
	}
	i.m.SetMapIndex(i.key, reflect.ValueOf(value))
	return nil
}
// appendValue returns the result of appending value to appendTo; callers must store the returned value.
// Interfaces and pointers around appendTo are unwrapped first. Supported destinations:
//   - absent (invalid): a new container is created from value's kind. A slice/array value yields an
//     empty slice of value's type, a map value yields an empty map of its key/elem types, and anything
//     else yields an empty []T of value's type.
//   - slice/array: value may be a slice/array whose elements are assignable (appended element-wise) or a
//     single assignable element. Elements already present (reflect.DeepEqual) are skipped; duplicates
//     within value itself are not.
//   - map: value must be a map with assignable key and elem types; its entries are written into the
//     destination map via SetMapIndex.
//   - struct: both sides must be set.HashedSet[string, any]; value's elements are added with Add.
//
// Any other destination kind, or an incompatible value, returns an error along with the unwrapped
// (or newly created) destination.
//
// NOTE(review): reflect.MakeSlice and reflect.Append panic for array (non-slice) types. Array inputs
// presumably rely on callers recovering via pathPanicRecover — confirm.
func appendValue(appendTo reflect.Value, value reflect.Value) (reflect.Value, error) {
a := appendTo
for a.Kind() == reflect.Interface || a.Kind() == reflect.Ptr {
a = a.Elem()
}
if !a.IsValid() {
// Appending to empty, create a new slice or map based on what value's type.
switch value.Kind() {
case reflect.Slice, reflect.Array:
// append(nil, []T{...}} => []T{...}
a = reflect.MakeSlice(value.Type(), 0, value.Len())
case reflect.Map:
// append(nil, map[K]V{...}) => map[K]V{...}
a = reflect.MakeMap(reflect.MapOf(value.Type().Key(), value.Type().Elem()))
default:
// append(nil, T) => []T{...}
a = reflect.MakeSlice(reflect.SliceOf(value.Type()), 0, 1)
}
}
switch a.Kind() {
case reflect.Slice, reflect.Array:
var values []reflect.Value
// The element-wise branch is checked first, so a []any value appended to a []any destination is
// flattened rather than appended as a single nested element.
if (value.Kind() == reflect.Slice || value.Kind() == reflect.Array) &&
value.Type().Elem().AssignableTo(a.Type().Elem()) {
// append(a []T, b []T) => []T{a..., b...}
values = make([]reflect.Value, value.Len())
for i := 0; i < value.Len(); i++ {
values[i] = value.Index(i)
}
} else if value.Type().AssignableTo(a.Type().Elem()) {
// append(a []T, b T) => []T{a..., b}
values = []reflect.Value{value}
} else {
return a, fmt.Errorf("expected %s or []%[1]s value for append, got %s", a.Type().Elem(), value.Type())
}
// NOTE(gg): If we ever need to allow for duplicates in a list, we'll likely need that behaviour
// specified in a template, which means this logic will need to be promoted out of here and into
// somewhere that has access to the templates.
toAdd := make([]reflect.Value, 0, len(values))
valuesLoop:
for _, v := range values {
// Skip v if an equal element already exists in the destination (only the destination is checked).
for i := 0; i < a.Len(); i++ {
existing := a.Index(i)
if reflect.DeepEqual(existing.Interface(), v.Interface()) {
continue valuesLoop
}
}
toAdd = append(toAdd, v)
}
return reflect.Append(a, toAdd...), nil
case reflect.Map:
aType := a.Type()
valType := value.Type()
if valType.Kind() != reflect.Map {
return a, fmt.Errorf("expected map value for append, got %s", valType)
}
if !valType.Key().AssignableTo(aType.Key()) {
return a, fmt.Errorf("expected map key type %s, got %s", aType.Key(), valType.Key())
}
if !valType.Elem().AssignableTo(aType.Elem()) {
return a, fmt.Errorf("expected map value type %s, got %s", aType.Elem(), valType.Elem())
}
// Entries are written into the existing destination map; keys present in both take value's entry.
for _, key := range value.MapKeys() {
a.SetMapIndex(key, value.MapIndex(key))
}
return a, nil
case reflect.Struct:
// The only supported struct destination is set.HashedSet[string, any].
val := value.Interface()
original := a.Interface()
current, ok := original.(set.HashedSet[string, any])
if !ok {
return a, fmt.Errorf("expected HashedSet as original struct, got %T", original)
}
additional, ok := val.(set.HashedSet[string, any])
if !ok {
return a, fmt.Errorf("expected HashedSet as additional struct, got %T", val)
}
current.Add(additional.ToSlice()...)
return reflect.ValueOf(current), nil
}
return a, fmt.Errorf("expected array, hashedset, or map destination for append, got %s", a.Kind())
}
// Append merges value into the entry under this item's key (see appendValue), creating the containing
// map first if needed, and stores the result back into the map.
func (i *mapValuePathItem) Append(value any) (err error) {
	defer pathPanicRecover(i, "Append on map", &err)
	if mapErr := i.ensureMap(); mapErr != nil {
		return mapErr
	}
	merged, appendErr := appendValue(i.m.MapIndex(i.key), reflect.ValueOf(value))
	if appendErr != nil {
		return &PropertyPathError{Path: itemToPath(i), Cause: appendErr}
	}
	i.m.SetMapIndex(i.key, merged)
	return nil
}
// arrRemoveByValue returns a new slice of arr's type holding every element of arr that is not equal
// to value (per reflect.Value.Equal). If nothing matched, it returns arr unchanged with an error.
func arrRemoveByValue(arr reflect.Value, value reflect.Value) (reflect.Value, error) {
	n := arr.Len()
	kept := reflect.MakeSlice(arr.Type(), 0, n)
	for idx := 0; idx < n; idx++ {
		if elem := arr.Index(idx); !elem.Equal(value) {
			kept = reflect.Append(kept, elem)
		}
	}
	if kept.Len() < n {
		return kept, nil
	}
	return arr, fmt.Errorf("value %v not found in array", value)
}
// Remove deletes this item's map entry when value is nil. Otherwise it removes value from the array
// or set.HashedSet stored under the key. It is a no-op if the containing map does not exist.
func (i *mapValuePathItem) Remove(value any) (err error) {
	defer pathPanicRecover(i, "Remove on map", &err)
	if !i.m.IsValid() {
		return nil
	}
	if value == nil {
		// SetMapIndex with the zero Value deletes the key.
		i.m.SetMapIndex(i.key, reflect.Value{})
		return nil
	}

	current := i.m.MapIndex(i.key)
	for current.Kind() == reflect.Interface || current.Kind() == reflect.Ptr {
		current = current.Elem()
	}

	switch current.Kind() {
	case reflect.Slice, reflect.Array:
		remaining, rmErr := arrRemoveByValue(current, reflect.ValueOf(value))
		if rmErr != nil {
			return &PropertyPathError{Path: itemToPath(i), Cause: rmErr}
		}
		i.m.SetMapIndex(i.key, remaining)
		return nil
	}

	hs, isSet := current.Interface().(set.HashedSet[string, any])
	if !isSet {
		return &PropertyPathError{
			Path:  itemToPath(i),
			Cause: fmt.Errorf("for non-nil value'd (%v), must be array (got %s) to remove by value", value, current.Type()),
		}
	}
	if !hs.Contains(value) {
		return &PropertyPathError{
			Path:  itemToPath(i),
			Cause: fmt.Errorf("value %v not found in set", value),
		}
	}
	if !hs.Remove(value) {
		return &PropertyPathError{
			Path:  itemToPath(i),
			Cause: fmt.Errorf("value %v not removed from set", value),
		}
	}
	return nil
}
// Get returns the value stored under this item's key and whether it exists. A containing map that has
// not been created yet is reported as not existing.
func (i *mapValuePathItem) Get() (any, bool) {
	if !i.m.IsValid() {
		return nil, false
	}
	if entry := i.m.MapIndex(i.key); entry.IsValid() {
		return entry.Interface(), true
	}
	return nil, false
}

func (i *mapValuePathItem) parent() PropertyPathItem { return i._parent }

// Key returns a path item addressing this entry's key rather than its value, so the key itself can
// be read, replaced, or removed.
func (i *mapValuePathItem) Key() PropertyPathItem {
	keyItem := (*mapKeyPathItem)(i)
	return keyItem
}
// Get returns the map key this item refers to; it always reports true.
func (i *mapKeyPathItem) Get() (any, bool) {
	k := i.key.Interface()
	return k, true
}

// Set moves the entry stored under the current key to the key given by value. The item's own key
// field is not updated.
func (i *mapKeyPathItem) Set(value any) (err error) {
	defer pathPanicRecover(i, "Set on map key", &err)
	entry := i.m.MapIndex(i.key)
	i.m.SetMapIndex(i.key, reflect.Value{})       // delete the old key
	i.m.SetMapIndex(reflect.ValueOf(value), entry) // re-insert under the new key
	return nil
}

// Append always fails: appending to a map key is not supported.
func (i *mapKeyPathItem) Append(value any) (err error) {
	err = &PropertyPathError{
		Path:  itemToPath(i),
		Cause: fmt.Errorf("cannot append to map key"),
	}
	return err
}

// Remove deletes the whole map entry; value is ignored.
func (i *mapKeyPathItem) Remove(value any) (err error) {
	defer pathPanicRecover(i, "Remove on map key", &err)
	i.m.SetMapIndex(i.key, reflect.Value{})
	return nil
}

func (i *mapKeyPathItem) parent() PropertyPathItem { return i._parent }
// Set stores value at this item's index. Reflection panics (e.g. an unassignable type or a
// non-settable element) are converted into an error by pathPanicRecover.
func (i *arrayIndexPathItem) Set(value any) (err error) {
	defer pathPanicRecover(i, "Set on array", &err)
	elem := i.a.Index(i.index)
	elem.Set(reflect.ValueOf(value))
	return nil
}

// Append merges value into the element at this item's index (see appendValue) and stores the result
// back into that element.
func (i *arrayIndexPathItem) Append(value any) (err error) {
	defer pathPanicRecover(i, "Append on array", &err)
	elem := i.a.Index(i.index)
	merged, appendErr := appendValue(elem, reflect.ValueOf(value))
	if appendErr != nil {
		return &PropertyPathError{Path: itemToPath(i), Cause: appendErr}
	}
	elem.Set(merged)
	return nil
}
// Remove splices the element at this item's index out of the array when value is nil, storing the
// shortened slice in the parent item. Otherwise it removes value from the array held at that index.
// It is a no-op if the array does not exist.
func (i *arrayIndexPathItem) Remove(value any) (err error) {
	defer pathPanicRecover(i, "Remove on array", &err)
	if !i.a.IsValid() {
		return nil
	}
	if value == nil {
		head := i.a.Slice(0, i.index)
		tail := i.a.Slice(i.index+1, i.a.Len())
		i.a = reflect.AppendSlice(head, tail)
		return i._parent.Set(i.a.Interface())
	}

	target := i.a.Index(i.index)
	for target.Kind() == reflect.Interface || target.Kind() == reflect.Ptr {
		target = target.Elem()
	}
	if kind := target.Kind(); kind != reflect.Slice && kind != reflect.Array {
		return &PropertyPathError{
			Path:  itemToPath(i),
			Cause: fmt.Errorf("for non-nil value'd (%v), must be array (got %s) to remove by value", value, target.Type()),
		}
	}
	remaining, rmErr := arrRemoveByValue(target, reflect.ValueOf(value))
	if rmErr != nil {
		return &PropertyPathError{Path: itemToPath(i), Cause: rmErr}
	}
	target.Set(remaining)
	return nil
}
// Get returns the element at this item's index. It reports false when the
// backing value is missing, is not a slice/array, or is too short.
func (i *arrayIndexPathItem) Get() (any, bool) {
	switch {
	case !i.a.IsValid():
		return nil, false
	case !reflectutil.IsAnyOf(reflectutil.GetConcreteElement(i.a), reflect.Slice, reflect.Array):
		return nil, false
	case i.index >= i.a.Len():
		return nil, false
	}
	return i.a.Index(i.index).Interface(), true
}
// parent returns the enclosing path item recorded for this element.
func (i *arrayIndexPathItem) parent() PropertyPathItem {
	p := i._parent
	return p
}
// Set sets the value addressed by the path by delegating to its final item.
func (i PropertyPath) Set(value any) error {
	last := i[len(i)-1]
	return last.Set(value)
}
// Append appends value at the path's final item. Only array items support it.
func (i PropertyPath) Append(value any) error {
	last := i[len(i)-1]
	return last.Append(value)
}
// Remove removes data at the path's final item. A nil value removes the item
// itself. A non-nil value is only supported on array items, where it removes
// that value from the array.
func (i PropertyPath) Remove(value any) error {
	last := i[len(i)-1]
	return last.Remove(value)
}
// Get returns the value at the path's final item and whether it exists.
func (i PropertyPath) Get() (any, bool) {
	last := i[len(i)-1]
	return last.Get()
}
// parent returns the parent of the path's final item.
func (i PropertyPath) parent() PropertyPathItem {
	last := i[len(i)-1]
	return last.parent()
}
// Parts returns one string segment per path item. Concatenated, the segments
// form the path's string form:
//   - map values render as "key" (first segment) or ".key", or as "[key]" when
//     the key contains '.' or '['
//   - array indices render as "[N]"
//   - any other item kind contributes an empty segment
func (i PropertyPath) Parts() []string {
	parts := make([]string, len(i))
	for idx := range i {
		if mv, ok := i[idx].(*mapValuePathItem); ok {
			key := mv.key.String()
			switch {
			case strings.ContainsAny(key, ".["):
				key = "[" + key + "]"
			case idx > 0:
				key = "." + key
			}
			parts[idx] = key
		} else if ai, ok := i[idx].(*arrayIndexPathItem); ok {
			parts[idx] = fmt.Sprintf("[%d]", ai.index)
		}
	}
	return parts
}
// String renders the path by concatenating its Parts.
func (i PropertyPath) String() string {
	var sb strings.Builder
	for _, part := range i.Parts() {
		sb.WriteString(part)
	}
	return sb.String()
}
// Last returns the final item of the path.
func (i PropertyPath) Last() PropertyPathItem {
	lastIdx := len(i) - 1
	return i[lastIdx]
}
// WalkPropertiesFunc is called by WalkProperties for each visited property
// path. err is whatever the previous invocation returned (nil on the first
// call, and reset to nil after SkipProperty). Returning SkipProperty skips the
// property's descendants, and returning StopWalk ends the walk.
type WalkPropertiesFunc func(path PropertyPath, err error) error
// stringerType is the reflect.Type of the fmt.Stringer interface. mapKeys uses
// it to accept map key types that can be ordered by their String() form.
var stringerType = reflect.TypeOf(new(fmt.Stringer)).Elem()

// mapKeys returns the keys of map m sorted by their string form. Keys must be
// of string kind or implement fmt.Stringer; any other key type is an error.
func mapKeys(m reflect.Value) ([]reflect.Value, error) {
	var keyString func(reflect.Value) string
	switch kt := m.Type().Key(); {
	case kt.Kind() == reflect.String:
		keyString = reflect.Value.String
	case kt.Implements(stringerType):
		keyString = func(k reflect.Value) string {
			return k.Interface().(fmt.Stringer).String()
		}
	default:
		return nil, fmt.Errorf("expected map[string|fmt.Stringer]..., got %s", m.Type())
	}

	keys := m.MapKeys()
	sort.Slice(keys, func(a, b int) bool {
		return keyString(keys[a]) < keyString(keys[b])
	})
	return keys, nil
}
var (
	// SkipProperty may be returned by a WalkPropertiesFunc to skip the current
	// property's descendants (for map, array, or set values).
	SkipProperty = fmt.Errorf("skip property")
)
// WalkProperties walks the resource's properties; it delegates to
// Properties.WalkProperties.
func (r *Resource) WalkProperties(fn WalkPropertiesFunc) error {
	props := r.Properties
	return props.WalkProperties(fn)
}
// WalkProperties walks the properties of the resource, calling fn for each property. If fn returns
// SkipProperty, the property and its descendants (if a map or array type) are skipped. If fn returns
// StopWalk, the walk is stopped.
//
// The walk is breadth-first. Top-level keys and the keys of nested maps are queued in sorted order,
// while array elements are queued from last to first. Any other value returned by fn is passed as
// err to the next call, and the value from the final call is returned.
// NOTE: does not walk over the _keys_ of any maps, only values.
func (p Properties) WalkProperties(fn WalkPropertiesFunc) error {
	props := reflect.ValueOf(p)
	// The error is ignored as before. On failure keys is nil, and sizing the queue from keys
	// (rather than len(p)) means nothing is walked instead of walking nil paths.
	keys, _ := mapKeys(props)
	queue := make([]PropertyPath, len(keys))
	for i, k := range keys {
		queue[i] = PropertyPath{&mapValuePathItem{m: props, key: k}}
	}

	var err error
	var current PropertyPath
	for len(queue) > 0 {
		current, queue = queue[0], queue[1:]
		// appendPath returns a fresh copy of current extended by item, so queued
		// paths never share a backing array.
		appendPath := func(item PropertyPathItem) PropertyPath {
			n := make(PropertyPath, len(current)+1)
			copy(n, current)
			n[len(n)-1] = item
			return n
		}

		err = fn(current, err)
		if err == StopWalk {
			return nil
		}
		if err == SkipProperty {
			err = nil
			continue
		}

		rv, ok := current.Get()
		if !ok {
			continue
		}
		v := reflect.ValueOf(rv)
		switch v.Kind() {
		case reflect.Map:
			childKeys, keyErr := mapKeys(v)
			if keyErr != nil {
				return keyErr
			}
			for _, k := range childKeys {
				queue = append(queue, appendPath(&mapValuePathItem{
					_parent: current.Last(),
					m:       v,
					key:     k,
				}))
			}

		case reflect.Array, reflect.Slice:
			// Go backwards so that if the walk function removes an item, we don't skip items (or cause a panic)
			// due to items shifting down.
			for i := v.Len() - 1; i >= 0; i-- {
				queue = append(queue, appendPath(&arrayIndexPathItem{
					_parent: current.Last(),
					a:       v,
					index:   i,
				}))
			}

		case reflect.Struct:
			// Only support HashedSet[string, any]
			hs, ok := v.Interface().(set.HashedSet[string, any])
			if !ok {
				continue
			}
			m := reflect.ValueOf(hs.M)
			childKeys, keyErr := mapKeys(m)
			if keyErr != nil {
				return keyErr
			}
			for _, k := range childKeys {
				queue = append(queue, appendPath(&mapValuePathItem{
					_parent: current.Last(),
					m:       m,
					key:     k,
				}))
			}
		}
	}
	return err
}
// MarshalYAML returns nil for empty properties. Otherwise it marshals the map
// with its top-level keys in ascending order; nested maps are not re-sorted here.
func (p Properties) MarshalYAML() (interface{}, error) {
	if len(p) > 0 {
		ascending := func(a, b string) bool { return a < b }
		return yaml_util.MarshalMap(p, ascending)
	}
	return nil, nil
}
package construct
import (
"fmt"
"strings"
)
// PropertyRef refers to one property of a resource. Its text form is
// "<resource id>#<property>".
type PropertyRef struct {
	// Resource is the resource that owns the property.
	Resource ResourceId
	// Property names the property (the part after '#').
	Property string
}
// String formats the reference as "<resource>#<property>", or "" for the zero
// value.
func (v PropertyRef) String() string {
	if !v.IsZero() {
		return v.Resource.String() + "#" + v.Property
	}
	return ""
}
// MarshalText encodes the reference using String.
func (v PropertyRef) MarshalText() ([]byte, error) {
	text := v.String()
	return []byte(text), nil
}
// Parse fills v from s of the form "<resource>#<property>". The property is
// assigned before the resource part is parsed, so it is set even if parsing the
// resource ID fails.
func (v *PropertyRef) Parse(s string) error {
	resourcePart, propertyPart, found := strings.Cut(s, "#")
	if !found {
		return fmt.Errorf("invalid PropertyRef format: %s", s)
	}
	v.Property = propertyPart
	return v.Resource.Parse(resourcePart)
}
// Validate validates the resource ID part. The property name is not checked.
func (v *PropertyRef) Validate() error {
	res := &v.Resource
	return res.Validate()
}
// UnmarshalText parses b with Parse and then validates the result.
func (v *PropertyRef) UnmarshalText(b []byte) error {
	err := v.Parse(string(b))
	if err == nil {
		err = v.Validate()
	}
	return err
}
// Equals reports whether ref is a PropertyRef, or a non-nil *PropertyRef, with
// the same resource ID and property as v. Any other type, including a nil
// *PropertyRef, compares unequal. Accepting the pointer form matches
// Resource.Equals.
func (v *PropertyRef) Equals(ref interface{}) bool {
	var other PropertyRef
	switch r := ref.(type) {
	case PropertyRef:
		other = r
	case *PropertyRef:
		if r == nil {
			return false
		}
		other = *r
	default:
		return false
	}
	return v.Resource == other.Resource && v.Property == other.Property
}
// IsZero reports whether neither a resource nor a property is set.
func (v *PropertyRef) IsZero() bool {
	if v.Property != "" {
		return false
	}
	return v.Resource.IsZero()
}
package construct
// Resource is a single resource: its identifier, its property values, and
// whether it is imported.
type Resource struct {
	// ID uniquely identifies the resource.
	ID ResourceId
	// Properties holds the resource's property values.
	Properties Properties
	// Imported marks the resource as imported (set when an import constraint is applied).
	Imported bool
}
// Equals reports whether other is a Resource, or a non-nil *Resource, with the
// same ID and equal properties. Imported is not compared. Any other type,
// including a nil *Resource, is reported as unequal; previously a nil pointer
// caused a nil dereference.
func (r Resource) Equals(other any) bool {
	switch other := other.(type) {
	case Resource:
		return r.ID == other.ID && r.Properties.Equals(other.Properties)
	case *Resource:
		if other == nil {
			return false
		}
		return r.ID == other.ID && r.Properties.Equals(other.Properties)
	default:
		return false
	}
}
package construct
import (
"encoding/json"
"errors"
"fmt"
"regexp"
"strings"
)
// ResourceId identifies a resource. Its text form (see String and Parse) is
// "provider:type[:namespace]:name".
type ResourceId struct {
	// Provider is the resource's provider, e.g. "aws" (or "klotho" for abstract resources).
	Provider string `yaml:"provider" toml:"provider"`
	// Type is the resource type within the provider.
	Type string `yaml:"type" toml:"type"`
	// Namespace is optional and is used to disambiguate resources that might have
	// the same name. It can also be used to associate an imported resource with
	// a specific namespace such as a subnet to a VPC.
	Namespace string `yaml:"namespace" toml:"namespace"`
	// Name is the resource's name; unlike Namespace it may contain ':'.
	Name string `yaml:"name" toml:"name"`
}
// ResourceIdChangeResults is a map of old ResourceIds to new ResourceIds
type ResourceIdChangeResults map[ResourceId]ResourceId

// ResourceList is a list of resource IDs. Its text form is the bare ID when it
// holds exactly one element, and a JSON array otherwise.
type ResourceList []ResourceId
// Merge copies every mapping in other into m. It overwrites existing keys and
// allocates m first if it is nil.
func (m *ResourceIdChangeResults) Merge(other ResourceIdChangeResults) {
	if *m == nil {
		*m = make(ResourceIdChangeResults)
	}
	dst := *m
	for oldID, newID := range other {
		dst[oldID] = newID
	}
}
// RemoveNoop deletes entries that map an ID to itself.
func (m ResourceIdChangeResults) RemoveNoop() {
	for from, to := range m {
		if from != to {
			continue
		}
		delete(m, from)
	}
}
// String returns the single ID's text for a one-element list, and the JSON
// encoding of the list otherwise. A marshalling failure panics.
func (l ResourceList) String() string {
	if len(l) != 1 {
		encoded, err := json.Marshal(l)
		if err != nil {
			panic(fmt.Errorf("could not marshal resource list: %w", err))
		}
		return string(encoded)
	}
	return l[0].String()
}
// UnmarshalText accepts either a single resource ID or a JSON array of IDs.
func (l *ResourceList) UnmarshalText(b []byte) error {
	var single ResourceId
	if err := single.UnmarshalText(b); err == nil {
		*l = []ResourceId{single}
		return nil
	}
	var many []ResourceId
	if json.Unmarshal(b, &many) == nil {
		*l = many
		return nil
	}
	return fmt.Errorf("could not unmarshal resource list: %s", string(b))
}
// MatchesAny reports whether any ID in l, used as a filter (see Matches),
// matches id.
func (l ResourceList) MatchesAny(id ResourceId) bool {
	for i := range l {
		if l[i].Matches(id) {
			return true
		}
	}
	return false
}
// zeroId is the empty ResourceId, used for zero-value comparisons.
var zeroId ResourceId

// IsZero reports whether every field of id is empty.
func (id ResourceId) IsZero() bool {
	return zeroId == id
}
// String formats the ID as "provider:type[:namespace][:name]"; the zero ID
// formats as "".
//
// The namespace segment is written when it is non-empty or when the name
// contains ':'. In the latter case the empty namespace slot keeps Parse from
// reading part of the name as a namespace. The name segment is written when it
// is non-empty.
//
// NOTE(review): an ID with a namespace but no name formats as
// "provider:type:namespace", which Parse reads back as a name. Confirm whether
// that combination can occur.
func (id ResourceId) String() string {
	if id.IsZero() {
		return ""
	}
	const maxColons = 3 // separators between the four fields
	var sb strings.Builder
	sb.Grow(len(id.Provider) + len(id.Type) + len(id.Namespace) + len(id.Name) + maxColons)

	sb.WriteString(id.Provider)
	sb.WriteByte(':')
	sb.WriteString(id.Type)

	writeNamespace := id.Namespace != "" || strings.Contains(id.Name, ":")
	if writeNamespace {
		sb.WriteByte(':')
		sb.WriteString(id.Namespace)
	}
	if id.Name == "" {
		return sb.String()
	}
	sb.WriteByte(':')
	sb.WriteString(id.Name)
	return sb.String()
}
// QualifiedTypeName returns "provider:type", omitting namespace and name.
func (id ResourceId) QualifiedTypeName() string {
	return strings.Join([]string{id.Provider, id.Type}, ":")
}
// MarshalText encodes the ID using String.
func (id ResourceId) MarshalText() ([]byte, error) {
	text := id.String()
	return []byte(text), nil
}
// IsAbstractResource reports whether id belongs to the "klotho" provider, which
// denotes abstract resources.
func (id ResourceId) IsAbstractResource() bool {
	const abstractProvider = "klotho"
	return id.Provider == abstractProvider
}
// Matches treats `id` (the receiver) as a filter for `other` (the argument). It
// returns true when every non-empty field of `id` equals the corresponding
// field of `other`; empty fields match anything.
func (id ResourceId) Matches(other ResourceId) bool {
	fieldMatches := func(filter, actual string) bool {
		return filter == "" || filter == actual
	}
	return fieldMatches(id.Provider, other.Provider) &&
		fieldMatches(id.Type, other.Type) &&
		fieldMatches(id.Namespace, other.Namespace) &&
		fieldMatches(id.Name, other.Name)
}
// SelectIds returns, in their original order, the ids that selector matches
// (see ResourceId.Matches). The result is never nil.
func SelectIds(ids []ResourceId, selector ResourceId) []ResourceId {
	selected := make([]ResourceId, 0, len(ids))
	for i := range ids {
		if !selector.Matches(ids[i]) {
			continue
		}
		selected = append(selected, ids[i])
	}
	return selected
}
// resourceProviderPattern is the allowed form of ResourceId.Provider.
var resourceProviderPattern = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)

// resourceTypePattern is the allowed form of a non-empty ResourceId.Type.
var resourceTypePattern = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)

// resourceNamespacePattern is like resourceNamePattern, but ':' is not allowed.
var resourceNamespacePattern = regexp.MustCompile(`^[a-zA-Z0-9_./\-\[\]]*$`)

// resourceNamePattern is the allowed form of ResourceId.Name (may be empty).
var resourceNamePattern = regexp.MustCompile(`^[a-zA-Z0-9_./\-:\[\]]*$`)
// Parse parses s into id. Accepted forms:
//   - ""                             -> the zero ID
//   - "provider:type"                -> provider and type only ("provider:" for provider-only)
//   - "provider:type:name"
//   - "provider:type:namespace:name" (the name may itself contain ':')
//
// Every field of id is overwritten on success, so parsing into a previously
// populated ResourceId leaves no stale Namespace/Name behind. A single
// non-empty segment without ':' is rejected, and id is left unchanged on error.
func (id *ResourceId) Parse(s string) error {
	parts := strings.SplitN(s, ":", 4)
	if len(parts) == 1 {
		if parts[0] != "" {
			return fmt.Errorf("must have trailing ':' for provider-only ID")
		}
		*id = ResourceId{}
		return nil
	}

	parsed := ResourceId{Provider: parts[0], Type: parts[1]}
	switch len(parts) {
	case 3:
		parsed.Name = parts[2]
	case 4:
		parsed.Namespace = parts[2]
		parsed.Name = parts[3]
	}
	*id = parsed
	return nil
}
// Validate checks each field of id against its pattern and reports every
// violation, joined, under a single "invalid resource id" error. The zero ID is
// valid. Type and Namespace are only checked when non-empty.
func (id *ResourceId) Validate() error {
	if id.IsZero() {
		return nil
	}
	checks := []struct {
		bad     bool
		format  string
		value   string
		pattern *regexp.Regexp
	}{
		{!resourceProviderPattern.MatchString(id.Provider), "invalid provider '%s' (must match %s)", id.Provider, resourceProviderPattern},
		{id.Type != "" && !resourceTypePattern.MatchString(id.Type), "invalid type '%s' (must match %s)", id.Type, resourceTypePattern},
		{id.Namespace != "" && !resourceNamespacePattern.MatchString(id.Namespace), "invalid namespace '%s' (must match %s)", id.Namespace, resourceNamespacePattern},
		{!resourceNamePattern.MatchString(id.Name), "invalid name '%s' (must match %s)", id.Name, resourceNamePattern},
	}
	var err error
	for _, c := range checks {
		if c.bad {
			// Joined incrementally (not all at once) to keep the same error structure.
			err = errors.Join(err, fmt.Errorf(c.format, c.value, c.pattern))
		}
	}
	if err == nil {
		return nil
	}
	return fmt.Errorf("invalid resource id '%s': %w", id.String(), err)
}
// UnmarshalText parses data with Parse and then validates the result.
func (id *ResourceId) UnmarshalText(data []byte) error {
	if err := id.Parse(string(data)); err != nil {
		return err
	}
	return id.Validate()
}
// MarshalTOML encodes the ID in the same text form as MarshalText.
func (id ResourceId) MarshalTOML() ([]byte, error) {
	text, err := id.MarshalText()
	return text, err
}
// UnmarshalTOML decodes the ID in the same way as UnmarshalText.
func (id *ResourceId) UnmarshalTOML(data []byte) error {
	err := id.UnmarshalText(data)
	return err
}
package dot
import (
"fmt"
"sort"
"strings"
)
// AttributesToString renders attribs as a DOT attribute list, e.g.
// ` [color="red", label=<b>x</b>]`, with keys in sorted order. Values wrapped
// in <...> are emitted unquoted as HTML-like labels. Other values are
// double-quoted with embedded '"' escaped. An empty map yields "".
func AttributesToString(attribs map[string]string) string {
	if len(attribs) == 0 {
		return ""
	}
	names := make([]string, 0, len(attribs))
	for name := range attribs {
		names = append(names, name)
	}
	sort.Strings(names)

	var sb strings.Builder
	sb.WriteString(" [")
	for idx, name := range names {
		if idx > 0 {
			sb.WriteString(", ")
		}
		val := attribs[name]
		isHTML := len(val) > 1 && val[0] == '<' && val[len(val)-1] == '>'
		if isHTML {
			fmt.Fprintf(&sb, `%s=%s`, name, val)
		} else {
			fmt.Fprintf(&sb, `%s="%s"`, name, strings.ReplaceAll(val, `"`, `\"`))
		}
	}
	sb.WriteString("]")
	return sb.String()
}
package dot
import (
"bytes"
"fmt"
"io"
"os/exec"
"regexp"
"strings"
"github.com/google/pprof/third_party/svgpan"
"go.uber.org/zap"
)
// The following adds SVG pan to the SVG output from DOT, taken from
// https://github.com/google/pprof/blob/main/internal/driver/svg.go

// viewBox matches the width/height/viewBox attributes of dot's opening <svg> tag.
var viewBox = regexp.MustCompile(`<svg\s*width="[^"]+"\s*height="[^"]+"\s*viewBox="[^"]+"`)

// graphID matches the opening tag of dot's top-level graph group.
var graphID = regexp.MustCompile(`<g id="graph\d"`)

// svgClose matches the closing </svg> tag.
var svgClose = regexp.MustCompile(`</svg>`)
// SvgPan enhances the SVG output from DOT to provide better
// panning inside a web browser. It uses the svgpan library, which is
// embedded into the svgpan.JSSource variable.
func SvgPan(svg string) string {
	// Work around for dot bug which misses quoting some ampersands,
	// resulting in unparsable SVG. (Previously this replaced "&;" with itself,
	// which was a no-op; the upstream pprof fix escapes the '&' as "&amp;".)
	svg = strings.ReplaceAll(svg, "&;", "&amp;;")

	// Dot's SVG output is
	//
	//    <svg width="___" height="___"
	//     viewBox="___" xmlns=...>
	//    <g id="graph0" transform="...">
	//    ...
	//    </g>
	//    </svg>
	//
	// Change it to
	//
	//    <svg width="100%" height="100%"
	//     xmlns=...>
	//    <script type="text/ecmascript"><![CDATA[` ..$(svgpan.JSSource)... `]]></script>`
	//    <g id="viewport" transform="translate(0,0)">
	//    <g id="graph0" transform="...">
	//    ...
	//    </g>
	//    </g>
	//    </svg>

	if loc := viewBox.FindStringIndex(svg); loc != nil {
		svg = svg[:loc[0]] +
			`<svg width="100%" height="100%"` +
			svg[loc[1]:]
	}

	if loc := graphID.FindStringIndex(svg); loc != nil {
		svg = svg[:loc[0]] +
			`<script type="text/ecmascript"><![CDATA[` + svgpan.JSSource + `]]></script>` +
			`<g id="viewport" transform="scale(0.5,0.5) translate(0,0)">` +
			svg[loc[0]:]
	}

	// Close the viewport group opened above, just before </svg>.
	if loc := svgClose.FindStringIndex(svg); loc != nil {
		svg = svg[:loc[0]] +
			`</g>` +
			svg[loc[0]:]
	}

	return svg
}
// Execute runs the Graphviz `dot` binary, reading a DOT graph from input and
// writing SVG to output. On failure the error includes dot's stderr.
func Execute(input io.Reader, output io.Writer) error {
	var stderr bytes.Buffer
	cmd := exec.Command("dot", "-Tsvg")
	cmd.Stdin, cmd.Stdout, cmd.Stderr = input, output, &stderr
	if err := cmd.Run(); err != nil {
		return fmt.Errorf("could not run 'dot': %w: %s", err, stderr.String())
	}
	return nil
}
// ExecPan renders the DOT graph from input to SVG via Execute and returns it
// post-processed by SvgPan.
func ExecPan(input io.Reader) (string, error) {
	var svg bytes.Buffer
	if err := Execute(input, &svg); err != nil {
		return "", err
	}
	zap.S().Named("dot").Debugf("dot output %d bytes", svg.Len())
	return SvgPan(svg.String()), nil
}
package engine
import (
"github.com/klothoplatform/klotho/pkg/collectionutil"
construct "github.com/klothoplatform/klotho/pkg/construct"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
)
// ListResources returns the ID of every resource template in the knowledge
// base. The result is non-nil even when there are no templates.
func (e *Engine) ListResources() []construct.ResourceId {
	templates := e.Kb.ListResources()
	ids := make([]construct.ResourceId, 0, len(templates))
	for _, tmpl := range templates {
		ids = append(ids, tmpl.Id())
	}
	return ids
}
// ListProviders returns the distinct providers of all resource templates, in
// first-seen order.
func (e *Engine) ListProviders() []string {
	providers := []string{}
	seen := make(map[string]bool)
	for _, tmpl := range e.Kb.ListResources() {
		p := tmpl.Id().Provider
		if seen[p] {
			continue
		}
		seen[p] = true
		providers = append(providers, p)
	}
	return providers
}
// ListFunctionalities returns the distinct functionalities of all resource
// templates, in first-seen order.
func (e *Engine) ListFunctionalities() []knowledgebase.Functionality {
	result := []knowledgebase.Functionality{}
	for _, tmpl := range e.Kb.ListResources() {
		f := tmpl.GetFunctionality()
		if collectionutil.Contains(result, f) {
			continue
		}
		result = append(result, f)
	}
	return result
}
// ListAttributes returns the distinct classification attributes (both "is"
// and "gives") across all resource templates, in first-seen order.
func (e *Engine) ListAttributes() []string {
	attributes := []string{}
	seen := make(map[string]struct{})
	add := func(attr string) {
		if _, dup := seen[attr]; dup {
			return
		}
		seen[attr] = struct{}{}
		attributes = append(attributes, attr)
	}
	for _, tmpl := range e.Kb.ListResources() {
		for _, is := range tmpl.Classification.Is {
			add(is)
		}
		for _, gives := range tmpl.Classification.Gives {
			add(gives.Attribute)
		}
	}
	return attributes
}
package engine
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"os"
"reflect"
"strings"
"sync"
"github.com/iancoleman/strcase"
clicommon "github.com/klothoplatform/klotho/pkg/cli_common"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/engine/constraints"
engine_errs "github.com/klothoplatform/klotho/pkg/engine/errors"
"github.com/klothoplatform/klotho/pkg/engine/solution"
kio "github.com/klothoplatform/klotho/pkg/io"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
"github.com/klothoplatform/klotho/pkg/provider/aws"
"github.com/klothoplatform/klotho/pkg/templates"
"github.com/pkg/errors"
"github.com/spf13/cobra"
"go.uber.org/zap"
"gopkg.in/yaml.v3"
)
type (
	// EngineMain connects the engine to the CLI commands.
	EngineMain struct {
		// Engine is created by AddEngine.
		Engine *Engine
		// cleanup is returned by clicommon.SetupCoreCommand in AddEngineCli.
		cleanup func()
	}
)
// commonCfg holds the flags shared by all commands (see clicommon.SetupCoreCommand).
var commonCfg clicommon.CommonConfig

// engineCfg holds the flags of the ListResourceTypes and ListAttributes commands.
var engineCfg struct {
	provider   string
	guardrails string
}

// architectureEngineCfg holds the flags of the Run command.
var architectureEngineCfg struct {
	provider    string
	guardrails  string
	inputGraph  string
	constraints string
	outputDir   string
	globalTag   string
}

// getValidEdgeTargetsCfg holds the flags of the GetValidEdgeTargets command.
var getValidEdgeTargetsCfg struct {
	guardrails string
	inputGraph string
	configFile string
	outputDir  string
}
// AddEngineCli performs the common CLI setup on root and registers the engine
// subcommands (ListResourceTypes, ListAttributes, Run, GetValidEdgeTargets)
// under the "engine" command group. Their flags are bound to the package-level
// config variables.
func (em *EngineMain) AddEngineCli(root *cobra.Command) {
	// The cleanup returned here is invoked by the Run command before it exits.
	// NOTE(review): the RunE commands never call em.cleanup — confirm intended.
	em.cleanup = clicommon.SetupCoreCommand(root, &commonCfg)
	engineGroup := &cobra.Group{
		ID:    "engine",
		Title: "engine",
	}
	listResourceTypesCmd := &cobra.Command{
		Use:     "ListResourceTypes",
		Short:   "List resource types available in the klotho engine",
		GroupID: engineGroup.ID,
		RunE:    em.ListResourceTypes,
	}
	// ListResourceTypes and ListAttributes bind the same engineCfg fields.
	flags := listResourceTypesCmd.Flags()
	flags.StringVarP(&engineCfg.provider, "provider", "p", "aws", "Provider to use")
	flags.StringVar(&engineCfg.guardrails, "guardrails", "", "Guardrails file")
	listAttributesCmd := &cobra.Command{
		Use:     "ListAttributes",
		Short:   "List attributes available in the klotho engine",
		GroupID: engineGroup.ID,
		RunE:    em.ListAttributes,
	}
	flags = listAttributesCmd.Flags()
	flags.StringVarP(&engineCfg.provider, "provider", "p", "aws", "Provider to use")
	flags.StringVar(&engineCfg.guardrails, "guardrails", "", "Guardrails file")
	runCmd := &cobra.Command{
		Use:     "Run",
		Short:   "Run the klotho engine",
		GroupID: engineGroup.ID,
		// Run (not RunE) so that RunEngine's exit code becomes the process exit
		// code. cleanup runs first because os.Exit skips deferred functions.
		Run: func(cmd *cobra.Command, args []string) {
			exitCode := em.RunEngine(cmd, args)
			em.cleanup()
			os.Exit(exitCode)
		},
	}
	flags = runCmd.Flags()
	flags.StringVarP(&architectureEngineCfg.provider, "provider", "p", "aws", "Provider to use")
	flags.StringVar(&architectureEngineCfg.guardrails, "guardrails", "", "Guardrails file")
	flags.StringVarP(&architectureEngineCfg.inputGraph, "input-graph", "i", "", "Input graph file")
	flags.StringVarP(&architectureEngineCfg.constraints, "constraints", "c", "", "Constraints file")
	flags.StringVarP(&architectureEngineCfg.outputDir, "output-dir", "o", "", "Output directory")
	flags.StringVarP(&architectureEngineCfg.globalTag, "global-tag", "t", "", "Global tag")
	getPossibleEdgesCmd := &cobra.Command{
		Use:     "GetValidEdgeTargets",
		Short:   "Get the valid topological edge targets for the supplied configuration and input graph",
		GroupID: engineGroup.ID,
		RunE:    em.GetValidEdgeTargets,
	}
	flags = getPossibleEdgesCmd.Flags()
	flags.StringVar(&getValidEdgeTargetsCfg.guardrails, "guardrails", "", "Guardrails file")
	flags.StringVarP(&getValidEdgeTargetsCfg.inputGraph, "input-graph", "i", "", "Input graph file")
	flags.StringVarP(&getValidEdgeTargetsCfg.configFile, "config", "c", "", "config file")
	flags.StringVarP(&getValidEdgeTargetsCfg.outputDir, "output-dir", "o", "", "Output directory")
	// The group must be registered for the commands' GroupID to resolve.
	root.AddGroup(engineGroup)
	root.AddCommand(listResourceTypesCmd)
	root.AddCommand(listAttributesCmd)
	root.AddCommand(runCmd)
	root.AddCommand(getPossibleEdgesCmd)
}
// AddEngine builds the knowledge base with templates.NewKBFromTemplates and
// sets em.Engine to a new engine using it. On error em.Engine is not modified.
func (em *EngineMain) AddEngine() error {
	kb, err := templates.NewKBFromTemplates()
	if err == nil {
		em.Engine = NewEngine(kb)
	}
	return err
}
// resourceInfo is the JSON description of one resource type printed by
// ListResourceTypes.
type resourceInfo struct {
	// Classifications are the template's "is" classifications.
	Classifications []string `json:"classifications"`
	DisplayName     string   `json:"displayName"`
	// Properties maps property names to descriptions built by addSubProperties.
	Properties map[string]any    `json:"properties"`
	Views      map[string]string `json:"views"`
}
// validationFields names the optional validation fields looked up (by
// reflection) on property implementations and copied into the
// ListResourceTypes output when set.
var validationFields = []string{
	"MinLength", "MaxLength",
	"MinValue", "MaxValue",
	"AllowedValues",
	"UniqueItems", "UniqueKeys",
	"MinSize", "MaxSize",
}
// addSubProperties writes a JSON-friendly description of each property in
// subProperties into properties, keyed by the property's name. Each description
// holds the property's type and detail flags. It also includes every non-zero
// field listed in validationFields, under its lowerCamel name, and describes
// nested sub-properties recursively under "properties".
func addSubProperties(properties map[string]any, subProperties map[string]knowledgebase.Property) {
	for _, subProperty := range subProperties {
		details := subProperty.Details()
		info := map[string]any{
			"type":                  subProperty.Type(),
			"deployTime":            details.DeployTime,
			"configurationDisabled": details.ConfigurationDisabled,
			"required":              details.Required,
			"description":           details.Description,
			"important":             details.IsImportant,
		}

		// Validation settings are optional fields on the concrete property type,
		// so they are looked up by reflection. reflect.Indirect accepts both
		// pointer and value implementations (Elem() panics on non-pointers), and
		// the Struct check avoids a FieldByName panic on non-struct types.
		impl := reflect.Indirect(reflect.ValueOf(subProperty))
		if impl.Kind() == reflect.Struct {
			for _, field := range validationFields {
				fv := impl.FieldByName(field)
				if fv.IsValid() && !fv.IsZero() {
					info[strcase.ToLowerCamel(field)] = fv.Interface()
				}
			}
		}

		if nested := subProperty.SubProperties(); nested != nil {
			nestedInfo := map[string]any{}
			addSubProperties(nestedInfo, nested)
			info["properties"] = nestedInfo
		}
		properties[details.Name] = info
	}
}
// ListResourceTypes prints one JSON object keyed by qualified type name. For
// each resource template it gives the classifications, display name, property
// descriptions, and views.
func (em *EngineMain) ListResourceTypes(cmd *cobra.Command, args []string) error {
	if err := em.AddEngine(); err != nil {
		return err
	}
	rts := em.Engine.Kb.ListResources()
	infoByType := map[string]resourceInfo{}
	for _, rt := range rts {
		props := map[string]any{}
		addSubProperties(props, rt.Properties)
		infoByType[rt.QualifiedTypeName] = resourceInfo{
			Classifications: rt.Classification.Is,
			DisplayName:     rt.DisplayName,
			Properties:      props,
			Views:           rt.Views,
		}
	}
	out, err := json.Marshal(infoByType)
	if err != nil {
		return err
	}
	fmt.Println(string(out))
	return nil
}
// ListAttributes prints every classification attribute known to the engine,
// one per line.
func (em *EngineMain) ListAttributes(cmd *cobra.Command, args []string) error {
	if err := em.AddEngine(); err != nil {
		return err
	}
	fmt.Println(strings.Join(em.Engine.ListAttributes(), "\n"))
	return nil
}
// extractEngineErrors walks err's wrap tree breadth-first, following both
// Unwrap() error and Unwrap() []error, and collects every
// engine_errs.EngineError it finds. It does not look inside an EngineError.
// If none are found, err itself is reported as a single InternalError.
// A nil err yields nil.
func extractEngineErrors(err error) []engine_errs.EngineError {
	if err == nil {
		return nil
	}
	var found []engine_errs.EngineError
	pending := []error{err}
	for len(pending) > 0 {
		next := pending[0]
		pending = pending[1:]
		if ee, ok := next.(engine_errs.EngineError); ok {
			found = append(found, ee)
		} else if multi, ok := next.(interface{ Unwrap() []error }); ok {
			pending = append(pending, multi.Unwrap()...)
		} else if single, ok := next.(interface{ Unwrap() error }); ok {
			pending = append(pending, single.Unwrap())
		}
	}
	if len(found) == 0 {
		found = append(found, engine_errs.InternalError{Err: err})
	}
	return found
}
// Run solves req with the engine and collects the engine errors to report.
//
// The returned code is 1 when the engine itself returned an error (evaluation
// halted). Otherwise it is 2 when any solution decision converts to an engine
// error, and 0 otherwise. Debug graphs are written whenever a solution is
// returned.
func (em *EngineMain) Run(ctx context.Context, req *SolveRequest) (int, solution.Solution, []engine_errs.EngineError) {
	returnCode := 0
	var engErrs []engine_errs.EngineError

	log := zap.S().Named("engine")
	log.Info("Running engine")
	sol, err := em.Engine.Run(ctx, req)
	if err != nil {
		// When the engine returns an error, that indicates that it halted evaluation, thus is a fatal error.
		// This is returned as exit code 1, and add the details to be printed to stdout.
		returnCode = 1
		engErrs = append(engErrs, extractEngineErrors(err)...)
		log.Errorf("Engine returned error: %v", err)
	}
	if sol == nil {
		// Nothing to inspect. This also avoids nil dereferences inside
		// writeDebugGraphs' goroutines, whose panics cannot be recovered by
		// the caller (recover only works in the panicking goroutine).
		return returnCode, sol, engErrs
	}
	writeDebugGraphs(sol)

	// If there are any decisions that are engine errors, add them to the list of error details
	// to be printed to stdout. These are returned as exit code 2 unless it is already code 1.
	for _, decision := range sol.GetDecisions() {
		asErr, ok := decision.(solution.AsEngineError)
		if !ok {
			continue
		}
		ee := asErr.TryEngineError()
		if ee == nil {
			continue
		}
		engErrs = append(engErrs, ee)
		if returnCode != 1 {
			returnCode = 2
		}
	}
	return returnCode, sol, engErrs
}
// writeEngineErrsJson writes errs to out as an indented JSON array. Each
// element is the error's ToJSONMap plus its "error_code". When the error wraps
// another error, an "error" entry holds engine_errs.ErrorsToTree of the
// wrapped error.
func writeEngineErrsJson(errs []engine_errs.EngineError, out io.Writer) error {
	enc := json.NewEncoder(out)
	enc.SetIndent("", "  ")
	// NOTE: since this isn't used in a web context (it's a CLI), we can disable escaping.
	enc.SetEscapeHTML(false)

	outErrs := make([]map[string]any, len(errs))
	for i, e := range errs {
		m := e.ToJSONMap()
		if m == nil {
			// Guard against an implementation returning a nil map: writing the
			// keys below into it would panic.
			m = make(map[string]any)
		}
		m["error_code"] = e.ErrorCode()
		if wrapped := errors.Unwrap(e); wrapped != nil {
			m["error"] = engine_errs.ErrorsToTree(wrapped)
		}
		outErrs[i] = m
	}
	return enc.Encode(outErrs)
}
// RunEngine implements the Run command. It loads the knowledge base and builds
// a SolveRequest from the --input-graph, --constraints and --global-tag flags.
// It then solves the request and writes the outputs to --output-dir:
// config_errors.json, the generated views, resources.yaml, and, for the "aws"
// provider, deployment_permissions_policy.json.
//
// It returns the process exit code: 0 on success, 1 for fatal/internal errors
// (including a recovered panic), and 2 when the solution contains decisions
// that are engine errors. In every case the collected engine errors are
// written to stdout as JSON.
func (em *EngineMain) RunEngine(cmd *cobra.Command, args []string) (exitCode int) {
	var engErrs []engine_errs.EngineError
	// internalError records err as an InternalError and marks the run as fatal.
	internalError := func(err error) {
		engErrs = append(engErrs, engine_errs.InternalError{Err: err})
		exitCode = 1
	}

	log := zap.S().Named("engine")
	defer func() { // defer functions execute in FILO order, so this executes after the 'recover'.
		err := writeEngineErrsJson(engErrs, os.Stdout)
		if err != nil {
			log.Errorf("failed to output errors to stdout: %v", err)
		}
	}()
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		log.Errorf("panic: %v", r)
		// A panic is always fatal. Without this the caller would exit with
		// whatever code had been computed so far, possibly 0.
		exitCode = 1
		switch r := r.(type) {
		case engine_errs.EngineError:
			engErrs = append(engErrs, r)
		case error:
			engErrs = append(engErrs, engine_errs.InternalError{Err: r})
		default:
			engErrs = append(engErrs, engine_errs.InternalError{Err: fmt.Errorf("panic: %v", r)})
		}
	}()

	err := em.AddEngine()
	if err != nil {
		internalError(err)
		return
	}

	// Named req (not context) to avoid shadowing the context package.
	req := &SolveRequest{
		GlobalTag: architectureEngineCfg.globalTag,
	}
	if architectureEngineCfg.inputGraph != "" {
		var input FileFormat
		log.Info("Loading input graph")
		inputF, err := os.Open(architectureEngineCfg.inputGraph)
		if err != nil {
			internalError(err)
			return
		}
		defer inputF.Close()
		err = yaml.NewDecoder(inputF).Decode(&input)
		if err != nil {
			internalError(fmt.Errorf("failed to decode input graph: %w", err))
			return
		}
		req.InitialState = input.Graph
		// Constraints embedded in the input file apply unless a constraints file was given.
		if architectureEngineCfg.constraints == "" {
			req.Constraints = input.Constraints
		}
	} else {
		req.InitialState = construct.NewGraph()
	}

	log.Info("Loading constraints")
	if architectureEngineCfg.constraints != "" {
		runConstraints, err := constraints.LoadConstraintsFromFile(architectureEngineCfg.constraints)
		if err != nil {
			internalError(fmt.Errorf("failed to load constraints: %w", err))
			return
		}
		req.Constraints = runConstraints
	}

	// len(engErrs) == 0 at this point so overwriting it is safe
	// All other assignments prior are via 'internalError' and return
	exitCode, sol, engErrs := em.Run(cmd.Context(), req)
	if exitCode == 1 {
		return
	}

	var files []kio.File

	// Decision-level engine errors are also persisted alongside the outputs.
	configErrors := new(bytes.Buffer)
	err = writeEngineErrsJson(engErrs, configErrors)
	if err != nil {
		internalError(fmt.Errorf("failed to write config errors: %w", err))
		return
	}
	files = append(files, &kio.RawFile{
		FPath:   "config_errors.json",
		Content: configErrors.Bytes(),
	})

	log.Info("Engine finished running... Generating views")
	vizFiles, err := em.Engine.VisualizeViews(sol)
	if err != nil {
		internalError(fmt.Errorf("failed to generate views %w", err))
		return
	}
	files = append(files, vizFiles...)

	log.Info("Generating resources.yaml")
	b, err := yaml.Marshal(construct.YamlGraph{Graph: sol.DataflowGraph()})
	if err != nil {
		internalError(fmt.Errorf("failed to marshal graph: %w", err))
		return
	}
	files = append(files,
		&kio.RawFile{
			FPath:   "resources.yaml",
			Content: b,
		},
	)

	if architectureEngineCfg.provider == "aws" {
		policyBytes, err := aws.DeploymentPermissionsPolicy(sol)
		if err != nil {
			internalError(fmt.Errorf("failed to generate deployment permissions policy: %w", err))
			return
		}
		files = append(files,
			&kio.RawFile{
				FPath:   "deployment_permissions_policy.json",
				Content: policyBytes,
			},
		)
	}

	err = kio.OutputTo(files, architectureEngineCfg.outputDir)
	if err != nil {
		internalError(fmt.Errorf("failed to write output files: %w", err))
		return
	}
	return
}
// GetValidEdgeTargets implements the GetValidEdgeTargets command. It reads the
// --input-graph file and the --config file, and asks the engine for the valid
// edge targets. The result is written as valid_edge_targets.yaml to
// --output-dir.
func (em *EngineMain) GetValidEdgeTargets(cmd *cobra.Command, args []string) error {
	log := zap.S().Named("engine")
	err := em.AddEngine()
	if err != nil {
		return err
	}

	log.Info("loading config")
	inputF, err := os.ReadFile(getValidEdgeTargetsCfg.inputGraph)
	if err != nil {
		return fmt.Errorf("failed to read input graph: %w", err)
	}
	config, err := ReadGetValidEdgeTargetsConfig(getValidEdgeTargetsCfg.configFile)
	if err != nil {
		// This is the edge-target config file, not a constraints file.
		return fmt.Errorf("failed to load config: %w", err)
	}

	context := &GetPossibleEdgesContext{
		InputGraph:                inputF,
		GetValidEdgeTargetsConfig: config,
	}

	log.Info("getting valid edge targets")
	validTargets, err := em.Engine.GetValidEdgeTargets(context)
	if err != nil {
		return fmt.Errorf("failed to run engine: %w", err)
	}

	log.Info("writing output files")
	b, err := yaml.Marshal(validTargets)
	if err != nil {
		return fmt.Errorf("failed to marshal possible edges: %w", err)
	}
	var files []kio.File
	files = append(files, &kio.RawFile{
		FPath:   "valid_edge_targets.yaml",
		Content: b,
	})

	err = kio.OutputTo(files, getValidEdgeTargetsCfg.outputDir)
	if err != nil {
		return fmt.Errorf("failed to write output files: %w", err)
	}
	return nil
}
// writeDebugGraphs renders the solution's dataflow and deployment ("iac")
// graphs with GraphToSVG concurrently and waits for both. Failures are logged,
// not returned.
func writeDebugGraphs(sol solution.Solution) {
	wg := sync.WaitGroup{}
	wg.Add(2)
	go func() {
		defer wg.Done()
		err := GraphToSVG(sol.KnowledgeBase(), sol.DataflowGraph(), "dataflow")
		if err != nil {
			// %v rather than %w: the sugared logger formats with fmt.Sprintf,
			// which does not support %w and would print "%!w(...)".
			zap.S().Named("engine").Errorf("failed to write dataflow graph: %v", err)
		}
	}()
	go func() {
		defer wg.Done()
		err := GraphToSVG(sol.KnowledgeBase(), sol.DeploymentGraph(), "iac")
		if err != nil {
			zap.S().Named("engine").Errorf("failed to write iac graph: %v", err)
		}
	}()
	wg.Wait()
}
package engine
import (
"errors"
"fmt"
"github.com/dominikbraun/graph"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/engine/constraints"
"github.com/klothoplatform/klotho/pkg/engine/reconciler"
"github.com/klothoplatform/klotho/pkg/engine/solution"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
"github.com/klothoplatform/klotho/pkg/tui"
)
// ApplyConstraints applies the solution's constraints in three phases: first
// application constraints, then edge constraints, then sanitization of
// resource constraints. Progress is reported after each constraint. Within a
// phase every constraint is attempted. If any fail, their errors are joined
// and returned, and later phases are skipped.
func ApplyConstraints(sol solution.Solution) error {
	prog := tui.GetProgress(sol.Context())
	cs := sol.Constraints()
	done, total := 0, len(cs.Application)+len(cs.Edges)+len(cs.Resources)
	tick := func() {
		done++
		prog.Update("Loading constraints", done, total)
	}

	var errs []error
	for _, c := range cs.Application {
		if err := applyApplicationConstraint(sol, c); err != nil {
			errs = append(errs, fmt.Errorf("failed to apply constraint %#v: %w", c, err))
		}
		tick()
	}
	if len(errs) > 0 {
		return errors.Join(errs...)
	}

	for _, c := range cs.Edges {
		if err := applyEdgeConstraint(sol, c); err != nil {
			errs = append(errs, fmt.Errorf("failed to apply constraint %#v: %w", c, err))
		}
		tick()
	}
	if len(errs) > 0 {
		return errors.Join(errs...)
	}

	// Sanitization mutates the constraints in place, hence indexing by pointer.
	resourceConstraints := cs.Resources
	for i := range resourceConstraints {
		rc := &resourceConstraints[i]
		if err := applySanitization(sol, rc); err != nil {
			errs = append(errs, fmt.Errorf("failed to apply constraint %#v: %w", *rc, err))
		}
		tick()
	}
	// errors.Join returns nil when errs is empty.
	return errors.Join(errs...)
}
// applyApplicationConstraint applies a single application constraint to the
// solution, according to its operator:
//   - Add: adds the resource to the operational view.
//   - MustExist: adds the resource, tolerating graph.ErrVertexAlreadyExists.
//   - Import: adds the resource with Imported set.
//   - Remove: removes the resource via reconciler.RemoveResource.
//   - MustNotExist: removes the resource, tolerating graph.ErrVertexNotFound.
//   - Replace: if the replacement has the same qualified type, sanitizes its
//     name and renames the resource in place. Otherwise it creates the
//     replacement resource and substitutes it for the existing one.
//
// It returns an error if the resource cannot be created, the operator is
// unknown, or the graph operation fails.
func applyApplicationConstraint(ctx solution.Solution, constraint constraints.ApplicationConstraint) error {
	res, err := knowledgebase.CreateResource(ctx.KnowledgeBase(), constraint.Node)
	if err != nil {
		return err
	}
	switch constraint.Operator {
	case constraints.AddConstraintOperator:
		return ctx.OperationalView().AddVertex(res)

	case constraints.MustExistConstraintOperator:
		err := ctx.OperationalView().AddVertex(res)
		if errors.Is(err, graph.ErrVertexAlreadyExists) {
			return nil
		}
		return err

	case constraints.ImportConstraintOperator:
		res.Imported = true
		return ctx.OperationalView().AddVertex(res)

	case constraints.RemoveConstraintOperator:
		return reconciler.RemoveResource(ctx, res.ID, true)

	case constraints.MustNotExistConstraintOperator:
		err := reconciler.RemoveResource(ctx, res.ID, false)
		if errors.Is(err, graph.ErrVertexNotFound) {
			return nil
		}
		return err

	case constraints.ReplaceConstraintOperator:
		node, err := ctx.RawView().Vertex(res.ID)
		if err != nil {
			return fmt.Errorf("could not find resource for %s: %w", res.ID, err)
		}
		if node.ID.QualifiedTypeName() != constraint.ReplacementNode.QualifiedTypeName() {
			// Different type: a new resource is needed.
			replacement, err := knowledgebase.CreateResource(ctx.KnowledgeBase(), constraint.ReplacementNode)
			if err != nil {
				return err
			}
			return construct.ReplaceResource(ctx.OperationalView(), res.ID, replacement)
		}
		// Same type: only the ID changes, so sanitize the new name and rename in place.
		rt, err := ctx.KnowledgeBase().GetResourceTemplate(constraint.ReplacementNode)
		if err != nil {
			return err
		}
		constraint.ReplacementNode.Name, err = rt.SanitizeName(constraint.ReplacementNode.Name)
		if err != nil {
			return err
		}
		return ctx.OperationalView().UpdateResourceID(res.ID, constraint.ReplacementNode)

	default:
		return fmt.Errorf("unknown operator %s", constraint.Operator)
	}
}
// applyEdgeConstraint applies an edge constraint to the solution's operational view.
//
// The source and target names are first sanitized in place using their resource templates.
// If exactly one of the two names is then empty, the constraint is a type-level (global)
// constraint that is applied whenever a matching resource is added to the graph, so nothing is
// done here. If both names are empty, an error is returned.
//
// Operators are handled as follows:
//   - add:            the edge is added to the operational view
//   - must_exist:     like add, but graph.ErrEdgeAlreadyExists is not treated as an error
//   - remove:         the path between source and target is removed with reconciler.RemovePath
//   - must_not_exist: like remove, but graph.ErrTargetNotReachable (no path) is not treated as an
//     error; any other failure is returned
//
// Any other operator is not applied here and nil is returned.
func applyEdgeConstraint(ctx solution.Solution, constraint constraints.EdgeConstraint) error {
	// Pair each endpoint with its label explicitly, so error messages stay correct even when the
	// source and target IDs are equal.
	endpoints := []struct {
		label string
		id    *construct.ResourceId
	}{
		{label: "source", id: &constraint.Target.Source},
		{label: "target", id: &constraint.Target.Target},
	}
	for _, ep := range endpoints {
		rt, err := ctx.KnowledgeBase().GetResourceTemplate(*ep.id)
		if err != nil {
			return fmt.Errorf("could not get template for %s: %w", ep.label, err)
		}
		ep.id.Name, err = rt.SanitizeName(ep.id.Name)
		if err != nil {
			return fmt.Errorf("could not sanitize %s name: %w", ep.label, err)
		}
	}

	if constraint.Target.Source.Name == "" || constraint.Target.Target.Name == "" {
		if constraint.Target.Source.Name == "" && constraint.Target.Target.Name == "" {
			return fmt.Errorf("source and target names are empty")
		}
		// This is considered a global constraint for the type which does not have a name and
		// will be applied anytime a new resource is added to the graph.
		return nil
	}

	switch constraint.Operator {
	case constraints.AddConstraintOperator:
		return ctx.OperationalView().AddEdge(constraint.Target.Source, constraint.Target.Target)
	case constraints.MustExistConstraintOperator:
		err := ctx.OperationalView().AddEdge(constraint.Target.Source, constraint.Target.Target)
		if errors.Is(err, graph.ErrEdgeAlreadyExists) {
			return nil
		}
		return err
	case constraints.RemoveConstraintOperator:
		return reconciler.RemovePath(constraint.Target.Source, constraint.Target.Target, ctx)
	case constraints.MustNotExistConstraintOperator:
		err := reconciler.RemovePath(constraint.Target.Source, constraint.Target.Target, ctx)
		if errors.Is(err, graph.ErrTargetNotReachable) {
			// No path exists, so the constraint already holds.
			return nil
		}
		// Surface every other failure instead of silently reporting success.
		return err
	}
	return nil
}
// applySanitization rewrites constraint.Target.Name in place using the SanitizeName rule of the
// target's resource template.
//
// Only ResourceConstraints need this step. Application and edge constraints are sanitized as
// they are applied within the graph, which also covers resources the engine generates.
func applySanitization(ctx solution.Solution, constraint *constraints.ResourceConstraint) error {
	kb := ctx.KnowledgeBase()
	tmpl, err := kb.GetResourceTemplate(constraint.Target)
	if err != nil {
		return err
	}
	sanitized, err := tmpl.SanitizeName(constraint.Target.Name)
	constraint.Target.Name = sanitized
	return err
}
package constraints
import (
"errors"
"fmt"
construct "github.com/klothoplatform/klotho/pkg/construct"
)
// ApplicationConstraint describes an intent that applies to the resource graph as a whole, such
// as adding, importing, removing, or replacing a single node.
//
// Example YAML:
//
//	- scope: application
//	  operator: add
//	  node: klotho:execution_unit:my_compute
//
// Applying the example should result in the execution unit construct being added to the
// construct graph for processing. ReplacementNode is only consulted by the replace operator.
type ApplicationConstraint struct {
	Operator        ConstraintOperator   `yaml:"operator" json:"operator"`
	Node            construct.ResourceId `yaml:"node" json:"node"`
	ReplacementNode construct.ResourceId `yaml:"replacement_node,omitempty" json:"replacement_node,omitempty"`
}
// Scope always reports ApplicationConstraintScope.
func (*ApplicationConstraint) Scope() ConstraintScope { return ApplicationConstraintScope }
// IsSatisfied reports whether the constraint holds for the given graph.
//
// A node counts as present when a resource for it can be found. Abstract (construct) IDs are
// first resolved to the resource that references them via GetConstructsResource; concrete IDs
// are looked up directly with GetResource, and any lookup error is treated as "absent".
//
//   - add / must_exist:        the node is present
//   - import:                  the node is present and marked Imported
//   - remove / must_not_exist: the node is absent
//   - replace:                 the original node is absent and the replacement node is present
//
// Any other operator is never satisfied.
func (constraint *ApplicationConstraint) IsSatisfied(ctx ConstraintGraph) bool {
	// lookup returns the resource currently in the graph for id, or nil when there is none.
	// A construct that no resource references yields nil rather than a nil-pointer dereference.
	lookup := func(id construct.ResourceId) *construct.Resource {
		if id.IsAbstractResource() {
			ref := ctx.GetConstructsResource(id)
			if ref == nil {
				return nil
			}
			id = ref.ID
		}
		res, _ := ctx.GetResource(id) // any lookup error is treated as "absent"
		return res
	}

	switch constraint.Operator {
	case AddConstraintOperator, MustExistConstraintOperator:
		return lookup(constraint.Node) != nil
	case ImportConstraintOperator:
		res := lookup(constraint.Node)
		return res != nil && res.Imported
	case RemoveConstraintOperator, MustNotExistConstraintOperator:
		return lookup(constraint.Node) == nil
	case ReplaceConstraintOperator:
		// Edges are not checked yet; a fuller check would also ensure the original node's edges
		// were copied over to the replacement node.
		return lookup(constraint.Node) == nil && lookup(constraint.ReplacementNode) != nil
	}
	return false
}
// Validate checks that the node fields required by the constraint's operator are set.
//
// add, must_exist, import, remove, and must_not_exist require Node; replace requires both Node
// and ReplacementNode. Other operators are not checked here.
func (constraint *ApplicationConstraint) Validate() error {
	switch constraint.Operator {
	case AddConstraintOperator, MustExistConstraintOperator:
		if constraint.Node.IsZero() {
			return errors.New("add/must_exist constraint must have a node defined")
		}
	case ImportConstraintOperator:
		// Import creates the node (marked as imported), so like add it needs a node.
		if constraint.Node.IsZero() {
			return errors.New("import constraint must have a node defined")
		}
	case RemoveConstraintOperator, MustNotExistConstraintOperator:
		if constraint.Node.IsZero() {
			return errors.New("remove/must_not_exist constraint must have a node defined")
		}
	case ReplaceConstraintOperator:
		if constraint.Node.IsZero() || constraint.ReplacementNode.IsZero() {
			return errors.New("replace constraint must have a node and replacement node defined")
		}
	}
	return nil
}
// String renders the operator, node, and replacement node for logs and sorting.
func (constraint *ApplicationConstraint) String() string {
	const format = "ApplicationConstraint: %s %s %s"
	return fmt.Sprintf(format, constraint.Operator, constraint.Node, constraint.ReplacementNode)
}
package constraints
import (
"encoding/json"
"errors"
"fmt"
"os"
construct "github.com/klothoplatform/klotho/pkg/construct"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
"github.com/klothoplatform/klotho/pkg/yaml_util"
"gopkg.in/yaml.v3"
)
// Constraint is implemented by every constraint kind. It describes an intent that can be
// applied to, and checked against, a resource graph.
type Constraint interface {
	// Scope reports which part of the resource graph the constraint applies to.
	Scope() ConstraintScope
	// IsSatisfied reports whether the constraint holds for the given graph. A resource graph is
	// only considered valid when every constraint is satisfied.
	IsSatisfied(ctx ConstraintGraph) bool
	// Validate returns an error if the constraint is not well-formed.
	Validate() error
	String() string
}

// ConstraintGraph is the set of graph queries constraints use to check whether they hold.
type ConstraintGraph interface {
	GetConstructsResource(construct.ResourceId) *construct.Resource
	GetResource(construct.ResourceId) (*construct.Resource, error)
	AllPaths(src, dst construct.ResourceId) ([][]*construct.Resource, error)
	GetClassification(construct.ResourceId) knowledgebase.Classification
	GetOutput(string) construct.Output
}

// BaseConstraint holds only the scope of a constraint. It is decoded first while parsing so the
// scope can select which concrete constraint struct to decode into.
type BaseConstraint struct {
	Scope ConstraintScope `yaml:"scope"`
}

// Edge identifies an edge of the resource graph by its source and target resource IDs.
type Edge struct {
	Source construct.ResourceId `yaml:"source"`
	Target construct.ResourceId `yaml:"target"`
}

// ConstraintScope enumerates the parts of the graph a constraint can apply to.
type ConstraintScope string

// ConstraintOperator enumerates the operations a constraint can request.
type ConstraintOperator string

// ConstraintList is a heterogeneous, ordered list of constraints.
type ConstraintList []Constraint

// Constraints holds constraints grouped by their concrete type.
type Constraints struct {
	Application []ApplicationConstraint
	Construct   []ConstructConstraint
	Resources   []ResourceConstraint
	Edges       []EdgeConstraint
	Outputs     []OutputConstraint
}
// Scopes a constraint can apply to. The value is what appears under the "scope" key of a
// serialized constraint.
const (
	ApplicationConstraintScope ConstraintScope = "application"
	ConstructConstraintScope   ConstraintScope = "construct"
	EdgeConstraintScope        ConstraintScope = "edge"
	ResourceConstraintScope    ConstraintScope = "resource"
	OutputConstraintScope      ConstraintScope = "output"
)

// Operators a constraint can use. Which operators are meaningful depends on the scope.
const (
	MustExistConstraintOperator    ConstraintOperator = "must_exist"
	MustNotExistConstraintOperator ConstraintOperator = "must_not_exist"
	AddConstraintOperator          ConstraintOperator = "add"
	ImportConstraintOperator       ConstraintOperator = "import"
	RemoveConstraintOperator       ConstraintOperator = "remove"
	ReplaceConstraintOperator      ConstraintOperator = "replace"
	EqualsConstraintOperator       ConstraintOperator = "equals"
)
// MarshalYAML encodes each constraint as a YAML mapping and inserts a leading "scope" key, so
// the list can be decoded again by UnmarshalYAML.
func (cs ConstraintList) MarshalYAML() (interface{}, error) {
	var nodes []yaml.Node
	for i := range cs {
		c := cs[i]
		var encoded yaml.Node
		if err := encoded.Encode(c); err != nil {
			return nil, err
		}
		scopeKey := &yaml.Node{Kind: yaml.ScalarNode, Value: "scope"}
		scopeVal := &yaml.Node{Kind: yaml.ScalarNode, Value: string(c.Scope())}
		encoded.Content = append([]*yaml.Node{scopeKey, scopeVal}, encoded.Content...)
		nodes = append(nodes, encoded)
	}
	return nodes, nil
}
// MarshalJSON encodes the list as a JSON array of objects. Each object holds the constraint's
// own fields plus a "scope" field.
func (cs ConstraintList) MarshalJSON() ([]byte, error) {
	out := make([]map[string]interface{}, 0, len(cs))
	for _, c := range cs {
		encoded, err := json.Marshal(c)
		if err != nil {
			return nil, err
		}
		// Seed the map with the scope, then merge the constraint's own fields into it.
		fields := map[string]interface{}{"scope": c.Scope()}
		if err := json.Unmarshal(encoded, &fields); err != nil {
			return nil, err
		}
		out = append(out, fields)
	}
	return json.Marshal(out)
}
// UnmarshalYAML decodes a YAML sequence of constraints. Each entry's "scope" key selects the
// concrete constraint type it is decoded into, and every decoded constraint is validated.
//
// A failing entry does not stop decoding. All errors are joined and returned together, and the
// slots of failed entries are left nil.
func (cs *ConstraintList) UnmarshalYAML(node *yaml.Node) error {
	var entries []yaml_util.RawNode
	if err := node.Decode(&entries); err != nil {
		return err
	}

	// emptyFor returns a zero constraint of the type registered for scope, or nil if the scope
	// is unknown.
	emptyFor := func(scope ConstraintScope) Constraint {
		switch scope {
		case ApplicationConstraintScope:
			return &ApplicationConstraint{}
		case ConstructConstraintScope:
			return &ConstructConstraint{}
		case EdgeConstraintScope:
			return &EdgeConstraint{}
		case ResourceConstraintScope:
			return &ResourceConstraint{}
		case OutputConstraintScope:
			return &OutputConstraint{}
		}
		return nil
	}

	*cs = make(ConstraintList, len(entries))
	var errs error
	for i, entry := range entries {
		var base BaseConstraint
		if err := entry.Decode(&base); err != nil {
			errs = errors.Join(errs, err)
			continue
		}
		c := emptyFor(base.Scope)
		if c == nil {
			errs = errors.Join(errs, fmt.Errorf("invalid scope %q", base.Scope))
			continue
		}
		if err := entry.Decode(c); err != nil {
			errs = errors.Join(errs, err)
			continue
		}
		if err := c.Validate(); err != nil {
			errs = errors.Join(errs, err)
			continue
		}
		(*cs)[i] = c
	}
	return errs
}
// ToConstraints groups the list by concrete constraint type, keeping list order within each
// group. It fails on any element that is not one of the known constraint pointer types,
// including a nil element.
func (list ConstraintList) ToConstraints() (Constraints, error) {
	var grouped Constraints
	for _, item := range list {
		switch typed := item.(type) {
		case *ApplicationConstraint:
			grouped.Application = append(grouped.Application, *typed)
		case *ConstructConstraint:
			grouped.Construct = append(grouped.Construct, *typed)
		case *ResourceConstraint:
			grouped.Resources = append(grouped.Resources, *typed)
		case *EdgeConstraint:
			grouped.Edges = append(grouped.Edges, *typed)
		case *OutputConstraint:
			grouped.Outputs = append(grouped.Outputs, *typed)
		default:
			return Constraints{}, fmt.Errorf("invalid constraint type %T", item)
		}
	}
	return grouped, nil
}
// NaturalSort is a less function over list indices (suitable for sort.Slice). It orders
// constraints by scope first and then by their String form.
func (list ConstraintList) NaturalSort(i, j int) bool {
	left, right := list[i], list[j]
	leftScope, rightScope := left.Scope(), right.Scope()
	if leftScope == rightScope {
		return left.String() < right.String()
	}
	return leftScope < rightScope
}
// LoadConstraintsFromFile reads the YAML file at path, expecting the constraint list under a
// top-level "constraints" key, and groups the constraints by type.
func LoadConstraintsFromFile(path string) (Constraints, error) {
	f, err := os.Open(path)
	if err != nil {
		return Constraints{}, err
	}
	defer f.Close() //nolint:errcheck

	var doc struct {
		Constraints ConstraintList `yaml:"constraints"`
	}
	if err := yaml.NewDecoder(f).Decode(&doc); err != nil {
		return Constraints{}, err
	}
	return doc.Constraints.ToConstraints()
}
// ParseConstraintsFromFile decodes YAML bytes that contain a top-level list of constraints (not
// wrapped in a "constraints" key) and groups them by type into a Constraints value.
//
// Within each scope the constraints keep the order in which they appear in the input; they
// cannot be grouped any other way. A future spec may allow ordering how constraints are applied.
func ParseConstraintsFromFile(bytes []byte) (Constraints, error) {
	var list ConstraintList
	if err := yaml.Unmarshal(bytes, &list); err != nil {
		return Constraints{}, err
	}
	return list.ToConstraints()
}
// ToList flattens c into a single list ordered application, construct, resource, edge, output.
//
// The list holds pointers to the elements of c's slices, which share backing arrays with the
// caller's Constraints, not pointers to per-iteration copies.
func (c Constraints) ToList() ConstraintList {
	var list ConstraintList
	for i := 0; i < len(c.Application); i++ {
		list = append(list, &c.Application[i])
	}
	for i := 0; i < len(c.Construct); i++ {
		list = append(list, &c.Construct[i])
	}
	for i := 0; i < len(c.Resources); i++ {
		list = append(list, &c.Resources[i])
	}
	for i := 0; i < len(c.Edges); i++ {
		list = append(list, &c.Edges[i])
	}
	for i := 0; i < len(c.Outputs); i++ {
		list = append(list, &c.Outputs[i])
	}
	return list
}
// MarshalYAML serializes the grouped constraints as a single flat constraint list.
func (c Constraints) MarshalYAML() (interface{}, error) {
	list := c.ToList()
	return list, nil
}
// UnmarshalYAML decodes a flat constraint list and stores it grouped by type. On a grouping
// error c is reset to its zero value.
func (c *Constraints) UnmarshalYAML(node *yaml.Node) error {
	var list ConstraintList
	if err := node.Decode(&list); err != nil {
		return err
	}
	grouped, err := list.ToConstraints()
	*c = grouped
	return err
}
// Append adds every constraint in other to c, after c's existing constraints of the same scope.
func (c *Constraints) Append(other Constraints) {
	c.Outputs = append(c.Outputs, other.Outputs...)
	c.Edges = append(c.Edges, other.Edges...)
	c.Resources = append(c.Resources, other.Resources...)
	c.Construct = append(c.Construct, other.Construct...)
	c.Application = append(c.Application, other.Application...)
}
package constraints
import (
"errors"
"fmt"
construct "github.com/klothoplatform/klotho/pkg/construct"
)
// ConstructConstraint controls how a specific abstract construct is expanded and configured.
//
// Example YAML:
//
//	- scope: construct
//	  operator: equals
//	  target: klotho:orm:my_orm
//	  type: rds_instance
//
// The intended result is that the orm construct is expanded into an rds instance plus the
// resources it needs. Type optionally restricts the resource type; Attributes are carried along
// but are not currently checked by IsSatisfied.
type ConstructConstraint struct {
	Operator   ConstraintOperator   `yaml:"operator" json:"operator"`
	Target     construct.ResourceId `yaml:"target" json:"target"`
	Type       string               `yaml:"type" json:"type"`
	Attributes map[string]any       `yaml:"attributes" json:"attributes"`
}
// Scope always reports ConstructConstraintScope.
func (*ConstructConstraint) Scope() ConstraintScope { return ConstructConstraintScope }
// IsSatisfied reports whether an equals constraint holds. A resource must reference the target
// construct and, when Type is set, that resource's ID type must equal Type. Attributes are not
// checked. Operators other than equals are never satisfied.
func (constraint *ConstructConstraint) IsSatisfied(ctx ConstraintGraph) bool {
	if constraint.Operator != EqualsConstraintOperator {
		return false
	}
	res := ctx.GetConstructsResource(constraint.Target)
	return res != nil && (constraint.Type == "" || res.ID.Type == constraint.Type)
}
// Validate rejects targets that are not abstract constructs.
func (constraint *ConstructConstraint) Validate() error {
	if constraint.Target.IsAbstractResource() {
		return nil
	}
	return errors.New("node constraint must be applied to an abstract construct")
}
// String renders the scope, operator, and target construct.
func (constraint *ConstructConstraint) String() string {
	return fmt.Sprintf("Constraint: %s %s %s", ConstructConstraintScope, constraint.Operator, constraint.Target)
}
package constraints
import (
"fmt"
construct "github.com/klothoplatform/klotho/pkg/construct"
)
// EdgeConstraint is a constraint on the connection between two resources in the resource graph.
//
// Example
//
// To require an edge from a compute construct to an ORM construct, use the YAML below:
//
//	- scope: edge
//	  operator: must_exist
//	  target:
//	    source: klotho:execution_unit:my_compute
//	    target: klotho:orm:my_orm
//
// The operators checked for edge constraints are add, must_exist, remove, and must_not_exist.
// Data holds optional edge data supplied alongside the constraint.
type EdgeConstraint struct {
	Operator ConstraintOperator `yaml:"operator" json:"operator"`
	Target   Edge               `yaml:"target" json:"target"`
	Data     construct.EdgeData `yaml:"data" json:"data"`
}
// Scope always reports EdgeConstraintScope.
func (*EdgeConstraint) Scope() ConstraintScope { return EdgeConstraintScope }
// IsSatisfied reports whether the edge constraint holds for the given graph.
//
// Abstract construct endpoints are resolved to the resource that references the construct. This
// relies on only the direct child of a construct referencing it: when an execution unit expands,
// the lambda references the execution unit, while the role and other resources reference the
// lambda. The paths between the resolved endpoints are then examined:
//   - add / must_exist:        satisfied when at least one path satisfies checkSatisfication
//   - remove / must_not_exist: satisfied when every path satisfies checkSatisfication, which is
//     vacuously true when there are no paths or an endpoint construct has no resource
//
// An error from AllPaths makes the constraint unsatisfied. Other operators are never satisfied.
func (constraint *EdgeConstraint) IsSatisfied(ctx ConstraintGraph) bool {
	wantAbsent := constraint.Operator == RemoveConstraintOperator ||
		constraint.Operator == MustNotExistConstraintOperator

	// resolve maps an abstract construct ID to the resource that references it. ok is false when
	// no such resource exists.
	resolve := func(id construct.ResourceId) (construct.ResourceId, bool) {
		if !id.IsAbstractResource() {
			return id, true
		}
		res := ctx.GetConstructsResource(id)
		if res == nil {
			return construct.ResourceId{}, false
		}
		return res.ID, true
	}

	src, ok := resolve(constraint.Target.Source)
	if !ok {
		// Without a backing resource no path can exist, which only satisfies absence operators.
		return wantAbsent
	}
	dst, ok := resolve(constraint.Target.Target)
	if !ok {
		return wantAbsent
	}

	paths, err := ctx.AllPaths(src, dst)
	if err != nil {
		return false
	}

	if wantAbsent {
		// Previously this used "any path" semantics, so an empty path set was never satisfied.
		for _, path := range paths {
			if !constraint.checkSatisfication(path, ctx) {
				return false
			}
		}
		return true
	}
	for _, path := range paths {
		if constraint.checkSatisfication(path, ctx) {
			return true
		}
	}
	return false
}
// checkSatisfication reports whether a single path between the constraint's endpoints agrees
// with the operator. Presence operators need a non-empty path, absence operators need an empty
// one, and any other operator is never satisfied. ctx is currently unused.
func (constraint *EdgeConstraint) checkSatisfication(path []*construct.Resource, ctx ConstraintGraph) bool {
	op := constraint.Operator
	if op == AddConstraintOperator || op == MustExistConstraintOperator {
		return len(path) > 0
	}
	if op == RemoveConstraintOperator || op == MustNotExistConstraintOperator {
		return len(path) == 0
	}
	return false
}
// Validate requires both endpoints to be set and to differ from each other.
//
// The missing-endpoint check runs first. Otherwise a constraint with two empty endpoints would
// be reported as having "the same node" rather than as missing its endpoints.
func (constraint *EdgeConstraint) Validate() error {
	var zero construct.ResourceId
	if constraint.Target.Source == zero || constraint.Target.Target == zero {
		return fmt.Errorf("edge constraint must have a source and target defined")
	}
	if constraint.Target.Source == constraint.Target.Target {
		return fmt.Errorf("edge constraint must not have a source and target be the same node")
	}
	return nil
}
// String renders the operator and the target edge.
func (constraint *EdgeConstraint) String() string {
	op, target := constraint.Operator, constraint.Target
	return fmt.Sprintf("EdgeConstraint{Operator: %s, Target: %s}", op, target)
}
package constraints
import (
"fmt"
construct "github.com/klothoplatform/klotho/pkg/construct"
)
// OutputConstraint requests that a value from the resource graph be exported as a named output.
//
// Example
//
// To export a property of a resource as an output, use the YAML below:
//
//	- scope: output
//	  operator: add
//	  ref: aws:ec2:instance:my_instance#public_ip
//	  name: my_instance_public_ip
//
// The intended result is that the public_ip property of aws:ec2:instance:my_instance is exposed
// as an output named my_instance_public_ip.
//
// NOTE(review): Value appears to be an alternative to Ref for supplying a literal output value.
// Confirm this against the code that consumes output constraints.
type OutputConstraint struct {
	Operator ConstraintOperator    `yaml:"operator" json:"operator"`
	Ref      construct.PropertyRef `yaml:"ref" json:"ref"`
	Name     string                `yaml:"name" json:"name"`
	Value    any                   `yaml:"value" json:"value"`
}
// Scope always reports OutputConstraintScope.
func (*OutputConstraint) Scope() ConstraintScope { return OutputConstraintScope }
// IsSatisfied always reports true; output constraints are not checked against the graph.
func (*OutputConstraint) IsSatisfied(_ ConstraintGraph) bool { return true }
// Validate performs no checks; every output constraint is currently accepted.
func (*OutputConstraint) Validate() error { return nil }
// String renders the operator, output name, and property reference.
func (constraint *OutputConstraint) String() string {
	const format = "OutputConstraint: %s %s %s"
	return fmt.Sprintf(format, constraint.Operator, constraint.Name, constraint.Ref)
}
package constraints
import (
"errors"
"fmt"
"reflect"
"strings"
construct "github.com/klothoplatform/klotho/pkg/construct"
)
// ResourceConstraint controls an intrinsic property of a specific (non-abstract) resource in the
// resource graph.
//
// Example YAML:
//
//	- scope: resource
//	  operator: equals
//	  target: aws:rds_instance:my_instance
//	  property: InstanceClass
//	  value: db.t3.micro
//
// The intended result is that the rds instance's InstanceClass property is set to db.t3.micro.
type ResourceConstraint struct {
	Operator ConstraintOperator   `yaml:"operator" json:"operator"`
	Target   construct.ResourceId `yaml:"target" json:"target"`
	Property string               `yaml:"property" json:"property"`
	Value    any                  `yaml:"value" json:"value"`
}
// Scope always reports ResourceConstraintScope.
func (*ResourceConstraint) Scope() ConstraintScope { return ResourceConstraintScope }
// IsSatisfied reports whether the constraint holds for the target resource.
//
// The property is read by reflection from the struct behind the *construct.Resource returned by
// GetResource.
//   - equals: the field is located by exact Go name, then by lower-cased name, json tag, or yaml
//     tag. The constraint holds when the field's value equals Value.
//   - add: the field is located by exact Go name only. The constraint holds when the field is a
//     slice or array containing Value; for any other field kind it is treated as satisfied.
//
// A missing target resource or an unknown property makes equals/add unsatisfied. Operators other
// than equals and add are treated as satisfied.
func (constraint *ResourceConstraint) IsSatisfied(ctx ConstraintGraph) bool {
	// equalsValue compares v with constraint.Value using reflect.DeepEqual. Using == on interface
	// values would panic when both hold the same uncomparable type (maps, slices). Values of
	// unexported fields cannot be read through Interface() and never match.
	equalsValue := func(v reflect.Value) bool {
		if !v.CanInterface() {
			return false
		}
		return reflect.DeepEqual(v.Interface(), constraint.Value)
	}

	// tagName returns the name part of a struct tag value such as "name,omitempty".
	tagName := func(tag string) string {
		name, _, _ := strings.Cut(tag, ",")
		return name
	}

	// fieldByAlias finds a field of strct whose lower-cased name (how untagged fields are
	// YAML-marshalled), json tag, or yaml tag equals the property. This handles Pascal vs snake
	// case and is replicated from resource_configuration.go#parseFieldName so there is no
	// dependency. It returns the zero Value when nothing matches.
	fieldByAlias := func(strct reflect.Value) reflect.Value {
		typ := strct.Type()
		for i := 0; i < typ.NumField(); i++ {
			field := typ.Field(i)
			if constraint.Property == strings.ToLower(field.Name) ||
				constraint.Property == tagName(field.Tag.Get("json")) ||
				constraint.Property == tagName(field.Tag.Get("yaml")) {
				return strct.Field(i)
			}
		}
		return reflect.Value{}
	}

	switch constraint.Operator {
	case EqualsConstraintOperator:
		res, _ := ctx.GetResource(constraint.Target)
		if res == nil {
			return false
		}
		strct := reflect.ValueOf(res)
		for strct.Kind() == reflect.Ptr {
			strct = strct.Elem()
		}
		val := strct.FieldByName(constraint.Property)
		if !val.IsValid() {
			val = fieldByAlias(strct)
		}
		if !val.IsValid() {
			return false
		}
		return equalsValue(val)

	case AddConstraintOperator:
		res, _ := ctx.GetResource(constraint.Target)
		if res == nil {
			return false
		}
		val := reflect.ValueOf(res).Elem().FieldByName(constraint.Property)
		if !val.IsValid() {
			return false
		}
		if kind := val.Kind(); kind == reflect.Slice || kind == reflect.Array {
			for i := 0; i < val.Len(); i++ {
				if equalsValue(val.Index(i)) {
					return true
				}
			}
			return false
		}
	}
	return true
}
// Validate requires a concrete (non-abstract) target and a non-empty property name.
func (constraint *ResourceConstraint) Validate() error {
	switch {
	case constraint.Target.IsAbstractResource():
		return errors.New("node constraint cannot be applied to an abstract construct")
	case constraint.Property == "":
		return errors.New("node constraint must have a property defined")
	default:
		return nil
	}
}
// String renders the target, property, operator, and expected value.
func (constraint *ResourceConstraint) String() string {
	const format = "ResourceConstraint: %s %s %s %v"
	return fmt.Sprintf(format, constraint.Target, constraint.Property, constraint.Operator, constraint.Value)
}
package constructexpansion
import (
"errors"
"fmt"
"github.com/dominikbraun/graph"
"github.com/klothoplatform/klotho/pkg/collectionutil"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/engine/constraints"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
"go.uber.org/zap"
)
// ExpansionSet is a construct to expand together with the attribute names the expansion must
// satisfy.
type ExpansionSet struct {
	Construct  *construct.Resource
	Attributes []string
}

// ExpansionSolution is one candidate expansion. Edges are the extra edges it needs, and
// DirectlyMappedResource is the ID of the resource created to stand in for the construct.
type ExpansionSolution struct {
	Edges                  []graph.Edge[construct.Resource]
	DirectlyMappedResource construct.ResourceId
}

// ConstructExpansionContext carries the knowledge base used when searching for expansions.
type ConstructExpansionContext struct {
	Construct *construct.Resource
	Kb        knowledgebase.TemplateKB
}
// ExpandConstruct finds the possible ways to expand the abstract construct res into concrete
// resources.
//
// Construct constraints targeting res.ID are folded together first. A non-empty Type restricts
// the candidate resource type, and Attributes become the attributes the expansion must satisfy.
// Two different non-empty types, or one attribute with two different values, is an error. The
// search itself is delegated to findPossibleExpansions.
//
// It returns an error if res is not an abstract construct.
func (ctx *ConstructExpansionContext) ExpandConstruct(res *construct.Resource, constraints []constraints.ConstructConstraint) ([]ExpansionSolution, error) {
	// Only abstract constructs can be expanded (the check was previously inverted).
	if !res.ID.IsAbstractResource() {
		return nil, fmt.Errorf("unable to expand construct %s, resource is not an abstract construct", res.ID)
	}
	zap.S().Debugf("Expanding construct %s", res.ID)
	constructType := ""
	attributes := make(map[string]any)
	for _, constructConstraint := range constraints {
		if constructConstraint.Target != res.ID {
			continue
		}
		if constructConstraint.Type != "" {
			// Compare before assigning so that two constraints requesting different types are
			// detected; an untyped constraint does not clear a type set by another one.
			if constructType != "" && constructType != constructConstraint.Type {
				return nil, fmt.Errorf("unable to expand construct %s, conflicting types in constraints", res.ID)
			}
			constructType = constructConstraint.Type
		}
		for k, v := range constructConstraint.Attributes {
			// NOTE(review): != on `any` panics if both values hold the same uncomparable type
			// (e.g. lists decoded from YAML).
			if val, ok := attributes[k]; ok && v != val {
				return nil, fmt.Errorf("unable to expand construct %s, attribute %s has conflicting values", res.ID, k)
			}
			attributes[k] = v
		}
	}
	expansionSet := ExpansionSet{Construct: res}
	for attribute := range attributes {
		expansionSet.Attributes = append(expansionSet.Attributes, attribute)
	}
	return ctx.findPossibleExpansions(expansionSet, constructType)
}
// findPossibleExpansions enumerates the expansions of expansionSet.Construct into concrete
// resources.
//
// A knowledge-base resource template is a candidate when it matches constructQualifiedType (if
// non-empty) and is classified with the construct's functionality. For each candidate, a base
// resource named after the construct is created. findExpansions then searches for edges that
// provide the requested attributes the candidate's own classification does not.
//
// A failure for an individual candidate is not fatal. If no candidate yields an expansion, the
// returned error includes those per-candidate failures.
func (ctx *ConstructExpansionContext) findPossibleExpansions(expansionSet ExpansionSet, constructQualifiedType string) ([]ExpansionSolution, error) {
	var possibleExpansions []ExpansionSolution
	var joinedErr error
	functionality := knowledgebase.GetFunctionality(ctx.Kb, expansionSet.Construct.ID)
	for _, res := range ctx.Kb.ListResources() {
		if constructQualifiedType != "" && res.Id().QualifiedTypeName() != constructQualifiedType {
			continue
		}
		classifications := res.Classification
		if !collectionutil.Contains(classifications.Is, string(functionality)) {
			continue
		}
		// Attributes the candidate does not provide itself must come from additional resources.
		unsatisfiedAttributes := []string{}
		for _, ms := range expansionSet.Attributes {
			if !collectionutil.Contains(classifications.Is, ms) {
				unsatisfiedAttributes = append(unsatisfiedAttributes, ms)
			}
		}
		baseRes, err := knowledgebase.CreateResource(ctx.Kb, construct.ResourceId{
			Provider: res.Id().Provider,
			Type:     res.Id().Type,
			Name:     expansionSet.Construct.ID.Name,
		})
		if err != nil {
			joinedErr = errors.Join(joinedErr, err)
			continue
		}
		expansions, err := ctx.findExpansions(unsatisfiedAttributes, []graph.Edge[construct.Resource](nil), *baseRes, functionality)
		if err != nil {
			joinedErr = errors.Join(joinedErr, err)
			continue
		}
		for _, expansion := range expansions {
			possibleExpansions = append(possibleExpansions, ExpansionSolution{Edges: expansion, DirectlyMappedResource: baseRes.ID})
		}
	}
	if len(possibleExpansions) == 0 {
		// Report why each candidate failed instead of discarding the collected errors.
		// errors.Join drops a nil joinedErr, so the message is unchanged when nothing failed.
		return nil, errors.Join(
			fmt.Errorf("no expansions found for attributes %v", expansionSet.Attributes),
			joinedErr,
		)
	}
	return possibleExpansions, nil
}
// findExpansions returns every set of edges that satisfies all of attributes for baseResource.
//
// For each remaining attribute, it looks for knowledge-base resource types, other than
// baseResource's own type, that have a functional path from baseResource and provide that
// attribute for functionality. Each match appends an edge baseResource -> new resource to a
// private copy of edges and recurses on the remaining attributes. When no attributes remain, the
// accumulated edges form one complete expansion.
//
// It returns an error if no expansion satisfies the attributes or if a recursive call fails.
// NOTE(review): choosing the same resources in a different attribute order yields duplicate
// expansions.
func (ctx *ConstructExpansionContext) findExpansions(attributes []string, edges []graph.Edge[construct.Resource], baseResource construct.Resource, functionality knowledgebase.Functionality) ([][]graph.Edge[construct.Resource], error) {
	if len(attributes) == 0 {
		return [][]graph.Edge[construct.Resource]{edges}, nil
	}
	var result [][]graph.Edge[construct.Resource]
	for _, attribute := range attributes {
		// The attributes left after this one do not depend on which resource satisfies it.
		remaining := make([]string, 0, len(attributes))
		for _, other := range attributes {
			if other != attribute {
				remaining = append(remaining, other)
			}
		}
		for _, res := range ctx.Kb.ListResources() {
			if res.Id().QualifiedTypeName() == baseResource.ID.QualifiedTypeName() {
				continue
			}
			if !ctx.Kb.HasFunctionalPath(baseResource.ID, res.Id()) || !res.GivesAttributeForFunctionality(attribute, functionality) {
				continue
			}
			resource := construct.Resource{
				ID:         construct.ResourceId{Type: res.Id().Type, Name: baseResource.ID.Name, Provider: res.Id().Provider},
				Properties: make(construct.Properties),
			}
			// Copy before appending. Appending to the shared `edges` would leak this candidate's
			// edge into later sibling candidates, and a shared backing array could overwrite
			// edges of expansions already returned.
			branch := make([]graph.Edge[construct.Resource], len(edges), len(edges)+1)
			copy(branch, edges)
			branch = append(branch, graph.Edge[construct.Resource]{Source: baseResource, Target: resource})

			expansions, err := ctx.findExpansions(remaining, branch, baseResource, functionality)
			if err != nil {
				return nil, err
			}
			result = append(result, expansions...)
		}
	}
	if len(result) == 0 {
		return nil, fmt.Errorf("no expansions found for attributes %v", attributes)
	}
	return result, nil
}
package debug
import (
"context"
"os"
)
// contextKey is an unexported key type, so context values stored by this package cannot
// collide with keys defined by other packages.
type contextKey string

// debugDirKey is the context key under which WithDebugDir stores the debug directory.
var debugDirKey contextKey = "debugDir"

// GetDebugDir returns the debug directory stored in ctx by WithDebugDir. If none was stored, it
// falls back to the KLOTHO_DEBUG_DIR environment variable.
func GetDebugDir(ctx context.Context) string {
	stored := ctx.Value(debugDirKey)
	if stored != nil {
		return stored.(string)
	}
	return os.Getenv("KLOTHO_DEBUG_DIR")
}

// WithDebugDir returns a copy of ctx that carries debugDir for later retrieval by GetDebugDir.
func WithDebugDir(ctx context.Context, debugDir string) context.Context {
	derived := context.WithValue(ctx, debugDirKey, debugDir)
	return derived
}
package engine
import (
"bytes"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"sort"
"github.com/dominikbraun/graph"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/dot"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
)
// dotAttributes builds the Graphviz attributes for a resource vertex. It copies the vertex's own
// attributes, draws the vertex as a box, and labels it with the resource ID. When the resource's
// template has classifications, they are appended on a second label line.
func dotAttributes(kb knowledgebase.TemplateKB, r *construct.Resource, props graph.VertexProperties) map[string]string {
	label := r.ID.String()
	// A missing template only means there is no classification to show, so the error is ignored.
	if tmpl, _ := kb.GetResourceTemplate(r.ID); tmpl != nil && len(tmpl.Classification.Is) > 0 {
		label += fmt.Sprintf("\n%v", tmpl.Classification.Is)
	}

	attrs := make(map[string]string, len(props.Attributes)+2)
	for key, value := range props.Attributes {
		attrs[key] = value
	}
	attrs["label"] = label
	attrs["shape"] = "box"
	return attrs
}
// dotEdgeAttributes builds the Graphviz attributes for an edge.
//
// The label is the first property path on the source whose value is the target's ID, followed
// by the edge weight on its own line when the weight is positive. Edges for which
// knowledgebase.IsOperationalResourceSideEffect reports true are coloured green.
func dotEdgeAttributes(kb knowledgebase.TemplateKB, g construct.Graph, e construct.ResourceEdge) map[string]string {
	attrs := make(map[string]string)

	_ = e.Source.WalkProperties(func(path construct.PropertyPath, _ error) error {
		value, _ := path.Get()
		if value != e.Target.ID {
			return nil
		}
		attrs["label"] = path.String()
		return construct.StopWalk
	})

	if weight := e.Properties.Weight; weight > 0 {
		weightText := fmt.Sprintf("%d", weight)
		if existing := attrs["label"]; existing != "" {
			attrs["label"] = existing + "\n" + weightText
		} else {
			attrs["label"] = weightText
		}
	}

	if sideEffect, err := knowledgebase.IsOperationalResourceSideEffect(g, kb, e.Source.ID, e.Target.ID); err == nil && sideEffect {
		attrs["color"] = "green"
	}
	return attrs
}
// GraphToDOT writes g to out in Graphviz DOT format.
//
// Vertices are written in topological order. Edges are written sorted by the topological
// position of their source, then of their target. Errors while fetching vertices abort before
// any edge is written. Write errors and edge-lookup errors are collected and returned joined.
func GraphToDOT(kb knowledgebase.TemplateKB, g construct.Graph, out io.Writer) error {
	ids, err := construct.TopologicalSort(g)
	if err != nil {
		return err
	}
	var errs []error
	printf := func(s string, args ...any) {
		if _, err := fmt.Fprintf(out, s, args...); err != nil {
			errs = append(errs, err)
		}
	}
	printf(`digraph {
rankdir = TB
`)
	for _, id := range ids {
		n, props, err := g.VertexWithProperties(id)
		if err != nil {
			errs = append(errs, err)
			continue
		}
		printf(" %q%s\n", n.ID, dot.AttributesToString(dotAttributes(kb, n, props)))
	}
	if err := errors.Join(errs...); err != nil {
		return err
	}

	// Precompute each vertex's topological position once. The previous linear scan per
	// comparison made sorting O(E log E * V). The first occurrence wins, and unknown IDs sort
	// as -1, as before.
	position := make(map[construct.ResourceId]int, len(ids))
	for i, id := range ids {
		if _, seen := position[id]; !seen {
			position[id] = i
		}
	}
	topoIndex := func(id construct.ResourceId) int {
		if i, ok := position[id]; ok {
			return i
		}
		return -1
	}

	edges, err := g.Edges()
	if err != nil {
		return err
	}
	sort.Slice(edges, func(i, j int) bool {
		ti, tj := topoIndex(edges[i].Source), topoIndex(edges[j].Source)
		if ti != tj {
			return ti < tj
		}
		return topoIndex(edges[i].Target) < topoIndex(edges[j].Target)
	})
	for _, e := range edges {
		edge, err := g.Edge(e.Source, e.Target)
		if err != nil {
			errs = append(errs, err)
			continue
		}
		printf(" %q -> %q%s\n", e.Source, e.Target, dot.AttributesToString(dotEdgeAttributes(kb, g, edge)))
	}
	printf("}\n")
	return errors.Join(errs...)
}
// GraphToSVG writes the DOT rendering of g to <prefix>.gv and an SVG rendering produced by
// dot.ExecPan to <prefix>.gv.svg. When KLOTHO_DEBUG_DIR is set, prefix is joined onto that
// directory.
//
// Underlying errors are wrapped with %w so callers can inspect them with errors.Is/As. A
// failure to close either written file is reported unless an earlier error is already being
// returned.
func GraphToSVG(kb knowledgebase.TemplateKB, g construct.Graph, prefix string) (err error) {
	if debugDir := os.Getenv("KLOTHO_DEBUG_DIR"); debugDir != "" {
		prefix = filepath.Join(debugDir, prefix)
	}
	dotPath := prefix + ".gv"
	svgPath := prefix + ".gv.svg"

	f, err := os.Create(dotPath)
	if err != nil {
		return err
	}
	defer func() {
		if cerr := f.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()

	dotContent := new(bytes.Buffer)
	if err = GraphToDOT(kb, g, io.MultiWriter(f, dotContent)); err != nil {
		return fmt.Errorf("could not render graph to file %s: %w", dotPath, err)
	}

	svgContent, err := dot.ExecPan(bytes.NewReader(dotContent.Bytes()))
	if err != nil {
		return fmt.Errorf("could not run 'dot' for %s: %w", dotPath, err)
	}

	svgFile, err := os.Create(svgPath)
	if err != nil {
		return fmt.Errorf("could not create file %s: %w", svgPath, err)
	}
	defer func() {
		if cerr := svgFile.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()

	_, err = fmt.Fprint(svgFile, svgContent)
	return err
}
package engine
import (
"context"
"os"
"slices"
"sync"
"time"
"github.com/klothoplatform/klotho/pkg/engine/constraints"
"github.com/klothoplatform/klotho/pkg/engine/solution"
"github.com/alitto/pond"
"github.com/dominikbraun/graph"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/engine/path_selection"
"go.uber.org/zap"
"gopkg.in/yaml.v3"
)
// GetValidEdgeTargetsConfig holds the optional filters GetValidEdgeTargets
// applies when choosing candidate edge sources and targets. An empty filter
// list means "no restriction".
type GetValidEdgeTargetsConfig struct {
	// Resources restricts candidates to these exact resource IDs.
	Resources struct {
		Sources []construct.ResourceId
		Targets []construct.ResourceId
	}
	// ResourceTypes restricts candidates by matching qualified type names.
	ResourceTypes struct {
		Sources []construct.ResourceId
		Targets []construct.ResourceId
	}
	// Tags restricts candidates to resources whose dataflow viz tag is listed.
	Tags []Tag
}

// GetPossibleEdgesContext is the input to GetValidEdgeTargets: a YAML-encoded
// graph plus the embedded filter configuration.
type GetPossibleEdgesContext struct {
	InputGraph []byte
	GetValidEdgeTargetsConfig
}
// EdgeCanBeExpanded reports, on a best-effort basis, whether the edge
// source -> target can be expanded along knowledge-base paths that satisfy the
// edge's path-satisfaction classifications. It avoids fully solving the graph,
// which is expensive.
//
// It returns false when source matches target, or when the satisfactions do
// not include both a source and a target classification. Satisfactions with no
// classification, or whose property references cannot be resolved, are
// skipped. An error building the path-selection graph or expanding the edge
// yields false plus that error.
//
// cacheable becomes false as soon as a satisfaction uses a property reference,
// because the answer then depends on the specific resources involved.
func (e *Engine) EdgeCanBeExpanded(sol *engineSolution, source construct.ResourceId, target construct.ResourceId) (result bool, cacheable bool, err error) {
	cacheable = true
	expander := path_selection.EdgeExpand{Ctx: sol}

	if source.Matches(target) {
		return false, cacheable, nil
	}

	satisfactions, err := e.Kb.GetPathSatisfactionsFromEdge(source, target)
	if err != nil {
		return false, cacheable, err
	}

	// Both ends need at least one classification for an expansion to be meaningful.
	hasSourceClassification, hasTargetClassification := false, false
	for _, sat := range satisfactions {
		if sat.Source.Classification != "" {
			hasSourceClassification = true
		}
		if sat.Target.Classification != "" {
			hasTargetClassification = true
		}
	}
	if !hasSourceClassification || !hasTargetClassification {
		return false, cacheable, nil
	}

	for _, sat := range satisfactions {
		if sat.Classification == "" {
			continue
		}

		// When a side uses a property reference, expand from/to the last
		// resource that reference resolves to instead.
		edgeSource, edgeTarget := source, target
		if sat.Source.PropertyReference != "" {
			cacheable = false
			refs, refErr := solution.GetResourcesFromPropertyReference(sol, source, sat.Source.PropertyReference)
			if refErr != nil || len(refs) == 0 {
				continue // unresolved reference: ignore this satisfaction
			}
			edgeSource = refs[len(refs)-1]
		}
		if sat.Target.PropertyReference != "" {
			cacheable = false
			refs, refErr := solution.GetResourcesFromPropertyReference(sol, target, sat.Target.PropertyReference)
			if refErr != nil || len(refs) == 0 {
				continue // unresolved reference: ignore this satisfaction
			}
			edgeTarget = refs[len(refs)-1]
		}

		pathGraph, buildErr := path_selection.BuildPathSelectionGraph(
			sol.Context(),
			construct.SimpleEdge{Source: edgeSource, Target: edgeTarget},
			sol.KnowledgeBase(),
			sat.Classification,
			false,
		)
		if buildErr != nil {
			return false, cacheable, buildErr
		}

		sourceRes, vertexErr := pathGraph.Vertex(edgeSource)
		if vertexErr != nil {
			continue
		}
		targetRes, vertexErr := pathGraph.Vertex(edgeTarget)
		if vertexErr != nil {
			continue
		}

		_, expandErr := expander.ExpandEdge(path_selection.ExpansionInput{
			SatisfactionEdge: construct.ResourceEdge{Source: sourceRes, Target: targetRes},
			Classification:   sat.Classification,
			TempGraph:        pathGraph,
		})
		if expandErr != nil {
			return false, cacheable, expandErr
		}
	}
	return true, cacheable, nil
}
// ReadGetValidEdgeTargetsConfig loads a GetValidEdgeTargetsConfig from the YAML
// file at path.
//
// If the file cannot be read, the zero config is returned with the error. If
// decoding fails, the returned config may be partially populated.
func ReadGetValidEdgeTargetsConfig(path string) (GetValidEdgeTargetsConfig, error) {
	var cfg GetValidEdgeTargetsConfig
	raw, err := os.ReadFile(path)
	if err == nil {
		err = yaml.Unmarshal(raw, &cfg)
	}
	return cfg, err
}
/*
GetValidEdgeTargets returns a map of valid edge targets for each source resource in the supplied graph.
The returned map is keyed by the source resource's string representation.
The value for each source resource is a list of valid target resources.
Targets are considered valid if there is a set of kb paths between the source and target
that satisfies both source and target path satisfaction classifications.
A partial set of valid targets can be generated using the filter criteria in the context's config.

Candidate pairs are checked concurrently on a bounded worker pool. The following pairs are skipped:
  - self-pairs;
  - pairs where one resource's Namespace equals the other's Name;
  - pairs already connected in the dataflow view.

Checking stops after a 60 second deadline. Results from checks still running at
that point are dropped, so the output may be incomplete.
*/
func (e *Engine) GetValidEdgeTargets(req *GetPossibleEdgesContext) (map[string][]string, error) {
	inputGraph, err := unmarshallInputGraph(req.InputGraph)
	if err != nil {
		return nil, err
	}
	solutionCtx := NewSolution(context.TODO(), e.Kb, "", &constraints.Constraints{})
	err = solutionCtx.LoadGraph(inputGraph)
	if err != nil {
		return nil, err
	}
	topologyGraph, err := e.GetViewsDag(DataflowView, solutionCtx)
	if err != nil {
		return nil, err
	}

	var sources []construct.ResourceId
	var targets []construct.ResourceId
	qualifiedTypeMatcher := func(id construct.ResourceId) func(otherType construct.ResourceId) bool {
		return func(otherType construct.ResourceId) bool {
			return otherType.QualifiedTypeName() == id.QualifiedTypeName()
		}
	}

	// filter resources based on the context
	ids, err := construct.TopologicalSort(topologyGraph)
	if err != nil {
		return nil, err
	}
	for _, id := range ids {
		tag := GetResourceVizTag(e.Kb, DataflowView, id)
		if len(req.Tags) > 0 && !slices.Contains(req.Tags, tag) {
			continue
		}
		isSource := true
		isTarget := true
		if len(req.Resources.Sources) > 0 && !slices.Contains(req.Resources.Sources, id) {
			isSource = false
		}
		if len(req.Resources.Targets) > 0 && !slices.Contains(req.Resources.Targets, id) {
			isTarget = false
		}
		if len(req.ResourceTypes.Sources) > 0 && !slices.ContainsFunc(req.ResourceTypes.Sources, qualifiedTypeMatcher(id)) {
			isSource = false
		}
		if len(req.ResourceTypes.Targets) > 0 && !slices.ContainsFunc(req.ResourceTypes.Targets, qualifiedTypeMatcher(id)) {
			isTarget = false
		}
		if isSource {
			sources = append(sources, id)
		}
		if isTarget {
			targets = append(targets, id)
		}
	}

	results := make(chan *edgeValidity)
	// done is closed once the pool has stopped. It tells the collector to exit.
	// It also releases any check still running past the pool deadline. Such a
	// check would otherwise block forever on its send, or panic if results were
	// closed.
	done := make(chan struct{})

	// Collect results while tasks are being submitted. results is unbuffered,
	// so each worker blocks on its send until the collector receives it. If the
	// collector only started after all submissions, the workers would block,
	// the pool's task queue would fill, and Submit would deadlock.
	output := make(map[string][]string, len(sources))
	var processResultsGroup sync.WaitGroup
	processResultsGroup.Add(1)
	go func() {
		defer processResultsGroup.Done()
		for {
			select {
			case result := <-results:
				if result.IsValid {
					key := result.Source.String()
					output[key] = append(output[key], result.Target.String())
				}
			case <-done:
				return
			}
		}
	}()

	checkerPool := pond.New(5, 1000, pond.Strategy(pond.Lazy()))
	knownTargetValidity := make(map[string]map[string]bool)
	rwLock := &sync.RWMutex{}
	// get all valid-edge combinations for resource types in the supplied graph
	for _, s := range sources {
		for _, t := range targets {
			source := s
			target := t
			if source.Matches(target) {
				continue
			}
			if source.Namespace == target.Name || target.Namespace == source.Name {
				continue
			}
			// skip pairs that are already connected in the dataflow view
			path, err := graph.ShortestPath(topologyGraph, source, target)
			if len(path) > 0 && err == nil {
				continue
			}
			checkerPool.Submit(func() {
				sourceType := source.QualifiedTypeName()
				targetType := target.QualifiedTypeName()

				// check if we already know the validity of this source/target type pair
				isValid := false
				previouslyCached := false
				rwLock.RLock()
				if known, ok := knownTargetValidity[sourceType]; ok {
					isValid, previouslyCached = known[targetType]
				}
				rwLock.RUnlock()

				cacheable := false
				if !previouslyCached {
					// The error is intentionally ignored: every error return from
					// EdgeCanBeExpanded also reports false, so it counts as "not valid".
					isValid, cacheable, _ = e.EdgeCanBeExpanded(solutionCtx, source, target)
				} else {
					zap.S().Debugf("Using cached result for %s -> %s: %t", source, target, isValid)
				}
				zap.S().Debugf("valid target: %s -> %s: %t", source, target, isValid)

				select {
				case results <- &edgeValidity{Source: source, Target: target, IsValid: isValid}:
				case <-done:
					// The pool deadline passed and the collector has exited; drop the result.
				}

				// cache the result, so we don't have to recompute it for the same source and target types
				// performance benefit is unclear given potential lock contention between goroutines
				if previouslyCached || !cacheable {
					return
				}
				rwLock.Lock()
				if _, ok := knownTargetValidity[sourceType]; !ok {
					knownTargetValidity[sourceType] = make(map[string]bool)
				}
				knownTargetValidity[sourceType][targetType] = isValid
				rwLock.Unlock()
			})
		}
	}

	checkerPool.StopAndWaitFor(60 * time.Second)
	close(done)
	processResultsGroup.Wait()
	return output, nil
}
// unmarshallInputGraph decodes YAML bytes as a construct.YamlGraph and returns
// the graph it holds.
func unmarshallInputGraph(input []byte) (construct.Graph, error) {
	decoded := construct.YamlGraph{}
	if err := yaml.Unmarshal(input, &decoded); err != nil {
		return nil, err
	}
	return decoded.Graph, nil
}
// edgeValidity carries the outcome of checking one candidate edge in
// GetValidEdgeTargets.
type edgeValidity struct {
	Source, Target construct.ResourceId
	// IsValid reports whether Target was found to be a valid target for Source.
	IsValid bool
}
package engine
import (
"context"
"fmt"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/engine/constraints"
"github.com/klothoplatform/klotho/pkg/engine/solution"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
"gopkg.in/yaml.v3"
)
// Engine processes a resource graph and applies constraints to it, using the
// templates in its knowledge base.
type Engine struct {
	Kb knowledgebase.TemplateKB
}

// SolveRequest holds the inputs to Engine.Run:
//   - the constraints to apply;
//   - the initial resource graph;
//   - the global tag for the resulting solution.
type SolveRequest struct {
	Constraints  constraints.Constraints
	InitialState construct.Graph
	GlobalTag    string
}
// NewEngine returns an Engine that uses kb as its knowledge base.
func NewEngine(kb knowledgebase.TemplateKB) *Engine {
	engine := new(Engine)
	engine.Kb = kb
	return engine
}
// Run solves req in three steps:
//  1. load req.InitialState into a new solution;
//  2. apply req.Constraints;
//  3. solve.
//
// The solution is returned even when a step fails, so callers can inspect the
// partial state.
func (e *Engine) Run(ctx context.Context, req *SolveRequest) (solution.Solution, error) {
	sol := NewSolution(ctx, e.Kb, req.GlobalTag, &req.Constraints)
	if err := sol.LoadGraph(req.InitialState); err != nil {
		return sol, err
	}
	if err := ApplyConstraints(sol); err != nil {
		return sol, err
	}
	solveErr := sol.Solve()
	return sol, solveErr
}
// MarshalYAML encodes the request as a single YAML mapping. The first key is
// "constraints" with the encoded Constraints as its value. It is followed by
// the top-level entries produced by encoding InitialState as a
// construct.YamlGraph. GlobalTag is not included.
func (req SolveRequest) MarshalYAML() (interface{}, error) {
	var initState yaml.Node
	if err := initState.Encode(construct.YamlGraph{Graph: req.InitialState}); err != nil {
		return nil, fmt.Errorf("failed to marshal initial state: %w", err)
	}
	// Named constraintsNode so it does not shadow the imported constraints package.
	var constraintsNode yaml.Node
	if err := constraintsNode.Encode(req.Constraints); err != nil {
		return nil, fmt.Errorf("failed to marshal constraints: %w", err)
	}
	// One key/value pair for "constraints", then the initial state's entries
	// spliced in directly. This assumes initState encodes as a mapping node
	// (TODO confirm).
	content := make([]*yaml.Node, 0, 2+len(initState.Content))
	content = append(content,
		&yaml.Node{Kind: yaml.ScalarNode, Value: "constraints"},
		&constraintsNode,
	)
	content = append(content, initState.Content...)
	return yaml.Node{
		Kind:    yaml.MappingNode,
		Content: content,
	}, nil
}
package enginetesting
import (
"github.com/dominikbraun/graph"
construct "github.com/klothoplatform/klotho/pkg/construct"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
"github.com/stretchr/testify/mock"
)
// MockKB is a testify mock of the knowledge base (knowledgebase.TemplateKB)
// for engine tests. Register expectations with On(...).Return(...).
type MockKB struct {
	mock.Mock
}

// mockKBReturn returns the i-th value configured via Return as a T.
//
// An untyped nil, as in Return(nil, err), yields T's zero value instead of
// panicking on the type assertion. Any other type mismatch still panics, so
// misconfigured expectations surface immediately.
func mockKBReturn[T any](args mock.Arguments, i int) T {
	v := args.Get(i)
	if v == nil {
		var zero T
		return zero
	}
	return v.(T)
}

func (m *MockKB) ListResources() []*knowledgebase.ResourceTemplate {
	return mockKBReturn[[]*knowledgebase.ResourceTemplate](m.Called(), 0)
}

func (m *MockKB) GetModel(model string) *knowledgebase.Model {
	return mockKBReturn[*knowledgebase.Model](m.Called(model), 0)
}

func (m *MockKB) Edges() ([]graph.Edge[*knowledgebase.ResourceTemplate], error) {
	args := m.Called()
	return mockKBReturn[[]graph.Edge[*knowledgebase.ResourceTemplate]](args, 0), args.Error(1)
}

func (m *MockKB) AddResourceTemplate(template *knowledgebase.ResourceTemplate) error {
	args := m.Called(template)
	return args.Error(0)
}

func (m *MockKB) AddEdgeTemplate(template *knowledgebase.EdgeTemplate) error {
	args := m.Called(template)
	return args.Error(0)
}

func (m *MockKB) GetResourceTemplate(id construct.ResourceId) (*knowledgebase.ResourceTemplate, error) {
	args := m.Called(id)
	return mockKBReturn[*knowledgebase.ResourceTemplate](args, 0), args.Error(1)
}

func (m *MockKB) GetEdgeTemplate(from, to construct.ResourceId) *knowledgebase.EdgeTemplate {
	return mockKBReturn[*knowledgebase.EdgeTemplate](m.Called(from, to), 0)
}

func (m *MockKB) HasDirectPath(from, to construct.ResourceId) bool {
	args := m.Called(from, to)
	return args.Bool(0)
}

func (m *MockKB) HasFunctionalPath(from, to construct.ResourceId) bool {
	args := m.Called(from, to)
	return args.Bool(0)
}

func (m *MockKB) AllPaths(from, to construct.ResourceId) ([][]*knowledgebase.ResourceTemplate, error) {
	args := m.Called(from, to)
	return mockKBReturn[[][]*knowledgebase.ResourceTemplate](args, 0), args.Error(1)
}

func (m *MockKB) GetAllowedNamespacedResourceIds(
	ctx knowledgebase.DynamicValueContext,
	resourceId construct.ResourceId,
) ([]construct.ResourceId, error) {
	args := m.Called(ctx, resourceId)
	return mockKBReturn[[]construct.ResourceId](args, 0), args.Error(1)
}

func (m *MockKB) GetFunctionality(id construct.ResourceId) knowledgebase.Functionality {
	return mockKBReturn[knowledgebase.Functionality](m.Called(id), 0)
}

func (m *MockKB) GetClassification(id construct.ResourceId) knowledgebase.Classification {
	return mockKBReturn[knowledgebase.Classification](m.Called(id), 0)
}

func (m *MockKB) GetResourcesNamespaceResource(resource *construct.Resource) (construct.ResourceId, error) {
	args := m.Called(resource)
	return mockKBReturn[construct.ResourceId](args, 0), args.Error(1)
}

func (m *MockKB) GetResourcePropertyType(resource construct.ResourceId, propertyName string) string {
	args := m.Called(resource, propertyName)
	return args.String(0)
}

func (m *MockKB) GetPathSatisfactionsFromEdge(
	source, target construct.ResourceId,
) ([]knowledgebase.EdgePathSatisfaction, error) {
	args := m.Called(source, target)
	return mockKBReturn[[]knowledgebase.EdgePathSatisfaction](args, 0), args.Error(1)
}
package enginetesting
import (
"context"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/engine/constraints"
"github.com/klothoplatform/klotho/pkg/engine/solution"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
"github.com/stretchr/testify/mock"
)
// MockSolution is a testify-backed solution double. Every method except
// Context answers from expectations registered with On(...).Return(...).
type MockSolution struct {
	mock.Mock
	KB MockKB
}

// Context is deliberately not mocked: the context is not used for
// computation-critical work, so a background context keeps tests simple.
func (m *MockSolution) Context() context.Context {
	return context.Background()
}

func (m *MockSolution) KnowledgeBase() knowledgebase.TemplateKB {
	return m.Called().Get(0).(knowledgebase.TemplateKB)
}

func (m *MockSolution) Constraints() *constraints.Constraints {
	return m.Called().Get(0).(*constraints.Constraints)
}

func (m *MockSolution) RecordDecision(d solution.SolveDecision) {
	_ = m.Called(d)
}

func (m *MockSolution) GetDecisions() []solution.SolveDecision {
	return m.Called().Get(0).([]solution.SolveDecision)
}

func (m *MockSolution) DataflowGraph() construct.Graph {
	return m.Called().Get(0).(construct.Graph)
}

func (m *MockSolution) DeploymentGraph() construct.Graph {
	return m.Called().Get(0).(construct.Graph)
}

func (m *MockSolution) OperationalView() solution.OperationalView {
	return m.Called().Get(0).(solution.OperationalView)
}

func (m *MockSolution) RawView() construct.Graph {
	return m.Called().Get(0).(construct.Graph)
}

func (m *MockSolution) GlobalTag() string {
	return m.Called().String(0)
}

func (m *MockSolution) Outputs() map[string]construct.Output {
	return m.Called().Get(0).(map[string]construct.Output)
}
package enginetesting
import (
"context"
"testing"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/construct/graphtest"
"github.com/klothoplatform/klotho/pkg/engine/constraints"
"github.com/klothoplatform/klotho/pkg/engine/solution"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
"github.com/stretchr/testify/mock"
)
// TestSolution is an in-memory solution for engine tests. It holds:
//   - real dataflow and deployment graphs;
//   - a mockable knowledge base (KB);
//   - the constraints (Constr).
//
// Create one with NewTestSolution.
type TestSolution struct {
	mock.Mock
	KB     MockKB
	Constr constraints.Constraints

	dataflow, deployment construct.Graph
}

// NewTestSolution returns a TestSolution with an empty dataflow graph and an
// empty acyclic deployment graph.
func NewTestSolution() *TestSolution {
	return &TestSolution{
		dataflow:   construct.NewGraph(),
		deployment: construct.NewAcyclicGraph(),
	}
}

func (sol *TestSolution) Context() context.Context { return context.Background() }

// UseEmptyTemplates makes every resource and edge template lookup on the KB
// succeed with an empty template.
func (sol *TestSolution) UseEmptyTemplates() {
	emptyResource := &knowledgebase.ResourceTemplate{}
	emptyEdge := &knowledgebase.EdgeTemplate{}
	sol.KB.On("GetResourceTemplate", mock.Anything).Return(emptyResource, nil)
	sol.KB.On("GetEdgeTemplate", mock.Anything, mock.Anything).Return(emptyEdge, nil)
}

// LoadState populates the solution through its raw view. It then wraps both
// graphs so that only changes made after loading are recorded.
func (sol *TestSolution) LoadState(t *testing.T, initGraph ...any) {
	graphtest.MakeGraph(t, sol.RawView(), initGraph...)
	sol.dataflow = graphtest.RecordChanges(sol.dataflow)
	sol.deployment = graphtest.RecordChanges(sol.deployment)
}

// DataflowChanges returns the changes recorded on the dataflow graph since
// LoadState. It panics if LoadState was not called.
func (sol *TestSolution) DataflowChanges() *graphtest.GraphChanges {
	return sol.dataflow.(*graphtest.GraphChanges)
}

// DeploymentChanges returns the changes recorded on the deployment graph since
// LoadState. It panics if LoadState was not called.
func (sol *TestSolution) DeploymentChanges() *graphtest.GraphChanges {
	return sol.deployment.(*graphtest.GraphChanges)
}

func (sol *TestSolution) KnowledgeBase() knowledgebase.TemplateKB { return &sol.KB }

func (sol *TestSolution) Constraints() *constraints.Constraints { return &sol.Constr }

// RecordDecision discards decisions; TestSolution does not track them.
func (sol *TestSolution) RecordDecision(d solution.SolveDecision) {}

func (sol *TestSolution) GetDecisions() []solution.SolveDecision { return nil }

func (sol *TestSolution) DataflowGraph() construct.Graph { return sol.dataflow }

func (sol *TestSolution) DeploymentGraph() construct.Graph { return sol.deployment }

// OperationalView exposes the raw view, with operational calls routed to the
// solution's embedded mock.
func (sol *TestSolution) OperationalView() solution.OperationalView {
	view := testOperationalView{Graph: sol.RawView(), Mock: &sol.Mock}
	return view
}

func (sol *TestSolution) RawView() construct.Graph { return solution.NewRawView(sol) }

func (sol *TestSolution) GlobalTag() string { return "test" }

func (sol *TestSolution) Outputs() map[string]construct.Output { return nil }
// testOperationalView presents a graph as an operational view. Its
// operational methods are forwarded to a testify mock, so tests can stub them
// and assert on the calls.
type testOperationalView struct {
	construct.Graph
	Mock *mock.Mock
}

func (view testOperationalView) MakeResourcesOperational(resources []*construct.Resource) error {
	return view.Mock.Called(resources).Error(0)
}

func (view testOperationalView) UpdateResourceID(oldId, newId construct.ResourceId) error {
	return view.Mock.Called(oldId, newId).Error(0)
}

func (view testOperationalView) MakeEdgesOperational(edges []construct.Edge) error {
	return view.Mock.Called(edges).Error(0)
}
// ExpectedGraphs lists the expected contents of a solution's dataflow and
// deployment graphs, in graphtest.MakeGraph form. A nil list skips that graph.
type ExpectedGraphs struct {
	Dataflow, Deployment []any
}

// AssertEqual compares each non-nil expectation against the matching graph of
// sol. A graph getter is only called when its expectation is set.
func (expect ExpectedGraphs) AssertEqual(t *testing.T, sol solution.Solution) {
	check := func(items []any, actual func() construct.Graph, name string) {
		if items == nil {
			return
		}
		graphtest.AssertGraphEqual(t,
			graphtest.MakeGraph(t, construct.NewGraph(), items...),
			actual(),
			name,
		)
	}
	check(expect.Dataflow, sol.DataflowGraph, "Dataflow")
	check(expect.Deployment, sol.DeploymentGraph, "Deployment")
}
package engine_errs
import (
"fmt"
"strings"
)
// ErrorTree is a serializable view of an error.
//   - Chain holds the messages of a linear wrap chain, outermost first.
//   - Children holds one subtree per member of a multi-error (errors.Join
//     style) that ends the chain.
type ErrorTree struct {
	Chain    []string    `json:"chain,omitempty"`
	Children []ErrorTree `json:"children,omitempty"`
}

// chainErr is an error that wraps exactly one other error.
type chainErr interface {
	error
	Unwrap() error
}

// joinErr is an error that wraps several errors, such as one built by errors.Join.
type joinErr interface {
	error
	Unwrap() []error
}
// unwrapChain follows err's single-error Unwrap chain and returns the message
// each level contributes, outermost first. A level's contribution is its
// Error() text with the wrapped error's text and a trailing ": " trimmed off;
// empty contributions are skipped. A multi-error with exactly one member is
// treated like a single wrap.
//
// The walk stops at the first of:
//   - an error that wraps nothing: its full message ends the chain;
//   - a wrapper whose Unwrap returns nil: its full message ends the chain;
//   - a multi-error with zero or several members: it is returned as last, so
//     the caller can expand it into child trees.
func unwrapChain(err error) (chain []string, last joinErr) {
	for current := err; current != nil; {
		var next error
		switch e := current.(type) {
		case chainErr:
			next = e.Unwrap()
		case joinErr:
			members := e.Unwrap()
			if len(members) != 1 {
				return chain, e
			}
			next = members[0]
		default:
			return append(chain, current.Error()), nil
		}
		if next == nil {
			// Some wrappers (e.g. a typed error with an unset inner error) unwrap
			// to nil. Calling next.Error() on nil would panic, so end the chain here.
			return append(chain, current.Error()), nil
		}
		msg := strings.TrimSuffix(strings.TrimSuffix(current.Error(), next.Error()), ": ")
		if msg != "" {
			chain = append(chain, msg)
		}
		current = next
	}
	return chain, nil
}
// ErrorsToTree converts err into an ErrorTree.
//   - A nil error yields the zero tree.
//   - An ErrorTree value is returned unchanged.
//   - Otherwise the wrap chain becomes Chain. If the chain ends in a
//     multi-error, each member is converted recursively into Children, in order.
func ErrorsToTree(err error) (tree ErrorTree) {
	switch existing := err.(type) {
	case nil:
		return tree
	case ErrorTree:
		return existing
	}

	chain, joined := unwrapChain(err)
	tree.Chain = chain
	if joined == nil {
		return tree
	}
	members := joined.Unwrap()
	tree.Children = make([]ErrorTree, 0, len(members))
	for _, member := range members {
		tree.Children = append(tree.Children, ErrorsToTree(member))
	}
	return tree
}
// Error renders the tree as indented text, one line per node.
func (t ErrorTree) Error() string {
	var sb strings.Builder
	t.print(&sb, 0, 0)
	return sb.String()
}
// print writes t.Chain (formatted with %v) on one line, indented by indent
// tabs. When childChar is non-zero, the line is prefixed with childChar and a
// space. Children are printed recursively one level deeper: '└' marks the last
// child and '├' marks the others.
func (t ErrorTree) print(out *strings.Builder, indent int, childChar rune) {
	lead := strings.Repeat("\t", indent)
	if childChar != 0 {
		lead += string(childChar) + " "
	}
	fmt.Fprintf(out, "%s%v\n", lead, t.Chain)

	lastIdx := len(t.Children) - 1
	for idx, child := range t.Children {
		marker := '├'
		if idx == lastIdx {
			marker = '└'
		}
		child.print(out, indent+1, marker)
	}
}
package engine_errs
import (
"fmt"
construct "github.com/klothoplatform/klotho/pkg/construct"
)
// EngineError is an error that carries a machine-readable ErrorCode and can
// describe itself as a JSON-ready map.
type EngineError interface {
	error
	// ErrorCode classifies the error.
	ErrorCode() ErrorCode
	// ToJSONMap returns a map that can be marshaled to JSON. It is used instead
	// of MarshalJSON to avoid repeatedly marshalling common fields (such as
	// 'error_code') and to allow consistent formatting (e.g. pretty-printing).
	ToJSONMap() map[string]any
}

// ErrorCode is the machine-readable category of an EngineError.
type ErrorCode string
// Error codes reported through EngineError.ErrorCode.
const (
	// InternalErrCode is reported by InternalError.
	InternalErrCode ErrorCode = "internal"
	// ConfigInvalidCode presumably marks an invalid configuration (no user visible here).
	ConfigInvalidCode ErrorCode = "config_invalid"
	// EdgeInvalidCode is reported by InvalidPathErr.
	EdgeInvalidCode ErrorCode = "edge_invalid"
	// EdgeUnsupportedCode is reported by UnsupportedExpansionErr.
	EdgeUnsupportedCode ErrorCode = "edge_unsupported"
)
// InternalError wraps an unexpected failure inside the engine. It reports
// InternalErrCode and carries no extra JSON fields.
type InternalError struct {
	Err error
}

func (e InternalError) Error() string { return "internal error: " + fmt.Sprint(e.Err) }

func (e InternalError) ErrorCode() ErrorCode { return InternalErrCode }

// ToJSONMap returns an empty, non-nil map.
func (e InternalError) ToJSONMap() map[string]any { return make(map[string]any) }

// Unwrap exposes the wrapped error to errors.Is / errors.As.
func (e InternalError) Unwrap() error { return e.Err }
// UnsupportedExpansionErr reports that expanding an edge is not supported.
type UnsupportedExpansionErr struct {
	// ExpandEdge is the overall edge that is being expanded
	ExpandEdge construct.SimpleEdge
	// SatisfactionEdge is the specific edge that was being expanded when the error occurred
	SatisfactionEdge construct.SimpleEdge
	Classification   string
}

// Error describes the unsupported expansion. If there is no distinct
// satisfaction edge, only the overall edge is mentioned.
func (e UnsupportedExpansionErr) Error() string {
	if e.SatisfactionEdge.Source.IsZero() || e.ExpandEdge == e.SatisfactionEdge {
		return fmt.Sprintf("unsupported expansion %s in %s", e.ExpandEdge, e.Classification)
	}
	// Arguments follow the verbs: overall edge, satisfaction edge, classification.
	return fmt.Sprintf(
		"while expanding %s, unsupported expansion of %s in %s",
		e.ExpandEdge,
		e.SatisfactionEdge,
		e.Classification,
	)
}

func (e UnsupportedExpansionErr) ErrorCode() ErrorCode {
	return EdgeUnsupportedCode
}

// ToJSONMap always includes "satisfaction_edge". "expand_edge" and
// "classification" are included only when set.
func (e UnsupportedExpansionErr) ToJSONMap() map[string]any {
	m := map[string]any{
		"satisfaction_edge": e.SatisfactionEdge,
	}
	if !e.ExpandEdge.Source.IsZero() {
		m["expand_edge"] = e.ExpandEdge
	}
	if e.Classification != "" {
		m["classification"] = e.Classification
	}
	return m
}
// InvalidPathErr reports that an edge expansion produced an invalid path.
type InvalidPathErr struct {
	// ExpandEdge is the overall edge that is being expanded
	ExpandEdge construct.SimpleEdge
	// SatisfactionEdge is the specific edge that was being expanded when the error occurred
	SatisfactionEdge construct.SimpleEdge
	Classification   string
}

// Error describes the invalid expansion. If there is no distinct satisfaction
// edge, only the overall edge is mentioned.
func (e InvalidPathErr) Error() string {
	if e.SatisfactionEdge.Source.IsZero() || e.ExpandEdge == e.SatisfactionEdge {
		return fmt.Sprintf("invalid expansion %s in %s", e.ExpandEdge, e.Classification)
	}
	// Arguments follow the verbs: overall edge, satisfaction edge, classification.
	return fmt.Sprintf(
		"while expanding %s, invalid expansion of %s in %s",
		e.ExpandEdge,
		e.SatisfactionEdge,
		e.Classification,
	)
}

func (e InvalidPathErr) ErrorCode() ErrorCode {
	return EdgeInvalidCode
}

// ToJSONMap always includes "satisfaction_edge". "expand_edge" and
// "classification" are included only when set.
func (e InvalidPathErr) ToJSONMap() map[string]any {
	m := map[string]any{
		"satisfaction_edge": e.SatisfactionEdge,
	}
	if !e.ExpandEdge.Source.IsZero() {
		m["expand_edge"] = e.ExpandEdge
	}
	if e.Classification != "" {
		m["classification"] = e.Classification
	}
	return m
}
package operational_eval
import (
"bytes"
"fmt"
"io"
"os"
"path/filepath"
"sync"
"github.com/dominikbraun/graph"
"github.com/klothoplatform/klotho/pkg/dot"
"github.com/klothoplatform/klotho/pkg/engine/debug"
"github.com/klothoplatform/klotho/pkg/logging"
"go.uber.org/zap"
"gopkg.in/yaml.v3"
)
// Attribute keys for the evaluation graph. NOTE(review): their usage is
// outside this view. The names suggest provenance/debug metadata for rendered
// graphs; confirm.
const (
	attribAddedIn = "added_in"
	attribError   = "error"
	attribReady   = "ready"

	attribAddedBy  = "added_by"
	attribDuration = "duration"
)
// PrintGraph writes each dependency edge of g to stdout as
// "<vertex> -> <dependency>", visiting vertices in topological order.
//
// Previously only "-> <dependency>" was printed, which made the output
// impossible to read. If sorting or building the adjacency map fails, the error
// is logged and nothing is printed. Dependencies of a single vertex are printed
// in map iteration order.
func PrintGraph(g Graph) {
	topo, err := graph.TopologicalSort(g)
	if err != nil {
		zap.S().Errorf("could not topologically sort graph: %v", err)
		return
	}
	adj, err := g.AdjacencyMap()
	if err != nil {
		zap.S().Errorf("could not get adjacency map: %v", err)
		return
	}
	for _, v := range topo {
		for dep := range adj[v] {
			fmt.Printf("%s -> %s\n", v, dep)
		}
	}
}
// writeGraph renders the evaluator's graph into the debug directory in two
// forms, concurrently:
//   - a clustered DOT/SVG pair named after prefix;
//   - a flat pair named after prefix + "_flat".
//
// It does nothing when the solution's context has no debug directory.
func (eval *Evaluator) writeGraph(prefix string) {
	debugDir := debug.GetDebugDir(eval.Solution.Context())
	if debugDir == "" {
		return
	}
	prefix = filepath.Join(debugDir, prefix)

	log := logging.GetLogger(eval.Solution.Context()).Sugar()
	dir := filepath.Dir(prefix)
	if err := os.MkdirAll(dir, 0755); err != nil {
		log.Errorf("could not create debug directory %s: %v", dir, err)
		return
	}

	renders := []struct {
		name  string
		toDot func(*Evaluator, io.Writer) error
	}{
		{name: prefix, toDot: graphToClusterDOT},
		{name: prefix + "_flat", toDot: graphToDOT},
	}
	var wg sync.WaitGroup
	wg.Add(len(renders))
	for _, r := range renders {
		r := r
		go func() {
			defer wg.Done()
			writeGraph(eval, r.name, r.toDot)
		}()
	}
	wg.Wait()
}
// writeGraph renders eval to "<filename>.gv" using toDot. It then converts the
// DOT output to SVG with dot.ExecPan and writes it to "<filename>.gv.svg".
// Because this is debug output, failures are logged rather than returned,
// including SVG write/close failures (previously ignored). Log messages name
// the actual file involved.
func writeGraph(eval *Evaluator, filename string, toDot func(*Evaluator, io.Writer) error) {
	log := logging.GetLogger(eval.Solution.Context()).Sugar()
	gvPath := filename + ".gv"
	f, err := os.Create(gvPath)
	if err != nil {
		log.Errorf("could not create file %s: %v", gvPath, err)
		return
	}
	defer f.Close()

	dotContent := new(bytes.Buffer)
	if err := toDot(eval, io.MultiWriter(f, dotContent)); err != nil {
		log.Errorf("could not render graph to file %s: %v", gvPath, err)
		return
	}

	svgContent, err := dot.ExecPan(bytes.NewReader(dotContent.Bytes()))
	if err != nil {
		log.Errorf("could not run 'dot' for %s: %v", gvPath, err)
		return
	}

	svgPath := gvPath + ".svg"
	svgFile, err := os.Create(svgPath)
	if err != nil {
		log.Errorf("could not create file %s: %v", svgPath, err)
		return
	}
	_, writeErr := fmt.Fprint(svgFile, svgContent)
	if closeErr := svgFile.Close(); writeErr == nil {
		writeErr = closeErr
	}
	if writeErr != nil {
		log.Errorf("could not write file %s: %v", svgPath, writeErr)
	}
}
// writeExecOrder writes the evaluator's evaluation order to "exec-order.yaml"
// in the debug directory, as a YAML list of groups of key strings. It does
// nothing when the solution's context has no debug directory. Failures are
// logged.
func (eval *Evaluator) writeExecOrder() {
	debugDir := debug.GetDebugDir(eval.Solution.Context())
	if debugDir == "" {
		return
	}
	path := filepath.Join(debugDir, "exec-order.yaml")

	log := logging.GetLogger(eval.Solution.Context()).Sugar()
	if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
		log.Errorf("could not create debug directory %s: %v", filepath.Dir(path), err)
		return
	}
	f, err := os.Create(path)
	if err != nil {
		log.Errorf("could not create file %s: %v", path, err)
		return
	}
	defer f.Close()

	order := make([][]string, len(eval.evaluatedOrder))
	for i, group := range eval.evaluatedOrder {
		order[i] = make([]string, len(group))
		for j, key := range group {
			order[i][j] = key.String()
		}
	}

	enc := yaml.NewEncoder(f)
	if err := enc.Encode(order); err != nil {
		log.Errorf("could not write exec order to file %s: %v", path, err)
		return
	}
	// yaml.v3 documents that an Encoder should be closed after use to flush all
	// data to the underlying writer.
	if err := enc.Close(); err != nil {
		log.Errorf("could not write exec order to file %s: %v", path, err)
	}
}
package operational_eval
import (
"errors"
"fmt"
"io"
"strings"
"text/template"
construct "github.com/klothoplatform/klotho/pkg/construct"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
)
//go:generate mockgen -source=./dependency_capture.go --destination=./dependency_capture_mock_test.go --package=operational_eval
// dependencyCapturer runs rules and templates in order to discover which
// properties and graph state they depend on, rather than to apply them.
type dependencyCapturer interface {
	ExecuteOpRule(data knowledgebase.DynamicValueData, rule knowledgebase.OperationalRule) error
	ExecutePropertyRule(data knowledgebase.DynamicValueData, rule knowledgebase.PropertyRule) error
	DAG() construct.Graph
	KB() knowledgebase.TemplateKB
	ExecuteDecode(tmpl string, data knowledgebase.DynamicValueData, value interface{}) error
	GetChanges() graphChanges
}

// fauxConfigContext behaves like a [knowledgebase.DynamicValueContext], except
// that its template functions (such as [FieldValue]) record every property
// reference they touch as a dependency of src. When the real value is
// unavailable, FieldValue falls back to the property type's empty value.
type fauxConfigContext struct {
	// propRef is the property whose value expressions are being captured.
	propRef construct.PropertyRef
	// inner is the real context used to resolve and execute templates.
	inner knowledgebase.DynamicValueContext
	// changes accumulates the captured nodes and dependency edges.
	changes graphChanges
	// src is the key that newly found dependencies are attached to.
	src Key
}
// newDepCapture returns a fauxConfigContext that records src's dependencies
// into changes and resolves templates through inner. The captured property
// starts out as src.Ref.
func newDepCapture(inner knowledgebase.DynamicValueContext, changes graphChanges, src Key) *fauxConfigContext {
	capture := &fauxConfigContext{inner: inner, changes: changes, src: src}
	capture.propRef = src.Ref
	return capture
}
// addRef records that the current source key depends on the property ref.
func (ctx *fauxConfigContext) addRef(ref construct.PropertyRef) {
	dependency := Key{Ref: ref}
	ctx.changes.addEdge(ctx.src, dependency)
}

// addGraphState registers v as a captured node and records that the current
// source key depends on it.
func (ctx *fauxConfigContext) addGraphState(v *graphStateVertex) {
	key := v.Key()
	ctx.changes.nodes[key] = v
	ctx.changes.addEdge(ctx.src, key)
}
// DAG returns the inner context's graph.
func (ctx *fauxConfigContext) DAG() construct.Graph { return ctx.inner.DAG() }

// KB returns the inner context's knowledge base.
func (ctx *fauxConfigContext) KB() knowledgebase.TemplateKB { return ctx.inner.KB() }

// GetChanges returns the changes captured so far.
func (ctx *fauxConfigContext) GetChanges() graphChanges { return ctx.changes }
// ExecuteDecode parses tmpl with this context's template functions, so field
// accesses are recorded as dependencies, then executes and decodes it through
// the inner context. Only parse failures are reported. Execution/decode errors
// are deliberately ignored, since capture may run with placeholder values.
func (ctx *fauxConfigContext) ExecuteDecode(tmpl string, data knowledgebase.DynamicValueData, value interface{}) error {
	parsed, err := template.New("config").Funcs(ctx.TemplateFunctions()).Parse(tmpl)
	if err == nil {
		_ = ctx.inner.ExecuteTemplateDecode(parsed, data, value)
		return nil
	}
	return fmt.Errorf("could not parse template: %w", err)
}
// ExecuteValue runs v through knowledgebase.TransformToPropertyValue for the
// captured property, using this context so any template references in v are
// recorded. The converted value and any error are discarded.
func (ctx *fauxConfigContext) ExecuteValue(v any, data knowledgebase.DynamicValueData) {
	resource, property := ctx.propRef.Resource, ctx.propRef.Property
	_, _ = knowledgebase.TransformToPropertyValue(resource, property, v, ctx, data)
}
// Execute treats string values as templates: it parses v with this context's
// template functions and executes it purely for the side effect of recording
// dependencies. Non-string values are ignored. Only parse errors are returned.
func (ctx *fauxConfigContext) Execute(v any, data knowledgebase.DynamicValueData) error {
	switch text := v.(type) {
	case string:
		tmpl, err := template.New(ctx.propRef.String()).Funcs(ctx.TemplateFunctions()).Parse(text)
		if err != nil {
			return fmt.Errorf("could not parse template %w", err)
		}
		// Execution errors are ignored: placeholder values may violate other
		// assumptions. A genuinely broken template will fail later, when the
		// value is actually processed.
		_ = tmpl.Execute(io.Discard, data)
		return nil
	default:
		return nil
	}
}
// DecodeConfigRef resolves the property ref a configuration rule targets. It
// decodes rule.Config.Field into the property and then rule.Resource into the
// resource. If a template fails, the partially filled ref and a wrapped error
// are returned.
func (ctx *fauxConfigContext) DecodeConfigRef(
	data knowledgebase.DynamicValueData,
	rule knowledgebase.ConfigurationRule,
) (construct.PropertyRef, error) {
	var ref construct.PropertyRef
	if err := ctx.ExecuteDecode(rule.Config.Field, data, &ref.Property); err != nil {
		return ref, fmt.Errorf("could not execute field template: %w", err)
	}
	if err := ctx.ExecuteDecode(rule.Resource, data, &ref.Resource); err != nil {
		return ref, fmt.Errorf("could not execute resource template: %w", err)
	}
	return ref, nil
}
// ExecuteOpRule records the dependencies of an operational rule without applying it.
//
// For each configuration rule with a Resource, the targeted property ref is decoded.
// If that property exists on the resource, it temporarily becomes the dependency
// source: the rule's If condition and config value are attributed to it, and it is
// made to depend on the original source. If the property does not exist but
// addresses a list element, the original source depends on the list instead.
// Steps (and the If condition, when there are steps) are captured against the
// original source. Errors are accumulated with errors.Join and returned together.
func (ctx *fauxConfigContext) ExecuteOpRule(
data knowledgebase.DynamicValueData,
opRule knowledgebase.OperationalRule,
) error {
var errs error
exec := func(v any) {
errs = errors.Join(errs, ctx.Execute(v, data))
}
// ctx.src may be switched per configuration rule below; it is always restored to this.
originalSrc := ctx.src
for _, rule := range opRule.ConfigurationRules {
if rule.Resource != "" {
ref, err := ctx.DecodeConfigRef(data, rule)
if err != nil {
errs = errors.Join(errs, err)
continue
}
if ref.Resource.IsZero() || ref.Property == "" {
// Can't determine the ref yet, continue
// NOTE(gg): It's possible that whatever this will eventually resolve to
// would get evaluated before this has a chance to add the dependency.
// If that ever occurs, we may need to add speculative dependencies
// for all refs that could match this.
continue
}
// Check to see if we're setting a list element's property
// If we are, we need to depend on the list resolving first.
res, err := ctx.DAG().Vertex(ref.Resource)
if err != nil {
errs = errors.Join(errs, fmt.Errorf("could not find rule's resource: %w", err))
continue
}
_, err = res.GetProperty(ref.Property)
if err != nil {
if bracketIdx := strings.Index(ref.Property, "["); bracketIdx != -1 {
// Depend on the list itself (the path up to the first '['), from the original source.
listRef := ref
listRef.Property = ref.Property[:bracketIdx]
ctx.addRef(listRef)
} else {
errs = errors.Join(errs, fmt.Errorf("could not find rule's property: %w", err))
continue
}
} else {
// set the source to the ref that is being configured, not necessarily the key that dependencies are being
// calculated for, but only when the reference exists
ctx.src = Key{Ref: ref}
}
}
// The If condition is captured per configuration rule so that the configured
// property (when ctx.src was switched above) also depends on it.
exec(opRule.If)
ctx.ExecuteValue(rule.Config.Value, data)
if ctx.src != originalSrc {
// Make sure the configured property depends on the edge
ctx.changes.addEdge(ctx.src, originalSrc)
// reset inside the loop in case the next rule doesn't specify the ref
ctx.src = originalSrc
}
}
if len(opRule.Steps) > 0 {
exec(opRule.If)
}
for _, step := range opRule.Steps {
errs = errors.Join(errs, ctx.executeOpStep(data, step))
}
return errs
}
// ExecutePropertyRule records the dependencies of a property rule, in order:
//  1. its If condition;
//  2. its Value, when set;
//  3. its operational step.
//
// Errors are accumulated with errors.Join.
func (ctx *fauxConfigContext) ExecutePropertyRule(
	data knowledgebase.DynamicValueData,
	propRule knowledgebase.PropertyRule,
) error {
	// errors.Join(x) mirrors the original accumulation (Join(nil, x)) exactly.
	errs := errors.Join(ctx.Execute(propRule.If, data))
	if propRule.Value != nil {
		ctx.ExecuteValue(propRule.Value, data)
	}
	stepErr := ctx.executeOpStep(data, propRule.Step)
	return errors.Join(errs, stepErr)
}
// executeOpStep records the dependencies of an operational step. For each of
// the step's resources, it captures the selector and then every property
// value. Errors are accumulated with errors.Join.
func (ctx *fauxConfigContext) executeOpStep(
	data knowledgebase.DynamicValueData,
	step knowledgebase.OperationalStep,
) error {
	var errs error
	for _, stepRes := range step.Resources {
		errs = errors.Join(errs, ctx.Execute(stepRes.Selector, data))
		for _, propValue := range stepRes.Properties {
			errs = errors.Join(errs, ctx.Execute(propValue, data))
		}
	}
	return errs
}
// TemplateFunctions returns the inner context's template functions, with the graph- and
// field-inspecting functions replaced by this context's dependency-capturing versions.
func (ctx *fauxConfigContext) TemplateFunctions() template.FuncMap {
	overrides := template.FuncMap{
		"hasField":          ctx.HasField,
		"fieldValue":        ctx.FieldValue,
		"hasUpstream":       ctx.HasUpstream,
		"upstream":          ctx.Upstream,
		"allUpstream":       ctx.AllUpstream,
		"hasDownstream":     ctx.HasDownstream,
		"downstream":        ctx.Downstream,
		"closestDownstream": ctx.ClosestDownstream,
		"allDownstream":     ctx.AllDownstream,
	}
	funcs := ctx.inner.TemplateFunctions()
	for name, fn := range overrides {
		funcs[name] = fn
	}
	return funcs
}
// HasField records a dependency on field of resource, then delegates to the inner context.
// A zero resource id yields (false, nil) without recording anything.
func (ctx *fauxConfigContext) HasField(field string, resource any) (bool, error) {
	resId, err := knowledgebase.TemplateArgToRID(resource)
	if err != nil {
		return false, err
	}
	if resId.IsZero() {
		return false, nil
	}
	// Cannot depend on properties within lists, stop at the list itself
	depProperty, _, _ := strings.Cut(field, "[")
	ctx.addRef(construct.PropertyRef{Resource: resId, Property: depProperty})
	return ctx.inner.HasField(field, resId)
}
// FieldValue records a dependency on field of resource. It returns the inner context's value
// when that value is non-nil, and otherwise the property's zero value from the resource's
// template. When the inner lookup fails, the dependency is truncated at the first list
// index. The inner lookup's error itself is not returned.
func (ctx *fauxConfigContext) FieldValue(field string, resource any) (any, error) {
	resId, err := knowledgebase.TemplateArgToRID(resource)
	if err != nil {
		return "", err
	}
	if resId.IsZero() {
		return nil, nil
	}
	value, innerErr := ctx.inner.FieldValue(field, resId)
	depProperty := field
	if innerErr != nil {
		// Cannot depend on properties within lists, stop at the list itself
		depProperty, _, _ = strings.Cut(field, "[")
	}
	ctx.addRef(construct.PropertyRef{Resource: resId, Property: depProperty})
	if value != nil {
		return value, nil
	}
	tmpl, err := ctx.inner.KB().GetResourceTemplate(resId)
	if err != nil {
		return "", err
	}
	return emptyValue(tmpl, field)
}
// emptyValue returns the zero value of property as declared on tmpl. It returns an error if
// the template does not declare the property.
func emptyValue(tmpl *knowledgebase.ResourceTemplate, property string) (any, error) {
	if prop := tmpl.GetProperty(property); prop != nil {
		return prop.ZeroValue(), nil
	}
	return nil, fmt.Errorf("could not find property %s on template %s", property, tmpl.Id())
}
// HasUpstream delegates to the inner context and registers a graph-state vertex describing
// the query. If an upstream match already exists, that vertex is immediately ready.
// Otherwise it becomes ready once a first-functional-layer upstream resource matching the
// selector appears.
func (ctx *fauxConfigContext) HasUpstream(selector any, resource construct.ResourceId) (bool, error) {
	selId, err := knowledgebase.TemplateArgToRID(selector)
	if err != nil {
		return false, err
	}
	repr := graphStateRepr(fmt.Sprintf("hasUpstream(%s, %s)", selId, resource))
	has, innerErr := ctx.inner.HasUpstream(selector, resource)
	if innerErr == nil && has {
		ctx.addGraphState(&graphStateVertex{
			repr: repr,
			Test: func(construct.Graph) (ReadyPriority, error) { return ReadyNow, nil },
		})
		return true, nil
	}
	ctx.addGraphState(&graphStateVertex{
		repr: repr,
		Test: func(g construct.Graph) (ReadyPriority, error) {
			candidates, err := knowledgebase.Upstream(g, ctx.KB(), resource, knowledgebase.FirstFunctionalLayer)
			if err != nil {
				return NotReadyMax, err
			}
			for _, candidate := range candidates {
				if selId.Matches(candidate) {
					return ReadyNow, nil
				}
			}
			return NotReadyMid, nil
		},
	})
	return has, innerErr
}
// Upstream delegates to the inner context and registers a graph-state vertex for the query.
// The vertex is immediately ready when a non-zero upstream id was found. Otherwise it becomes
// ready once a first-functional-layer upstream resource matching the selector appears.
func (ctx *fauxConfigContext) Upstream(selector any, resource construct.ResourceId) (construct.ResourceId, error) {
	selId, err := knowledgebase.TemplateArgToRID(selector)
	if err != nil {
		return construct.ResourceId{}, err
	}
	repr := graphStateRepr(fmt.Sprintf("Upstream(%s, %s)", selId, resource))
	up, innerErr := ctx.inner.Upstream(selector, resource)
	if innerErr == nil && !up.IsZero() {
		ctx.addGraphState(&graphStateVertex{
			repr: repr,
			Test: func(construct.Graph) (ReadyPriority, error) { return ReadyNow, nil },
		})
		return up, nil
	}
	ctx.addGraphState(&graphStateVertex{
		repr: repr,
		Test: func(g construct.Graph) (ReadyPriority, error) {
			candidates, err := knowledgebase.Upstream(g, ctx.KB(), resource, knowledgebase.FirstFunctionalLayer)
			if err != nil {
				return NotReadyMax, err
			}
			for _, candidate := range candidates {
				if selId.Matches(candidate) {
					return ReadyNow, nil
				}
			}
			return NotReadyMid, nil
		},
	})
	return up, innerErr
}
// AllUpstream registers a graph-state vertex that always reports NotReadyHigh, then delegates
// to the inner context.
func (ctx *fauxConfigContext) AllUpstream(selector any, resource construct.ResourceId) (construct.ResourceList, error) {
	// Can never say that [AllUpstream] is ready until it must be evaluated due to being one of the final ones
	neverReady := func(construct.Graph) (ReadyPriority, error) { return NotReadyHigh, nil }
	ctx.addGraphState(&graphStateVertex{
		repr: graphStateRepr(fmt.Sprintf("AllUpstream(%s, %s)", selector, resource)),
		Test: neverReady,
	})
	return ctx.inner.AllUpstream(selector, resource)
}
// HasDownstream delegates to the inner context and registers a graph-state vertex describing
// the query. If a downstream match already exists, that vertex is immediately ready.
// Otherwise it becomes ready once a first-functional-layer downstream resource matching the
// selector appears.
func (ctx *fauxConfigContext) HasDownstream(selector any, resource construct.ResourceId) (bool, error) {
	selId, err := knowledgebase.TemplateArgToRID(selector)
	if err != nil {
		return false, err
	}
	repr := graphStateRepr(fmt.Sprintf("hasDownstream(%s, %s)", selId, resource))
	has, innerErr := ctx.inner.HasDownstream(selector, resource)
	if innerErr == nil && has {
		ctx.addGraphState(&graphStateVertex{
			repr: repr,
			Test: func(construct.Graph) (ReadyPriority, error) { return ReadyNow, nil },
		})
		return true, nil
	}
	ctx.addGraphState(&graphStateVertex{
		repr: repr,
		Test: func(g construct.Graph) (ReadyPriority, error) {
			candidates, err := knowledgebase.Downstream(g, ctx.KB(), resource, knowledgebase.FirstFunctionalLayer)
			if err != nil {
				return NotReadyMax, err
			}
			for _, candidate := range candidates {
				if selId.Matches(candidate) {
					return ReadyNow, nil
				}
			}
			return NotReadyMid, nil
		},
	})
	return has, innerErr
}
// Downstream delegates to the inner context and registers a graph-state vertex for the query.
// The vertex is immediately ready when a non-zero downstream id was found. Otherwise it
// becomes ready once a first-functional-layer downstream resource matching the selector appears.
func (ctx *fauxConfigContext) Downstream(selector any, resource construct.ResourceId) (construct.ResourceId, error) {
	selId, err := knowledgebase.TemplateArgToRID(selector)
	if err != nil {
		return construct.ResourceId{}, err
	}
	repr := graphStateRepr(fmt.Sprintf("Downstream(%s, %s)", selId, resource))
	down, innerErr := ctx.inner.Downstream(selector, resource)
	if innerErr == nil && !down.IsZero() {
		ctx.addGraphState(&graphStateVertex{
			repr: repr,
			Test: func(construct.Graph) (ReadyPriority, error) { return ReadyNow, nil },
		})
		return down, nil
	}
	ctx.addGraphState(&graphStateVertex{
		repr: repr,
		Test: func(g construct.Graph) (ReadyPriority, error) {
			candidates, err := knowledgebase.Downstream(g, ctx.KB(), resource, knowledgebase.FirstFunctionalLayer)
			if err != nil {
				return NotReadyMax, err
			}
			for _, candidate := range candidates {
				if selId.Matches(candidate) {
					return ReadyNow, nil
				}
			}
			return NotReadyMid, nil
		},
	})
	return down, innerErr
}
// ClosestDownstream registers a graph-state vertex that always reports NotReadyMid, then
// delegates to the inner context.
func (ctx *fauxConfigContext) ClosestDownstream(selector any, resource construct.ResourceId) (construct.ResourceId, error) {
	// Can never say that [ClosestDownstream] is ready because something closer could always be added
	alwaysMid := func(construct.Graph) (ReadyPriority, error) { return NotReadyMid, nil }
	ctx.addGraphState(&graphStateVertex{
		repr: graphStateRepr(fmt.Sprintf("closestDownstream(%s, %s)", selector, resource)),
		Test: alwaysMid,
	})
	return ctx.inner.ClosestDownstream(selector, resource)
}
// AllDownstream registers a graph-state vertex that always reports NotReadyHigh, then
// delegates to the inner context.
func (ctx *fauxConfigContext) AllDownstream(selector any, resource construct.ResourceId) (construct.ResourceList, error) {
	// Can never say that [AllDownstream] is ready until it must be evaluated due to being one of the final ones
	neverReady := func(construct.Graph) (ReadyPriority, error) { return NotReadyHigh, nil }
	ctx.addGraphState(&graphStateVertex{
		repr: graphStateRepr(fmt.Sprintf("allDownstream(%s, %s)", selector, resource)),
		Test: neverReady,
	})
	return ctx.inner.AllDownstream(selector, resource)
}
package operational_eval
import (
"errors"
"fmt"
"io"
"slices"
"strings"
"github.com/klothoplatform/klotho/pkg/dot"
)
// Layout and colour settings for the DOT renderings of the evaluator's dependency graph.
const (
// rankSize is the maximum number of keys placed in one sub-rank of the clustered output;
// larger evaluation groups are split into several sub-ranks.
rankSize = 20
// errorColour fills keys that failed to evaluate and colours edges marked with attribError.
errorColour = "#e87b7b"
// unevaluatedColour colours the cluster of keys that were never evaluated.
unevaluatedColour = "#e3cf9d"
// addedByColour colours edges whose attribAddedBy attribute names the edge's own target.
addedByColour = "#3f822b"
)
// keyAttributes returns the DOT node attributes used to render key.
//
// The label and shape depend on the key's type. The evaluation duration is appended to the
// label when it is recorded on the vertex (attribDuration). Keys in eval.errored are filled
// with errorColour. The "style" attribute is always present, though it may be empty.
func keyAttributes(eval *Evaluator, key Key) map[string]string {
	attribs := make(map[string]string)
	var style []string
	switch key.keyType() {
	case keyTypeProperty:
		attribs["label"] = fmt.Sprintf(`%s\n#%s`, key.Ref.Resource, key.Ref.Property)
		attribs["shape"] = "box"
	case keyTypeEdge:
		attribs["label"] = fmt.Sprintf(`%s\n→ %s`, key.Edge.Source, key.Edge.Target)
		attribs["shape"] = "parallelogram"
	case keyTypeGraphState:
		attribs["label"] = string(key.GraphState)
		attribs["shape"] = "box"
		style = append(style, "dashed")
	case keyTypePathExpand:
		attribs["label"] = fmt.Sprintf(`%s\n→ %s`, key.Edge.Source, key.Edge.Target)
		ps := key.PathSatisfication
		var extra []string
		if ps.Classification != "" {
			extra = append(extra, fmt.Sprintf("<%s>", ps.Classification))
		}
		if ps.Target.PropertyReferenceChangesBoundary() {
			extra = append(extra, fmt.Sprintf("target#%s", ps.Target.PropertyReference))
		}
		if ps.Source.PropertyReferenceChangesBoundary() {
			// the source annotation must show the source's own property reference
			extra = append(extra, fmt.Sprintf("source#%s", ps.Source.PropertyReference))
		}
		if len(extra) > 0 {
			attribs["label"] += `\n` + strings.Join(extra, " ")
		}
		attribs["shape"] = "parallelogram"
		style = append(style, "dashed")
	default:
		attribs["label"] = fmt.Sprintf(`%s\n(UNKNOWN)`, key)
		attribs["color"] = "#fc8803"
	}
	if _, props, err := eval.graph.VertexWithProperties(key); err == nil {
		if dur := props.Attributes[attribDuration]; dur != "" {
			attribs["label"] = fmt.Sprintf(`%s\n%s`, attribs["label"], dur)
		}
	}
	if eval.errored.Contains(key) {
		style = append(style, "filled")
		attribs["fillcolor"] = errorColour
	}
	attribs["style"] = strings.Join(style, ",")
	return attribs
}
// evalRank is one cluster in the clustered DOT output (graphToClusterDOT), built by toRanks.
type evalRank struct {
// Unevaluated marks the trailing rank holding keys that were never evaluated.
Unevaluated bool
// Rank is the rank's index; for evaluated ranks it equals the evaluation group index.
Rank int
// SubRanks holds the rank's keys split into chunks of at most rankSize keys.
SubRanks [][]Key
}
// toRanks groups the evaluator's keys into display ranks for the clustered DOT output.
//
// Each evaluation group in eval.evaluatedOrder becomes one rank. A group with more than
// rankSize keys is split into sub-ranks of at most rankSize keys, in this order:
//   - keys with successors but no predecessors,
//   - keys with no edges at all,
//   - keys with predecessors.
//
// Keys present in the graph that were never evaluated are sorted with [Key.Less], so the
// output is deterministic, and collected into a trailing rank marked Unevaluated.
func toRanks(eval *Evaluator) ([]evalRank, error) {
	ranks := make([]evalRank, len(eval.evaluatedOrder), len(eval.evaluatedOrder)+1)
	pred, err := eval.graph.PredecessorMap()
	if err != nil {
		return nil, err
	}
	adj, err := eval.graph.AdjacencyMap()
	if err != nil {
		return nil, err
	}
	// chunk splits keys into consecutive slices of at most rankSize keys.
	chunk := func(keys []Key) [][]Key {
		var out [][]Key
		for i := 0; i < len(keys); i += rankSize {
			out = append(out, keys[i:min(i+rankSize, len(keys))])
		}
		return out
	}
	// evaluated lets the unevaluated scan below be a set lookup instead of a scan of every group
	evaluated := make(map[Key]struct{})
	for i, keys := range eval.evaluatedOrder {
		for _, key := range keys {
			evaluated[key] = struct{}{}
		}
		ranks[i] = evalRank{Rank: i}
		rank := &ranks[i]
		if len(keys) <= rankSize {
			rank.SubRanks = [][]Key{keys}
			continue
		}
		// split large ranks into smaller ones
		var noDeps, onlyDownstream, hasUpstream []Key
		for _, key := range keys {
			switch {
			case len(pred[key]) == 0 && len(adj[key]) == 0:
				noDeps = append(noDeps, key)
			case len(pred[key]) == 0:
				onlyDownstream = append(onlyDownstream, key)
			default:
				hasUpstream = append(hasUpstream, key)
			}
		}
		rank.SubRanks = append(rank.SubRanks, chunk(onlyDownstream)...)
		rank.SubRanks = append(rank.SubRanks, chunk(noDeps)...)
		rank.SubRanks = append(rank.SubRanks, chunk(hasUpstream)...)
	}
	var unevaluated []Key
	for key := range pred {
		if _, done := evaluated[key]; !done {
			unevaluated = append(unevaluated, key)
		}
	}
	if len(unevaluated) > 0 {
		// map iteration order is random; sort so repeated renders are identical
		slices.SortStableFunc(unevaluated, func(a, b Key) int {
			switch {
			case a.Less(b):
				return -1
			case b.Less(a):
				return 1
			}
			return 0
		})
		ranks = append(ranks, evalRank{
			Unevaluated: true,
			Rank:        len(ranks),
			SubRanks:    chunk(unevaluated),
		})
	}
	return ranks, nil
}
// graphToClusterDOT writes the evaluator's dependency graph to out as a DOT digraph.
// Each evaluation rank from toRanks is a cluster, and invisible weighted edges keep
// consecutive sub-ranks stacked in evaluation order.
//
// Edges whose attribAddedBy attribute names their own target are omitted. Write errors are
// joined and returned. Errors from building the ranks or the adjacency map are returned
// immediately.
func graphToClusterDOT(eval *Evaluator, out io.Writer) error {
	var errs error
	printf := func(s string, args ...any) {
		_, err := fmt.Fprintf(out, s, args...)
		errs = errors.Join(errs, err)
	}
	printf(`strict digraph {
rankdir = "BT"
ranksep = 4
newrank = true
concentrate = true
`)
	ranks, err := toRanks(eval)
	if err != nil {
		return err
	}
	adj, err := eval.graph.AdjacencyMap()
	if err != nil {
		return err
	}
	// prevSubrank is the most recently printed non-empty sub-rank (possibly from the previous
	// rank). It anchors the invisible ordering edges.
	var prevSubrank []Key
	for _, evalRank := range ranks {
		rank := evalRank.Rank
		printf(" subgraph cluster_%d {\n", rank)
		if evalRank.Unevaluated {
			printf(` label = "Unevaluated"
style=filled
color="%s"
`, unevaluatedColour)
		} else {
			printf(" label = \"Evaluation Order %d\"\n", rank)
		}
		printf(" labelloc=b\n")
		for i, subrank := range evalRank.SubRanks {
			if len(subrank) == 0 {
				// An empty evaluation group (e.g. evaluation stopped before anything in it ran)
				// has no vertices to print or to anchor ordering edges on.
				continue
			}
			printf(" {")
			if evalRank.Unevaluated {
				printf("\n")
			} else {
				printf("rank=same\n")
			}
			for _, key := range subrank {
				attribs := keyAttributes(eval, key)
				attribs["group"] = fmt.Sprintf("group%d.%d", rank, i)
				printf(" %q%s\n", key, dot.AttributesToString(attribs))
				for tgt, e := range adj[key] {
					if addedBy := e.Properties.Attributes[attribAddedBy]; addedBy == tgt.String() {
						continue
					}
					printf(" %q -> %q\n", key, tgt)
				}
			}
			printf(" }\n")
			if len(prevSubrank) > 0 {
				printf(" %q -> %q [style=invis, weight=10]\n", subrank[0], prevSubrank[0])
				printf(" %q -> %q [style=invis, weight=10]\n", subrank[len(subrank)-1], prevSubrank[len(prevSubrank)-1])
			}
			prevSubrank = subrank
		}
		printf(" }\n")
	}
	printf("}\n")
	return errs
}
// graphToDOT writes the evaluator's full dependency graph to out as a flat DOT digraph.
//
// Each vertex label is prefixed with its evaluation group ("[?]" when it was never
// evaluated, in which case it is also filled with unevaluatedColour). Labels are annotated
// with the group the vertex was added in and any non-ReadyNow readiness. Edges are annotated
// with the group they were added in, and coloured when they were added by their own target
// or marked as errored. Write errors are joined and returned.
func graphToDOT(eval *Evaluator, out io.Writer) error {
	var errs error
	printf := func(s string, args ...any) {
		_, err := fmt.Fprintf(out, s, args...)
		errs = errors.Join(errs, err)
	}
	printf(`strict digraph {
rankdir = BT
ranksep = 1
`)
	adj, err := eval.graph.AdjacencyMap()
	if err != nil {
		return err
	}
	evalOrder := make(map[Key]int)
	for i, keys := range eval.evaluatedOrder {
		for _, key := range keys {
			evalOrder[key] = i
		}
	}
	for src, a := range adj {
		attribs := keyAttributes(eval, src)
		if order, hasOrder := evalOrder[src]; hasOrder {
			attribs["label"] = fmt.Sprintf("[%d] %s", order, attribs["label"])
		} else {
			attribs["label"] = fmt.Sprintf("[?] %s", attribs["label"])
			// keyAttributes may return an empty style; avoid emitting a leading comma
			if s := attribs["style"]; s != "" {
				attribs["style"] = s + ",filled"
			} else {
				attribs["style"] = "filled"
			}
			attribs["fillcolor"] = unevaluatedColour
		}
		// Best-effort: a lookup failure leaves props zero-valued, which the nil check below skips.
		_, props, _ := eval.graph.VertexWithProperties(src)
		if props.Attributes != nil {
			if group := props.Attributes[attribAddedIn]; group != "" {
				attribs["label"] = fmt.Sprintf(`%s\n+%s`, attribs["label"], group)
			}
			if ready := props.Attributes[attribReady]; ready != "" && ready != ReadyNow.String() {
				attribs["label"] = fmt.Sprintf(`%s\n%s`, attribs["label"], ready)
			}
		}
		printf(" %q%s\n", src, dot.AttributesToString(attribs))
		for tgt, e := range a {
			edgeAttribs := make(map[string]string)
			if group := e.Properties.Attributes[attribAddedIn]; group != "" {
				edgeAttribs["label"] = fmt.Sprintf("+%s", group)
			}
			if addedBy := e.Properties.Attributes[attribAddedBy]; addedBy == tgt.String() {
				edgeAttribs["color"] = addedByColour
				edgeAttribs["style"] = "dashed"
			}
			if errored := e.Properties.Attributes[attribError]; errored != "" {
				edgeAttribs["color"] = errorColour
				edgeAttribs["penwidth"] = "2"
			}
			printf(" %q -> %q%s\n", src, tgt, dot.AttributesToString(edgeAttribs))
		}
	}
	printf("}\n")
	return errs
}
package operational_eval
import (
"errors"
"fmt"
)
// EnqueueErrors collects errors per vertex key encountered while enqueueing graph changes.
type EnqueueErrors map[Key]error

// Error formats every collected error, keyed by the vertex key it belongs to.
func (e EnqueueErrors) Error() string {
	byKey := map[Key]error(e)
	return fmt.Sprintf("enqueue errors: %v", byKey)
}
// Unwrap exposes each collected error, prefixed with its key, so that errors.Is and
// errors.As can inspect them. The order follows map iteration and is not stable.
func (e EnqueueErrors) Unwrap() []error {
	wrapped := make([]error, 0, len(e))
	for key, cause := range e {
		wrapped = append(wrapped, fmt.Errorf("%s: %w", key, cause))
	}
	return wrapped
}
// Append records err under key, allocating the map on first use. A nil err is ignored.
// A second error for the same key is joined with the existing one.
func (e *EnqueueErrors) Append(key Key, err error) {
	if err == nil {
		return
	}
	if *e == nil {
		*e = EnqueueErrors{}
	}
	collected := *e
	if prev, found := collected[key]; found {
		err = errors.Join(prev, err)
	}
	collected[key] = err
}
package operational_eval
import (
"errors"
"fmt"
"sort"
"strings"
"time"
"github.com/dominikbraun/graph"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/graph_addons"
"github.com/klothoplatform/klotho/pkg/tui"
"go.uber.org/zap"
)
// updateSolveProgress reports "Solving" progress as the number of evaluated vertices (total
// graph order minus unevaluated graph order) out of the total graph order.
func (eval *Evaluator) updateSolveProgress() error {
	progress := tui.GetProgress(eval.Solution.Context())
	remaining, err := eval.unevaluated.Order()
	if err != nil {
		return err
	}
	total, err := eval.graph.Order()
	if err != nil {
		return err
	}
	done := total - remaining
	progress.Update("Solving", done, total)
	return nil
}
// Evaluate runs the evaluation loop until the unevaluated graph is empty. Each iteration:
//  1. starts a new group in eval.evaluatedOrder;
//  2. dequeues the ready vertices with pollReady;
//  3. evaluates them one by one;
//  4. calls RecalculateUnevaluated.
//
// It returns an error when:
//   - no vertex is ready while vertices remain (reported as a possible circular dependency);
//   - any vertex in the group failed (the rest of the group is still attempted first);
//   - polling, progress reporting or recalculation fails.
//
// The deferred writeExecOrder and writeGraph calls run on every return path.
func (eval *Evaluator) Evaluate() error {
defer eval.writeGraph("property_deps")
defer eval.writeExecOrder()
for {
size, err := eval.unevaluated.Order()
if err != nil {
return fmt.Errorf("failed to get unevaluated order: %w", err)
}
if size == 0 {
return nil
}
// add to evaluatedOrder so that in popReady it has the correct group number
// which is based on `len(eval.evaluatedOrder)`
eval.evaluatedOrder = append(eval.evaluatedOrder, []Key{})
ready, err := eval.pollReady()
if err != nil {
return fmt.Errorf("failed to poll ready: %w", err)
}
if len(ready) == 0 {
return fmt.Errorf("possible circular dependency detected in properties graph: %d remaining", size)
}
log := eval.Log().Named("eval")
groupStart := time.Now()
// errors from every vertex in this group; reported together after the group finishes
var errs []error
for _, v := range ready {
k := v.Key()
addErr := func(err error) {
errs = append(errs, fmt.Errorf("failed to evaluate %s: %w", k, err))
}
_, err := eval.unevaluated.Vertex(k)
switch {
case err != nil && !errors.Is(err, graph.ErrVertexNotFound):
addErr(err)
continue
case errors.Is(err, graph.ErrVertexNotFound):
// vertex was removed by earlier ready vertex
continue
}
log.Debugf("Evaluating %s", k)
// record k in the current (last) group before evaluating it
eval.evaluatedOrder[len(eval.evaluatedOrder)-1] = append(eval.evaluatedOrder[len(eval.evaluatedOrder)-1], k)
// currentKey attributes vertices/edges added during this evaluation to k (see enqueue/addEdge)
eval.currentKey = &k
// k leaves the unevaluated graph before it is evaluated; a removal error is recorded
// but evaluation still proceeds
if err := graph_addons.RemoveVertexAndEdges(eval.unevaluated, v.Key()); err != nil {
addErr(err)
}
start := time.Now()
err = v.Evaluate(eval)
duration := time.Since(start)
if err != nil {
eval.errored.Add(k)
addErr(err)
}
// duration is stored on the vertex so keyAttributes can show it in DOT output
if _, props, err := eval.graph.VertexWithProperties(k); err != nil {
log.Errorf("failed to get properties for %s: %s", k, err)
} else {
props.Attributes[attribDuration] = duration.String()
}
// NOTE: a progress error aborts immediately, skipping the rest of the group
if err := eval.updateSolveProgress(); err != nil {
return err
}
}
log.Debugf("Completed group in %s", time.Since(groupStart))
if len(errs) > 0 {
return fmt.Errorf("failed to evaluate group %d: %w", len(eval.evaluatedOrder), errors.Join(errs...))
}
recalcStart := time.Now()
if err := eval.RecalculateUnevaluated(); err != nil {
return err
}
log.Debugf("Recalculated unevaluated in %s", time.Since(recalcStart))
}
}
// printUnevaluated logs, at debug level, every vertex still in the unevaluated graph. Each
// line shows the vertex's outstanding-dependency count, its readiness (for conditional
// vertices) and its sorted dependencies. It does nothing when debug logging is disabled.
func (eval *Evaluator) printUnevaluated() {
	log := eval.Log().Named("eval.poll-deps")
	if !log.Desugar().Core().Enabled(zap.DebugLevel) {
		return
	}
	adj, err := eval.unevaluated.AdjacencyMap()
	if err != nil {
		log.Errorf("Could not get adjacency map: %s", err)
		return
	}
	// byKey builds a Less function over list; it sees list's in-place swaps during the sort.
	byKey := func(list []Key) func(i, j int) bool {
		return func(i, j int) bool { return list[i].Less(list[j]) }
	}
	sources := make([]Key, 0, len(adj))
	for src := range adj {
		sources = append(sources, src)
	}
	sort.SliceStable(sources, byKey(sources))
	log.Debugf("Unevaluated vertices: %d", len(sources))
	for _, src := range sources {
		deps := adj[src]
		line := fmt.Sprintf("%s (%d)", src, len(deps))
		if vertex, vErr := eval.unevaluated.Vertex(src); vErr != nil {
			line += fmt.Sprintf(" [error: %s]", vErr)
		} else if cond, isCond := vertex.(conditionalVertex); isCond {
			if prio, rErr := cond.Ready(eval); rErr != nil {
				line += fmt.Sprintf(" [error: %s]", rErr)
			} else {
				line += fmt.Sprintf(" [%s]", prio)
			}
		}
		log.Debug(line)
		targets := make([]Key, 0, len(deps))
		for tgt := range deps {
			targets = append(targets, tgt)
		}
		sort.SliceStable(targets, byKey(targets))
		for _, tgt := range targets {
			log.Debugf(" - %s", tgt)
		}
	}
}
// pollReady returns the unevaluated vertices that should be evaluated next.
//
// Candidates are the vertices with no outstanding dependencies in the unevaluated graph.
// Conditional vertices are asked for their [ReadyPriority], which is also recorded in the
// vertex's attribReady attribute. Other vertices are treated as [ReadyNow]. Only the
// candidates in the best (lowest) priority present are returned, sorted by [Key.Less].
//
// Any error looking up a candidate or computing its readiness aborts the poll with the
// joined errors. So does a priority outside ReadyNow..NotReadyMax.
func (eval *Evaluator) pollReady() ([]Vertex, error) {
	log := eval.Log().Named("eval.dequeue")
	adj, err := eval.unevaluated.AdjacencyMap()
	if err != nil {
		return nil, err
	}
	eval.printUnevaluated()
	var readyKeys []Key
	for v, deps := range adj {
		if len(deps) == 0 {
			readyKeys = append(readyKeys, v)
		}
	}
	// One bucket per priority, ReadyNow through NotReadyMax inclusive. Sizing this with
	// NotReadyMax alone leaves no bucket for NotReadyMax and panics on that index.
	readyPriorities := make([][]Vertex, NotReadyMax+1)
	var errs error
	for _, key := range readyKeys {
		v, err := eval.unevaluated.Vertex(key)
		if err != nil {
			errs = errors.Join(errs, err)
			continue
		}
		cond, ok := v.(conditionalVertex)
		if !ok {
			readyPriorities[ReadyNow] = append(readyPriorities[ReadyNow], v)
			continue
		}
		vReady, err := cond.Ready(eval)
		if err != nil {
			errs = errors.Join(errs, err)
			continue
		}
		_, props, _ := eval.graph.VertexWithProperties(key)
		if props.Attributes != nil {
			props.Attributes[attribReady] = vReady.String()
		}
		if vReady < ReadyNow || vReady > NotReadyMax {
			errs = errors.Join(errs, fmt.Errorf("vertex %s reported out-of-range ready priority %s", key, vReady))
			continue
		}
		readyPriorities[vReady] = append(readyPriorities[vReady], v)
	}
	if errs != nil {
		return nil, errs
	}
	var ready []Vertex
	for prio, bucket := range readyPriorities {
		if len(bucket) == 0 {
			continue
		}
		if ready != nil {
			log.Debugf("Remaining unready [%s]: %d", ReadyPriority(prio), len(bucket))
			continue
		}
		ready = bucket
		sort.SliceStable(ready, func(a, b int) bool {
			return ready[a].Key().Less(ready[b].Key())
		})
		log.Debugf("Dequeued [%s]: %d", ReadyPriority(prio), len(ready))
		for _, v := range ready {
			log.Debugf(" - %s", v.Key())
		}
	}
	return ready, nil
}
// RecalculateUnevaluated recomputes the dependencies of every vertex that has not been
// evaluated yet and enqueues any resulting changes.
//
// This is needed because a vertex's dependencies can "open up" once other vertices have
// been evaluated. Examples are a template `{{ if }}` whose branch depends on a now-known
// value, or a chained lookup such as `{{ fieldValue "X" (fieldValue "SomeRef" .Self) }}`,
// where X's dependency cannot be resolved until SomeRef is evaluated.
//
// Every unevaluated vertex is recalculated, in topological order, rather than only the
// affected ones; at current graph sizes the cost is small. Per-vertex errors are joined
// and returned after all vertices have been processed.
func (eval *Evaluator) RecalculateUnevaluated() error {
	order, err := graph.TopologicalSort(eval.unevaluated)
	if err != nil {
		return err
	}
	var errs error
	for _, key := range order {
		vertex, vErr := eval.unevaluated.Vertex(key)
		if vErr != nil {
			errs = errors.Join(errs, vErr)
			continue
		}
		changes := newChanges()
		if depErr := changes.AddVertexAndDeps(eval, vertex); depErr != nil {
			errs = errors.Join(errs, depErr)
			continue
		}
		errs = errors.Join(errs, eval.enqueue(changes))
	}
	return errs
}
// cleanupPropertiesSubVertices removes unevaluated property vertices on ref's resource whose
// property is ref.Property or nested beneath it, but whose location no longer exists in
// resource. Matching vertices are removed from both the full and unevaluated graphs.
//
// A vertex is kept when resource.PropertyPath resolves its property and either:
//   - the path has fewer than two elements (top-level property), or
//   - its parent element is present and non-nil, since such a vertex will still be evaluated.
//
// Removal errors are joined and returned.
func (eval *Evaluator) cleanupPropertiesSubVertices(ref construct.PropertyRef, resource *construct.Resource) error {
	topo, err := graph.TopologicalSort(eval.unevaluated)
	if err != nil {
		return err
	}
	var errs error
	for _, key := range topo {
		if key.keyType() != keyTypeProperty || key.Ref.Resource != ref.Resource {
			continue
		}
		// Match ref.Property itself or anything nested under it ("P.x", "P[0]"), but not a
		// sibling that merely shares a textual prefix (e.g. "PSuffix" when ref is "P").
		rest, ok := strings.CutPrefix(key.Ref.Property, ref.Property)
		if !ok || (rest != "" && rest[0] != '.' && rest[0] != '[') {
			continue
		}
		path, err := resource.PropertyPath(key.Ref.Property)
		// an error would mean that the path no longer exists so we know we should remove the vertex
		if err == nil {
			// if the paths parent still exists then we know we will end up evaluating the vertex and should not remove it
			parentIndex := len(path) - 2
			if parentIndex < 0 {
				continue
			}
			if parent, ok := path[parentIndex].Get(); ok && parent != nil {
				continue
			}
		}
		errs = errors.Join(errs, graph_addons.RemoveVertexAndEdges(eval.graph, key))
		errs = errors.Join(errs, graph_addons.RemoveVertexAndEdges(eval.unevaluated, key))
	}
	return errs
}
package operational_eval
import (
"errors"
"fmt"
"strings"
"github.com/dominikbraun/graph"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/engine/solution"
"github.com/klothoplatform/klotho/pkg/graph_addons"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
"github.com/klothoplatform/klotho/pkg/logging"
"github.com/klothoplatform/klotho/pkg/set"
"go.uber.org/zap"
)
type (
// Evaluator drives operational evaluation of a solution: it tracks the dependency graph
// between evaluation keys and evaluates them in dependency order (see Evaluate).
Evaluator struct {
Solution solution.Solution
// graph holds all of the property dependencies regardless of whether they've been evaluated or not
graph Graph
// unevaluated holds only vertices not yet evaluated; a vertex whose outgoing edges are all
// gone (no remaining dependencies) is a candidate in pollReady
unevaluated Graph
// evaluatedOrder records evaluated keys, one slice per evaluation group, in order
evaluatedOrder [][]Key
// errored holds keys whose Evaluate returned an error
errored set.Set[Key]
// currentKey is the key being evaluated; vertices and edges added meanwhile are attributed to it
currentKey *Key
// log is created lazily by Log
log *zap.SugaredLogger
}
// Key identifies a vertex in the dependency graphs. Which fields are meaningful depends on
// the key's type, derived from which fields are set (see keyType). Key is used as a map key,
// so every field participates in identity.
Key struct {
Ref construct.PropertyRef
RuleHash string
Edge construct.SimpleEdge
GraphState graphStateRepr
PathSatisfication knowledgebase.EdgePathSatisfaction
}
// Vertex is a unit of work in the dependency graphs.
Vertex interface {
Key() Key
// Evaluate performs the vertex's work against the evaluator.
Evaluate(eval *Evaluator) error
// UpdateFrom merges other into this vertex when enqueue sees the same key again.
UpdateFrom(other Vertex)
// Dependencies records the vertex's dependencies into propCtx (see AddVertexAndDeps).
Dependencies(eval *Evaluator, propCtx dependencyCapturer) error
}
// conditionalVertex is implemented by vertices whose readiness is not decided by their
// dependencies alone; pollReady consults Ready to prioritise them.
conditionalVertex interface {
Ready(*Evaluator) (ReadyPriority, error)
}
// ReadyPriority ranks how ready a vertex is; lower values are more ready.
ReadyPriority int
// graphChanges is a batch of vertices and edges to be applied by enqueue.
graphChanges struct {
nodes map[Key]Vertex
// edges is map[source]targets
// (an edge source -> target means source depends on target)
edges map[Key]set.Set[Key]
}
// keyType makes it easy to switch on, and by being an int makes sorting of keys easier
keyType int
)
// Ready priorities, from most to least ready. pollReady dequeues only the vertices in the
// lowest-valued priority that has any candidates.
const (
// ReadyNow indicates the vertex is ready to be evaluated
ReadyNow ReadyPriority = iota
// NotReadyLow is used when it's relatively certain that the vertex will be ready, but cannot be 100% certain.
NotReadyLow
// NotReadyMid is for cases which don't clearly fit in [NotReadyLow] or [NotReadyHigh]
NotReadyMid
// NotReadyHigh is used for vertices which can almost never be 100% certain that they're correct based on the
// current state.
NotReadyHigh
// NotReadyMax is reserved for when running the vertex would likely cause an error, rather than incorrect behaviour
NotReadyMax
)
// Key types. Key.Less orders keys of different types by these values, so property keys
// sort first and path-expansion keys last.
const (
keyTypeProperty keyType = iota
keyTypeEdge
keyTypeGraphState
keyTypePathExpand
)
// NewEvaluator returns an Evaluator for ctx with empty dependency graphs and no errored keys.
func NewEvaluator(ctx solution.Solution) *Evaluator {
	eval := &Evaluator{Solution: ctx}
	eval.graph = newGraph(nil)
	eval.unevaluated = newGraph(nil)
	eval.errored = make(set.Set[Key])
	return eval
}
// keyType classifies key by which of its fields are set, checked in priority order.
// It returns -1 for an empty key.
func (key Key) keyType() keyType {
	switch {
	case !key.Ref.Resource.IsZero():
		return keyTypeProperty
	case key.GraphState != "":
		return keyTypeGraphState
	case key.PathSatisfication != (knowledgebase.EdgePathSatisfaction{}):
		return keyTypePathExpand
	case key.Edge != (construct.SimpleEdge{}):
		// checked last because path-expand keys also carry an edge
		return keyTypeEdge
	default:
		return -1
	}
}
// String renders key for logs, error messages and DOT output. The representation depends on
// the key's type. A key with no recognised type renders as "<empty>".
func (key Key) String() string {
	switch key.keyType() {
	case keyTypeProperty:
		return key.Ref.String()
	case keyTypeEdge:
		return key.Edge.String()
	case keyTypeGraphState:
		return string(key.GraphState)
	case keyTypePathExpand:
		ps := key.PathSatisfication
		args := []string{key.Edge.String()}
		if ps.Classification != "" {
			args = append(args, fmt.Sprintf("<%s>", ps.Classification))
		}
		if ps.Target.PropertyReferenceChangesBoundary() {
			args = append(args, fmt.Sprintf("target#%s", ps.Target.PropertyReference))
		}
		if ps.Source.PropertyReferenceChangesBoundary() {
			// report the source's own property reference, not the target's
			args = append(args, fmt.Sprintf("source#%s", ps.Source.PropertyReference))
		}
		return fmt.Sprintf("Expand(%s)", strings.Join(args, ", "))
	}
	return "<empty>"
}
// Less orders keys first by key type, then within a type:
//   - properties by resource id, then property name;
//   - edges by edge order;
//   - graph states by representation;
//   - path expansions by classification, then edge.
//
// Empty keys are never less than anything.
func (key Key) Less(other Key) bool {
	mine, theirs := key.keyType(), other.keyType()
	if mine != theirs {
		return mine < theirs
	}
	switch mine {
	case keyTypeProperty:
		if key.Ref.Resource == other.Ref.Resource {
			return key.Ref.Property < other.Ref.Property
		}
		return construct.ResourceIdLess(key.Ref.Resource, other.Ref.Resource)
	case keyTypeEdge:
		return key.Edge.Less(other.Edge)
	case keyTypeGraphState:
		return key.GraphState < other.GraphState
	case keyTypePathExpand:
		a, b := key.PathSatisfication.Classification, other.PathSatisfication.Classification
		if a == b {
			return key.Edge.Less(other.Edge)
		}
		return a < b
	default:
		// Empty key, put that last, though it should never happen
		return false
	}
}
// String returns the constant's name, or "ReadyPriority(N)" for values outside the defined range.
func (r ReadyPriority) String() string {
	names := [...]string{
		ReadyNow:     "ReadyNow",
		NotReadyLow:  "NotReadyLow",
		NotReadyMid:  "NotReadyMid",
		NotReadyHigh: "NotReadyHigh",
		NotReadyMax:  "NotReadyMax",
	}
	if r >= 0 && int(r) < len(names) {
		return names[r]
	}
	return fmt.Sprintf("ReadyPriority(%d)", r)
}
// EvalutedOrder returns the keys evaluated so far, grouped by evaluation round in the order
// the rounds ran. The evaluator's own slice is returned, not a copy. (The misspelled name is
// part of the exported API.)
func (eval *Evaluator) EvalutedOrder() [][]Key {
return eval.evaluatedOrder
}
// Log returns the evaluator's logger (named "engine.opeval", created lazily from the solution's
// context), tagged with the current group number, i.e. len(eval.evaluatedOrder).
func (eval *Evaluator) Log() *zap.SugaredLogger {
	if eval.log == nil {
		base := logging.GetLogger(eval.Solution.Context())
		eval.log = base.Named("engine.opeval").Sugar()
	}
	group := len(eval.evaluatedOrder)
	return eval.log.With("group", group)
}
// isEvaluated reports whether k is present in the full graph but no longer in the unevaluated
// graph. Lookup errors other than "not found" are returned.
func (eval *Evaluator) isEvaluated(k Key) (bool, error) {
	if _, err := eval.graph.Vertex(k); err != nil {
		if errors.Is(err, graph.ErrVertexNotFound) {
			return false, nil
		}
		return false, err
	}
	_, err := eval.unevaluated.Vertex(k)
	switch {
	case err == nil:
		return false, nil
	case errors.Is(err, graph.ErrVertexNotFound):
		return true, nil
	default:
		return false, err
	}
}
// addEdge records that source depends on target. The edge goes into the full dependency
// graph and, when target has not been evaluated yet, into the unevaluated graph as well.
//
// An edge already present in the full graph is a no-op. New edges are annotated with the
// current group and, if set, the key currently being evaluated.
//
// Errors are returned when:
//   - the edge would create a cycle (the message lists the existing path from target back to source);
//   - the full graph rejects the edge;
//   - target is still unevaluated but source has already been evaluated;
//   - the unevaluated graph rejects the edge.
//
// In the last three cases the edge remains in the full graph marked with attribError.
func (eval *Evaluator) addEdge(source, target Key) error {
	log := eval.Log().Named("deps")
	if _, err := eval.graph.Edge(source, target); err == nil {
		log.Debugf(" -> %s ✓", target)
		return nil
	}
	err := eval.graph.AddEdge(source, target, func(ep *graph.EdgeProperties) {
		ep.Attributes[attribAddedIn] = fmt.Sprintf("%d", len(eval.evaluatedOrder))
		if eval.currentKey != nil {
			ep.Attributes[attribAddedBy] = eval.currentKey.String()
		}
	})
	if err != nil {
		if errors.Is(err, graph.ErrEdgeCreatesCycle) {
			// Best-effort diagnostics: if no path is found the message just shows an empty path.
			path, _ := graph.ShortestPath(eval.graph, target, source)
			pathS := make([]string, len(path))
			for i, k := range path {
				pathS[i] = `"` + k.String() + `"`
			}
			return fmt.Errorf(
				"could not add edge %q -> %q: would create cycle with existing path: %s",
				source, target, strings.Join(pathS, " -> "),
			)
		}
		// NOTE(gg): If this fails with target not in graph, then we might need to add the target in with a
		// new vertex type of "overwrite me later". It would be an odd situation though, which is why it is
		// an error for now.
		return fmt.Errorf("could not add edge %q -> %q: %w", source, target, err)
	}
	// markError flags the just-added edge so it stands out in DOT output; failures are ignored
	// because this is only a diagnostic annotation.
	markError := func() {
		_ = eval.graph.UpdateEdge(source, target, func(ep *graph.EdgeProperties) {
			ep.Attributes[attribError] = "true"
		})
	}
	_, err = eval.unevaluated.Vertex(target)
	switch {
	case errors.Is(err, graph.ErrVertexNotFound):
		// the 'graph.AddEdge' succeeded, thus the target exists in the total graph
		// which means that the target vertex is done, so ignore adding the edge to the unevaluated graph
		log.Debugf(" -> %s (done)", target)
	case err != nil:
		markError()
		return fmt.Errorf("could not get unevaluated vertex %s: %w", target, err)
	default:
		log.Debugf(" -> %s", target)
		sourceEvaluated, err := eval.isEvaluated(source)
		if err != nil {
			markError()
			return fmt.Errorf("could not check if source %s is evaluated: %w", source, err)
		}
		if sourceEvaluated {
			// an evaluated source cannot wait on a target that has not run yet
			markError()
			return fmt.Errorf(
				"could not add edge %q -> %q: source is already evaluated",
				source, target)
		}
		if err := eval.unevaluated.AddEdge(source, target); err != nil {
			markError()
			return fmt.Errorf("could not add unevaluated edge %q -> %q: %w", source, target, err)
		}
	}
	return nil
}
// enqueue applies changes to the evaluator's graphs.
//
// Each vertex in changes.nodes that is new to the full graph is added to both the full and
// unevaluated graphs. It is annotated with the current group and, if set, the key currently
// being evaluated, and an edge making it depend on that key is added to changes. A vertex
// that already exists has the new one merged into it via [Vertex.UpdateFrom].
//
// The edges in changes.edges are only added, via [Evaluator.addEdge], once every vertex has
// been applied without error. Errors are collected per key into an [EnqueueErrors].
func (eval *Evaluator) enqueue(changes graphChanges) error {
	if len(changes.nodes) == 0 && len(changes.edges) == 0 {
		// short circuit when there's nothing to change
		return nil
	}
	log := eval.Log().Named("enqueue")
	var errs EnqueueErrors
	for key, v := range changes.nodes {
		// a single lookup serves both the "is it new?" check and the merge below
		existing, err := eval.graph.Vertex(key)
		switch {
		case errors.Is(err, graph.ErrVertexNotFound):
			addErr := eval.graph.AddVertex(v, func(vp *graph.VertexProperties) {
				vp.Attributes[attribAddedIn] = fmt.Sprintf("%d", len(eval.evaluatedOrder))
				if eval.currentKey != nil {
					vp.Attributes[attribAddedBy] = eval.currentKey.String()
				}
			})
			if addErr != nil {
				errs.Append(key, fmt.Errorf("could not add vertex %s: %w", key, addErr))
				continue
			}
			if eval.currentKey != nil {
				// the new vertex depends on the vertex whose evaluation introduced it
				changes.addEdge(key, *eval.currentKey)
			}
			log.Debugf("Enqueued %s", key)
			if err := eval.unevaluated.AddVertex(v); err != nil {
				errs.Append(key, fmt.Errorf("could not add unevaluated vertex %s: %w", key, err))
			}
		case err != nil:
			errs.Append(key, fmt.Errorf("could not get existing vertex %s: %w", key, err))
		default:
			if v != existing {
				existing.UpdateFrom(v)
			}
		}
	}
	if errs != nil {
		return errs
	}
	log = eval.Log().Named("deps")
	for source, targets := range changes.edges {
		if len(targets) > 0 {
			log.Debug(source)
		}
		for target := range targets {
			if err := eval.addEdge(source, target); err != nil {
				errs.Append(source, fmt.Errorf("-> %s: %w", target, err))
			}
		}
	}
	if errs != nil {
		return errs
	}
	// return an untyped nil: a nil EnqueueErrors map stored in an error interface is non-nil
	return nil
}
// newChanges returns an empty graphChanges with both maps allocated.
func newChanges() graphChanges {
	var changes graphChanges
	changes.nodes = make(map[Key]Vertex)
	changes.edges = make(map[Key]set.Set[Key])
	return changes
}
// addNode stores v in [graphChanges.nodes] under its own key, replacing any previous entry.
//
//nolint:unused
func (changes graphChanges) addNode(v Vertex) {
	key := v.Key()
	changes.nodes[key] = v
}
// addEdge records in [graphChanges.edges] that source depends on target, creating the
// source's target set on first use.
func (changes graphChanges) addEdge(source, target Key) {
	if targets, exists := changes.edges[source]; exists {
		targets.Add(target)
		return
	}
	targets := make(set.Set[Key])
	targets.Add(target)
	changes.edges[source] = targets
}
// addEdges records in [graphChanges.edges] that source depends on every key in targets.
// An empty targets set is a no-op and does not create an entry for source.
func (changes graphChanges) addEdges(source Key, targets set.Set[Key]) {
	if len(targets) == 0 {
		return
	}
	if existing, exists := changes.edges[source]; exists {
		existing.AddFrom(targets)
		return
	}
	fresh := make(set.Set[Key])
	fresh.AddFrom(targets)
	changes.edges[source] = fresh
}
// AddVertexAndDeps records v in changes, then asks v for its dependencies.
//
// The dependencies are captured into a scratch graphChanges and merged into changes only when
// v.Dependencies succeeds. v itself stays recorded even when an error is returned.
func (changes graphChanges) AddVertexAndDeps(eval *Evaluator, v Vertex) error {
	key := v.Key()
	changes.nodes[key] = v
	captured := newChanges()
	capturer := newDepCapture(solution.DynamicCtx(eval.Solution), captured, key)
	if err := v.Dependencies(eval, capturer); err != nil {
		return fmt.Errorf("could not get dependencies for %s: %w", v.Key(), err)
	}
	changes.Merge(captured)
	return nil
}
// Merge folds other into changes. A vertex in other overwrites the vertex with the same key in
// changes, and each of other's edge sets is unioned into the matching edge set of changes.
func (changes graphChanges) Merge(other graphChanges) {
	for key, vertex := range other.nodes {
		changes.nodes[key] = vertex
	}
	for source, targets := range other.edges {
		changes.addEdges(source, targets)
	}
}
// UpdateId renames resource oldId to newId in the solution and in the evaluator's own state.
// It is a no-op when the two ids are equal. Otherwise it performs these steps in order:
//   - sets the resource's ID through the raw view, propagates the rename through the dataflow
//     graph, and replaces the vertex in the deployment graph;
//   - retargets resource constraints whose Target is oldId;
//   - walks eval.graph in topological order and rewrites property, edge and path-expansion
//     vertices (and their keys) that mention oldId, in eval.graph and, when present, eval.unevaluated;
//   - rewrites matching keys in evaluatedOrder, errored and currentKey.
//
// Errors from the per-vertex rewrites are joined and returned. When that happens, the bookkeeping
// in the last step is left unchanged.
func (eval *Evaluator) UpdateId(oldId, newId construct.ResourceId) error {
	if oldId == newId {
		return nil
	}
	eval.Log().Infof("Updating id %s to %s", oldId, newId)
	v, err := eval.Solution.RawView().Vertex(oldId)
	if err != nil {
		return err
	}
	// NOTE(review): this mutates the resource returned by the raw view in place, before the graphs
	// are updated; PropagateUpdatedId below is only given oldId, so presumably it relies on this
	// shared resource already carrying newId — confirm.
	v.ID = newId
	// We have to operate on these graphs separately since the deployment graph can store edges based on property references.
	// since these edges wont exist in the dataflow graph they would never get cleaned up if we passed in the raw view.
	err = errors.Join(
		construct.PropagateUpdatedId(eval.Solution.DataflowGraph(), oldId),
		graph_addons.ReplaceVertex(eval.Solution.DeploymentGraph(), oldId, v, construct.ResourceHasher),
	)
	if err != nil {
		return err
	}
	topo, err := graph.TopologicalSort(eval.graph)
	if err != nil {
		return err
	}
	// update all constraints that pertain to the old id
	c := eval.Solution.Constraints()
	for i, rc := range c.Resources {
		if rc.Target == oldId {
			c.Resources[i].Target = newId
		}
	}
	var errs error
	// replaceVertex swaps the vertex stored under oldKey for vertex (re-keyed via Vertex.Key) in
	// eval.graph, and also in eval.unevaluated when the old key is queued there.
	replaceVertex := func(oldKey Key, vertex Vertex) {
		errs = errors.Join(errs,
			graph_addons.ReplaceVertex(eval.graph, oldKey, Vertex(vertex), Vertex.Key),
		)
		if _, err := eval.unevaluated.Vertex(oldKey); err == nil {
			errs = errors.Join(errs,
				graph_addons.ReplaceVertex(eval.unevaluated, oldKey, Vertex(vertex), Vertex.Key),
			)
		} else if !errors.Is(err, graph.ErrVertexNotFound) {
			errs = errors.Join(errs, err)
		}
	}
	for _, key := range topo {
		vertex, err := eval.graph.Vertex(key)
		if err != nil {
			errs = errors.Join(errs, err)
			continue
		}
		switch vertex := vertex.(type) {
		case *propertyVertex:
			if key.Ref.Resource == oldId {
				vertex.Ref.Resource = newId
				replaceVertex(key, vertex)
			}
			// Go allows deleting and inserting while ranging over a map; an inserted key may or may
			// not be visited later in this same loop.
			// NOTE(review): only EdgeRules is re-keyed here. propertyVertex.TransformRules is also keyed
			// by construct.SimpleEdge (see resourceVertices) and is left as-is — confirm that is intended.
			for edge, rules := range vertex.EdgeRules {
				if edge.Source == oldId || edge.Target == oldId {
					delete(vertex.EdgeRules, edge)
					vertex.EdgeRules[UpdateEdgeId(edge, oldId, newId)] = rules
				}
			}
		case *edgeVertex:
			if key.Edge.Source == oldId || key.Edge.Target == oldId {
				updated := UpdateEdgeId(
					construct.SimpleEdge{Source: vertex.Edge.Source, Target: vertex.Edge.Target},
					oldId,
					newId,
				)
				vertex.Edge.Source = updated.Source
				vertex.Edge.Target = updated.Target
				replaceVertex(key, vertex)
			}
		case *pathExpandVertex:
			if key.Edge.Source == oldId || key.Edge.Target == oldId {
				vertex.SatisfactionEdge = UpdateEdgeId(vertex.SatisfactionEdge, oldId, newId)
				replaceVertex(key, vertex)
			}
			// because the temp graph contains the src and target as nodes, we need to update it if it exists
			if vertex.TempGraph != nil {
				_, err := vertex.TempGraph.Vertex(oldId)
				switch {
				case errors.Is(err, graph.ErrVertexNotFound):
					// do nothing
				case err != nil:
					errs = errors.Join(errs, err)
				default:
					errs = errors.Join(errs, construct.ReplaceResource(vertex.TempGraph, oldId, &construct.Resource{ID: newId}))
				}
			}
		}
	}
	if errs != nil {
		return errs
	}
	// Rewrite keys recorded by the evaluator's bookkeeping so they refer to newId.
	for i, keys := range eval.evaluatedOrder {
		for j, key := range keys {
			oldKey := key
			if key.Ref.Resource == oldId {
				key.Ref.Resource = newId
			}
			key.Edge = UpdateEdgeId(key.Edge, oldId, newId)
			if key != oldKey {
				eval.evaluatedOrder[i][j] = key
			}
		}
	}
	for key := range eval.errored {
		oldKey := key
		if key.Ref.Resource == oldId {
			key.Ref.Resource = newId
		}
		key.Edge = UpdateEdgeId(key.Edge, oldId, newId)
		if key != oldKey {
			eval.errored.Remove(oldKey)
			eval.errored.Add(key)
		}
	}
	if eval.currentKey != nil {
		if eval.currentKey.Ref.Resource == oldId {
			eval.currentKey.Ref.Resource = newId
		}
		eval.currentKey.Edge = UpdateEdgeId(eval.currentKey.Edge, oldId, newId)
	}
	return nil
}
package operational_eval
import (
"errors"
"fmt"
"github.com/dominikbraun/graph"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/engine/path_selection"
"github.com/klothoplatform/klotho/pkg/graph_addons"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
"github.com/klothoplatform/klotho/pkg/set"
"go.uber.org/zap"
)
// Graph is the type of the evaluator's dependency graph (see newGraph): vertices of type Vertex,
// hashed by their Key.
type Graph graph.Graph[Key, Vertex]
// newGraph creates a directed dependency graph, backed by an in-memory store, that rejects
// cycles. When log is non-nil, the graph is wrapped in a graph_addons.LoggingGraph that uses
// log's sugared logger.
func newGraph(log *zap.Logger) Graph {
	base := graph.NewWithStore(
		Vertex.Key,
		graph_addons.NewMemoryStore[Key, Vertex](),
		graph.Directed(),
		graph.PreventCycles(),
	)
	if log == nil {
		return base
	}
	return graph_addons.LoggingGraph[Key, Vertex]{
		Graph: base,
		Log:   log.Sugar(),
		Hash:  Vertex.Key,
	}
}
// AddResources builds the evaluation vertices (property and additional-rule vertices) for each
// resource and enqueues them. Failures are collected across all resources. If any failure occurs,
// the joined error is returned and nothing is enqueued.
func (eval *Evaluator) AddResources(rs ...*construct.Resource) error {
	pending := newChanges()
	var errs error
	for _, res := range rs {
		tmpl, tmplErr := eval.Solution.KnowledgeBase().GetResourceTemplate(res.ID)
		if tmplErr != nil {
			errs = errors.Join(errs, tmplErr)
			continue
		}
		vertices, vErr := eval.resourceVertices(res, tmpl)
		if vErr != nil {
			errs = errors.Join(errs, fmt.Errorf("could not add resource eval vertices %s: %w", res.ID, vErr))
			continue
		}
		pending.Merge(vertices)
	}
	if errs != nil {
		return errs
	}
	return eval.enqueue(pending)
}
// AddEdges enqueues evaluation vertices for each edge. It creates an edgeVertex when the knowledge
// base has an edge template for the pair, and path-expansion vertices otherwise. Failures are
// collected across all edges. If any failure occurs, nothing is enqueued.
func (eval *Evaluator) AddEdges(es ...construct.Edge) error {
	pending := newChanges()
	var errs error
	for _, e := range es {
		var (
			evs graphChanges
			err error
		)
		switch tmpl := eval.Solution.KnowledgeBase().GetEdgeTemplate(e.Source, e.Target); {
		case tmpl != nil:
			evs, err = eval.edgeVertices(e, tmpl)
		default:
			// No direct edge template, so the edge has to be satisfied by expanding a path.
			evs, err = eval.pathVertices(e.Source, e.Target)
		}
		if err != nil {
			errs = errors.Join(errs, fmt.Errorf("could not add edge eval vertex %s -> %s: %w", e.Source, e.Target, err))
			continue
		}
		pending.Merge(evs)
	}
	if errs != nil {
		return errs
	}
	return eval.enqueue(pending)
}
// pathVertices builds one pathExpandVertex, with its dependencies, for each path satisfaction the
// knowledge base defines for source -> target. Errors for individual satisfactions are joined.
// Finding no satisfactions at all is also reported as an error. The (possibly partial) changes
// are returned in every case.
func (eval *Evaluator) pathVertices(source, target construct.ResourceId) (graphChanges, error) {
	changes := newChanges()
	src, err := eval.Solution.RawView().Vertex(source)
	if err != nil {
		return changes, fmt.Errorf("failed to get source vertex for %s: %w", source, err)
	}
	dst, err := eval.Solution.RawView().Vertex(target)
	if err != nil {
		return changes, fmt.Errorf("failed to get target vertex for %s: %w", target, err)
	}
	// Its negation is passed as the last argument of BuildPathSelectionGraph below.
	requireFullBuild := src.Imported || dst.Imported
	edge := construct.SimpleEdge{Source: source, Target: target}

	addExpansion := func(kb knowledgebase.TemplateKB, sat knowledgebase.EdgePathSatisfaction) error {
		if sat.Classification == "" {
			return fmt.Errorf("edge %s has no classification to expand", edge)
		}
		vertex := &pathExpandVertex{SatisfactionEdge: edge, Satisfication: sat}
		// If a property reference on either side can move the expansion boundary, the dependency
		// ordering must settle first. In that case the expansion is not split up yet, and the temp
		// graph is left nil.
		boundaryMayChange := sat.Source.PropertyReferenceChangesBoundary() ||
			sat.Target.PropertyReferenceChangesBoundary()
		if !boundaryMayChange {
			tempGraph, buildErr := path_selection.BuildPathSelectionGraph(
				eval.Solution.Context(),
				edge,
				kb,
				sat.Classification,
				!requireFullBuild,
			)
			if buildErr != nil {
				return fmt.Errorf("could not build temp graph for %s: %w", edge, buildErr)
			}
			vertex.TempGraph = tempGraph
		}
		return changes.AddVertexAndDeps(eval, vertex)
	}

	kb := eval.Solution.KnowledgeBase()
	satisfactions, err := kb.GetPathSatisfactionsFromEdge(source, target)
	if err != nil {
		return changes, fmt.Errorf("could not get path satisfications for %s: %w", edge, err)
	}
	var errs error
	for _, sat := range satisfactions {
		errs = errors.Join(errs, addExpansion(kb, sat))
	}
	if len(satisfactions) == 0 {
		errs = errors.Join(errs, fmt.Errorf("could not find any path satisfications for %s", edge))
	}
	return changes, errs
}
// UpdateEdgeId returns a copy of e in which every endpoint equal to oldId is replaced by newId.
// Both endpoints are checked independently, so a self-referencing edge
// (Source == Target == oldId) is renamed completely instead of only at its source.
func UpdateEdgeId(e construct.SimpleEdge, oldId, newId construct.ResourceId) construct.SimpleEdge {
	if e.Source == oldId {
		e.Source = newId
	}
	if e.Target == oldId {
		e.Target = newId
	}
	return e
}
// resourceVertices builds the evaluation vertices for resource res:
//   - one propertyVertex for each property visited by tmpl.LoopProperties;
//   - one resourceRuleVertex for each of the template's AdditionalRules.
//
// Each vertex is added with its dependencies. A failure for one vertex does not prevent the others
// from being added. All errors are joined and returned together with the (possibly partial) changes.
func (eval *Evaluator) resourceVertices(
	res *construct.Resource,
	tmpl *knowledgebase.ResourceTemplate,
) (graphChanges, error) {
	changes := newChanges()
	var errs error
	addProp := func(prop knowledgebase.Property) error {
		vertex := &propertyVertex{
			Ref:            construct.PropertyRef{Resource: res.ID, Property: prop.Details().Path},
			Template:       prop,
			EdgeRules:      make(map[construct.SimpleEdge][]knowledgebase.OperationalRule),
			TransformRules: make(map[construct.SimpleEdge]*set.HashedSet[string, knowledgebase.OperationalRule]),
		}
		// Errors are collected in errs rather than returned to LoopProperties.
		errs = errors.Join(errs, changes.AddVertexAndDeps(eval, vertex))
		return nil
	}
	// addProp mutates errs. Run LoopProperties to completion before reading errs again: the Go spec
	// leaves unspecified whether a plain variable argument is read before or after a call appearing
	// in the same argument list.
	loopErr := tmpl.LoopProperties(res, addProp)
	errs = errors.Join(errs, loopErr)
	for _, rule := range tmpl.AdditionalRules {
		vertex := &resourceRuleVertex{
			Resource: res.ID,
			Rule:     rule,
			hash:     rule.Hash(),
		}
		errs = errors.Join(errs, changes.AddVertexAndDeps(eval, vertex))
	}
	return changes, errs
}
// edgeVertices builds the single edgeVertex for edge, using the operational rules of its edge
// template, and adds it with its dependencies. The changes are returned even when collecting
// dependencies fails.
func (eval *Evaluator) edgeVertices(
	edge construct.Edge,
	tmpl *knowledgebase.EdgeTemplate,
) (graphChanges, error) {
	changes := newChanges()
	err := changes.AddVertexAndDeps(eval, &edgeVertex{
		Edge:  edge,
		Rules: tmpl.OperationalRules,
	})
	return changes, err
}
// removeKey removes k and its edges from eval.unevaluated, then from eval.graph. A key that is
// absent from eval.unevaluated is not treated as an error; any other failure there aborts the
// removal before eval.graph is touched.
func (eval *Evaluator) removeKey(k Key) error {
	if err := graph_addons.RemoveVertexAndEdges(eval.unevaluated, k); err != nil && !errors.Is(err, graph.ErrVertexNotFound) {
		return err
	}
	return graph_addons.RemoveVertexAndEdges(eval.graph, k)
}
// RemoveEdge removes the evaluation state tied to the edge source -> target:
//   - the edge rules recorded for it on property vertices;
//   - the edge's own edgeVertex.
//
// Afterwards, graph-state vertices that are left with no predecessors are removed as well.
// All failures in a phase are joined and wrapped with the edge for context.
func (eval *Evaluator) RemoveEdge(source, target construct.ResourceId) error {
	g := eval.graph
	edge := construct.SimpleEdge{Source: source, Target: target}
	pred, err := g.PredecessorMap()
	if err != nil {
		return err
	}
	var errs error
	checkStates := make(set.Set[Key])
	for key := range pred {
		v, err := g.Vertex(key)
		if err != nil {
			errs = errors.Join(errs, fmt.Errorf("could not get vertex for %s: %w", key, err))
			continue
		}
		switch v := v.(type) {
		case *propertyVertex:
			// Map keys are compared with ==, so a direct delete is enough (and is a no-op when absent).
			delete(v.EdgeRules, edge)
		case *edgeVertex:
			if v.Edge.Source == edge.Source && v.Edge.Target == edge.Target {
				errs = errors.Join(errs, eval.removeKey(v.Key()))
			}
		case *graphStateVertex:
			// Candidate for cleanup once the edge's vertices are gone.
			checkStates.Add(v.Key())
		}
	}
	if errs != nil {
		// Wrap the accumulated errors. The outer err is always nil at this point.
		return fmt.Errorf("could not remove edge %s: %w", edge, errs)
	}
	// recompute the predecessors, since we may have removed some edges
	pred, err = g.PredecessorMap()
	if err != nil {
		return err
	}
	// Clean up any graph state keys that are no longer referenced. They don't do any harm except the performance
	// impact of recomputing the dependencies.
	for v := range checkStates {
		if len(pred[v]) == 0 {
			errs = errors.Join(errs, eval.removeKey(v))
		}
	}
	if errs != nil {
		return fmt.Errorf("could not clean up graph state keys when removing %s: %w", edge, errs)
	}
	return nil
}
// RemoveResource removes the evaluator's state for resource id:
//   - the resource's own property vertices and resource-rule vertices;
//   - every edge vertex that has id as an endpoint;
//   - edge rules mentioning id on other property vertices.
//
// Graph-state vertices that end up with no predecessors are then removed, since nothing references
// them anymore.
func (eval *Evaluator) RemoveResource(id construct.ResourceId) error {
	g := eval.graph
	pred, err := g.PredecessorMap()
	if err != nil {
		return err
	}
	var errs error
	stateKeys := make(set.Set[Key])
	for key := range pred {
		vertex, vErr := g.Vertex(key)
		if vErr != nil {
			errs = errors.Join(errs, fmt.Errorf("could not get vertex for %s: %w", key, vErr))
			continue
		}
		remove := false
		switch typed := vertex.(type) {
		case *propertyVertex:
			if typed.Ref.Resource == id {
				remove = true
				break
			}
			for edge := range typed.EdgeRules {
				if edge.Source == id || edge.Target == id {
					delete(typed.EdgeRules, edge)
				}
			}
		case *edgeVertex:
			remove = typed.Edge.Source == id || typed.Edge.Target == id
		case *graphStateVertex:
			stateKeys.Add(typed.Key())
		case *resourceRuleVertex:
			remove = typed.Resource == id
		}
		if remove {
			errs = errors.Join(errs, eval.removeKey(vertex.Key()))
		}
	}
	if errs != nil {
		return fmt.Errorf("could not remove resource %s: %w", id, errs)
	}
	// The removals above may have dropped edges, so the predecessors must be recomputed.
	pred, err = g.PredecessorMap()
	if err != nil {
		return err
	}
	// Unreferenced graph-state vertices are harmless but cost time whenever dependencies are recomputed.
	for stateKey := range stateKeys {
		if len(pred[stateKey]) == 0 {
			errs = errors.Join(errs, eval.removeKey(stateKey))
		}
	}
	if errs != nil {
		return fmt.Errorf("could not clean up graph state keys when removing %s: %w", id, errs)
	}
	return nil
}
package operational_eval
import (
"errors"
"fmt"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/engine/constraints"
"github.com/klothoplatform/klotho/pkg/engine/operational_rule"
"github.com/klothoplatform/klotho/pkg/engine/solution"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
)
// edgeVertex evaluates the operational rules of an edge template for one edge of the solution.
type edgeVertex struct {
	// Edge is the solution edge the rules apply to; its simple (id-only) form is the vertex's Key.
	Edge construct.Edge
	// Rules are run in order of how they exist in the template so that the order of operations handles the rules inter dependencies
	Rules []knowledgebase.OperationalRule
}
// Key identifies an edge vertex by the id-only form of its edge.
func (ev edgeVertex) Key() Key {
	var k Key
	k.Edge = construct.ToSimpleEdge(ev.Edge)
	return k
}
// Dependencies captures this edge vertex's dependencies by executing each of its operational rules
// against propCtx, with the edge as the rule's dynamic data. Rule errors are joined and returned,
// wrapped with the edge.
//
// After the rules run, some of the captured dependencies are pruned as a workaround; see the hack
// note below.
func (ev *edgeVertex) Dependencies(eval *Evaluator, propCtx dependencyCapturer) error {
	data := knowledgebase.DynamicValueData{
		Edge: &ev.Edge,
	}
	var errs error
	for _, rule := range ev.Rules {
		errs = errors.Join(errs, propCtx.ExecuteOpRule(data, rule))
	}
	if errs != nil {
		return fmt.Errorf(
			"could not execute dependencies for edge %s -> %s: %w",
			ev.Edge.Source, ev.Edge.Target, errs,
		)
	}
	// NOTE: begin hack - this is to help resolve the api_deployment#Triggers case
	// The dependency graph doesn't currently handle this because:
	// 1. Expand(api -> lambda) depends on Subnets to determine if it's reusable
	// 2. Subnets depends on graphstate hasDownstream(vpc, lambda)
	// 3. deploy -> api depends on graphstate allDownstream(integration, api)
	// to add the deploy -> integration edges
	// 4. deploy -> integration sets #Triggers
	//
	// Workaround: drop every captured dependency entry whose source key has already been evaluated
	// and has no predecessors in eval.graph.
	// NOTE(review): this assumes GetChanges returns the capturer's live maps, so that the delete below
	// changes what the caller merges — confirm. Errors from isEvaluated are ignored, so a key whose
	// state cannot be determined keeps its dependency.
	pred, err := eval.graph.PredecessorMap()
	if err != nil {
		return err
	}
	propChanges := propCtx.GetChanges()
	// Deleting map entries while ranging over the same map is permitted in Go.
	for src := range propChanges.edges {
		isEvaluated, err := eval.isEvaluated(src)
		if err == nil && isEvaluated && len(pred[src]) == 0 {
			// this is okay, since it has no dependencies then changing it during evaluation
			// won't impact anything. Remove the dependency, since we'll handle it in this vertex's
			// Evaluate
			delete(propChanges.edges, src)
		}
	}
	// NOTE: end hack
	return nil
}
// UpdateFrom replaces ev's rules with those of other. other must be an *edgeVertex with the same
// key; anything else is a programming error and panics. ev keeps its own Edge value. Merging a
// vertex with itself is a no-op.
func (ev *edgeVertex) UpdateFrom(other Vertex) {
	if ev == other {
		return
	}
	incoming, isEdge := other.(*edgeVertex)
	if !isEdge {
		panic(fmt.Sprintf("cannot merge edge with non-edge vertex: %T", other))
	}
	if mine, theirs := ev.Key().Edge, incoming.Key().Edge; mine != theirs {
		panic(fmt.Sprintf(
			"cannot merge edges with different refs: %s != %s",
			construct.ToSimpleEdge(ev.Edge),
			construct.ToSimpleEdge(incoming.Edge),
		))
	}
	ev.Rules = incoming.Rules
}
// Evaluate applies the edge's operational rules to the solution, then records delayed consumption
// between the edge's endpoints.
//
// Each rule is handled in two parts:
//  1. Its Steps (if any) are applied through HandleOperationalRule with AddConstraintOperator. If
//     this fails, the rest of that rule, including its configuration rules, is skipped.
//  2. Its ConfigurationRules are handed to addConfigurationRuleToPropertyVertex. Whatever that
//     returns, grouped by resource, is applied directly as configuration-only rules against each
//     resource.
//
// Rule errors are joined. If any occurred, they are returned before the consumption step runs.
func (ev *edgeVertex) Evaluate(eval *Evaluator) error {
	se := construct.ToSimpleEdge(ev.Edge)
	cfgCtx := solution.DynamicCtx(eval.Solution)
	opCtx := operational_rule.OperationalRuleContext{
		Solution: eval.Solution,
		Data: knowledgebase.DynamicValueData{
			Edge: &ev.Edge,
		},
	}
	var errs error
	// rule is a per-iteration copy, so clearing its Steps/ConfigurationRules fields below does not
	// modify ev.Rules.
	for _, rule := range ev.Rules {
		configRules := rule.ConfigurationRules
		rule.ConfigurationRules = nil
		if len(rule.Steps) > 0 {
			err := opCtx.HandleOperationalRule(rule, constraints.AddConstraintOperator)
			if err != nil {
				errs = errors.Join(errs, fmt.Errorf(
					"could not apply edge %s operational rule: %w",
					se, err,
				))
				continue
			}
		}
		// the configurations that are returned can be executed out of band from the property vertex
		// since the property vertex has already been evaluated. This is a hack to get around improper dep ordering
		configuration, err := addConfigurationRuleToPropertyVertex(
			knowledgebase.OperationalRule{
				If:                 rule.If,
				ConfigurationRules: configRules,
			},
			ev, cfgCtx, opCtx.Data, eval)
		if err != nil {
			errs = errors.Join(errs, fmt.Errorf(
				"could not apply edge %s configuration rule: %w",
				se, err,
			))
		}
		// From here on the rule carries only configuration rules.
		rule.Steps = nil
		// NOTE(review): ranging over a map means the resources are handled in unspecified order, and
		// opCtx.Data.Resource keeps the last value into later rules — confirm neither matters.
		for res, configRules := range configuration {
			opCtx.Data.Resource = res
			rule.ConfigurationRules = configRules
			err := opCtx.HandleOperationalRule(rule, constraints.AddConstraintOperator)
			if err != nil {
				errs = errors.Join(errs, fmt.Errorf(
					"could not apply edge %s (res: %s) operational rule: %w",
					se, res, err,
				))
				continue
			}
		}
	}
	if errs != nil {
		return errs
	}
	src, err := eval.Solution.DataflowGraph().Vertex(ev.Edge.Source)
	if err != nil {
		return err
	}
	target, err := eval.Solution.DataflowGraph().Vertex(ev.Edge.Target)
	if err != nil {
		return err
	}
	delays, err := knowledgebase.ConsumeFromResource(
		src,
		target,
		solution.DynamicCtx(eval.Solution),
	)
	if err != nil {
		return err
	}
	// we add constrains for the delayed consumption here since their property has not yet been evaluated
	// NOTE(review): the append only persists if Constraints() returns a pointer — presumably it does; confirm.
	c := eval.Solution.Constraints()
	for _, delay := range delays {
		c.Resources = append(c.Resources, constraints.ResourceConstraint{
			Operator: constraints.AddConstraintOperator,
			Target:   delay.Resource,
			Property: delay.PropertyPath,
			Value:    delay.Value,
		})
	}
	return nil
}
package operational_eval
import (
construct "github.com/klothoplatform/klotho/pkg/construct"
)
type (
	// graphTestFunc inspects a construct graph (Ready passes the solution's dataflow graph) and
	// reports the vertex's readiness.
	graphTestFunc func(construct.Graph) (ReadyPriority, error)
	// graphStateRepr is the string identity of a graph-state check; it becomes the GraphState field
	// of the vertex's Key.
	graphStateRepr string
	// graphStateVertex is a vertex whose readiness is decided by testing the dataflow graph rather
	// than by other vertices. It declares no dependencies, and evaluating it does nothing.
	graphStateVertex struct {
		repr graphStateRepr
		// Test is invoked by Ready.
		Test graphTestFunc
	}
)
// Key identifies a graph-state vertex by its repr.
func (gv graphStateVertex) Key() Key {
	var k Key
	k.GraphState = gv.repr
	return k
}
// Dependencies is a no-op: a graph-state vertex declares no dependencies on other vertices.
// Whether it is ready is decided by Ready instead.
func (gv *graphStateVertex) Dependencies(eval *Evaluator, propCtx dependencyCapturer) error {
	return nil
}
// UpdateFrom only checks that other has the same graph-state repr; gv keeps its own Test. Merging
// vertices with different reprs is a programming error and panics.
func (gv *graphStateVertex) UpdateFrom(other Vertex) {
	if other.Key().GraphState == gv.repr {
		return
	}
	panic("cannot merge graph states with different reprs")
}
// Evaluate is a no-op: a graph-state vertex only gates readiness and makes no changes itself.
func (gv *graphStateVertex) Evaluate(eval *Evaluator) error {
	return nil
}
// Ready reports readiness by running Test against the solution's dataflow graph.
func (gv *graphStateVertex) Ready(eval *Evaluator) (ReadyPriority, error) {
	dataflow := eval.Solution.DataflowGraph()
	return gv.Test(dataflow)
}
package operational_eval
import (
"errors"
"fmt"
"strings"
"github.com/dominikbraun/graph"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/engine/constraints"
"github.com/klothoplatform/klotho/pkg/engine/operational_rule"
"github.com/klothoplatform/klotho/pkg/engine/path_selection"
"github.com/klothoplatform/klotho/pkg/engine/solution"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
"github.com/klothoplatform/klotho/pkg/set"
"go.uber.org/zap"
)
//go:generate mockgen -source=./vertex_path_expand.go --destination=../operational_eval/vertex_path_expand_mock_test.go --package=operational_eval
type (
	// pathExpandVertex expands an edge that has no direct edge template into a path of resources,
	// for a single path satisfaction (classification).
	pathExpandVertex struct {
		// ExpandEdge is the overall edge that is being expanded.
		// NOTE(review): pathVertices and addSubExpansion never set this field, so vertices created
		// there pass a zero edge on to path_selection — confirm this is intended.
		ExpandEdge construct.SimpleEdge
		// SatisfactionEdge is the specific source -> target edge this vertex satisfies. Together with
		// Satisfication it forms the vertex's Key.
		SatisfactionEdge construct.SimpleEdge
		// TempGraph is the path selection graph used to choose a path. It may be nil: pathVertices
		// leaves it unset when a property reference can change the expansion boundary.
		TempGraph     construct.Graph
		Satisfication knowledgebase.EdgePathSatisfaction
	}
	// expansionRunner abstracts the side-effecting steps of evaluating a pathExpandVertex, so they
	// can be replaced in tests (mockgen generates a mock from this file).
	expansionRunner interface {
		getExpansionsToRun(v *pathExpandVertex) ([]path_selection.ExpansionInput, error)
		handleResultProperties(v *pathExpandVertex, result path_selection.ExpansionResult) error
		addSubExpansion(result path_selection.ExpansionResult, expansion path_selection.ExpansionInput, v *pathExpandVertex) error
		addResourcesAndEdges(result path_selection.ExpansionResult, expansion path_selection.ExpansionInput, v *pathExpandVertex) error
		consumeExpansionProperties(expansion path_selection.ExpansionInput) error
	}
	// pathExpandVertexRunner is the production expansionRunner; it applies every step to Eval's solution.
	pathExpandVertexRunner struct {
		Eval *Evaluator
	}
)
// Key identifies a path-expansion vertex by its satisfaction edge together with its path
// satisfaction.
func (v *pathExpandVertex) Key() Key {
	return Key{
		Edge:              v.SatisfactionEdge,
		PathSatisfication: v.Satisfication,
	}
}
// Evaluate runs the expansion using the production runner and edge expander, both bound to
// eval's solution.
func (v *pathExpandVertex) Evaluate(eval *Evaluator) error {
	return v.runEvaluation(
		eval,
		&pathExpandVertexRunner{Eval: eval},
		&path_selection.EdgeExpand{Ctx: eval.Solution},
	)
}
// runEvaluation expands each input returned by runner.getExpansionsToRun and applies the result
// in four steps:
//  1. add resources and edges;
//  2. add sub-expansions;
//  3. consume expansion properties;
//  4. handle result properties.
//
// The first failing step abandons the rest of that expansion, and processing moves on to the next
// one. All errors are joined and returned.
func (v *pathExpandVertex) runEvaluation(eval *Evaluator, runner expansionRunner, edgeExpander path_selection.EdgeExpander) error {
	var errs error
	expansions, err := runner.getExpansionsToRun(v)
	if err != nil {
		errs = errors.Join(errs, fmt.Errorf("could not get expansions to run: %w", err))
	}
	log := eval.Log().Named("path_expand")
	if len(expansions) > 1 && log.Desugar().Core().Enabled(zap.DebugLevel) {
		log.Debugf("Expansion %s subexpansions:", v.SatisfactionEdge)
		for _, sub := range expansions {
			log.Debugf(" %s -> %s", sub.SatisfactionEdge.Source.ID, sub.SatisfactionEdge.Target.ID)
		}
	}
	wrap := func(err error) error {
		return fmt.Errorf("could not run expansion %s -> %s <%s>: %w",
			v.SatisfactionEdge.Source, v.SatisfactionEdge.Target, v.Satisfication.Classification, err,
		)
	}
	for _, expansion := range expansions {
		result, err := edgeExpander.ExpandEdge(expansion)
		if err != nil {
			errs = errors.Join(errs, wrap(err))
			continue
		}
		resultStr, err := expansionResultString(result.Graph, expansion.SatisfactionEdge)
		if err != nil {
			errs = errors.Join(errs, wrap(err))
			continue
		}
		if v.Satisfication.Classification == "" {
			log.Infof("Satisfied %s -> %s through %s", v.SatisfactionEdge.Source, v.SatisfactionEdge.Target, resultStr)
		} else {
			log.Infof("Satisfied %s for %s through %s", v.Satisfication.Classification, v.SatisfactionEdge, resultStr)
		}
		// Result properties are handled last, after all resources and edges have been added to the
		// solution, so that ids are replaced properly.
		steps := []func() error{
			func() error { return runner.addResourcesAndEdges(result, expansion, v) },
			func() error { return runner.addSubExpansion(result, expansion, v) },
			func() error { return runner.consumeExpansionProperties(expansion) },
			func() error { return runner.handleResultProperties(v, result) },
		}
		for _, step := range steps {
			if stepErr := step(); stepErr != nil {
				errs = errors.Join(errs, wrap(stepErr))
				break
			}
		}
	}
	return errs
}
// UpdateFrom merges other into v by adopting other's temp graph; all other fields of v are kept.
// Callers are expected to pass a *pathExpandVertex with the same key (the key is not checked here).
// Any other vertex type is a programming error and panics with a descriptive message, consistent
// with edgeVertex.UpdateFrom.
func (v *pathExpandVertex) UpdateFrom(other Vertex) {
	otherVertex, ok := other.(*pathExpandVertex)
	if !ok {
		panic(fmt.Sprintf("cannot merge path expand vertex with non-path-expand vertex: %T", other))
	}
	v.TempGraph = otherVertex.TempGraph
}
// addDepsFromProps makes v depend on (changes edge v -> property) each property of res that meets
// all of the following:
//   - its type is resource-based;
//   - it has an operational rule whose If condition currently evaluates to true;
//   - it would accept at least one of dependencies as a value.
//
// Such a property may set up a resource that the chosen path could reuse. The path is not known
// yet, so the dependency is conservative. The satisfaction edge's own endpoints are never
// considered as candidate values.
func (v *pathExpandVertex) addDepsFromProps(
	eval *Evaluator,
	changes graphChanges,
	res construct.ResourceId,
	dependencies []construct.ResourceId,
) error {
	tmpl, err := eval.Solution.KnowledgeBase().GetResourceTemplate(res)
	if err != nil {
		return err
	}
	var errs error
	for name, prop := range tmpl.Properties {
		// Skip properties whose type cannot hold a resource at all.
		if !strings.HasPrefix(prop.Type(), "resource") {
			continue
		}
		// Skip properties that cannot create resources, or whose rule does not (yet) apply.
		opRule := prop.Details().OperationalRule
		if opRule == nil {
			continue
		}
		applies, condErr := operational_rule.EvaluateIfCondition(opRule.If,
			eval.Solution, knowledgebase.DynamicValueData{Resource: res})
		if condErr != nil || !applies {
			continue
		}
		propKey := Key{Ref: construct.PropertyRef{Resource: res, Property: name}}
		for _, dep := range dependencies {
			if dep == v.SatisfactionEdge.Source || dep == v.SatisfactionEdge.Target {
				continue
			}
			resource, vErr := eval.Solution.RawView().Vertex(res)
			if vErr != nil {
				errs = errors.Join(errs, vErr)
				continue
			}
			// A candidate that passes the property's validation makes the property a dependency.
			if prop.Validate(resource, dep, solution.DynamicCtx(eval.Solution)) == nil {
				changes.addEdge(v.Key(), propKey)
			}
		}
	}
	return errs
}
// addDepsFromEdge speculatively makes properties depend on v when edge's template may configure them.
//
// For each configuration rule in the template's operational rules, it resolves the rule's resource
// and field against a copy of edge whose endpoint names are cleared. Then, for every resource in the
// solution that matches the resolved (name-less) id and has a top-level property with that field
// name, it records property -> v in changes. In this file, changes edges point from the dependent
// key to the key it depends on.
//
// Rules whose resource cannot be resolved yet are skipped. Field-decoding errors and template
// lookup errors are joined and returned. It returns nil immediately when edge has no template.
func (v *pathExpandVertex) addDepsFromEdge(
	eval *Evaluator,
	changes graphChanges,
	edge construct.Edge,
) error {
	kb := eval.Solution.KnowledgeBase()
	tmpl := kb.GetEdgeTemplate(edge.Source, edge.Target)
	if tmpl == nil {
		return nil
	}
	allRes, err := construct.TopologicalSort(eval.Solution.RawView())
	if err != nil {
		return err
	}
	// se is edge with the endpoint names cleared, so that the ids resolved from the rules match any
	// resource of the same type rather than only the temp graph's placeholders.
	se := construct.Edge{
		Source:     edge.Source,
		Target:     edge.Target,
		Properties: edge.Properties,
	}
	se.Source.Name = ""
	se.Target.Name = ""
	// addDepsMatching adds property -> v for every existing resource matching ref.Resource that
	// declares ref.Property at the top level. It returns on the first template lookup error.
	addDepsMatching := func(ref construct.PropertyRef) error {
		for _, res := range allRes {
			if !ref.Resource.Matches(res) {
				continue
			}
			tmpl, err := kb.GetResourceTemplate(res)
			if err != nil {
				return err
			}
			// TODO: Go into nested properties to determine dependencies
			if _, hasProp := tmpl.Properties[ref.Property]; hasProp {
				actualRef := construct.PropertyRef{
					Resource: res,
					Property: ref.Property,
				}
				changes.addEdge(Key{Ref: actualRef}, v.Key())
				eval.Log().Named("path_expand").Debugf(
					"Adding speculative dependency %s -> %s (matches %s from %s)",
					actualRef, v.Key(), ref, se,
				)
			}
		}
		return nil
	}
	dyn := solution.DynamicCtx(eval.Solution)
	var errs error
	for i, rule := range tmpl.OperationalRules {
		for j, cfg := range rule.ConfigurationRules {
			var err error
			data := knowledgebase.DynamicValueData{Edge: &se}
			data.Resource, err = knowledgebase.ExecuteDecodeAsResourceId(dyn, cfg.Resource, data)
			// We ignore the error because it just means that we cant resolve the resource yet
			// therefore we cant add a dependency on this invocation
			if err != nil || data.Resource.IsZero() {
				continue
			}
			// NOTE(gg): does this need to consider `Fields`?
			field := cfg.Config.Field
			// The field may itself be a template; it is decoded in place using the same data.
			err = dyn.ExecuteDecode(field, data, &field)
			if err != nil {
				errs = errors.Join(errs, fmt.Errorf("could not decode field for rule %d cfg %d: %w", i, j, err))
				continue
			}
			if field == "" {
				continue
			}
			ref := construct.PropertyRef{Resource: data.Resource, Property: field}
			errs = errors.Join(errs, addDepsMatching(ref))
		}
	}
	return errs
}
// getDepsForPropertyRef walks the '#'-separated property path propertyRef, starting from res. At
// each level it does two things:
//   - records a property key for every resource reached;
//   - follows the property's current value (a ResourceId or a []ResourceId) to the next level.
//
// Properties whose value cannot be read yet are recorded but not followed; later calls to
// Dependencies will continue from them. An empty propertyRef yields nil.
func getDepsForPropertyRef(
	sol solution.Solution,
	res construct.ResourceId,
	propertyRef string,
) set.Set[Key] {
	if propertyRef == "" {
		return nil
	}
	deps := make(set.Set[Key])
	cfgCtx := solution.DynamicCtx(sol)
	frontier := []construct.ResourceId{res}
	for _, part := range strings.Split(propertyRef, "#") {
		var next []construct.ResourceId
		for _, id := range frontier {
			deps.Add(Key{Ref: construct.PropertyRef{Resource: id, Property: part}})
			val, err := cfgCtx.FieldValue(part, id)
			if err != nil {
				// Not resolved yet; nothing to follow from this resource for now.
				continue
			}
			switch resolved := val.(type) {
			case construct.ResourceId:
				next = append(next, resolved)
			case []construct.ResourceId:
				next = append(next, resolved...)
			}
		}
		frontier = next
	}
	return deps
}
// Dependencies records what this path expansion depends on, in three phases:
//  1. The properties reached while resolving the source and target property references of the
//     satisfaction. These are recorded first so that operational rules do not create resources
//     the expansion would then need to handle.
//  2. If a temp graph exists, properties of the endpoints that could set up reusable resources
//     (addDepsFromProps).
//  3. Properties that the temp graph's edge templates may configure (addDepsFromEdge).
//
// Without a temp graph, only phase 1 applies.
func (v *pathExpandVertex) Dependencies(eval *Evaluator, propCtx dependencyCapturer) error {
	changes := propCtx.GetChanges()
	self := v.Key()
	edge := v.SatisfactionEdge
	changes.addEdges(self, getDepsForPropertyRef(eval.Solution, edge.Source, v.Satisfication.Source.PropertyReference))
	changes.addEdges(self, getDepsForPropertyRef(eval.Solution, edge.Target, v.Satisfication.Target.PropertyReference))
	if v.TempGraph == nil {
		return nil
	}

	var errs error
	downstream, err := construct.AllDownstreamDependencies(v.TempGraph, edge.Source)
	if err != nil {
		return err
	}
	errs = errors.Join(errs, v.addDepsFromProps(eval, changes, edge.Source, downstream))
	upstream, err := construct.AllUpstreamDependencies(v.TempGraph, edge.Target)
	if err != nil {
		return err
	}
	errs = errors.Join(errs, v.addDepsFromProps(eval, changes, edge.Target, upstream))
	if errs != nil {
		return errs
	}

	tempEdges, err := v.TempGraph.Edges()
	if err != nil {
		return err
	}
	for _, tempEdge := range tempEdges {
		errs = errors.Join(errs, v.addDepsFromEdge(eval, changes, tempEdge))
	}
	return errs
}
// getExpansionsToRun resolves v's satisfaction edge to concrete resources and asks path_selection
// for the expansion inputs that satisfy it. An input whose satisfaction edge is exactly v's edge
// reuses v.TempGraph; any other (sub-)edge gets a freshly built path selection graph.
//
// Errors are accumulated. The inputs that could be prepared are returned alongside the joined
// error. An input whose temp graph could not be built is omitted, rather than returned as a
// zero-valued ExpansionInput that runEvaluation would then try to expand.
func (runner *pathExpandVertexRunner) getExpansionsToRun(v *pathExpandVertex) ([]path_selection.ExpansionInput, error) {
	eval := runner.Eval
	var errs error
	sourceRes, err := eval.Solution.RawView().Vertex(v.SatisfactionEdge.Source)
	if err != nil {
		return nil, fmt.Errorf("could not find source resource %s: %w", v.SatisfactionEdge.Source, err)
	}
	targetRes, err := eval.Solution.RawView().Vertex(v.SatisfactionEdge.Target)
	if err != nil {
		return nil, fmt.Errorf("could not find target resource %s: %w", v.SatisfactionEdge.Target, err)
	}
	edge := construct.ResourceEdge{Source: sourceRes, Target: targetRes}
	expansions, err := path_selection.DeterminePathSatisfactionInputs(eval.Solution, v.Satisfication, edge)
	if err != nil {
		// Keep going with whatever expansions were determined.
		errs = errors.Join(errs, err)
	}
	requireFullBuild := sourceRes.Imported || targetRes.Imported
	result := make([]path_selection.ExpansionInput, 0, len(expansions))
	for _, expansion := range expansions {
		input := path_selection.ExpansionInput{
			ExpandEdge:       v.ExpandEdge,
			SatisfactionEdge: expansion.SatisfactionEdge,
			Classification:   expansion.Classification,
			TempGraph:        v.TempGraph,
		}
		if expansion.SatisfactionEdge.Source != edge.Source || expansion.SatisfactionEdge.Target != edge.Target {
			simple := construct.SimpleEdge{Source: expansion.SatisfactionEdge.Source.ID, Target: expansion.SatisfactionEdge.Target.ID}
			// NOTE(review): pathVertices passes !requireFullBuild for this same argument; confirm
			// which polarity BuildPathSelectionGraph expects here.
			tempGraph, err := path_selection.BuildPathSelectionGraph(
				eval.Solution.Context(),
				simple,
				eval.Solution.KnowledgeBase(),
				expansion.Classification,
				requireFullBuild,
			)
			if err != nil {
				errs = errors.Join(errs, fmt.Errorf("error getting expansions to run. could not build path selection graph: %w", err))
				continue
			}
			input.TempGraph = tempGraph
		}
		result = append(result, input)
	}
	return result, errs
}
// addResourcesAndEdges applies an expansion result to the solution's operational view in three steps:
//  1. It adjusts the direct edge. If the result graph has more than two vertices, the direct edge
//     between v's satisfaction endpoints is removed (when present). If it has exactly two, the
//     direct edge between the expansion's endpoints is added.
//  2. It adds every result vertex that is missing from the view.
//  3. It adds every result edge that is missing from the view.
//
// Edges added here carry the data of the first edge constraint whose target matches v's
// satisfaction edge, if there is one.
func (runner *pathExpandVertexRunner) addResourcesAndEdges(
	result path_selection.ExpansionResult,
	expansion path_selection.ExpansionInput,
	v *pathExpandVertex,
) error {
	eval := runner.Eval
	op := eval.Solution.OperationalView()
	adj, err := result.Graph.AdjacencyMap()
	if err != nil {
		return err
	}
	// Copy the edge data from a constraint that matches the expansion
	// which enables setting, for example, lambda -> s3_bucket readonly,
	// then applying that to the iam_role -> s3_bucket edge, which is what actually
	// does something.
	var data construct.EdgeData
	for _, ec := range runner.Eval.Solution.Constraints().Edges {
		if ec.Target.Source.Matches(v.SatisfactionEdge.Source) && ec.Target.Target.Matches(v.SatisfactionEdge.Target) {
			data = ec.Data
			break
		}
	}
	// NOTE(review): the removal uses v.SatisfactionEdge while the addition uses
	// expansion.SatisfactionEdge; these differ for sub-expansions — confirm that is intended.
	if len(adj) > 2 {
		// More than the two endpoints: remove the direct edge if it exists.
		_, err := op.Edge(v.SatisfactionEdge.Source, v.SatisfactionEdge.Target)
		if err == nil {
			if err := op.RemoveEdge(v.SatisfactionEdge.Source, v.SatisfactionEdge.Target); err != nil {
				return err
			}
		} else if !errors.Is(err, graph.ErrEdgeNotFound) {
			return err
		}
	} else if len(adj) == 2 {
		// Exactly the two endpoints: the direct edge satisfies the expansion.
		err = op.AddEdge(
			expansion.SatisfactionEdge.Source.ID,
			expansion.SatisfactionEdge.Target.ID,
			graph.EdgeData(data),
		)
		if err != nil && !errors.Is(err, graph.ErrEdgeAlreadyExists) {
			// NOTE(gg): See note below why we're ignoring/allowing already exists errors.
			return err
		}
	}
	// Once the path is selected & expanded, first add all the resources to the graph
	var errs error
	for pathId := range adj {
		_, err := op.Vertex(pathId)
		switch {
		case errors.Is(err, graph.ErrVertexNotFound):
			res, err := result.Graph.Vertex(pathId)
			if err != nil {
				errs = errors.Join(errs, err)
				continue
			}
			err = op.AddVertex(res)
			if err != nil {
				errs = errors.Join(errs, err)
				continue
			}
		case err != nil:
			errs = errors.Join(errs, err)
			continue
		}
	}
	if errs != nil {
		return errs
	}
	// After all the resources, then add all the dependencies
	for _, edgeMap := range adj {
		for _, edge := range edgeMap {
			_, err := op.Edge(edge.Source, edge.Target)
			switch {
			case errors.Is(err, graph.ErrEdgeNotFound):
				err := op.AddEdge(edge.Source, edge.Target, graph.EdgeData(data))
				if err != nil {
					errs = errors.Join(errs, err)
					continue
				}
			case err != nil:
				errs = errors.Join(errs, err)
				continue
			default:
				// NOTE(gg): we could update the edge to set the edge data, but if that edge already existed it either:
				// 1) was added as a different path expansion process (another classification)
				// 2) was not added as part of path expansion
				// If (1), then the edge data is already correct
				// If (2), then the edge isn't unique to the expansion, so don't copy the edge data. This may not be the
				// correct behaviour, but without a use-case it's hard to tell.
			}
		}
	}
	if errs != nil {
		return errs
	}
	return nil
}
// addSubExpansion enqueues follow-up path expansions for the edges reported in result.Edges.
//
// For each such edge it looks up the knowledge base's path satisfactions and keeps only those
// with the same classification as v. It then enqueues a new pathExpandVertex for each one,
// reusing the expansion's TempGraph. The vertices are enqueued rather than evaluated right away
// because their dependencies may not have settled yet.
func (runner *pathExpandVertexRunner) addSubExpansion(
	result path_selection.ExpansionResult,
	expansion path_selection.ExpansionInput,
	v *pathExpandVertex,
) error {
	pending := newChanges()
	for _, sub := range result.Edges {
		satisfactions, err := runner.Eval.Solution.KnowledgeBase().GetPathSatisfactionsFromEdge(sub.Source, sub.Target)
		if err != nil {
			return fmt.Errorf("could not get path satisfications for sub expansion %s -> %s: %w",
				sub.Source, sub.Target, err)
		}
		for _, candidate := range satisfactions {
			if candidate.Classification != v.Satisfication.Classification {
				continue
			}
			pending.addNode(&pathExpandVertex{
				SatisfactionEdge: construct.SimpleEdge{Source: sub.Source, Target: sub.Target},
				TempGraph:        expansion.TempGraph,
				Satisfication:    candidate,
			})
		}
	}
	return runner.Eval.enqueue(pending)
}
// consumeExpansionProperties runs the knowledge base's property consumption from the source to
// the target of the expansion's satisfaction edge.
//
// Some consumptions cannot be applied yet because the target property has not been evaluated.
// Each of these is recorded as an "add" resource constraint on the solution, so it is applied
// when that property's constraints are evaluated.
func (runner *pathExpandVertexRunner) consumeExpansionProperties(expansion path_selection.ExpansionInput) error {
	sol := runner.Eval.Solution
	delayed, err := knowledgebase.ConsumeFromResource(
		expansion.SatisfactionEdge.Source,
		expansion.SatisfactionEdge.Target,
		solution.DynamicCtx(sol),
	)
	if err != nil {
		return err
	}
	cons := sol.Constraints()
	for _, d := range delayed {
		pendingConstraint := constraints.ResourceConstraint{
			Operator: constraints.AddConstraintOperator,
			Target:   d.Resource,
			Property: d.PropertyPath,
			Value:    d.Value,
		}
		cons.Resources = append(cons.Resources, pendingConstraint)
	}
	return nil
}
// handleResultProperties wires the properties of every resource in an expansion result to that
// resource's neighbours within the result.
//
// For each resource in the result graph, it looks at every template property whose operational
// rule step lists resource selectors. Downstream steps are checked against the resource's
// downstream neighbours in the result, and upstream steps against its upstream neighbours. When
// a selector can use a neighbour and the resource is not imported, the property is set to point
// at that neighbour via OperationalRuleContext.SetField. If that changes the resource's ID, the
// evaluator is told about the rename.
//
// All errors are collected and returned joined.
func (runner *pathExpandVertexRunner) handleResultProperties(
	v *pathExpandVertex,
	result path_selection.ExpansionResult,
) error {
	eval := runner.Eval
	adj, err := result.Graph.AdjacencyMap()
	if err != nil {
		return err
	}
	pred, err := result.Graph.PredecessorMap()
	if err != nil {
		return err
	}
	// applyToNeighbours sets the properties of res whose operational rule step points in
	// direction and whose selectors accept one of the given neighbours.
	applyToNeighbours := func(
		res *construct.Resource,
		rt *knowledgebase.ResourceTemplate,
		neighbours map[construct.ResourceId]graph.Edge[construct.ResourceId],
		direction knowledgebase.Direction,
	) error {
		var errs error
		for target := range neighbours {
			targetRes, err := result.Graph.Vertex(target)
			if err != nil {
				errs = errors.Join(errs, err)
				continue
			}
			// SetField errors are collected in their own variable. The callback must not write to
			// errs: `errs = errors.Join(errs, rt.LoopProperties(...))` would then depend on whether
			// errs is read before or after LoopProperties runs, which the Go spec leaves
			// unspecified.
			var setErrs error
			loopErr := rt.LoopProperties(res, func(prop knowledgebase.Property) error {
				details := prop.Details()
				if details.OperationalRule == nil || len(details.OperationalRule.Step.Resources) == 0 {
					return nil
				}
				step := details.OperationalRule.Step
				if step.Direction != direction {
					return nil
				}
				opRuleCtx := operational_rule.OperationalRuleContext{
					Solution: eval.Solution,
					Property: prop,
					Data:     knowledgebase.DynamicValueData{Resource: res.ID},
				}
				for _, selector := range step.Resources {
					canUse, err := selector.CanUse(
						solution.DynamicCtx(eval.Solution),
						knowledgebase.DynamicValueData{Resource: res.ID},
						targetRes,
					)
					// NOTE(review): a CanUse error is treated like "cannot use" (best-effort), as it
					// was before. Confirm this is intended rather than an error to surface.
					if !canUse || err != nil || res.Imported {
						continue
					}
					if setErr := opRuleCtx.SetField(res, targetRes, step); setErr != nil {
						setErrs = errors.Join(setErrs, setErr)
					}
				}
				return nil
			})
			errs = errors.Join(errs, loopErr, setErrs)
		}
		return errs
	}

	var errs error
	for id, downstreams := range adj {
		rt, err := eval.Solution.KnowledgeBase().GetResourceTemplate(id)
		if err != nil {
			errs = errors.Join(errs, err)
			continue
		}
		res, err := eval.Solution.RawView().Vertex(id)
		if err != nil {
			errs = errors.Join(errs, err)
			continue
		}
		errs = errors.Join(errs, applyToNeighbours(res, rt, downstreams, knowledgebase.DirectionDownstream))
		errs = errors.Join(errs, applyToNeighbours(res, rt, pred[id], knowledgebase.DirectionUpstream))
		// Setting a field may change the resource's ID (e.g. a namespace property).
		if id != res.ID {
			errs = errors.Join(errs, eval.UpdateId(id, res.ID))
		}
	}
	return errs
}
// expansionResultString renders an expansion result for humans.
//
// It first writes the stable shortest path from dep's source to its target as "a -> b -> c".
// It then appends ", src -> dst" for every other edge in the result that is not on that path.
// If no path exists, an error is returned. If listing the edges fails, the path rendered so far
// is returned together with that error.
func expansionResultString(result construct.Graph, dep construct.ResourceEdge) (string, error) {
	path, err := graph.ShortestPathStable(result, dep.Source.ID, dep.Target.ID, construct.ResourceIdLess)
	if err != nil {
		return "", fmt.Errorf("expansion result does not contain path from %s to %s: %w", dep.Source.ID, dep.Target.ID, err)
	}
	out := new(strings.Builder)
	onPath := make(set.Set[construct.SimpleEdge])
	for i := range path {
		if i == 0 {
			out.WriteString(path[i].String())
			continue
		}
		prev, cur := path[i-1], path[i]
		fmt.Fprintf(out, " -> %s", cur)
		onPath.Add(construct.SimpleEdge{Source: prev, Target: cur})
	}

	allEdges, err := result.Edges()
	if err != nil {
		return out.String(), err
	}
	for _, e := range allEdges {
		simple := construct.SimpleEdge{Source: e.Source, Target: e.Target}
		if !onPath.Contains(simple) {
			fmt.Fprintf(out, ", %s", simple.String())
		}
	}
	return out.String(), nil
}
package operational_eval
import (
"errors"
"fmt"
"strings"
"github.com/dominikbraun/graph"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/engine/constraints"
"github.com/klothoplatform/klotho/pkg/engine/operational_rule"
"github.com/klothoplatform/klotho/pkg/engine/solution"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
"github.com/klothoplatform/klotho/pkg/set"
)
type (
	// propertyVertex is an evaluation-graph vertex for a single property of a single resource.
	propertyVertex struct {
		// Ref identifies the resource and property this vertex evaluates.
		Ref construct.PropertyRef
		// Template is the knowledge-base template for the property. It can be nil, for example when
		// the vertex is created while adding an edge template. Ready returns NotReadyMax while it is nil.
		Template knowledgebase.Property
		// EdgeRules are operational rules contributed by edge templates, keyed by the edge they came from.
		EdgeRules map[construct.SimpleEdge][]knowledgebase.OperationalRule
		// TransformRules are Rules found in edge templates where the property depends on itself, thus transforming the existing value
		// Transform rules are initially a part of the edge vertex
		// when a TransformRule is found it is removed from the EdgeRules and added to the TransformRules
		TransformRules map[construct.SimpleEdge]*set.HashedSet[string, knowledgebase.OperationalRule]
		// ResourceRules is not read by the propertyVertex methods in this file.
		// NOTE(review): key meaning (presumably a rule identifier) is not visible here — confirm.
		ResourceRules map[string][]knowledgebase.OperationalRule
	}
)
// Key identifies a property vertex by its property reference alone; all other Key fields are zero.
func (prop propertyVertex) Key() Key {
	var k Key
	k.Ref = prop.Ref
	return k
}
// Dependencies records, through propCtx, everything this property vertex depends on.
//
// When a template is known, it evaluates the default value and the resource-level operational
// rule purely so propCtx can capture what they reference. If edges should be evaluated for this
// property (see shouldEvalEdges), it then executes each edge operational rule.
//
// An edge rule that introduces a new dependency of the property on itself transforms the
// existing value. Such a rule is moved from EdgeRules into TransformRules and the
// self-dependency is dropped, so it does not create a cycle.
func (prop *propertyVertex) Dependencies(eval *Evaluator, propCtx dependencyCapturer) error {
	res, err := eval.Solution.RawView().Vertex(prop.Ref.Resource)
	if err != nil {
		return fmt.Errorf("could not get resource for property vertex dependency calculation %s: %w", prop.Ref, err)
	}
	path, err := res.PropertyPath(prop.Ref.Property)
	if err != nil {
		return fmt.Errorf("could not get property path for %s: %w", prop.Ref, err)
	}
	resData := knowledgebase.DynamicValueData{Resource: prop.Ref.Resource, Path: path}
	// Template can be nil when checking for dependencies from a propertyVertex when adding an edge template
	if prop.Template != nil {
		// The value is discarded: evaluating the default only lets propCtx capture what it references.
		_, _ = prop.Template.GetDefaultValue(propCtx, resData)
		details := prop.Template.Details()
		if opRule := details.OperationalRule; opRule != nil {
			if err := propCtx.ExecutePropertyRule(resData, *opRule); err != nil {
				return fmt.Errorf("could not execute resource operational rule for %s: %w", prop.Ref, err)
			}
		}
	}
	if !prop.shouldEvalEdges(eval.Solution.Constraints().Resources) {
		return nil
	}

	selfKey := prop.Key()
	// Snapshot this vertex's current dependencies. The set must be copied element by element.
	// The live set for selfKey is mutated in place when rules add dependencies, so keeping a
	// reference would make every newly added self-dependency look pre-existing.
	existingDeps := make(set.Set[Key])
	for dep := range propCtx.GetChanges().edges[selfKey] {
		existingDeps.Add(dep)
	}

	for _, edge := range construct.EdgeKeys(prop.EdgeRules) {
		rules := prop.EdgeRules[edge]
		resEdge, err := eval.Solution.RawView().Edge(edge.Source, edge.Target)
		if err != nil {
			return fmt.Errorf("could not get edge for property vertex dependency calculation %s: %w", prop.Ref, err)
		}
		keyEdge := construct.ResourceEdgeToKeyEdge(resEdge)
		edgeData := knowledgebase.DynamicValueData{
			Resource: prop.Ref.Resource,
			Edge:     &keyEdge,
		}
		var keptRules []knowledgebase.OperationalRule
		for _, opRule := range rules {
			if err := propCtx.ExecuteOpRule(edgeData, opRule); err != nil {
				return fmt.Errorf("could not execute edge operational rule for %s: %w", prop.Ref, err)
			}
			// If this rule made the property newly depend on itself, label it a transform rule to be
			// applied at the end, and drop the self-dependency.
			currDeps := propCtx.GetChanges().edges[selfKey]
			if !currDeps.Contains(selfKey) || existingDeps.Contains(selfKey) {
				keptRules = append(keptRules, opRule)
				continue
			}
			if prop.TransformRules == nil {
				prop.TransformRules = make(map[construct.SimpleEdge]*set.HashedSet[string, knowledgebase.OperationalRule])
			}
			ruleSet := prop.TransformRules[edge]
			if ruleSet == nil {
				ruleSet = &set.HashedSet[string, knowledgebase.OperationalRule]{
					Hasher: func(s knowledgebase.OperationalRule) string {
						return fmt.Sprintf("%v", s)
					},
				}
				prop.TransformRules[edge] = ruleSet
			}
			ruleSet.Add(opRule)
			currDeps.Remove(selfKey)
		}
		prop.EdgeRules[edge] = keptRules
	}
	return nil
}
// UpdateFrom merges otherV into prop.
//
// otherV must be a *propertyVertex with the same Ref; anything else is a programming error and
// panics. A nil Template is filled from other. Other's edge rules are added only for edges
// prop has no entry for yet, and existing edge rules are never overwritten.
func (prop *propertyVertex) UpdateFrom(otherV Vertex) {
	if prop == otherV {
		return
	}
	other, isProperty := otherV.(*propertyVertex)
	switch {
	case !isProperty:
		panic(fmt.Sprintf("cannot merge property with non-property vertex: %T", otherV))
	case prop.Ref != other.Ref:
		panic(fmt.Sprintf("cannot merge properties with different refs: %s != %s", prop.Ref, other.Ref))
	}

	if prop.Template == nil {
		prop.Template = other.Template
	}
	if prop.EdgeRules == nil {
		prop.EdgeRules = make(map[construct.SimpleEdge][]knowledgebase.OperationalRule)
	}
	for edge, rules := range other.EdgeRules {
		// Only take rules for edges we have not seen, so rules are not duplicated.
		if _, seen := prop.EdgeRules[edge]; !seen {
			prop.EdgeRules[edge] = rules
		}
	}
}
// Evaluate computes the final value of the property on its resource.
//
// In order, it:
//   - applies matching resource constraints or the template default (evaluateConstraints);
//   - applies the template's own operational rule;
//   - applies edge operational rules, when shouldEvalEdges allows it;
//   - applies transform rules;
//   - propagates any resource ID change to the evaluator.
//
// For list/set/map properties it also cleans up sub-property vertices (when the property has a
// value) and re-adds the resource, so nested fields get evaluated. Finally it validates the value
// and records the result as a PropertyValidationDecision. A validation failure is recorded, not
// returned.
//
// NOTE(review): v.Template is used without a nil check (v.Template.Type(), Validate). Ready
// returns NotReadyMax while it is nil, which presumably keeps this from running in that state.
func (v *propertyVertex) Evaluate(eval *Evaluator) error {
	sol := eval.Solution
	res, err := sol.RawView().Vertex(v.Ref.Resource)
	if err != nil {
		return fmt.Errorf("could not get resource to evaluate %s: %w", v.Ref, err)
	}
	path, err := res.PropertyPath(v.Ref.Property)
	if err != nil {
		return fmt.Errorf("could not get property path for %s: %w", v.Ref, err)
	}
	dynData := knowledgebase.DynamicValueData{Resource: res.ID, Path: path, GlobalTag: eval.Solution.GlobalTag()}
	// dynData is passed by value, so ID updates made inside evaluateConstraints are not reflected here.
	if err := v.evaluateConstraints(
		&solution.Configurer{Ctx: sol},
		solution.DynamicCtx(sol),
		res,
		sol.Constraints().Resources,
		dynData,
	); err != nil {
		return err
	}
	// NOTE(review): opCtx.Data.Resource is the ID from before evaluateConstraints. If a constraint
	// or default changed res.ID, the resource operational rule below sees the old ID — confirm
	// whether dynData.Resource should be refreshed to res.ID first.
	opCtx := operational_rule.OperationalRuleContext{
		Solution: sol,
		Property: v.Template,
		Data:     dynData,
	}
	if err := v.evaluateResourceOperational(&opCtx); err != nil {
		return err
	}
	if v.shouldEvalEdges(eval.Solution.Constraints().Resources) {
		if err := v.evaluateEdgeOperational(eval, res, &opCtx); err != nil {
			return err
		}
	}
	if err := v.evaluateTransforms(res, &opCtx); err != nil {
		return err
	}
	// Any of the steps above may have renamed the resource; let the evaluator re-key its vertices.
	if err := eval.UpdateId(v.Ref.Resource, res.ID); err != nil {
		return err
	}
	propertyType := v.Template.Type()
	if strings.HasPrefix(propertyType, "list") || strings.HasPrefix(propertyType, "set") || strings.HasPrefix(propertyType, "map") {
		property, err := res.GetProperty(v.Ref.Property)
		if err != nil {
			return fmt.Errorf("could not get property %s on resource %s: %w", v.Ref.Property, v.Ref.Resource, err)
		}
		if property != nil {
			err = eval.cleanupPropertiesSubVertices(v.Ref, res)
			if err != nil {
				return fmt.Errorf("could not cleanup sub vertices for %s: %w", v.Ref, err)
			}
		}
		// If we have modified a list or set we want to re add the resource to be evaluated
		// so the nested fields are ensured to be set if required
		err = eval.AddResources(res)
		if err != nil {
			return fmt.Errorf("could not add resource %s to be re-evaluated: %w", res.ID, err)
		}
	}
	// Now that the vertex is evaluated, we will check it for validity and record our decision
	val, err := res.GetProperty(v.Ref.Property)
	if err != nil {
		return fmt.Errorf("error while validating resource property: could not get property %s on resource %s: %w", v.Ref.Property, v.Ref.Resource, err)
	}
	// The validation error is recorded in the decision rather than returned.
	err = v.Template.Validate(res, val, solution.DynamicCtx(eval.Solution))
	eval.Solution.RecordDecision(solution.PropertyValidationDecision{
		Resource: v.Ref.Resource,
		Property: v.Template,
		Value:    val,
		Error:    err,
	})
	return nil
}
// evaluateConstraints applies resource constraints and the template default to the property.
//
// Among the constraints in rcs that target this resource and property, the last "equals"
// constraint (if any) sets the value. Otherwise, the template's default value is set when all of
// these hold: the property has no current value, the resource is not imported, a template is
// known, and the default is non-nil. After that, every non-equals constraint is applied with its
// own operator.
//
// Configuring a property can change the resource's ID, so dynData.Resource is refreshed after
// each configuration.
func (v *propertyVertex) evaluateConstraints(
	rc solution.ResourceConfigurer,
	ctx knowledgebase.DynamicValueContext,
	res *construct.Resource,
	rcs []constraints.ResourceConstraint,
	dynData knowledgebase.DynamicValueData,
) error {
	var setConstraint constraints.ResourceConstraint
	var addConstraints []constraints.ResourceConstraint
	for _, c := range rcs {
		if c.Target != res.ID || c.Property != v.Ref.Property {
			continue
		}
		if c.Operator == constraints.EqualsConstraintOperator {
			// A later equals constraint overrides an earlier one.
			setConstraint = c
			continue
		}
		addConstraints = append(addConstraints, c)
	}
	hasSetConstraint := setConstraint.Operator != ""

	currentValue, err := res.GetProperty(v.Ref.Property)
	if err != nil {
		return fmt.Errorf("could not get current value for %s: %w", v.Ref, err)
	}

	switch {
	case hasSetConstraint:
		err = rc.ConfigureResource(
			res,
			knowledgebase.Configuration{Field: v.Ref.Property, Value: setConstraint.Value},
			dynData,
			constraints.EqualsConstraintOperator,
			true,
		)
		if err != nil {
			return fmt.Errorf("could not apply initial constraint for %s: %w", v.Ref, err)
		}
	case currentValue == nil && !res.Imported && v.Template != nil:
		// The template is nil-checked before computing the default; it was previously only
		// checked afterwards, so a nil template caused a nil-interface panic here.
		defaultVal, err := v.Template.GetDefaultValue(ctx, dynData)
		if err != nil {
			return fmt.Errorf("could not get default value for %s: %w", v.Ref, err)
		}
		if defaultVal != nil {
			err = rc.ConfigureResource(
				res,
				knowledgebase.Configuration{Field: v.Ref.Property, Value: defaultVal},
				dynData,
				constraints.EqualsConstraintOperator,
				false,
			)
			if err != nil {
				return fmt.Errorf("could not set default value for %s: %w", v.Ref, err)
			}
		}
	}

	dynData.Resource = res.ID // Update in case the property changes the ID
	var errs error
	// addConstraints never contains equals constraints (they are routed to setConstraint above).
	for _, c := range addConstraints {
		errs = errors.Join(errs, rc.ConfigureResource(
			res,
			knowledgebase.Configuration{Field: v.Ref.Property, Value: c.Value},
			dynData,
			c.Operator,
			true,
		))
		dynData.Resource = res.ID
	}
	if errs != nil {
		return fmt.Errorf("could not apply constraints for %s: %w", v.Ref, errs)
	}
	return nil
}
// evaluateResourceOperational applies the property template's own operational rule through
// opCtx. It does nothing when there is no template or the template defines no rule.
func (v *propertyVertex) evaluateResourceOperational(
	opCtx operational_rule.OpRuleHandler,
) error {
	if v.Template == nil {
		return nil
	}
	rule := v.Template.Details().OperationalRule
	if rule == nil {
		return nil
	}
	if err := opCtx.HandlePropertyRule(*rule); err != nil {
		return fmt.Errorf("could not apply operational rule for %s: %w", v.Ref, err)
	}
	return nil
}
// shouldEvalEdges reports whether edge operational rules should be applied to this property.
// Dependency calculation and Evaluate both call it, so they always agree.
//
// Collection properties always take edge rules. Any other property skips them only when an
// "equals" constraint targets exactly this resource and property.
func (v *propertyVertex) shouldEvalEdges(cs []constraints.ResourceConstraint) bool {
	if knowledgebase.IsCollectionProperty(v.Template) {
		return true
	}
	for _, c := range cs {
		// NOTE(gg): does operator even matter here? If it's not a collection,
		// what does an 'add' mean? Should it allow edges to overwrite?
		targetsThisProperty := c.Target == v.Ref.Resource && c.Property == v.Ref.Property
		if targetsThisProperty && c.Operator == constraints.EqualsConstraintOperator {
			return false
		}
	}
	return true
}
// evaluateEdgeOperational applies every edge operational rule on the vertex, in EdgeKeys order,
// using the "add" constraint operator.
//
// Before each rule, the edge endpoint equal to the vertex's original resource ID is renamed to
// res's current ID, because an earlier rule may have changed it. The edge is then re-read from
// the graph so the rule sees its data. Failures are collected and returned joined.
func (v *propertyVertex) evaluateEdgeOperational(
	eval *Evaluator,
	res *construct.Resource,
	opCtx operational_rule.OpRuleHandler,
) error {
	originalID := v.Ref.Resource
	var errs error
	record := func(e error) { errs = errors.Join(errs, e) }

	for _, edge := range construct.EdgeKeys(v.EdgeRules) {
		for _, rule := range v.EdgeRules[edge] {
			edge = UpdateEdgeId(edge, originalID, res.ID)
			graphEdge, err := eval.Solution.RawView().Edge(edge.Source, edge.Target)
			if err != nil {
				record(fmt.Errorf("could not get edge from graph for %s: %w", edge, err))
				continue
			}
			keyEdge := construct.ResourceEdgeToKeyEdge(graphEdge)
			opCtx.SetData(knowledgebase.DynamicValueData{Resource: res.ID, Edge: &keyEdge})
			if err := opCtx.HandleOperationalRule(rule, constraints.AddConstraintOperator); err != nil {
				record(fmt.Errorf(
					"could not apply edge %s -> %s operational rule for %s: %w",
					edge.Source, edge.Target, v.Ref, err,
				))
			}
		}
	}
	return errs
}
// evaluateTransforms applies the transform rules collected during dependency calculation (edge
// rules that depend on the property's own value), using the "equals" constraint operator.
//
// As with edge rules, the edge endpoint is re-pointed at res's current ID before each rule.
// The dynamic data carries an edge built from the endpoints only, not one read from the graph.
// Failures are collected and returned joined.
func (v *propertyVertex) evaluateTransforms(
	res *construct.Resource,
	opCtx operational_rule.OpRuleHandler,
) error {
	originalID := v.Ref.Resource
	var errs error
	for _, edge := range construct.EdgeKeys(v.TransformRules) {
		for _, rule := range v.TransformRules[edge].ToSlice() {
			edge = UpdateEdgeId(edge, originalID, res.ID)
			bareEdge := &graph.Edge[construct.ResourceId]{Source: edge.Source, Target: edge.Target}
			opCtx.SetData(knowledgebase.DynamicValueData{Resource: res.ID, Edge: bareEdge})

			err := opCtx.HandleOperationalRule(rule, constraints.EqualsConstraintOperator)
			if err == nil {
				continue
			}
			errs = errors.Join(errs, fmt.Errorf(
				"could not apply transform rule for %s: %w",
				v.Ref, err,
			))
		}
	}
	return errs
}
// Ready reports how soon this property vertex should be evaluated.
//
// The rules, in order:
//   - Without a template, the vertex waits (NotReadyMax).
//   - With an operational rule, it runs immediately so the rule can create the resources it needs.
//   - Lists, sets, and maps without sub-properties are appended to by edges, so they are never
//     clearly done and are pushed late (NotReadyHigh).
//   - Otherwise the vertex is ready once it has a default value or edge rules that can set it.
func (v *propertyVertex) Ready(eval *Evaluator) (ReadyPriority, error) {
	if v.Template == nil {
		// wait until we have a template
		return NotReadyMax, nil
	}
	if v.Template.Details().OperationalRule != nil {
		// operational rules should run as soon as possible to create any resources they need
		return ReadyNow, nil
	}
	ptype := v.Template.Type()
	isListLike := strings.HasPrefix(ptype, "list") || strings.HasPrefix(ptype, "set")
	isPlainMap := strings.HasPrefix(ptype, "map") && len(v.Template.SubProperties()) == 0
	if isListLike || isPlainMap {
		// never sure when a list/set (or a map that is not an object) is ready - it'll just be
		// appended to by edges through `v.EdgeRules`
		return NotReadyHigh, nil
	}

	// properties that have values set via edge rules don't have default values
	res, err := eval.Solution.RawView().Vertex(v.Ref.Resource)
	if err != nil {
		return NotReadyHigh, fmt.Errorf("could not get resource for property vertex dependency calculation %s: %w", v.Ref, err)
	}
	path, err := res.PropertyPath(v.Ref.Property)
	if err != nil {
		return NotReadyHigh, fmt.Errorf("could not get property path for %s: %w", v.Ref, err)
	}
	defaultVal, err := v.Template.GetDefaultValue(solution.DynamicCtx(eval.Solution),
		knowledgebase.DynamicValueData{Resource: v.Ref.Resource, Path: path})
	switch {
	case err != nil:
		// A default that cannot be computed yet is treated as "not ready", not as a failure.
		return NotReadyMid, nil
	case defaultVal != nil, len(v.EdgeRules) > 0:
		// for non-list/set types, once an edge is here to set the value, it can be run
		return ReadyNow, nil
	default:
		return NotReadyMid, nil
	}
}
// addConfigurationRuleToPropertyVertex routes each configuration rule in rule to the property
// vertex it configures.
//
// For every configuration rule, the target resource and property are decoded with cfgCtx and
// data. Then:
//   - If that property vertex has already been evaluated, the configuration is returned in the
//     result map, keyed by resource, so the caller can apply it immediately. This is allowed only
//     when none of the vertex's predecessors in the evaluation graph (called "dependents" in the
//     messages) have been evaluated; otherwise an error is recorded.
//   - If it is still unevaluated and v is an *edgeVertex, the configuration (with rule's If
//     condition) is appended to the property vertex's EdgeRules for that edge.
//   - Any other kind of v is reported as an error.
//
// Per-rule errors are joined and returned alongside whatever configuration was collected.
func addConfigurationRuleToPropertyVertex(
	rule knowledgebase.OperationalRule,
	v Vertex,
	cfgCtx knowledgebase.DynamicValueContext,
	data knowledgebase.DynamicValueData,
	eval *Evaluator,
) (map[construct.ResourceId][]knowledgebase.ConfigurationRule, error) {
	configuration := make(map[construct.ResourceId][]knowledgebase.ConfigurationRule)
	log := eval.Log()
	pred, err := eval.graph.PredecessorMap()
	if err != nil {
		return configuration, err
	}
	var errs error
	for _, config := range rule.ConfigurationRules {
		var ref construct.PropertyRef
		err := cfgCtx.ExecuteDecode(config.Resource, data, &ref.Resource)
		if err != nil {
			errs = errors.Join(errs, fmt.Errorf(
				"could not decode resource for %s: %w",
				config.Resource, err,
			))
			continue
		}
		err = cfgCtx.ExecuteDecode(config.Config.Field, data, &ref.Property)
		if err != nil {
			errs = errors.Join(errs, fmt.Errorf(
				"could not decode property for %s: %w",
				config.Config.Field, err,
			))
			continue
		}
		key := Key{Ref: ref}
		vertex, err := eval.graph.Vertex(key)
		if err != nil {
			errs = errors.Join(errs, fmt.Errorf("could not attempt to get existing vertex for %s: %w", ref, err))
			continue
		}
		_, unevalErr := eval.unevaluated.Vertex(key)
		if errors.Is(unevalErr, graph.ErrVertexNotFound) {
			// The property vertex was already evaluated: the rule can only be applied directly,
			// and only if nothing that depends on it has been evaluated yet.
			var evalDeps []string
			for dep := range pred[key] {
				depEvaled, err := eval.isEvaluated(dep)
				if err != nil {
					errs = errors.Join(errs, fmt.Errorf("could not check if %s is evaluated: %w", dep, err))
					continue
				}
				if depEvaled {
					evalDeps = append(evalDeps, `"`+dep.String()+`"`)
				}
			}
			if len(evalDeps) == 0 {
				configuration[ref.Resource] = append(configuration[ref.Resource], config)
				log.Debugf("Allowing config on %s to be evaluated due to no dependents", key)
			} else {
				errs = errors.Join(errs, fmt.Errorf(
					"cannot add rules to evaluated node %s: evaluated dependents: %s",
					ref, strings.Join(evalDeps, ", "),
				))
			}
			continue
		}
		if unevalErr != nil {
			// Wrap the lookup error itself (previously the nil `err` was wrapped here).
			errs = errors.Join(errs, fmt.Errorf("could not get existing unevaluated vertex for %s: %w", ref, unevalErr))
			continue
		}
		pv, ok := vertex.(*propertyVertex)
		if !ok {
			errs = errors.Join(errs,
				fmt.Errorf("existing vertex for %s is not a property vertex", ref),
			)
			// pv is nil here; proceeding would dereference it below.
			continue
		}
		switch v := v.(type) {
		case *edgeVertex:
			edge := construct.ToSimpleEdge(v.Edge)
			if pv.EdgeRules == nil {
				pv.EdgeRules = make(map[construct.SimpleEdge][]knowledgebase.OperationalRule)
			}
			pv.EdgeRules[edge] = append(pv.EdgeRules[edge], knowledgebase.OperationalRule{
				If:                 rule.If,
				ConfigurationRules: []knowledgebase.ConfigurationRule{config},
			})
		default:
			errs = errors.Join(errs,
				fmt.Errorf("existing vertex for %s is not able to add configuration rules to property vertex", ref),
			)
		}
	}
	return configuration, errs
}
package operational_eval
import (
"errors"
"fmt"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/engine/constraints"
"github.com/klothoplatform/klotho/pkg/engine/operational_rule"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
)
type (
	// resourceRuleVertex is an evaluation-graph vertex for one additional rule attached to a resource.
	resourceRuleVertex struct {
		// Resource is the resource the rule is evaluated against.
		Resource construct.ResourceId
		// Rule supplies the If condition and Steps that are executed as an operational rule.
		Rule knowledgebase.AdditionalRule
		// hash becomes Key().RuleHash, presumably to tell apart several rule vertices on the same
		// resource. NOTE(review): how it is computed is not visible here.
		hash string
	}
)
// Key identifies the vertex by its resource (with an empty property) plus the rule hash.
func (v resourceRuleVertex) Key() Key {
	ref := construct.PropertyRef{Resource: v.Resource}
	return Key{RuleHash: v.hash, Ref: ref}
}
// Dependencies lets propCtx capture what the rule references. It does this by executing the
// rule's If condition and Steps as an operational rule against the vertex's resource.
func (v *resourceRuleVertex) Dependencies(eval *Evaluator, propCtx dependencyCapturer) error {
	rule := knowledgebase.OperationalRule{If: v.Rule.If, Steps: v.Rule.Steps}
	data := knowledgebase.DynamicValueData{Resource: v.Resource}
	if err := errors.Join(propCtx.ExecuteOpRule(data, rule)); err != nil {
		return fmt.Errorf("could not execute %s: %w", v.Key(), err)
	}
	return nil
}
// UpdateFrom replaces v's rule with the rule from other.
//
// other must be a *resourceRuleVertex for the same resource. Anything else is a programming
// error and panics.
func (v *resourceRuleVertex) UpdateFrom(other Vertex) {
	if v == other {
		return
	}
	otherRule, ok := other.(*resourceRuleVertex)
	if !ok {
		// Message previously said "edge"/"non-edge", copied from the edge vertex implementation.
		panic(fmt.Sprintf("cannot merge resource rule with non-resource rule vertex: %T", other))
	}
	if v.Resource != otherRule.Resource {
		panic(fmt.Sprintf("cannot merge resource rule with different refs: %s != %s", v.Resource, otherRule.Resource))
	}
	v.Rule = otherRule.Rule
}
// Evaluate applies the resource rule to the solution. It then re-reads the resource and tells
// the evaluator about any change to the resource's ID.
func (v *resourceRuleVertex) Evaluate(eval *Evaluator) error {
	sol := eval.Solution
	ruleCtx := &operational_rule.OperationalRuleContext{
		Solution: sol,
		Data:     knowledgebase.DynamicValueData{Resource: v.Resource},
	}
	if err := v.evaluateResourceRule(ruleCtx, eval); err != nil {
		return err
	}

	res, err := sol.RawView().Vertex(v.Resource)
	if err != nil {
		return fmt.Errorf("could not get resource to evaluate %s: %w", v.Resource, err)
	}
	return eval.UpdateId(v.Resource, res.ID)
}
// evaluateResourceRule executes the vertex's rule (If condition and Steps) through opCtx with
// the "add" constraint operator. The eval parameter is currently unused.
func (v *resourceRuleVertex) evaluateResourceRule(
	opCtx operational_rule.OpRuleHandler,
	eval *Evaluator,
) error {
	rule := knowledgebase.OperationalRule{If: v.Rule.If, Steps: v.Rule.Steps}
	if err := opCtx.HandleOperationalRule(rule, constraints.AddConstraintOperator); err != nil {
		return fmt.Errorf(
			"could not apply resource %s operational rule: %w",
			v.Resource, err,
		)
	}
	return nil
}
package operational_rule
import (
"errors"
"fmt"
"sort"
"github.com/dominikbraun/graph"
"github.com/google/uuid"
"github.com/klothoplatform/klotho/pkg/collectionutil"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/engine/solution"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
"github.com/klothoplatform/klotho/pkg/set"
)
type (
	// operationalResourceAction holds the state for satisfying one operational step's resource
	// requirement for a single resource.
	operationalResourceAction struct {
		// Step is the operational step being satisfied.
		Step knowledgebase.OperationalStep
		// CurrentIds are resources already used for the step; useAvailableResources skips them.
		CurrentIds []construct.ResourceId
		// ruleCtx provides the solution and dynamic data the step is evaluated with.
		ruleCtx OperationalRuleContext
		// numNeeded is how many more resources are required. 0 means none. A negative value means
		// "use all available" (see handleOperationalResourceAction). It is decremented as
		// resources are attached.
		numNeeded int
	}
)
// handleOperationalResourceAction satisfies the step's resource requirement for resource.
//
//   - If numNeeded is 0, there is nothing to do.
//   - For a unique step with a positive count, unique resources are found or created.
//   - Otherwise existing resources are reused first. Both positive and negative counts trigger
//     this, so -1 means "all available". Any positive remainder is then created from the
//     highest-priority resource type.
func (action *operationalResourceAction) handleOperationalResourceAction(resource *construct.Resource) error {
	switch {
	case action.numNeeded == 0:
		return nil
	case action.Step.Unique && action.numNeeded > 0:
		if err := action.createUniqueResources(resource); err != nil {
			return fmt.Errorf("error during operational resource action while creating unique resources: %w", err)
		}
		return nil
	}

	// numNeeded is non-zero here.
	if err := action.useAvailableResources(resource); err != nil {
		return fmt.Errorf("error during operational resource action while using available resources: %w", err)
	}
	for action.numNeeded > 0 {
		priorityType, selector, err := action.getPriorityResourceType()
		if err != nil {
			return fmt.Errorf("cannot create resources to satisfy operational step: no resource types found for step: %w", err)
		}
		if err = action.createResource(priorityType, selector, resource); err != nil {
			return err
		}
	}
	return nil
}
// createUniqueResources satisfies a unique operational step for resource.
//
// It first reuses resources of the highest-priority type that meet two conditions: they are
// direct neighbours of resource in the step's direction, and their only direct neighbour in the
// opposite direction is resource itself (i.e. some other process already created them for
// resource). Any remaining need is met by creating new resources of that type.
func (action *operationalResourceAction) createUniqueResources(resource *construct.Resource) error {
	priorityType, selector, err := action.getPriorityResourceType()
	if err != nil {
		return err
	}

	// Direct neighbours of resource in the step's direction.
	var candidates []construct.ResourceId
	if action.Step.Direction == knowledgebase.DirectionDownstream {
		candidates, err = solution.Downstream(action.ruleCtx.Solution, resource.ID, knowledgebase.ResourceDirectLayer)
	} else {
		candidates, err = solution.Upstream(action.ruleCtx.Solution, resource.ID, knowledgebase.ResourceDirectLayer)
	}
	if err != nil {
		return err
	}

	for _, candidate := range candidates {
		if !priorityType.Matches(candidate) {
			continue
		}
		// Direct neighbours of the candidate, looking back toward resource.
		var owners []construct.ResourceId
		if action.Step.Direction == knowledgebase.DirectionUpstream {
			owners, err = solution.Downstream(action.ruleCtx.Solution, candidate, knowledgebase.ResourceDirectLayer)
		} else {
			owners, err = solution.Upstream(action.ruleCtx.Solution, candidate, knowledgebase.ResourceDirectLayer)
		}
		if err != nil {
			return err
		}
		if len(owners) != 1 || owners[0] != resource.ID {
			continue
		}
		existing, getErr := action.ruleCtx.Solution.RawView().Vertex(candidate)
		if getErr != nil {
			return getErr
		}
		if action.numNeeded <= 0 {
			continue
		}
		if err = action.ruleCtx.addDependencyForDirection(action.Step, resource, existing); err != nil {
			return err
		}
		action.numNeeded--
		if action.numNeeded == 0 {
			break
		}
	}

	for action.numNeeded > 0 {
		if err = action.createResource(priorityType, selector, resource); err != nil {
			return err
		}
	}
	return nil
}
// useAvailableResources collects existing resources that can satisfy the step for resource.
// It then hands them to the step's placer, which updates numNeeded.
//
// For every resource selector in the step and every resource ID the selector yields,
// candidates are drawn from the solution's resources in topological order. A candidate is
// skipped when:
//   - there is no functional path for the ID in the step's direction,
//   - the candidate is already in CurrentIds,
//   - the candidate neither matches the selector nor is usable by it, or
//   - the candidate would not satisfy the step resource's namespace.
//
// A candidate that is usable but not matching is updated in place with the selector's
// properties. If one of those is a namespace property, it is also moved into resource's
// namespace. A remaining candidate is available when the connecting edge already exists, or
// when the edge template's uniqueness rules allow adding it.
func (action *operationalResourceAction) useAvailableResources(resource *construct.Resource) error {
	configCtx := solution.DynamicCtx(action.ruleCtx.Solution)
	availableResources := make(set.Set[*construct.Resource])
	edges, err := action.ruleCtx.Solution.DataflowGraph().Edges()
	if err != nil {
		return err
	}
	resources, err := construct.TopologicalSort(action.ruleCtx.Solution.RawView())
	if err != nil {
		return err
	}
	kb := action.ruleCtx.Solution.KnowledgeBase()
	// Next we will loop through and try to use available resources if the unique flag is not set
	for _, resourceSelector := range action.Step.Resources {
		ids, err := resourceSelector.ExtractResourceIds(configCtx, action.ruleCtx.Data)
		if err != nil {
			return err
		}
		// because there can be multiple types if we only have classifications on the resource selector we want to loop over all ids
		for _, id := range ids {
			// if there is no functional path for the id then we can skip it since we know its not available to satisfy a valid graph
			switch action.Step.Direction {
			case knowledgebase.DirectionDownstream:
				if !kb.HasFunctionalPath(resource.ID, id) {
					continue
				}
			case knowledgebase.DirectionUpstream:
				if !kb.HasFunctionalPath(id, resource.ID) {
					continue
				}
			}
			for _, resId := range resources {
				res, err := action.ruleCtx.Solution.RawView().Vertex(resId)
				if err != nil {
					return err
				}
				if collectionutil.Contains(action.CurrentIds, res.ID) {
					continue
				}
				// Check the error before the result: IsMatch reports false alongside an error, and the
				// previous `!match` check ran first, which silently swallowed the error.
				match, err := resourceSelector.IsMatch(configCtx, action.ruleCtx.Data, res)
				if err != nil {
					return fmt.Errorf("error checking %s matches selector: %w", resId, err)
				}
				if !match {
					canUse, err := resourceSelector.CanUse(configCtx, action.ruleCtx.Data, res)
					if err != nil {
						return fmt.Errorf("error checking %s can use resource: %w", resId, err)
					}
					if !canUse {
						continue
					}
					// This can happen if an empty resource was created via path expansion, but isn't yet set up.
					tmpl, err := kb.GetResourceTemplate(res.ID)
					if err != nil {
						return err
					}
					for k, v := range resourceSelector.Properties {
						v, err := knowledgebase.TransformToPropertyValue(res.ID, k, v, configCtx, action.ruleCtx.Data)
						if err != nil {
							return err
						}
						if err := res.SetProperty(k, v); err != nil {
							return err
						}
						// GetProperty returns nil for a path the template does not declare; guard it instead of
						// panicking. Only a declared namespace property moves the resource's namespace.
						if prop := tmpl.GetProperty(k); prop != nil && prop.Details().Namespace {
							oldId := res.ID
							res.ID.Namespace = resource.ID.Namespace
							if err := action.ruleCtx.Solution.OperationalView().UpdateResourceID(oldId, res.ID); err != nil {
								return err
							}
						}
					}
				}
				// Same ordering fix as above: an error comes back with satisfy == false.
				satisfy, err := action.doesResourceSatisfyNamespace(resource, res)
				if err != nil {
					return fmt.Errorf("error checking %s satisfies namespace: %w", resId, err)
				}
				if !satisfy {
					continue
				}
				var edge construct.SimpleEdge
				if action.Step.Direction == knowledgebase.DirectionDownstream {
					edge = construct.SimpleEdge{Source: resource.ID, Target: res.ID}
				} else {
					edge = construct.SimpleEdge{Source: res.ID, Target: resource.ID}
				}
				// Check to see if the edge already exists, if it does, then we should be able to reuse the resource
				_, err = action.ruleCtx.Solution.RawView().Edge(edge.Source, edge.Target)
				if err == nil {
					availableResources.Add(res)
					continue
				}
				if !errors.Is(err, graph.ErrEdgeNotFound) {
					return err
				}
				edgeTmpl := kb.GetEdgeTemplate(edge.Source, edge.Target)
				if edgeTmpl != nil && edgeTmpl.Unique.CanAdd(edges, edge.Source, edge.Target) {
					availableResources.Add(res)
				}
			}
		}
	}
	if err := action.placeResources(resource, availableResources); err != nil {
		return fmt.Errorf("error during operational resource action while placing resources: %w", err)
	}
	return nil
}
// placeResources hands the available resources to the placer registered for the step's
// selection operator. The resources are sorted by ID so the placer sees a deterministic order.
// The placer also receives a pointer to numNeeded so it can update the remaining count.
func (action *operationalResourceAction) placeResources(resource *construct.Resource,
	availableResources set.Set[*construct.Resource]) error {
	newPlacer, registered := placerMap[action.Step.SelectionOperator]
	if !registered {
		return fmt.Errorf("unknown selection operator %s", action.Step.SelectionOperator)
	}
	p := newPlacer()
	p.SetCtx(action.ruleCtx)

	candidates := availableResources.ToSlice()
	byID := func(a, b int) bool {
		return construct.ResourceIdLess(candidates[a].ID, candidates[b].ID)
	}
	sort.Slice(candidates, byID)
	return p.PlaceResources(resource, action.Step, candidates, &action.numNeeded)
}
// doesResourceSatisfyNamespace reports whether resource can be used by stepResource with
// respect to namespacing.
//
// Any resource is acceptable in two cases: resource's type cannot be namespaced, or
// stepResource has no first-functional-layer downstream resources of an allowed namespace type
// (reachable via a functional path). Otherwise resource is acceptable only if it is not
// namespaced, or if it is namespaced into one of those downstream resources.
func (action *operationalResourceAction) doesResourceSatisfyNamespace(stepResource *construct.Resource, resource *construct.Resource) (bool, error) {
	kb := action.ruleCtx.Solution.KnowledgeBase()
	namespacedIds, err := kb.GetAllowedNamespacedResourceIds(solution.DynamicCtx(action.ruleCtx.Solution), resource.ID)
	if err != nil {
		return false, err
	}
	// If the type to create doesnt get namespaced, then we can ignore this satisfication
	if len(namespacedIds) == 0 {
		return true, nil
	}
	// The functional downstreams of stepResource do not depend on namespacedId. They are
	// computed at most once, and only if some namespace type has a functional path
	// (previously they were recomputed for every namespace type).
	var downstreams []construct.ResourceId
	downstreamsLoaded := false
	var namespaceResourcesForResource []construct.ResourceId
	for _, namespacedId := range namespacedIds {
		// If theres no functional path from one resource to the other, then we dont care about that namespacedId
		if !kb.HasFunctionalPath(stepResource.ID, namespacedId) {
			continue
		}
		if !downstreamsLoaded {
			downstreams, err = solution.Downstream(action.ruleCtx.Solution, stepResource.ID, knowledgebase.FirstFunctionalLayer)
			if err != nil {
				return false, err
			}
			downstreamsLoaded = true
		}
		for _, downstream := range downstreams {
			if namespacedId.Matches(downstream) {
				namespaceResourcesForResource = append(namespaceResourcesForResource, downstream)
			}
		}
	}
	// If there are no functional resources downstream for the possible namespace resource types
	// we have free will to choose any of the resources available with the type of the type to create
	if len(namespaceResourcesForResource) == 0 {
		return true, nil
	}
	// for the resource we are checking if its available based on if it is namespaced
	// if it is namespaced we will ensure that it is namespaced into one of the resources downstream of the step resource
	namespaceResourceId, err := kb.GetResourcesNamespaceResource(resource)
	if err != nil {
		return false, fmt.Errorf("error during operational resource action while getting namespace resource: %w", err)
	}
	if namespaceResourceId.IsZero() {
		// resource is not namespaced, so it is free to be used.
		return true, nil
	}
	namespaceResource, err := action.ruleCtx.Solution.RawView().Vertex(namespaceResourceId)
	if err != nil {
		return false, err
	}
	return collectionutil.Contains(namespaceResourcesForResource, namespaceResource.ID), nil
}
// getPriorityResourceType walks the step's resource selectors in order and returns the first id they
// produce that is non-zero and usable. An id that already exists in the graph is skipped unless the
// step is Unique. The selector that produced the id is returned with it. When no selector yields a
// usable id, an error is returned.
func (action *operationalResourceAction) getPriorityResourceType() (
	construct.ResourceId,
	knowledgebase.ResourceSelector,
	error,
) {
	for _, selector := range action.Step.Resources {
		candidates, err := selector.ExtractResourceIds(solution.DynamicCtx(action.ruleCtx.Solution), action.ruleCtx.Data)
		if err != nil {
			return construct.ResourceId{}, selector, err
		}
		for _, candidate := range candidates {
			existing, lookupErr := action.ruleCtx.Solution.RawView().Vertex(candidate)
			if lookupErr != nil && !errors.Is(lookupErr, graph.ErrVertexNotFound) {
				return construct.ResourceId{}, selector, lookupErr
			}
			alreadyInGraph := existing != nil
			skip := candidate.IsZero() || (alreadyInGraph && !action.Step.Unique)
			if skip {
				continue
			}
			chosen := construct.ResourceId{
				Provider:  candidate.Provider,
				Type:      candidate.Type,
				Namespace: candidate.Namespace,
				Name:      candidate.Name,
			}
			return chosen, selector, nil
		}
	}
	return construct.ResourceId{}, knowledgebase.ResourceSelector{}, fmt.Errorf("no resource types found for step, %s", action.Step.Resource)
}
// addSelectorProperties applies the selector's property values to resource.
// An unknown property name or a value that fails to transform aborts immediately. Failures from
// SetProperty are collected and returned together after every property has been attempted.
func (action *operationalResourceAction) addSelectorProperties(properties map[string]any, resource *construct.Resource) error {
	tmpl, err := action.ruleCtx.Solution.KnowledgeBase().GetResourceTemplate(resource.ID)
	if err != nil {
		return err
	}
	dyn := solution.DynamicCtx(action.ruleCtx.Solution)

	var setErrs error
	for name, raw := range properties {
		if tmpl.GetProperty(name) == nil {
			return fmt.Errorf("property %s not found in template %s", name, tmpl.Id())
		}
		val, transformErr := knowledgebase.TransformToPropertyValue(resource.ID, name, raw, dyn, action.ruleCtx.Data)
		if transformErr != nil {
			return transformErr
		}
		if setErr := resource.SetProperty(name, val); setErr != nil {
			setErrs = errors.Join(setErrs, setErr)
		}
	}
	return setErrs
}
// createResource names and instantiates a new resource of resourceType. It adds the resource to the
// operational view, wires it to stepResource in the step's direction, and applies the selector's
// properties to it.
func (action *operationalResourceAction) createResource(
	resourceType construct.ResourceId,
	selector knowledgebase.ResourceSelector,
	stepResource *construct.Resource,
) error {
	id := resourceType
	err := action.generateResourceName(&id, stepResource.ID)
	if err != nil {
		return err
	}

	created, err := knowledgebase.CreateResource(action.ruleCtx.Solution.KnowledgeBase(), id)
	if err != nil {
		return err
	}

	if err = action.createAndAddDependency(created, stepResource); err != nil {
		return err
	}
	return action.addSelectorProperties(selector.Properties, created)
}
// createAndAddDependency adds res to the operational view (tolerating an existing vertex) and
// connects it to stepResource in the step's direction. On success it counts res against
// action.numNeeded.
func (action *operationalResourceAction) createAndAddDependency(res, stepResource *construct.Resource) error {
	view := action.ruleCtx.Solution.OperationalView()
	if addErr := view.AddVertex(res); addErr != nil && !errors.Is(addErr, graph.ErrVertexAlreadyExists) {
		return addErr
	}
	if depErr := action.ruleCtx.addDependencyForDirection(action.Step, stepResource, res); depErr != nil {
		return depErr
	}
	action.numNeeded--
	return nil
}
func (action *operationalResourceAction) generateResourceName(resourceToSet *construct.ResourceId, resource construct.ResourceId) error {
if resourceToSet.Name != "" {
return nil
}
if action.Step.Unique {
// If creating unique resources, don't need to count the total resources because the owner's name is added
// which adds enough uniqueness against other resources in the graph. Just need to handle when the owner
// creates multiple resources of the same type.
suffix := ""
if action.Step.NumNeeded > 1 {
// If we are creating multiple resources, we want to append the number of resources we have created so far
// so that the names are unique.
suffix = fmt.Sprintf("-%d", action.Step.NumNeeded-action.numNeeded)
}
resourceToSet.Name = fmt.Sprintf("%s-%s%s", resource.Name, resourceToSet.Type, suffix)
return nil
}
return generateResourceName(action.ruleCtx.Solution, resourceToSet)
}
// generateResourceName names resourceToSet "<type>-<count>", where count is the number of resources
// of the same provider/type in the dataflow graph. If that name is already taken, an 8-character
// random suffix is used instead.
func generateResourceName(sol solution.Solution, resourceToSet *construct.ResourceId) error {
	ids, err := construct.TopologicalSort(sol.DataflowGraph())
	if err != nil {
		return err
	}

	// Namespace is deliberately excluded from the match: a resource created for an operational action
	// usually hasn't been namespaced yet, and we can't know where it will end up.
	sameType := construct.ResourceId{Provider: resourceToSet.Provider, Type: resourceToSet.Type}
	taken := make(set.Set[string])
	count := 0
	for _, id := range ids {
		if !sameType.Matches(id) {
			continue
		}
		taken.Add(id.Name)
		count++
	}

	candidate := fmt.Sprintf("%s-%d", resourceToSet.Type, count)
	if taken.Contains(candidate) {
		candidate = fmt.Sprintf("%s-%s", resourceToSet.Type, uuid.NewString()[:8])
	}
	resourceToSet.Name = candidate
	return nil
}
package operational_rule
import (
"fmt"
"github.com/klothoplatform/klotho/pkg/engine/constraints"
"github.com/klothoplatform/klotho/pkg/engine/solution"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
)
// HandleConfigurationRule resolves the rule's target resource from the dataflow graph. It decodes
// the configured field through the dynamic template context, then applies the configuration with
// the given operator via solution.Configurer.
func (ctx OperationalRuleContext) HandleConfigurationRule(
	config knowledgebase.ConfigurationRule,
	configurationOperator constraints.ConstraintOperator,
) error {
	dyn := solution.DynamicCtx(ctx.Solution)

	resId, err := knowledgebase.ExecuteDecodeAsResourceId(dyn, config.Resource, ctx.Data)
	if err != nil {
		return err
	}
	resource, err := ctx.Solution.DataflowGraph().Vertex(resId)
	if err != nil {
		return fmt.Errorf("resource %s not found: %w", resId, err)
	}

	// The field is decoded as a template before use.
	field := config.Config.Field
	if decodeErr := dyn.ExecuteDecode(config.Config.Field, ctx.Data, &field); decodeErr != nil {
		return decodeErr
	}
	config.Config.Field = field

	configurer := &solution.Configurer{Ctx: ctx.Solution}
	return configurer.ConfigureResource(resource, config.Config, ctx.Data, configurationOperator, false)
}
package operational_rule
import (
"errors"
"fmt"
"github.com/dominikbraun/graph"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/engine/constraints"
"github.com/klothoplatform/klotho/pkg/engine/reconciler"
"github.com/klothoplatform/klotho/pkg/engine/solution"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
"github.com/klothoplatform/klotho/pkg/set"
"go.uber.org/zap"
)
//go:generate mockgen -source=./operational_rule.go --destination=../operational_eval/operational_rule_mock_test.go --package=operational_eval
type (
// OperationalRuleContext holds the state used while evaluating operational, property and
// configuration rules.
OperationalRuleContext struct {
// Solution is the solution being read and modified by the rules.
Solution solution.Solution
// Property is the property a property rule targets; it is nil for edge rules.
Property knowledgebase.Property
// Data is the dynamic value data used when decoding rule templates.
Data knowledgebase.DynamicValueData
}
// OpRuleHandler is the rule-handling surface of OperationalRuleContext; a mock of it is generated
// by the go:generate directive preceding this declaration.
OpRuleHandler interface {
HandleOperationalRule(rule knowledgebase.OperationalRule, configurationOperator constraints.ConstraintOperator) error
HandlePropertyRule(rule knowledgebase.PropertyRule) error
SetData(data knowledgebase.DynamicValueData)
}
)
// HandleOperationalRule evaluates rule.If and, when it holds, applies every step and then every
// configuration rule. A failing step or configuration rule does not stop the others. All failures
// are joined and returned together, each tagged with its index.
func (ctx *OperationalRuleContext) HandleOperationalRule(
	rule knowledgebase.OperationalRule,
	configurationOperator constraints.ConstraintOperator,
) error {
	shouldRun, err := EvaluateIfCondition(rule.If, ctx.Solution, ctx.Data)
	switch {
	case err != nil:
		return err
	case !shouldRun:
		return nil
	}

	var errs error
	for idx := range rule.Steps {
		if stepErr := ctx.HandleOperationalStep(rule.Steps[idx]); stepErr != nil {
			errs = errors.Join(errs, fmt.Errorf("could not apply step %d: %w", idx, stepErr))
		}
	}
	for idx := range rule.ConfigurationRules {
		if cfgErr := ctx.HandleConfigurationRule(rule.ConfigurationRules[idx], configurationOperator); cfgErr != nil {
			errs = errors.Join(errs, fmt.Errorf("could not apply configuration rule %d: %w", idx, cfgErr))
		}
	}
	return errs
}
// HandlePropertyRule enforces rule on ctx.Property of resource ctx.Data.Resource.
// When the rule's If condition holds:
//   - If the rule has a step with selectors, non-matching existing values are cleaned first and then
//     the step is applied.
//   - If the rule has a Value, it is parsed and written to the property.
//
// A parse failure no longer results in a nil value being written; the parse error is reported
// instead. Step, lookup and set failures are joined into the returned error.
func (ctx *OperationalRuleContext) HandlePropertyRule(rule knowledgebase.PropertyRule) error {
	if ctx.Property == nil {
		return fmt.Errorf("property rule has no property")
	}
	if ctx.Data.Resource.IsZero() {
		return fmt.Errorf("property rule has no resource")
	}
	shouldRun, err := EvaluateIfCondition(rule.If, ctx.Solution, ctx.Data)
	if err != nil {
		return err
	}
	if !shouldRun {
		return nil
	}

	var errs error
	if len(rule.Step.Resources) > 0 {
		// Drop existing values that no longer satisfy the step's selectors before (re)applying it.
		if err := ctx.CleanProperty(rule.Step); err != nil {
			return err
		}
		if err := ctx.HandleOperationalStep(rule.Step); err != nil {
			errs = errors.Join(errs, fmt.Errorf("could not apply step: %w", err))
		}
	}

	if rule.Value != nil {
		val, err := ctx.Property.Parse(rule.Value, solution.DynamicCtx(ctx.Solution), ctx.Data)
		if err != nil {
			// Don't write a value that failed to parse.
			return errors.Join(errs, fmt.Errorf("could not parse value %v: %w", rule.Value, err))
		}
		resource, err := ctx.Solution.RawView().Vertex(ctx.Data.Resource)
		if err != nil {
			return errors.Join(errs, fmt.Errorf("could not get resource %s: %w", ctx.Data.Resource, err))
		}
		if err := ctx.Property.SetProperty(resource, val); err != nil {
			errs = errors.Join(errs, fmt.Errorf("could not set property %s: %w", ctx.Property.Details().Path, err))
		}
	}
	return errs
}
// SetData replaces the dynamic value data used when decoding this context's rule templates.
func (ctx *OperationalRuleContext) SetData(data knowledgebase.DynamicValueData) {
ctx.Data = data
}
// CleanProperty clears the property associated with the rule if it no longer matches the rule.
// For array properties, each element must match at least one step selector, and non-matching
// elements are removed. Every removed value also has its edges with the rule's resource deleted
// (in both directions) and is passed to reconciler.RemoveResource.
//
// For array properties, removal errors are collected and returned. Previously they were computed
// but discarded.
func (ctx OperationalRuleContext) CleanProperty(step knowledgebase.OperationalStep) error {
	log := zap.L().With(
		zap.String("op", "op_rule"),
		zap.String("property", ctx.Property.Details().Path),
		zap.String("resource", ctx.Data.Resource.String()),
	).Sugar()

	resource, err := ctx.Solution.RawView().Vertex(ctx.Data.Resource)
	if err != nil {
		return err
	}
	path, err := resource.PropertyPath(ctx.Property.Details().Path)
	if err != nil {
		return err
	}
	prop, _ := path.Get()
	if prop == nil {
		return nil
	}

	dyn := solution.DynamicCtx(ctx.Solution)

	// checkResForMatch reports whether res satisfies at least one of the step's selectors.
	checkResForMatch := func(res construct.ResourceId) (bool, error) {
		propRes, err := ctx.Solution.RawView().Vertex(res)
		if err != nil {
			return false, err
		}
		for i, sel := range step.Resources {
			match, err := sel.IsMatch(dyn, ctx.Data, propRes)
			if err != nil {
				return false, fmt.Errorf("error checking if %s matches selector %d: %w", res, i, err)
			}
			if match {
				return true, nil
			}
		}
		return false, nil
	}

	// removeSingle handles a scalar property value referencing id: if id no longer matches, the
	// property is removed and id is detached and removed.
	removeSingle := func(id construct.ResourceId) error {
		isMatch, err := checkResForMatch(id)
		if err != nil || isMatch {
			return err
		}
		log.Infof("removing %s, does not match selectors", id)
		if err := path.Remove(nil); err != nil {
			return err
		}
		if err := ForceRemoveDependency(ctx.Data.Resource, id, ctx.Solution); err != nil {
			return err
		}
		return reconciler.RemoveResource(ctx.Solution, id, false)
	}

	// removeUnmatched detaches and removes every resource in toRemove, joining all errors.
	removeUnmatched := func(toRemove set.Set[construct.ResourceId]) error {
		var errs error
		for rem := range toRemove {
			log.Infof("removing %s, does not match selectors", rem)
			errs = errors.Join(errs, ForceRemoveDependency(ctx.Data.Resource, rem, ctx.Solution))
			errs = errors.Join(errs, reconciler.RemoveResource(ctx.Solution, rem, false))
		}
		return errs
	}

	switch prop := prop.(type) {
	case construct.ResourceId:
		return removeSingle(prop)

	case construct.PropertyRef:
		return removeSingle(prop.Resource)

	case []construct.ResourceId:
		matching := make([]construct.ResourceId, 0, len(prop))
		toRemove := make(set.Set[construct.ResourceId])
		for _, id := range prop {
			isMatch, err := checkResForMatch(id)
			if err != nil {
				return err
			}
			if isMatch {
				matching = append(matching, id)
			} else {
				toRemove.Add(id)
			}
		}
		if len(matching) == len(prop) {
			return nil
		}
		if err := path.Set(matching); err != nil {
			return err
		}
		return removeUnmatched(toRemove)

	case []any:
		matching := make([]any, 0, len(prop))
		toRemove := make(set.Set[construct.ResourceId])
		for _, elem := range prop {
			var id construct.ResourceId
			switch v := elem.(type) {
			case construct.ResourceId:
				id = v
			case construct.PropertyRef:
				id = v.Resource
			default:
				// Values that don't reference a resource are kept untouched.
				matching = append(matching, elem)
				continue
			}
			isMatch, err := checkResForMatch(id)
			if err != nil {
				return err
			}
			if isMatch {
				// Keep the original element so PropertyRefs aren't collapsed to bare ids.
				matching = append(matching, elem)
			} else {
				toRemove.Add(id)
			}
		}
		if len(matching) == len(prop) {
			return nil
		}
		if err := path.Set(matching); err != nil {
			return err
		}
		return removeUnmatched(toRemove)
	}
	return nil
}
// EvaluateIfCondition decodes tmplString as a boolean template against data.
// An empty template means the rule applies unconditionally and yields true.
func EvaluateIfCondition(
	tmplString string,
	sol solution.Solution,
	data knowledgebase.DynamicValueData,
) (bool, error) {
	if tmplString == "" {
		return true, nil
	}
	var holds bool
	if err := solution.DynamicCtx(sol).ExecuteDecode(tmplString, data, &holds); err != nil {
		return false, err
	}
	return holds, nil
}
// ForceRemoveDependency deletes the raw-view edges res1->res2 and res2->res1.
// A missing edge is not an error; any other removal failure is returned immediately.
func ForceRemoveDependency(
	res1, res2 construct.ResourceId,
	sol solution.Solution,
) error {
	directions := [2][2]construct.ResourceId{{res1, res2}, {res2, res1}}
	for _, d := range directions {
		err := sol.RawView().RemoveEdge(d[0], d[1])
		if err != nil && !errors.Is(err, graph.ErrEdgeNotFound) {
			return err
		}
	}
	return nil
}
package operational_rule
import (
"errors"
"fmt"
"reflect"
"github.com/dominikbraun/graph"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/engine/reconciler"
"github.com/klothoplatform/klotho/pkg/engine/solution"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
"github.com/klothoplatform/klotho/pkg/logging"
"go.uber.org/zap"
)
// HandleOperationalStep ensures the step's resource has enough related resources.
// The step resource is ctx.Data.Resource, or the decoded step.Resource when that is zero.
// Existing values are gathered from ctx.Property when set, or from graph neighbours for edge rules,
// which have no Property. If enough exist, or the resource is imported, nothing more happens.
// Otherwise the step fails when FailIfMissing is set, or the remainder is created/placed by an
// operationalResourceAction. A NumNeeded of 0 is treated as 1.
func (ctx OperationalRuleContext) HandleOperationalStep(step knowledgebase.OperationalStep) error {
	if step.NumNeeded == 0 {
		step.NumNeeded = 1
	}

	dyn := solution.DynamicCtx(ctx.Solution)
	resourceId := ctx.Data.Resource
	if resourceId.IsZero() {
		decoded, decodeErr := knowledgebase.ExecuteDecodeAsResourceId(dyn, step.Resource, ctx.Data)
		if decodeErr != nil {
			return decodeErr
		}
		resourceId = decoded
	}

	resource, err := ctx.Solution.OperationalView().Vertex(resourceId)
	if err != nil {
		return fmt.Errorf("resource %s not found: %w", resourceId, err)
	}

	var (
		ids         []construct.ResourceId
		otherValues []any
	)
	switch {
	case ctx.Property != nil:
		ids, otherValues, err = ctx.addDependenciesFromProperty(step, resource, ctx.Property.Details().Path)
	default:
		// Edge rules carry no Property, so look at graph neighbours instead.
		ids, err = ctx.getResourcesForStep(step, resource.ID)
	}
	if err != nil {
		return err
	}

	have := len(ids) + len(otherValues)
	satisfied := step.NumNeeded > 0 && have >= step.NumNeeded
	if satisfied || resource.Imported {
		return nil
	}
	if step.FailIfMissing {
		return fmt.Errorf("operational resource '%s' missing when required", resource.ID)
	}

	action := operationalResourceAction{
		Step:       step,
		CurrentIds: ids,
		numNeeded:  step.NumNeeded - have,
		ruleCtx:    ctx,
	}
	return action.handleOperationalResourceAction(resource)
}
// getResourcesForStep returns the resources in resource's first functional layer that match any of
// the step's selectors. It looks upstream for DirectionUpstream steps and downstream otherwise.
// Each matching resource is returned once, even if several selectors match it. Duplicates would
// otherwise inflate the count HandleOperationalStep compares against NumNeeded.
func (ctx OperationalRuleContext) getResourcesForStep(step knowledgebase.OperationalStep, resource construct.ResourceId) ([]construct.ResourceId, error) {
	var ids []construct.ResourceId
	var err error
	if step.Direction == knowledgebase.DirectionUpstream {
		ids, err = solution.Upstream(ctx.Solution, resource, knowledgebase.FirstFunctionalLayer)
	} else {
		ids, err = solution.Downstream(ctx.Solution, resource, knowledgebase.FirstFunctionalLayer)
	}
	if err != nil {
		return nil, err
	}
	resources, err := construct.ResolveIds(ctx.Solution.RawView(), ids)
	if err != nil {
		return nil, fmt.Errorf("could not resolve ids for 'getResourcesForStep': %w", err)
	}

	dyn := solution.DynamicCtx(ctx.Solution)
	var resourcesOfType []construct.ResourceId
	for _, dep := range resources {
		for _, resourceSelector := range step.Resources {
			match, err := resourceSelector.IsMatch(dyn, ctx.Data, dep)
			if err != nil {
				return nil, fmt.Errorf("error checking if %s is side effect of %s: %w", dep.ID, resource, err)
			}
			if match {
				resourcesOfType = append(resourcesOfType, dep.ID)
				break // one match is enough; don't record dep again for later selectors
			}
		}
	}
	return resourcesOfType, nil
}
// addDependenciesFromProperty adds dependencies from the property of the resource
// and returns the resource ids and other (non-resource) values that were found in the property.
// ResourceId and PropertyRef values are connected to resource in the step's direction unless that
// edge already exists. Values of any other type are returned as otherValues without being connected.
func (ctx OperationalRuleContext) addDependenciesFromProperty(
	step knowledgebase.OperationalStep,
	resource *construct.Resource,
	propertyName string,
) ([]construct.ResourceId, []any, error) {
	val, err := resource.GetProperty(propertyName)
	if err != nil {
		return nil, nil, fmt.Errorf("error getting property %s on resource %s: %w", propertyName, resource.ID, err)
	}
	if val == nil {
		return nil, nil, nil
	}

	addDep := func(id construct.ResourceId) error {
		dep, err := ctx.Solution.RawView().Vertex(id)
		if err != nil {
			return fmt.Errorf("could not add dep to %s from %s#%s: %w", id, resource.ID, propertyName, err)
		}
		// Skip when the edge already exists in the direction addDependencyForDirection would add it.
		// Previously only resource->dep was checked, so upstream steps never short-circuited.
		source, target := resource.ID, dep.ID
		if step.Direction == knowledgebase.DirectionUpstream {
			source, target = dep.ID, resource.ID
		}
		if _, err := ctx.Solution.RawView().Edge(source, target); err == nil {
			return nil
		}
		return ctx.addDependencyForDirection(step, resource, dep)
	}

	switch val := val.(type) {
	case construct.ResourceId:
		if val.IsZero() {
			return nil, nil, nil
		}
		return []construct.ResourceId{val}, nil, addDep(val)
	case []construct.ResourceId:
		var errs error
		for _, id := range val {
			errs = errors.Join(errs, addDep(id))
		}
		return val, nil, errs
	case []any:
		var errs error
		var ids []construct.ResourceId
		var otherValues []any
		for _, elem := range val {
			switch elem := elem.(type) {
			case construct.ResourceId:
				ids = append(ids, elem)
				errs = errors.Join(errs, addDep(elem))
			case construct.PropertyRef:
				ids = append(ids, elem.Resource)
				errs = errors.Join(errs, addDep(elem.Resource))
			case any:
				// nil elements match no case and are ignored.
				otherValues = append(otherValues, elem)
			}
		}
		return ids, otherValues, errs
	case construct.PropertyRef:
		return []construct.ResourceId{val.Resource}, nil, addDep(val.Resource)
	default:
		return nil, []any{val}, nil
	}
}
// clearProperty removes propertyName from resource after removing the edge (in the step's
// direction) to every resource it references. Referenced resources whose functionality is Unknown
// are also passed to reconciler.RemoveResource.
// ResourceId and PropertyRef values are handled, both as scalars and inside []any. This matches
// addDependenciesFromProperty and CleanProperty. Any other value type yields an error.
func (ctx OperationalRuleContext) clearProperty(step knowledgebase.OperationalStep, resource *construct.Resource, propertyName string) error {
	val, err := resource.GetProperty(propertyName)
	if err != nil {
		return err
	}
	if val == nil {
		return nil
	}

	kb := ctx.Solution.KnowledgeBase()
	removeDep := func(id construct.ResourceId) error {
		err := ctx.removeDependencyForDirection(step.Direction, resource.ID, id)
		if err != nil {
			return err
		}
		if knowledgebase.GetFunctionality(kb, id) == knowledgebase.Unknown {
			return reconciler.RemoveResource(ctx.Solution, id, false)
		}
		return nil
	}

	switch val := val.(type) {
	case construct.ResourceId:
		if err := removeDep(val); err != nil {
			return err
		}
	case construct.PropertyRef:
		if err := removeDep(val.Resource); err != nil {
			return err
		}
	case []construct.ResourceId:
		var errs error
		for _, id := range val {
			errs = errors.Join(errs, removeDep(id))
		}
		if errs != nil {
			return errs
		}
	case []any:
		var errs error
		for _, elem := range val {
			switch elem := elem.(type) {
			case construct.ResourceId:
				errs = errors.Join(errs, removeDep(elem))
			case construct.PropertyRef:
				errs = errors.Join(errs, removeDep(elem.Resource))
			}
		}
		if errs != nil {
			return errs
		}
	default:
		return fmt.Errorf("cannot clear property %s on resource %s", propertyName, resource.ID)
	}
	return resource.RemoveProperty(propertyName, nil)
}
func (ctx OperationalRuleContext) addDependencyForDirection(
step knowledgebase.OperationalStep,
resource, dependentResource *construct.Resource,
) error {
var edge construct.Edge
if step.Direction == knowledgebase.DirectionUpstream {
edge = construct.Edge{Source: dependentResource.ID, Target: resource.ID}
} else {
edge = construct.Edge{Source: resource.ID, Target: dependentResource.ID}
}
err := ctx.Solution.OperationalView().AddEdge(edge.Source, edge.Target)
if err != nil && !errors.Is(err, graph.ErrEdgeAlreadyExists) {
return err
}
return ctx.SetField(resource, dependentResource, step)
}
// removeDependencyForDirection removes the operational-view edge between resource and
// dependentResource. It targets dependent->resource for upstream and resource->dependent otherwise.
func (ctx OperationalRuleContext) removeDependencyForDirection(direction knowledgebase.Direction, resource, dependentResource construct.ResourceId) error {
	from, to := resource, dependentResource
	if direction == knowledgebase.DirectionUpstream {
		from, to = dependentResource, resource
	}
	return ctx.Solution.OperationalView().RemoveEdge(from, to)
}
// SetField records fieldResource in ctx.Property on resource. The value recorded is the id, or a
// PropertyRef when step.UsePropertyRef is set.
//   - Imported resources are never modified: the call succeeds only if the value is already present.
//   - A scalar value pointing at a different resource is replaced. The old edge is removed and the
//     old resource is handed to reconciler.RemoveResource.
//   - Slice values that already contain the value are left untouched.
//
// After appending, the resource may be namespaced into fieldResource (see namespace).
// A failure to read the current value is tolerated (debug-logged and treated as unset); previously
// the read error leaked out through the later err check.
//
// NOTE(review): SetField has a value receiver, so namespace's updates to ctx.Data only affect this
// copy — confirm whether callers expect them to propagate.
func (ctx OperationalRuleContext) SetField(resource, fieldResource *construct.Resource, step knowledgebase.OperationalStep) error {
	if ctx.Property == nil {
		return nil
	}
	path := ctx.Property.Details().Path
	propVal, getErr := resource.GetProperty(path)
	if getErr != nil {
		zap.S().Debugf("property %s not found on resource %s", path, resource.ID)
		propVal = nil
	}

	var propertyValue any = fieldResource.ID
	if step.UsePropertyRef != "" {
		propertyValue = construct.PropertyRef{Resource: fieldResource.ID, Property: step.UsePropertyRef}
	}

	if resource.Imported {
		if ctx.Property.Contains(propVal, propertyValue) {
			ctx.namespace(resource, fieldResource, resource.ID)
			return nil
		}
		return fmt.Errorf("cannot set field on imported resource %s", resource.ID)
	}

	// snapshot the ID from before any field changes
	oldId := resource.ID

	removeResource := func(currResId construct.ResourceId) error {
		err := ctx.removeDependencyForDirection(step.Direction, resource.ID, currResId)
		if err != nil {
			return err
		}
		zap.S().Infof("Removing old field value for '%s' (%s) for %s", path, currResId, fieldResource.ID)
		// Remove the old field value if it's unused
		return reconciler.RemoveResource(ctx.Solution, currResId, false)
	}

	// A scalar value pointing elsewhere is replaced, so detach and remove the old resource.
	switch val := propVal.(type) {
	case construct.ResourceId:
		if val != fieldResource.ID {
			if err := removeResource(val); err != nil {
				return err
			}
		}
	case construct.PropertyRef:
		if val.Resource != fieldResource.ID {
			if err := removeResource(val.Resource); err != nil {
				return err
			}
		}
	}

	// For slice values, nothing to do if fieldResource is already present. That is the case if the
	// slice holds its id, or a PropertyRef identical to the one we'd append.
	if rv := reflect.ValueOf(propVal); rv.IsValid() && (rv.Kind() == reflect.Slice || rv.Kind() == reflect.Array) {
		for i := 0; i < rv.Len(); i++ {
			switch elem := rv.Index(i).Interface().(type) {
			case construct.ResourceId:
				if !elem.IsZero() && elem == fieldResource.ID {
					return nil
				}
			case construct.PropertyRef:
				if step.UsePropertyRef != "" && elem.Resource == fieldResource.ID && elem.Property == step.UsePropertyRef {
					return nil
				}
			}
		}
	}

	// Right now we only enforce the top level properties if they have rules, so we can assume the path is equal to the name of the property
	if err := ctx.Property.AppendProperty(resource, propertyValue); err != nil {
		return fmt.Errorf("error appending field %s#%s with %s: %w", resource.ID, path, fieldResource.ID, err)
	}
	log := logging.GetLogger(ctx.Solution.Context())
	log.Sugar().Infof("appended field %s#%s with %s", resource.ID, path, fieldResource.ID)
	ctx.namespace(resource, fieldResource, oldId)
	return nil
}
// namespace moves resource into fieldResource's namespace when the rule's property is a namespace
// property. It then repoints ctx.Data (the resource and either edge endpoint) from oldId to
// resource's current id.
func (ctx *OperationalRuleContext) namespace(resource, fieldResource *construct.Resource, oldId construct.ResourceId) {
	if ctx.Property.Details().Namespace {
		resource.ID.Namespace = fieldResource.ID.Name
	}
	newId := resource.ID

	if ctx.Data.Resource.Matches(oldId) {
		ctx.Data.Resource = newId
	}
	edge := ctx.Data.Edge
	if edge == nil {
		return
	}
	if edge.Source.Matches(oldId) {
		edge.Source = newId
	}
	if edge.Target.Matches(oldId) {
		edge.Target = newId
	}
}
package operational_rule
import (
"errors"
"math"
"sort"
"github.com/dominikbraun/graph"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/engine/solution"
"github.com/klothoplatform/klotho/pkg/graph_addons"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
)
type (
// ResourcePlacer picks which of availableResources to connect to resource for an operational step.
// Implementations connect resources via the context's addDependencyForDirection and decrement
// *numNeeded once per resource they place.
ResourcePlacer interface {
PlaceResources(
resource *construct.Resource,
step knowledgebase.OperationalStep,
availableResources []*construct.Resource,
numNeeded *int,
) error
SetCtx(ctx OperationalRuleContext)
}
// SpreadPlacer prefers the candidates with the fewest existing connections to resources of the same
// type; it places nothing when one or zero candidates are available.
SpreadPlacer struct {
ctx OperationalRuleContext
}
// ClusterPlacer prefers the candidates with the most existing connections to resources of the same type.
ClusterPlacer struct {
ctx OperationalRuleContext
}
// ClosestPlacer prefers the candidates with the shortest weighted path to the resource in an
// undirected copy of the graph.
ClosestPlacer struct {
ctx OperationalRuleContext
}
)
// placerMap maps each selection operator to a constructor for its placer. A fresh placer is built
// per use; its context must be supplied through SetCtx before PlaceResources is called.
var placerMap = map[knowledgebase.SelectionOperator]func() ResourcePlacer{
knowledgebase.SpreadSelectionOperator: func() ResourcePlacer { return &SpreadPlacer{} },
knowledgebase.ClusterSelectionOperator: func() ResourcePlacer { return &ClusterPlacer{} },
knowledgebase.ClosestSelectionOperator: func() ResourcePlacer { return &ClosestPlacer{} },
}
// PlaceResources spreads the resource across candidates, starting with those that have the fewest
// existing same-type connections. With zero or one candidate it places nothing, so a new resource
// gets created instead. It stops as soon as *numNeeded reaches zero.
func (p *SpreadPlacer) PlaceResources(
	resource *construct.Resource,
	step knowledgebase.OperationalStep,
	availableResources []*construct.Resource,
	numNeeded *int,
) error {
	if len(availableResources) <= 1 || *numNeeded == 0 {
		return nil
	}

	byConnections, err := p.ctx.findNumConnectionsToTypeForAvailableResources(step, availableResources, resource.ID)
	if err != nil {
		return err
	}

	// Ascending connection counts: least-connected candidates first.
	for _, count := range sortNumConnectionsMap(byConnections) {
		for _, candidate := range byConnections[count] {
			if depErr := p.ctx.addDependencyForDirection(step, resource, candidate); depErr != nil {
				return depErr
			}
			*numNeeded--
			if *numNeeded == 0 {
				return nil
			}
		}
	}
	return nil
}
// SetCtx stores the rule context used to inspect and modify the solution during placement.
func (p *SpreadPlacer) SetCtx(ctx OperationalRuleContext) {
p.ctx = ctx
}
// PlaceResources clusters the resource onto candidates, starting with those that have the most
// existing same-type connections. It places until *numNeeded reaches zero or candidates run out.
// A negative *numNeeded places every candidate.
// When nothing is needed it returns immediately, consistent with the other placers. Without that
// guard, a zero count would be decremented past zero and every candidate would be placed.
func (p *ClusterPlacer) PlaceResources(
	resource *construct.Resource,
	step knowledgebase.OperationalStep,
	availableResources []*construct.Resource,
	numNeeded *int,
) error {
	if *numNeeded == 0 {
		return nil
	}
	mapOfConnections, err := p.ctx.findNumConnectionsToTypeForAvailableResources(step, availableResources, resource.ID)
	if err != nil {
		return err
	}
	numConnectionsArray := sortNumConnectionsMap(mapOfConnections)
	// Descending connection counts: most-connected candidates first.
	sort.Sort(sort.Reverse(sort.IntSlice(numConnectionsArray)))
	for _, numConnections := range numConnectionsArray {
		for _, availableResource := range mapOfConnections[numConnections] {
			if err := p.ctx.addDependencyForDirection(step, resource, availableResource); err != nil {
				return err
			}
			*numNeeded--
			if *numNeeded == 0 {
				return nil
			}
		}
	}
	return nil
}
// SetCtx stores the rule context used to inspect and modify the solution during placement.
func (p *ClusterPlacer) SetCtx(ctx OperationalRuleContext) {
p.ctx = ctx
}
// PlaceResources connects the resource to the candidates closest to it. Distance is the summed edge
// weight of the shortest path in an undirected copy of the graph (see BuildUndirectedGraph);
// unreachable candidates sort last.
// It places min(*numNeeded, len(availableResources)) candidates; a negative *numNeeded places all of
// them. availableResources is sorted in place.
func (p ClosestPlacer) PlaceResources(
	resource *construct.Resource,
	step knowledgebase.OperationalStep,
	availableResources []*construct.Resource,
	numNeeded *int,
) error {
	if *numNeeded == 0 {
		return nil
	}
	undirectedGraph, err := BuildUndirectedGraph(p.ctx.Solution)
	if err != nil {
		return err
	}
	pather, err := construct.ShortestPaths(undirectedGraph, resource.ID, construct.DontSkipEdges)
	if err != nil {
		return err
	}

	resourceDepths := make(map[construct.ResourceId]int, len(availableResources))
	for _, availableResource := range availableResources {
		path, err := pather.ShortestPath(availableResource.ID)
		if err != nil && !errors.Is(err, graph.ErrTargetNotReachable) {
			return err
		}
		// Unreachable targets get the largest possible distance. math.MaxInt (not MaxInt64) fits the
		// platform's int on every architecture.
		if path == nil {
			resourceDepths[availableResource.ID] = math.MaxInt
			continue
		}
		length := 0
		for i := 1; i < len(path); i++ {
			edge, err := undirectedGraph.Edge(path[i-1], path[i])
			if err != nil {
				return err
			}
			length += edge.Properties.Weight
		}
		resourceDepths[availableResource.ID] = length
	}

	// Stable so equally-distant candidates keep the caller's (sorted) order.
	sort.SliceStable(availableResources, func(i, j int) bool {
		return resourceDepths[availableResources[i].ID] < resourceDepths[availableResources[j].ID]
	})

	num := *numNeeded
	if num > len(availableResources) || num < 0 {
		num = len(availableResources)
	}
	for _, availableResource := range availableResources[:num] {
		if err := p.ctx.addDependencyForDirection(step, resource, availableResource); err != nil {
			return err
		}
	}
	*numNeeded -= num
	return nil
}
// SetCtx stores the rule context used to inspect and modify the solution during placement.
func (p *ClosestPlacer) SetCtx(ctx OperationalRuleContext) {
p.ctx = ctx
}
// BuildUndirectedGraph copies the solution's raw view into a new graph created without the Directed
// trait. Edges touching a functional resource (functionality other than Unknown) get weight 1000;
// all others get weight 1. Shortest paths therefore prefer routes through non-functional resources.
func BuildUndirectedGraph(ctx solution.Solution) (construct.Graph, error) {
	raw := ctx.RawView()
	kb := ctx.KnowledgeBase()

	undirected := graph.NewWithStore(
		construct.ResourceHasher,
		graph_addons.NewMemoryStore[construct.ResourceId, *construct.Resource](),
	)
	if err := undirected.AddVerticesFrom(raw); err != nil {
		return nil, err
	}

	edges, err := raw.Edges()
	if err != nil {
		return nil, err
	}

	isFunctional := func(id construct.ResourceId) bool {
		return knowledgebase.GetFunctionality(kb, id) != knowledgebase.Unknown
	}
	for _, e := range edges {
		weight := 1
		if isFunctional(e.Source) || isFunctional(e.Target) {
			weight = 1000
		}
		if addErr := undirected.AddEdge(e.Source, e.Target, graph.EdgeWeight(weight)); addErr != nil {
			return nil, addErr
		}
	}
	return undirected, nil
}
// findNumConnectionsToTypeForAvailableResources groups availableResources by how many direct
// neighbours they already have of the same qualified type as resource.
// Neighbours are looked up opposite to the step's direction: upstream for downstream steps,
// downstream otherwise. Only the direct layer is considered so non-functional connections don't
// skew the counts. The result maps connection count to the candidates with that count.
func (ctx OperationalRuleContext) findNumConnectionsToTypeForAvailableResources(
	step knowledgebase.OperationalStep,
	availableResources []*construct.Resource,
	resource construct.ResourceId,
) (map[int][]*construct.Resource, error) {
	mapOfConnections := map[int][]*construct.Resource{}
	wantType := resource.QualifiedTypeName()
	for _, availableResource := range availableResources {
		var (
			connections []construct.ResourceId
			err         error
		)
		if step.Direction == knowledgebase.DirectionDownstream {
			connections, err = solution.Upstream(ctx.Solution, availableResource.ID, knowledgebase.ResourceDirectLayer)
		} else {
			connections, err = solution.Downstream(ctx.Solution, availableResource.ID, knowledgebase.ResourceDirectLayer)
		}
		// Check the error before touching connections.
		if err != nil {
			return nil, err
		}

		count := 0
		for _, connection := range connections {
			if connection.QualifiedTypeName() == wantType {
				count++
			}
		}
		mapOfConnections[count] = append(mapOfConnections[count], availableResource)
	}
	return mapOfConnections, nil
}
// sortNumConnectionsMap returns the connection counts (map keys) in ascending order.
// The result is never nil, even for an empty map.
func sortNumConnectionsMap(mapOfConnections map[int][]*construct.Resource) []int {
	keys := make([]int, 0, len(mapOfConnections))
	for k := range mapOfConnections {
		keys = append(keys, k)
	}
	sort.Ints(keys)
	return keys
}
package path_selection
import (
"errors"
"fmt"
"strings"
"github.com/dominikbraun/graph"
"github.com/klothoplatform/klotho/pkg/collectionutil"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/engine/solution"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
"go.uber.org/zap"
)
// downstreamChecker decides whether a path candidate is valid based on which resources are downstream of the
// resources it is asked about. It queries (and, via makeValid, may modify) the wrapped solution.
type downstreamChecker struct {
	ctx solution.Solution
}
// checkModifiesImportedResource reports whether adding the edge source -> target would cause one of the edge
// template's operational configuration rules to modify an imported resource. Such an edge is considered invalid,
// so a true result means the edge should be rejected.
//
// Endpoints that are not in the graph (for example phantom resources) are never treated as imported. If et is nil,
// the edge template is looked up in the knowledge base; when no template exists the edge has no operational rules
// and therefore cannot modify anything. Graph lookup failures other than "vertex not found" are returned.
func checkModifiesImportedResource(
	source, target construct.ResourceId,
	ctx solution.Solution,
	et *knowledgebase.EdgeTemplate,
) (bool, error) {
	var importedResources construct.ResourceList
	for _, id := range []construct.ResourceId{source, target} {
		res, err := ctx.RawView().Vertex(id)
		if errors.Is(err, graph.ErrVertexNotFound) {
			continue
		}
		if err != nil {
			return false, fmt.Errorf("could not get resource %s: %w", id, err)
		}
		if res != nil && res.Imported {
			importedResources = append(importedResources, id)
		}
	}
	// Nothing imported means nothing the edge could illegally modify; skip the template work entirely.
	if len(importedResources) == 0 {
		return false, nil
	}

	if et == nil {
		et = ctx.KnowledgeBase().GetEdgeTemplate(source, target)
	}
	if et == nil {
		// Without an edge template there are no operational rules (and dereferencing et would panic).
		return false, nil
	}

	dynamicCtx := solution.DynamicCtx(ctx)
	for _, rule := range et.OperationalRules {
		for _, config := range rule.ConfigurationRules {
			id := construct.ResourceId{}
			// we ignore the error since phantom resources will cause errors in the decoding of templates
			_ = dynamicCtx.ExecuteDecode(config.Resource, knowledgebase.DynamicValueData{
				Edge: &construct.Edge{
					Source: source,
					Target: target,
				}}, &id)
			if importedResources.MatchesAny(id) {
				return true, nil
			}
		}
	}
	return false, nil
}
// checkCandidatesValidity reports whether resource is a valid intermediate candidate for path, according to the
// path-satisfaction validity rules of its resource template for classification. The candidate must be valid both
// as a target of the path's first resource and as a source for the path's last resource.
//
// Paths with three or fewer resources are always accepted: the candidate is then directly connected to both ends,
// and direct edges are known to be valid. A candidate whose template is missing is rejected.
// Non-fatal errors from the individual checks are joined and returned alongside the verdict.
func checkCandidatesValidity(
	ctx solution.Solution,
	resource *construct.Resource,
	path []construct.ResourceId,
	classification string,
) (bool, error) {
	if len(path) <= 3 {
		return true, nil
	}
	if rt, err := ctx.KnowledgeBase().GetResourceTemplate(resource.ID); err != nil || rt == nil {
		return false, err
	}

	var errs error

	// The candidate receives the edge coming from the start of the path...
	validAsTarget, err := checkAsTargetValidity(ctx, resource, path[0], classification)
	if err != nil {
		errs = errors.Join(errs, err)
	}
	if !validAsTarget {
		zap.S().Debugf("candidate %s is not valid as target", resource.ID)
		return false, errs
	}

	// ...and emits the edge going to the end of the path.
	validAsSource, err := checkAsSourceValidity(ctx, resource, path[len(path)-1], classification)
	if err != nil {
		errs = errors.Join(errs, err)
	}
	if !validAsSource {
		zap.S().Debugf("candidate %s is not valid as source", resource.ID)
		return false, errs
	}

	return true, errs
}
// checkNamespaceValidity reports whether resource may be connected to target with respect to namespacing.
//
// If target matches one of the resource ids that resource is allowed to be namespaced in, resource is only valid
// when the namespace resource it currently belongs to also matches target. If target matches none of them,
// namespacing imposes no restriction and the candidate is valid.
func checkNamespaceValidity(
	ctx solution.Solution,
	resource *construct.Resource,
	target construct.ResourceId,
) (bool, error) {
	kb := ctx.KnowledgeBase()
	allowed, err := kb.GetAllowedNamespacedResourceIds(solution.DynamicCtx(ctx), resource.ID)
	if err != nil {
		return false, err
	}
	for _, candidate := range allowed {
		if !candidate.Matches(target) {
			continue
		}
		// The namespace resource does not depend on which allowed id matched, so the first match decides.
		ns, err := kb.GetResourcesNamespaceResource(resource)
		if err != nil {
			return false, err
		}
		return ns.Matches(target), nil
	}
	return true, nil
}
// checkAsTargetValidity reports whether resource is valid as the target of an edge from source. Only the AsTarget
// path-satisfaction routes of its template that match classification and declare a validity operation are
// considered; a missing template means there are no rules, so the candidate is valid.
//
// When a route uses a property reference that resolves to no resources (e.g. the property is not set), the
// candidate is not rejected; instead assignForValidity tries to set the property so the candidate becomes valid.
// Errors are joined and returned alongside the verdict.
func checkAsTargetValidity(
	ctx solution.Solution,
	resource *construct.Resource,
	source construct.ResourceId,
	classification string,
) (bool, error) {
	rt, err := ctx.KnowledgeBase().GetResourceTemplate(resource.ID)
	if err != nil {
		return false, err
	}
	if rt == nil {
		return true, nil
	}

	var errs error
	for _, route := range rt.PathSatisfaction.AsTarget {
		if route.Classification != classification || route.Validity == "" {
			continue
		}

		toCheck := []construct.ResourceId{resource.ID}
		if route.PropertyReference != "" {
			toCheck, err = solution.GetResourcesFromPropertyReference(ctx,
				resource.ID, route.PropertyReference)
			if err != nil {
				// Not an error for the caller: an unset property just means we are free to make the resource valid.
				zap.S().Debugf(
					"no resource available from resource %s from property ref %s: %v",
					resource.ID, route.PropertyReference, err,
				)
			}
			if len(toCheck) == 0 {
				errs = errors.Join(errs, assignForValidity(ctx, resource, source, route))
			}
		}

		for _, id := range toCheck {
			valid, checkErr := checkValidityOperation(ctx, source, id, route)
			if checkErr != nil {
				errs = errors.Join(errs, checkErr)
			}
			if !valid {
				return false, errs
			}
		}
	}
	return true, errs
}
// checkAsSourceValidity reports whether resource is valid as the source of an edge to target. Only the AsSource
// path-satisfaction routes of its template that match classification and declare a validity operation are
// considered; a missing template means there are no rules, so the candidate is valid.
//
// When a route uses a property reference that resolves to no resources (e.g. the property is not set), the
// candidate is not rejected; instead assignForValidity tries to set the property so the candidate becomes valid.
// Errors are joined and returned alongside the verdict.
func checkAsSourceValidity(
	ctx solution.Solution,
	resource *construct.Resource,
	target construct.ResourceId,
	classification string,
) (bool, error) {
	rt, err := ctx.KnowledgeBase().GetResourceTemplate(resource.ID)
	if err != nil {
		return false, err
	}
	if rt == nil {
		return true, nil
	}

	var errs error
	for _, route := range rt.PathSatisfaction.AsSource {
		if route.Classification != classification || route.Validity == "" {
			continue
		}

		toCheck := []construct.ResourceId{resource.ID}
		if route.PropertyReference != "" {
			toCheck, err = solution.GetResourcesFromPropertyReference(ctx,
				resource.ID, route.PropertyReference)
			if err != nil {
				// Not an error for the caller: an unset property just means we are free to make the resource valid.
				zap.S().Debugf(
					"no resource available from resource %s from property ref %s: %v",
					resource.ID, route.PropertyReference, err,
				)
			}
			if len(toCheck) == 0 {
				errs = errors.Join(errs, assignForValidity(ctx, resource, target, route))
			}
		}

		for _, id := range toCheck {
			valid, checkErr := checkValidityOperation(ctx, id, target, route)
			if checkErr != nil {
				errs = errors.Join(errs, checkErr)
			}
			if !valid {
				return false, errs
			}
		}
	}
	return true, errs
}
// checkValidityOperation evaluates the validity operation named by ps for the pair src -> target.
// Only knowledgebase.DownstreamOperation is understood: it requires target to be among the resources downstream of
// src (first functional layer). Any other validity value is treated as valid.
func checkValidityOperation(
	ctx solution.Solution,
	src, target construct.ResourceId,
	ps knowledgebase.PathSatisfactionRoute,
) (bool, error) {
	if ps.Validity != knowledgebase.DownstreamOperation {
		return true, nil
	}
	var errs error
	valid, err := downstreamChecker{ctx: ctx}.isValid(src, target)
	if err != nil {
		errs = errors.Join(errs, fmt.Errorf("error checking downstream validity: %w", err))
	}
	return valid, errs
}
// assignForValidity tries to make resource valid for the route ps, relative to the resource identified by
// operationResourceId. It is intended for the case where the property reference used by the validity check is not
// set on resource. Only knowledgebase.DownstreamOperation has an effect; other validity values do nothing.
// An error is returned if operationResourceId is not in the graph or the assignment fails.
func assignForValidity(
	ctx solution.Solution,
	resource *construct.Resource,
	operationResourceId construct.ResourceId,
	ps knowledgebase.PathSatisfactionRoute,
) error {
	operationResource, err := ctx.RawView().Vertex(operationResourceId)
	if err != nil {
		return err
	}
	if ps.Validity != knowledgebase.DownstreamOperation {
		return nil
	}
	var errs error
	checker := downstreamChecker{ctx: ctx}
	if mkErr := checker.makeValid(resource, operationResource, ps.PropertyReference); mkErr != nil {
		errs = errors.Join(errs, fmt.Errorf("error making resource downstream validity: %w", mkErr))
	}
	return errs
}
// makeValid tries to make resource satisfy a downstream validity check relative to operationResource.
//
// propertyRef is a "#"-separated chain of property names starting at resource. The chain is walked one hop at a
// time; whenever a property on the current hop cannot be read (presumably because it is unset), makeValid tries to
// set it to one of operationResource's first-functional-layer downstream resources, or to operationResource itself.
// It picks the first candidate that knowledgebase.TransformToPropertyValue accepts for that property.
//
// All failures are joined and returned. A property that no candidate fits is left unset without an error.
func (d downstreamChecker) makeValid(resource, operationResource *construct.Resource, propertyRef string) error {
	downstreams, err := solution.Downstream(d.ctx, operationResource.ID, knowledgebase.FirstFunctionalLayer)
	if err != nil {
		return err
	}
	// include the operation resource in downstreams in case it also can be assigned to the target property
	downstreams = append(downstreams, operationResource.ID)
	cfgCtx := solution.DynamicCtx(d.ctx)

	// assign sets property on resource r to the first downstream that is a valid value for it, reporting whether
	// an assignment was made.
	assign := func(r construct.ResourceId, property string) (bool, error) {
		rt, err := d.ctx.KnowledgeBase().GetResourceTemplate(r)
		if err != nil {
			return false, fmt.Errorf("error getting resource template for resource %s: %w", r, err)
		}
		if rt == nil {
			return false, fmt.Errorf("no resource template found for resource %s", r)
		}
		p := rt.Properties[property]
		var errs error
		for _, downstream := range downstreams {
			val, err := knowledgebase.TransformToPropertyValue(r, property, downstream, cfgCtx,
				knowledgebase.DynamicValueData{Resource: r})
			if err != nil || val == nil {
				continue // Because this error may just mean that its not the right type of resource
			}
			// When r is the resource we are operating on, use it directly rather than searching the raw view:
			// it could be a phantom resource that is not in the graph.
			currRes := resource
			if r != resource.ID {
				currRes, err = d.ctx.RawView().Vertex(r)
				if err != nil {
					errs = errors.Join(errs, fmt.Errorf("error getting resource %s: %w", r, err))
					continue
				}
			}
			// Guard against a property missing from the template; calling AppendProperty on it would panic.
			if p == nil {
				return false, errors.Join(errs,
					fmt.Errorf("property %s not found in resource template for resource %s", property, r))
			}
			return true, errors.Join(errs, p.AppendProperty(currRes, downstream))
		}
		return false, errs
	}

	var errs error
	currResources := []construct.ResourceId{resource.ID}
	for _, part := range strings.Split(propertyRef, "#") {
		var nextResources []construct.ResourceId
		for _, currResource := range currResources {
			val, err := cfgCtx.FieldValue(part, currResource)
			if err != nil {
				// The property can't be read on this hop, so try to fill it in instead of following it.
				if _, assignErr := assign(currResource, part); assignErr != nil {
					errs = errors.Join(errs, assignErr)
				}
				continue
			}
			switch v := val.(type) {
			case construct.ResourceId:
				nextResources = append(nextResources, v)
			case []construct.ResourceId:
				nextResources = append(nextResources, v...)
			}
		}
		currResources = nextResources
	}
	return errs
}
// isValid reports whether targetResource is one of the resources downstream of resourceToCheck, as returned by
// solution.Downstream for knowledgebase.FirstFunctionalLayer. On a lookup error it reports false with the error.
func (d downstreamChecker) isValid(resourceToCheck, targetResource construct.ResourceId) (bool, error) {
	reachable, err := solution.Downstream(d.ctx, resourceToCheck, knowledgebase.FirstFunctionalLayer)
	if err == nil {
		return collectionutil.Contains(reachable, targetResource), nil
	}
	return false, err
}
package path_selection
import (
"errors"
"github.com/dominikbraun/graph"
"github.com/klothoplatform/klotho/pkg/collectionutil"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/engine/solution"
"github.com/klothoplatform/klotho/pkg/graph_addons"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
)
// determineCandidateWeight determines the weight of candidate resource id based on its relationship to the src and
// target resources and on whether it is already in the result graph. Higher weights increase the chance that the
// candidate is reused.
//
// The weight is the sum of:
//  1. +10 if the candidate is directly downstream of src (otherwise +5 if it is downstream through the glue layer),
//     and likewise +10 / +5 for being upstream of target.
//  2. +9 if the candidate is already in the result graph.
//  3. A proximity bonus. It starts at 10 and is reduced for each resource on the shortest undirected path to src
//     (2 per functional resource, 1 otherwise) and for each functional resource on the shortest path to target.
//     If either endpoint is unreachable, or the bonus would fall below 2, it is set to 2 so that reusing an existing
//     resource is still valued higher than creating a new one.
//
// 'undirected' is from the 'ctx' raw view, but given as an argument here to avoid having to recompute it.
// Errors from the graph queries are joined and returned together with the (possibly partial) weight.
func determineCandidateWeight(
	ctx solution.Solution,
	src, target construct.ResourceId,
	id construct.ResourceId,
	resultGraph construct.Graph,
	undirected construct.Graph,
) (weight int, errs error) {
	// note(gg) perf: these Downstream/Upstream functions don't need the full list and don't need to run twice
	downstreams, err := solution.Downstream(ctx, src, knowledgebase.ResourceDirectLayer)
	errs = errors.Join(errs, err)
	if collectionutil.Contains(downstreams, id) {
		weight += 10
	} else {
		downstreams, err := solution.Downstream(ctx, src, knowledgebase.ResourceGlueLayer)
		errs = errors.Join(errs, err)
		if collectionutil.Contains(downstreams, id) {
			weight += 5
		}
	}

	upstreams, err := solution.Upstream(ctx, target, knowledgebase.ResourceDirectLayer)
	errs = errors.Join(errs, err)
	if collectionutil.Contains(upstreams, id) {
		weight += 10
	} else {
		upstreams, err := solution.Upstream(ctx, target, knowledgebase.ResourceGlueLayer)
		errs = errors.Join(errs, err)
		if collectionutil.Contains(upstreams, id) {
			weight += 5
		}
	}

	// See if its currently in the result graph and if so add weight to increase chances of being reused
	if _, err := resultGraph.Vertex(id); err == nil {
		weight += 9
	}

	pather, err := construct.ShortestPaths(undirected, id, construct.DontSkipEdges)
	if err != nil {
		errs = errors.Join(errs, err)
		return
	}

	// minAvailableWeight is the floor for the proximity bonus, so reusing resources is always valued higher than
	// creating new ones if possible.
	const minAvailableWeight = 2
	availableWeight := 10
	reachable := true

	shortestPath, err := pather.ShortestPath(src)
	if err != nil {
		reachable = false
	}
	for _, res := range shortestPath {
		if knowledgebase.GetFunctionality(ctx.KnowledgeBase(), res) != knowledgebase.Unknown {
			availableWeight -= 2
		} else {
			availableWeight -= 1
		}
	}

	shortestPath, err = pather.ShortestPath(target)
	if err != nil {
		reachable = false
	}
	for _, res := range shortestPath {
		if knowledgebase.GetFunctionality(ctx.KnowledgeBase(), res) != knowledgebase.Unknown {
			availableWeight -= 1
		}
	}

	// Previously only negative values were clamped, so a reachable-but-distant candidate (0 or 1) scored lower than
	// an unreachable one (2). Clamp to the documented floor instead.
	if !reachable || availableWeight < minAvailableWeight {
		availableWeight = minAvailableWeight
	}
	weight += availableWeight
	return
}
// BuildUndirectedGraph returns an undirected, weighted copy of g with the same vertices and one edge per connected
// pair of resources, suitable for proximity queries such as the shortest paths used by determineCandidateWeight.
// Edges touching a resource with known functionality get weight 1000, all others weight 1. Presumably this makes
// weighted shortest-path searches avoid routing through functional resources.
//
// If g has edges in both directions between the same two resources, only one undirected edge is kept. Both would
// have the same weight, because the weight depends only on the two endpoints.
func BuildUndirectedGraph(g construct.Graph, kb knowledgebase.TemplateKB) (construct.Graph, error) {
	undirected := graph.NewWithStore(
		construct.ResourceHasher,
		graph_addons.NewMemoryStore[construct.ResourceId, *construct.Resource](),
		graph.Weighted(),
	)
	if err := undirected.AddVerticesFrom(g); err != nil {
		return nil, err
	}
	edges, err := g.Edges()
	if err != nil {
		return nil, err
	}
	for _, e := range edges {
		weight := 1
		// increase weights for edges that are connected to a functional resource
		if knowledgebase.GetFunctionality(kb, e.Source) != knowledgebase.Unknown ||
			knowledgebase.GetFunctionality(kb, e.Target) != knowledgebase.Unknown {
			weight = 1000
		}
		err := undirected.AddEdge(e.Source, e.Target, graph.EdgeWeight(weight))
		// An undirected graph treats A-B and B-A as the same edge, so the reverse of an edge already added is
		// reported as existing; that is not a failure here.
		if err != nil && !errors.Is(err, graph.ErrEdgeAlreadyExists) {
			return nil, err
		}
	}
	return undirected, nil
}
package path_selection
import (
"errors"
"fmt"
"slices"
"github.com/dominikbraun/graph"
"github.com/klothoplatform/klotho/pkg/knowledgebase"
)
// ClassPaths walks the knowledge base graph kb depth-first from start and calls cb with every path of resource
// types that reaches end and satisfies classification (an empty classification is always satisfied). The walk
// follows these rules:
//   - paths never revisit a type;
//   - edges marked DirectEdgeOnly are not followed;
//   - intermediate types with known functionality are never passed through;
//   - types whose path satisfaction denies the classification are skipped.
//
// An error is returned if the adjacency map or the start template cannot be loaded, or if lookups or callbacks fail.
func ClassPaths(
	kb knowledgebase.Graph,
	start, end string,
	classification string,
	cb func([]string) error,
) error {
	adjacencyMap, err := kb.AdjacencyMap()
	if err != nil {
		return err
	}
	startTmpl, err := kb.Vertex(start)
	if err != nil {
		return fmt.Errorf("failed to find start template: %w", err)
	}

	// The start type itself may already satisfy the classification before any edge is taken.
	satisfiedAtStart := classification == "" || slices.Contains(startTmpl.Classification.Is, classification)
	initialPath := []string{start}
	return classPaths(kb, adjacencyMap, start, end, classification, cb, initialPath, satisfiedAtStart)
}
// SkipPathErr is a sentinel error whose message is "skip path".
// NOTE(review): classPaths does not special-case it; an error returned from a ClassPaths callback (including this
// one) is collected and returned. Presumably callers filter it with errors.Is — confirm against call sites.
var SkipPathErr = errors.New("skip path")
// classPaths is the recursive depth-first step of ClassPaths. currentPath holds the types walked so far, and its
// last element is the type being expanded. classificationSatisfied records whether a type or edge on currentPath
// already satisfies classification.
//
// For each neighbour next of the last type, the walk:
//   - skips types already on currentPath (the knowledge base graph may be cyclic);
//   - skips edges whose template is DirectEdgeOnly;
//   - skips non-end types with known functionality;
//   - skips types whose path satisfaction denies classification.
//
// It then either reports currentPath+end to cb (when next is end and the classification is satisfied) or recurses.
// Lookup, callback and recursion errors do not stop the walk. They are joined and wrapped with the expanded type.
func classPaths(
	kb knowledgebase.Graph,
	adjacencyMap map[string]map[string]graph.Edge[string],
	start, end string,
	classification string,
	cb func([]string) error,
	currentPath []string,
	classificationSatisfied bool,
) error {
	last := currentPath[len(currentPath)-1]
	frontier := adjacencyMap[last]
	if len(frontier) == 0 {
		return nil
	}

	var errs []error
	for next := range frontier {
		if slices.Contains(currentPath, next) {
			// Prevent infinite looping, since the knowledge base can be cyclic
			continue
		}
		nextClassificationSatisfied := classificationSatisfied

		edge, err := kb.Edge(last, next)
		if err != nil {
			errs = append(errs, err)
			continue
		}
		// Checked assertion: a malformed edge should be reported, not crash the whole walk.
		edgeTmpl, ok := edge.Properties.Data.(*knowledgebase.EdgeTemplate)
		if !ok || edgeTmpl == nil {
			errs = append(errs, fmt.Errorf("edge %s -> %s has no edge template (data is %T)", last, next, edge.Properties.Data))
			continue
		}
		if edgeTmpl.DirectEdgeOnly {
			continue
		}
		if !nextClassificationSatisfied && slices.Contains(edgeTmpl.Classification, classification) {
			nextClassificationSatisfied = true
		}

		tmpl, err := kb.Vertex(next)
		if err != nil {
			errs = append(errs, err)
			continue
		}
		if next != end {
			// ContainsUnneccessaryHopsInPath: functional types may not appear in the middle of a path
			if fct := tmpl.GetFunctionality(); fct != knowledgebase.Unknown {
				continue
			}
		}
		if !nextClassificationSatisfied && slices.Contains(tmpl.Classification.Is, classification) {
			nextClassificationSatisfied = true
		}
		if classification != "" && slices.Contains(tmpl.PathSatisfaction.DenyClassifications, classification) {
			continue
		}

		if next == end {
			// NOTE(gg): The old code let the end point satisfy the classification. Is this correct?
			if nextClassificationSatisfied {
				// Hand the callback its own copy: currentPath's backing array is reused by sibling branches of
				// this DFS, so a callback that keeps the slice would otherwise see it overwritten later.
				found := make([]string, len(currentPath), len(currentPath)+1)
				copy(found, currentPath)
				if err := cb(append(found, end)); err != nil {
					errs = append(errs, err)
				}
			}
			continue
		}

		err = classPaths(
			kb,
			adjacencyMap,
			start, end,
			classification,
			cb,
			// This append is okay because we're only doing one path at a time, in DFS: deeper levels only write
			// past the end of this level's slice. Otherwise, we'd need to copy the slice. This is why we use DFS
			// instead of BFS or a stack-based approach (like used in [graph.AllPathsBetween]).
			append(currentPath, next),
			nextClassificationSatisfied,
		)
		if err != nil {
			errs = append(errs, err)
		}
	}
	if err := errors.Join(errs...); err != nil {
		return fmt.Errorf("failed to find paths from %s: %w", last, err)
	}
	return nil
}
package path_selection
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"sync"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/dot"
"github.com/klothoplatform/klotho/pkg/engine/debug"
"github.com/klothoplatform/klotho/pkg/set"
"go.uber.org/zap"
)
var (
	// seenFiles records which debug graph files this execution has already written to. writeGraph uses it to decide
	// between appending to a file (already seen by this execution) and truncating it first (to reset content left
	// over from a previous execution).
	seenFiles = make(set.Set[string])
	// seenFilesLock guards seenFiles.
	seenFilesLock = new(sync.Mutex)
)
// writeGraph writes a debug visualisation of one path-selection expansion. working is the graph of candidates that
// were considered and result is the graph that was chosen. The expansion is rendered as a DOT cluster and appended
// to "<dir>/<source>-<target>.gv", where dir is "selection" inside the debug directory from ctx (or "selection" in
// the working directory). The whole file is then rendered again to "<dir>/<source>-<target>.gv.svg".
//
// This is best-effort debug output: failures are logged and never returned.
func writeGraph(ctx context.Context, input ExpansionInput, working, result construct.Graph) {
	dir := "selection"
	if debugDir := debug.GetDebugDir(ctx); debugDir != "" {
		dir = filepath.Join(debugDir, "selection")
	}
	err := os.MkdirAll(dir, 0755)
	if err != nil && !errors.Is(err, os.ErrExist) {
		zap.S().Warnf("Could not create folder for selection diagram: %v", err)
		return
	}

	fprefix := fmt.Sprintf("%s-%s", input.SatisfactionEdge.Source.ID, input.SatisfactionEdge.Target.ID)
	fprefix = strings.ReplaceAll(fprefix, ":", "_") // some filesystems (NTFS) don't like colons in filenames
	fprefix = filepath.Join(dir, fprefix)

	content, ok := appendSelectionCluster(fprefix, input, working, result)
	if !ok {
		return
	}

	svgContent, err := dot.ExecPan(strings.NewReader(content))
	if err != nil {
		zap.S().Errorf("could not render graph to file %s: %v", fprefix, err)
		return
	}
	svgFile, err := os.Create(fprefix + ".gv.svg")
	if err != nil {
		zap.S().Errorf("could not create file %s.gv.svg: %v", fprefix, err)
		return
	}
	defer svgFile.Close()
	if _, err := fmt.Fprint(svgFile, svgContent); err != nil {
		zap.S().Errorf("could not write file %s: %v", svgFile.Name(), err)
	}
}

// appendSelectionCluster adds this expansion's cluster to the DOT file "<fprefix>.gv" and returns the file's new
// complete content. The first time this execution touches the file, it is truncated and a new digraph header is
// written. Otherwise the existing digraph's closing brace is removed and the new cluster is added before a new
// closing brace. Failures are logged and reported with ok == false.
func appendSelectionCluster(
	fprefix string,
	input ExpansionInput,
	working, result construct.Graph,
) (content string, ok bool) {
	// Hold the lock for the whole read-modify-write. Expansions for the same source/target pair share one file,
	// so interleaved updates would otherwise drop each other's clusters.
	seenFilesLock.Lock()
	defer seenFilesLock.Unlock()

	f, err := os.OpenFile(fprefix+".gv", os.O_RDWR|os.O_CREATE, 0644)
	if err != nil {
		zap.S().Errorf("could not create file %s: %v", fprefix, err)
		return "", false
	}
	defer f.Close()

	if !seenFiles.Contains(f.Name()) {
		seenFiles.Add(f.Name())
		if err := f.Truncate(0); err != nil {
			zap.S().Errorf("could not truncate file %s: %v", f.Name(), err)
			return "", false
		}
	}

	dotContent := new(bytes.Buffer)
	if _, err := io.Copy(dotContent, f); err != nil {
		zap.S().Errorf("could not read file %s: %v", f.Name(), err)
		return "", false
	}
	if dotContent.Len() > 0 {
		// Reopen the existing digraph by dropping its closing brace; a new one is written below.
		existing := strings.TrimSuffix(strings.TrimSpace(dotContent.String()), "}")
		dotContent.Reset()
		dotContent.WriteString(existing)
	} else {
		fmt.Fprintf(dotContent, `digraph {
label = "%s → %s"
rankdir = LR
labelloc = t
graph [ranksep = 2]
`, input.SatisfactionEdge.Source.ID, input.SatisfactionEdge.Target.ID)
	}

	if err := graphToDOTCluster(input.Classification, working, result, dotContent); err != nil {
		zap.S().Errorf("could not render graph for %s: %v", fprefix, err)
		return "", false
	}
	fmt.Fprintln(dotContent, "}")
	content = dotContent.String()

	_, err = f.Seek(0, io.SeekStart)
	if err == nil {
		_, err = io.Copy(f, strings.NewReader(content))
	}
	if err == nil {
		// Drop any stale bytes beyond the new end in case the new content is shorter than the old.
		err = f.Truncate(int64(len(content)))
	}
	if err != nil {
		zap.S().Errorf("could not write file %s: %v", f.Name(), err)
		return "", false
	}
	return content, true
}
package path_selection
import (
"errors"
"fmt"
"io"
"strings"
"github.com/dominikbraun/graph"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/dot"
)
// Colours used when rendering path-selection debug graphs.
const (
	// notChosenColour marks vertices that were considered but are not part of the result graph.
	notChosenColour = "#e87b7b"
	// yellow = "#e3cf9d"

	// choenColour (sic) marks vertices and edges that are part of the result graph.
	choenColour = "#3f822b"
)
// attributes builds the DOT vertex attributes for id from the working graph.
//
// Phantom resources (names starting with PHANTOM_PREFIX) are drawn as dashed ellipses labelled with their type;
// other resources are drawn as boxes labelled with their full id. A vertex with a "new_id" attribute (a renamed
// resource) is dashed and labelled with its type plus its new [namespace:]name. Vertices present in result are
// coloured choenColour, all others notChosenColour.
func attributes(result construct.Graph, id construct.ResourceId, props graph.VertexProperties) map[string]string {
	attrs := make(map[string]string)
	switch {
	case strings.HasPrefix(id.Name, PHANTOM_PREFIX):
		attrs["style"], attrs["label"], attrs["shape"] = "dashed", id.QualifiedTypeName(), "ellipse"
	default:
		attrs["label"], attrs["shape"] = id.String(), "box"
	}

	if renamed := props.Attributes["new_id"]; renamed != "" {
		// vertex is a renamed node
		newID := id
		// Parse errors are deliberately ignored: this only affects a debug label.
		_ = newID.UnmarshalText([]byte(renamed))
		name := newID.Name
		if newID.Namespace != "" {
			name = fmt.Sprintf("%s:%s", newID.Namespace, name)
		}
		attrs["style"] = "dashed"
		attrs["label"] = fmt.Sprintf(`%s\n:%s`, id.QualifiedTypeName(), name)
	}

	attrs["color"] = notChosenColour
	if _, err := result.Vertex(id); err == nil {
		attrs["color"] = choenColour
	}
	return attrs
}
// graphToDOTCluster writes working to out as a DOT subgraph cluster for classification class; an empty class is
// written as "default" with the label "<default>". Vertex ids are prefixed with the class so that several clusters
// can share one digraph. Vertices renamed through a "new_id" attribute are written under their new id.
// Edges that also exist in result are highlighted; all other edges are dashed. Every edge is labelled with its
// weight.
//
// Write and id-parsing errors are joined and returned. If working's adjacency map cannot be read, that error is
// returned immediately, joined with any errors collected up to that point.
func graphToDOTCluster(class string, working, result construct.Graph, out io.Writer) error {
	var errs error
	printf := func(s string, args ...any) {
		_, err := fmt.Fprintf(out, s, args...)
		errs = errors.Join(errs, err)
	}
	label := class
	if class == "" {
		class = "default"
		label = "<default>"
	}
	// Quote the cluster id: classifications are free-form, and bare DOT ids may only contain letters, digits and
	// underscores. Graphviz still treats a quoted name starting with "cluster" as a cluster.
	printf(` subgraph %q {
label = %q
`, "cluster_"+class, label)
	adj, err := working.AdjacencyMap()
	if err != nil {
		return errors.Join(errs, err)
	}

	// fixID swaps in the renamed id when the working vertex carries a "new_id" attribute.
	fixID := func(id construct.ResourceId) construct.ResourceId {
		_, props, _ := working.VertexWithProperties(id)
		if newID := props.Attributes["new_id"]; newID != "" {
			errs = errors.Join(errs, id.UnmarshalText([]byte(newID)))
		}
		return id
	}

	for src, targets := range adj {
		_, props, _ := working.VertexWithProperties(src)
		src = fixID(src)
		prefixedSrc := fmt.Sprintf("%s/%s", class, src)
		printf(" %q%s\n", prefixedSrc, dot.AttributesToString(attributes(result, src, props)))
		for tgt, e := range targets {
			tgt = fixID(tgt)
			prefixedTgt := fmt.Sprintf("%s/%s", class, tgt)
			edgeAttribs := make(map[string]string)
			if _, err := result.Edge(src, tgt); err == nil {
				edgeAttribs["color"] = choenColour
				edgeAttribs["weight"] = "1000"
				edgeAttribs["penwidth"] = "2"
			} else {
				edgeAttribs["style"] = "dashed"
			}
			edgeAttribs["label"] = fmt.Sprintf("%d", e.Properties.Weight)
			printf(" %q -> %q%s\n", prefixedSrc, prefixedTgt, dot.AttributesToString(edgeAttribs))
		}
	}
	printf(" }\n")
	return errs
}
package path_selection
import (
"errors"
"fmt"
"reflect"
"strings"
"github.com/dominikbraun/graph"
"github.com/klothoplatform/klotho/pkg/collectionutil"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/engine/solution"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
)
// checkUniquenessValidity reports whether the edge src -> trgt is acceptable with respect to unique resources.
// A resource created by an operational rule with the unique flag is not considered a valid candidate for other
// resources. The check is delegated to checkProperties, once from the target's side (upstream) and once from the
// source's side (downstream). Endpoints that are not yet in the graph (phantom nodes) are represented by a
// placeholder resource that has only its id set.
func checkUniquenessValidity(
	ctx solution.Solution,
	src, trgt construct.ResourceId,
) (bool, error) {
	lookup := func(id construct.ResourceId) (*construct.Resource, error) {
		res, err := ctx.RawView().Vertex(id)
		if errors.Is(err, graph.ErrVertexNotFound) {
			// phantom node: not in the graph yet
			return &construct.Resource{ID: id}, nil
		}
		return res, err
	}

	source, err := lookup(src)
	if err != nil {
		return false, err
	}
	target, err := lookup(trgt)
	if err != nil {
		return false, err
	}

	// does the upstream resource have a unique rule for the matched resource type?
	if valid, err := checkProperties(ctx, target, source, knowledgebase.DirectionUpstream); err != nil || !valid {
		return false, err
	}
	// does the downstream resource have a unique rule for the matched resource type?
	if valid, err := checkProperties(ctx, source, target, knowledgebase.DirectionDownstream); err != nil || !valid {
		return false, err
	}
	return true, nil
}
// checkProperties reports whether toCheck may be used by resource, given resource's properties whose operational
// rules require a unique resource in direction.
//
// For the first such rule with a selector that accepts toCheck:
//   - if the rule's property is set, toCheck is valid only if the property already references it;
//   - if the property is unset, toCheck is valid only if checkIfLoneDependency accepts it.
//
// When no unique rule applies, the result comes from checkIfCreatedAsUniqueValidity, i.e. whether resource was
// itself created as a unique resource for some other dependency. Phantom resources are always considered valid.
func checkProperties(ctx solution.Solution, resource, toCheck *construct.Resource, direction knowledgebase.Direction) (bool, error) {
	template, err := ctx.KnowledgeBase().GetResourceTemplate(resource.ID)
	if err != nil {
		return false, fmt.Errorf("error getting resource template for resource %s: %w", resource.ID, err)
	}
	if template == nil {
		return false, fmt.Errorf("no resource template found for resource %s", resource.ID)
	}
	if strings.Contains(resource.ID.Name, PHANTOM_PREFIX) {
		return true, nil
	}

	explicitlyNotValid := false
	explicitlyValid := false
	err = template.LoopProperties(resource, func(prop knowledgebase.Property) error {
		details := prop.Details()
		rule := details.OperationalRule
		if rule == nil || len(rule.Step.Resources) == 0 {
			return nil
		}
		step := rule.Step
		if !step.Unique || step.Direction != direction {
			return nil
		}
		// check if toCheck is the kind of resource this unique rule selects
		for _, selector := range step.Resources {
			match, err := selector.CanUse(solution.DynamicCtx(ctx), knowledgebase.DynamicValueData{Resource: resource.ID},
				toCheck)
			if err != nil {
				return fmt.Errorf("error checking if resource %s matches selector %s: %w", toCheck.ID, selector.Selector, err)
			}
			// The first selector that accepts toCheck decides the outcome.
			if !match {
				continue
			}
			property, err := resource.GetProperty(details.Path)
			if err != nil {
				return fmt.Errorf("error getting property %s for resource %s: %w", details.Path, resource.ID, err)
			}
			if property != nil {
				// The unique property is already set: toCheck is only valid if it is what the property references.
				if checkIfPropertyContainsResource(property, toCheck.ID) {
					explicitlyValid = true
					return knowledgebase.ErrStopWalk
				}
			} else {
				// The unique property is unset: toCheck is only valid as resource's lone matching dependency.
				loneDep, err := checkIfLoneDependency(ctx, resource.ID, toCheck.ID, direction, selector)
				if err != nil {
					return err
				}
				if loneDep {
					explicitlyValid = true
					return knowledgebase.ErrStopWalk
				}
			}
			explicitlyNotValid = true
			return knowledgebase.ErrStopWalk
		}
		return nil
	})
	if err != nil {
		return false, err
	}
	if explicitlyValid {
		return true, nil
	}
	if explicitlyNotValid {
		return false, nil
	}

	// if we cant validate uniqueness off of properties we then need to see if the resource was created to be unique
	// by any of its direct dependents
	valid, err := checkIfCreatedAsUniqueValidity(ctx, resource, toCheck, direction)
	if err != nil {
		return false, err
	}
	return valid, nil
}
// checkIfPropertyContainsResource reports whether property references resource. property may be a single
// construct.ResourceId or construct.PropertyRef, or a slice or array of such values. Matching uses
// ResourceId.Matches; any other kind of value (including nil) reports false.
func checkIfPropertyContainsResource(property interface{}, resource construct.ResourceId) bool {
	refersTo := func(v any) bool {
		switch ref := v.(type) {
		case construct.ResourceId:
			return ref.Matches(resource)
		case construct.PropertyRef:
			return ref.Resource.Matches(resource)
		}
		return false
	}

	rv := reflect.ValueOf(property)
	switch rv.Kind() {
	case reflect.Slice, reflect.Array:
		for i, n := 0, rv.Len(); i < n; i++ {
			if refersTo(rv.Index(i).Interface()) {
				return true
			}
		}
	case reflect.Struct:
		return refersTo(property)
	}
	return false
}
// checkIfLoneDependency reports whether toCheck can be treated as resource's unique dependency for selector.
//
// It inspects resource's direct dependencies on the side opposite to direction: upstream for a downstream rule,
// downstream otherwise. It returns true when there are none, or when toCheck is the only one. Otherwise it returns
// false as soon as any existing dependency is accepted by selector, and true if none is.
// NOTE(review): when there are several dependencies, toCheck itself may be among them and be accepted by selector,
// which yields false — confirm that is intended.
func checkIfLoneDependency(ctx solution.Solution,
	resource, toCheck construct.ResourceId, direction knowledgebase.Direction,
	selector knowledgebase.ResourceSelector) (bool, error) {
	var (
		deps []construct.ResourceId
		err  error
	)
	if direction == knowledgebase.DirectionDownstream {
		deps, err = solution.Upstream(ctx, resource, knowledgebase.ResourceDirectLayer)
	} else {
		deps, err = solution.Downstream(ctx, resource, knowledgebase.ResourceDirectLayer)
	}
	if err != nil {
		return false, err
	}

	switch {
	case len(deps) == 0:
		return true, nil
	case len(deps) == 1 && deps[0].Matches(toCheck):
		return true, nil
	}

	for _, dep := range deps {
		depRes, err := ctx.RawView().Vertex(dep)
		if err != nil {
			return false, err
		}
		canUse, err := selector.CanUse(solution.DynamicCtx(ctx), knowledgebase.DynamicValueData{Resource: resource}, depRes)
		if err != nil {
			return false, err
		}
		if canUse {
			return false, nil
		}
	}
	return true, nil
}
// checkIfCreatedAsUniqueValidity reports whether resource may be used for an edge with other, based on whether
// resource was created as a unique resource by one of its direct dependencies.
//
// It inspects resource's direct dependencies in direction (upstream for DirectionUpstream, downstream otherwise).
// If other is already one of them, the edge is assumed valid. Otherwise resource is not valid if any dependency has
// a property whose operational rule meets all of these conditions:
//   - the rule is unique;
//   - its step points in the opposite direction;
//   - it has a selector that accepts resource.
//
// Such a resource was created for that dependency, not for other.
func checkIfCreatedAsUniqueValidity(ctx solution.Solution, resource, other *construct.Resource, direction knowledgebase.Direction) (bool, error) {
	var (
		resources []construct.ResourceId
		err       error
	)
	// here the direction matches because we are checking the resource for being used by another resource similar to other
	if direction == knowledgebase.DirectionUpstream {
		resources, err = solution.Upstream(ctx, resource.ID, knowledgebase.ResourceDirectLayer)
	} else {
		resources, err = solution.Downstream(ctx, resource.ID, knowledgebase.ResourceDirectLayer)
	}
	if err != nil {
		return false, err
	}

	// if the dependencies contains the other resource, dont run any checks as we assume its valid
	if collectionutil.Contains(resources, other.ID) {
		return true, nil
	}

	for _, res := range resources {
		template, err := ctx.KnowledgeBase().GetResourceTemplate(res)
		if err != nil {
			return false, fmt.Errorf("error getting resource template for resource %s: %w", res, err)
		}
		if template == nil {
			return false, fmt.Errorf("no resource template found for resource %s", res)
		}
		currRes, err := ctx.RawView().Vertex(res)
		if err != nil {
			return false, err
		}

		createdForCurr := false
		err = template.LoopProperties(currRes, func(prop knowledgebase.Property) error {
			details := prop.Details()
			rule := details.OperationalRule
			if rule == nil || len(rule.Step.Resources) == 0 {
				return nil
			}
			step := rule.Step
			// we want the step to be the opposite of the direction passed in so we know its creating the resource in
			// the direction of the resource since we are looking at the resources dependencies
			if !step.Unique || step.Direction == direction {
				return nil
			}
			for _, selector := range step.Resources {
				match, err := selector.CanUse(solution.DynamicCtx(ctx), knowledgebase.DynamicValueData{Resource: currRes.ID},
					resource)
				if err != nil {
					return fmt.Errorf("error checking if resource %s matches selector %s: %w", resource.ID, selector.Selector, err)
				}
				if match {
					createdForCurr = true
					return knowledgebase.ErrStopWalk
				}
			}
			return nil
		})
		if err != nil {
			return false, err
		}
		if createdForCurr {
			return false, nil
		}
	}
	return true, nil
}
package path_selection
import (
"errors"
"fmt"
"sort"
"strings"
"github.com/dominikbraun/graph"
construct "github.com/klothoplatform/klotho/pkg/construct"
engine_errs "github.com/klothoplatform/klotho/pkg/engine/errors"
"github.com/klothoplatform/klotho/pkg/engine/solution"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
"github.com/klothoplatform/klotho/pkg/logging"
"github.com/klothoplatform/klotho/pkg/set"
)
//go:generate mockgen -source=./path_expansion.go --destination=../operational_eval/path_expansion_mock_test.go --package=operational_eval

// ExpansionInput describes one edge that needs to be expanded into a concrete path of resources.
type ExpansionInput struct {
	// ExpandEdge is the edge whose expansion was requested; it is reported back in expansion errors.
	ExpandEdge construct.SimpleEdge
	// SatisfactionEdge holds the concrete source and target resources that must be connected.
	SatisfactionEdge construct.ResourceEdge
	// Classification is the classification the connecting path should satisfy (may be empty).
	Classification string
	// TempGraph is the path selection graph searched for candidate paths; expansion mutates it.
	TempGraph construct.Graph
}

// ExpansionResult is the outcome of expanding a single ExpansionInput.
type ExpansionResult struct {
	// Edges are follow-up (sub-expansion) edges that still need to be expanded.
	Edges []graph.Edge[construct.ResourceId]
	// Graph contains the resources and edges chosen for the expansion.
	Graph construct.Graph
}

// EdgeExpander expands an ExpansionInput into an ExpansionResult.
type EdgeExpander interface {
	ExpandEdge(input ExpansionInput) (ExpansionResult, error)
}

// EdgeExpand is the EdgeExpander backed by a solution.
type EdgeExpand struct {
	Ctx solution.Solution
}
// ExpandEdge satisfies input.SatisfactionEdge by building the required resources and edges into a fresh
// result graph. It first tries connectThroughNamespace. Only when that reports the endpoints as not connected
// does it fall back to expanding the edge through input.TempGraph.
// Errors from both stages are joined and returned together with the (possibly partial) result.
func (e *EdgeExpand) ExpandEdge(
	input ExpansionInput,
) (ExpansionResult, error) {
	sol := e.Ctx
	result := ExpansionResult{Graph: construct.NewGraph()}
	// Deferred so the graphs are written on every return path.
	defer writeGraph(sol.Context(), input, input.TempGraph, result.Graph)

	// TODO: Revisit if we want to run on namespaces (this causes issue depending on what the namespace is)
	// A file system can be a namespace and that doesnt really fit the reason we are running this at the moment
	// errs = errors.Join(errs, runOnNamespaces(dep.Source, dep.Target, ctx, result))
	connected, nsErr := connectThroughNamespace(
		input.SatisfactionEdge.Source,
		input.SatisfactionEdge.Target,
		sol,
		result,
	)
	// errors.Join of a single nil error is nil, so errs stays nil when the namespace step succeeds.
	errs := errors.Join(nsErr)
	if connected {
		return result, errs
	}

	edges, expandErr := expandEdge(sol, input, result.Graph)
	result.Edges = append(result.Edges, edges...)
	return result, errors.Join(errs, expandErr)
}
// expandEdge resolves input.SatisfactionEdge through input.TempGraph.
//
// Every path between the edge's endpoints in the temp graph is expanded with candidate resources
// (expandPath). The cheapest resulting path is then materialized into g (renameAndReplaceInTempGraph).
// The edges returned are the sub-expansions that the chosen path still requires.
//
// It returns UnsupportedExpansionErr when the temp graph has no path between the endpoints.
func expandEdge(
	ctx solution.Solution,
	input ExpansionInput,
	g construct.Graph,
) ([]graph.Edge[construct.ResourceId], error) {
	srcID := input.SatisfactionEdge.Source.ID
	dstID := input.SatisfactionEdge.Target.ID
	satisfaction := construct.SimpleEdge{Source: srcID, Target: dstID}

	paths, err := graph.AllPathsBetween(input.TempGraph, srcID, dstID)
	if err != nil {
		return nil, err
	}
	if len(paths) == 0 {
		return nil, engine_errs.UnsupportedExpansionErr{
			ExpandEdge:       input.ExpandEdge,
			SatisfactionEdge: satisfaction,
			Classification:   input.Classification,
		}
	}

	// Expand shorter paths first and break ties element-wise so the expansion order is deterministic.
	sort.Slice(paths, func(a, b int) bool {
		left, right := paths[a], paths[b]
		if len(left) != len(right) {
			return len(left) < len(right)
		}
		for k := range left {
			if left[k] != right[k] {
				return construct.ResourceIdLess(left[k], right[k])
			}
		}
		return false
	})

	undirected, err := BuildUndirectedGraph(ctx.RawView(), ctx.KnowledgeBase())
	if err != nil {
		return nil, err
	}

	var expandErrs error
	for _, p := range paths {
		if pathErr := expandPath(ctx, undirected, input, p, g); pathErr != nil {
			expandErrs = errors.Join(expandErrs, fmt.Errorf("error expanding path %s: %w", construct.Path(p), pathErr))
		}
	}
	if expandErrs != nil {
		return nil, expandErrs
	}

	best, err := graph.ShortestPathStable(input.TempGraph, srcID, dstID, construct.ResourceIdLess)
	if err != nil {
		// NOTE(gg) this can't happen with the current expandPath implementation
		// but may in the future.
		return nil, engine_errs.InvalidPathErr{
			ExpandEdge:       input.ExpandEdge,
			SatisfactionEdge: satisfaction,
			Classification:   input.Classification,
		}
	}

	resultResources, err := renameAndReplaceInTempGraph(ctx, input, g, best)
	renameErrs := errors.Join(err)
	edges, err := findSubExpansionsToRun(resultResources, ctx)
	return edges, errors.Join(renameErrs, err)
}
// renameAndReplaceInTempGraph materializes the chosen `path` into g and returns its resources in path order.
//
// Phantom ids (names starting with PHANTOM_PREFIX) are handled as follows:
//   - They are renamed to "<source name>-<target name>".
//   - While that name is already used in the dataflow graph by a resource of the same provider/type
//     (see getCurrNames), a "-<n>" suffix is tried instead.
//   - They are replaced by a resource created with knowledgebase.CreateResource that takes over the
//     phantom's properties.
//
// Every resource is added to g (g's existing vertex is reused if the id is already present) and linked
// to its predecessor on the path.
//
// If anything in that first pass fails, nil and the joined errors are returned. Otherwise each path
// element in input.TempGraph is replaced by its materialized resource. Errors from that replacement
// step are returned alongside the result.
func renameAndReplaceInTempGraph(
	ctx solution.Solution,
	input ExpansionInput,
	g construct.Graph,
	path construct.Path,
) ([]*construct.Resource, error) {
	var errs error
	name := fmt.Sprintf("%s-%s", input.SatisfactionEdge.Source.ID.Name, input.SatisfactionEdge.Target.ID.Name)
	// rename phantom nodes
	result := make([]*construct.Resource, len(path))
	for i, id := range path {
		r, props, err := input.TempGraph.VertexWithProperties(id)
		if err != nil {
			errs = errors.Join(errs, err)
			continue
		}
		if strings.HasPrefix(id.Name, PHANTOM_PREFIX) {
			id.Name = name
			// because certain resources may be namespaced, we will check against all resource type names
			currNames, err := getCurrNames(ctx, &id)
			if err != nil {
				errs = errors.Join(errs, err)
				continue
			}
			// NOTE(review): if all 1000 suffixes are taken, the loop exits with the last candidate name
			// even though it is still in use; no error is reported.
			for suffix := 0; suffix < 1000; suffix++ {
				if !currNames.Contains(id.Name) {
					break
				}
				id.Name = fmt.Sprintf("%s-%d", name, suffix)
			}
			// Record the materialized id on the temp-graph vertex's attributes (only if an attribute map
			// already exists). Presumably consumed by graph output/debugging; confirm.
			if props.Attributes != nil {
				props.Attributes["new_id"] = id.String()
			}
			phantomRes := r
			r, err = knowledgebase.CreateResource(ctx.KnowledgeBase(), id)
			if err != nil {
				errs = errors.Join(errs, err)
				continue
			}
			// Hand the phantom's properties over to the new resource and give the phantom a fresh map,
			// so the two resources do not share the same properties map.
			r.Properties = phantomRes.Properties
			phantomRes.Properties = make(construct.Properties)
		}
		err = g.AddVertex(r)
		switch {
		case errors.Is(err, graph.ErrVertexAlreadyExists):
			// Reuse the resource already in g under this id.
			r, err = g.Vertex(id)
			if err != nil {
				errs = errors.Join(errs, err)
				continue
			}
		case err != nil:
			errs = errors.Join(errs, err)
			continue
		}
		result[i] = r
		if i > 0 {
			err = g.AddEdge(result[i-1].ID, r.ID)
			if err != nil && !errors.Is(err, graph.ErrEdgeAlreadyExists) {
				errs = errors.Join(errs, fmt.Errorf("error adding edge for path[%d]: %w", i, err))
			}
		}
	}
	if errs != nil {
		return nil, errs
	}
	// We need to replace the phantom nodes in the temp graph in case we reuse the temp graph for sub expansions
	for i, res := range result {
		err := construct.ReplaceResource(input.TempGraph, path[i], res)
		if err != nil {
			errs = errors.Join(errs, fmt.Errorf("error replacing path[%d] %s: %w", i, path[i], err))
		}
	}
	return result, errs
}
// getCurrNames collects the names of every dataflow-graph resource that has the same provider and type
// as resourceToSet.
//
// The namespace is deliberately ignored. When a resource is created for an operational action, it has
// likely not been namespaced yet, and we don't know where it will end up.
//
// On a topological-sort failure the (empty) set is returned together with the error.
func getCurrNames(sol solution.Solution, resourceToSet *construct.ResourceId) (set.Set[string], error) {
	names := make(set.Set[string])

	sortedIds, err := construct.TopologicalSort(sol.DataflowGraph())
	if err != nil {
		return names, err
	}

	sameType := construct.ResourceId{Provider: resourceToSet.Provider, Type: resourceToSet.Type}
	for _, candidate := range sortedIds {
		if !sameType.Matches(candidate) {
			continue
		}
		names.Add(candidate.Name)
	}
	return names, nil
}
// findSubExpansionsToRun inspects the interior resources of a materialized path (`result`, in path order).
// It returns the additional edges that must themselves be expanded because of those resources'
// path-satisfaction rules:
//
//   - An interior resource with AsSource rules gets an edge to every resource at least two hops further
//     along the path that has AsTarget rules, and always to the path's final resource.
//   - An interior resource with AsTarget rules gets an edge from every resource at least two hops earlier
//     that has AsSource rules, and always from the path's first resource.
//
// Each (source, target) pair is emitted at most once. Resources whose template cannot be loaded are
// skipped, and their lookup errors are joined into errs.
func findSubExpansionsToRun(
	result []*construct.Resource,
	ctx solution.Solution,
) (edges []graph.Edge[construct.ResourceId], errs error) {
	resourceTemplates := make(map[construct.ResourceId]*knowledgebase.ResourceTemplate)
	added := make(map[construct.ResourceId]map[construct.ResourceId]bool)

	// getResourceTemplate memoizes successful template lookups; failures are reported (each time) and return nil.
	getResourceTemplate := func(id construct.ResourceId) *knowledgebase.ResourceTemplate {
		if rt, found := resourceTemplates[id]; found {
			return rt
		}
		rt, err := ctx.KnowledgeBase().GetResourceTemplate(id)
		if err != nil || rt == nil {
			errs = errors.Join(errs, fmt.Errorf("could not find resource template for %s: %w", id, err))
			return nil
		}
		resourceTemplates[id] = rt
		return rt
	}

	// addEdge records source -> target once, ignoring repeats.
	addEdge := func(source, target construct.ResourceId) {
		targets, ok := added[source]
		if !ok {
			targets = make(map[construct.ResourceId]bool)
			added[source] = targets
		}
		if !targets[target] {
			edges = append(edges, graph.Edge[construct.ResourceId]{Source: source, Target: target})
			targets[target] = true
		}
	}

	last := len(result) - 1
	for i, res := range result {
		// The boundary resources are the original edge's endpoints; only interior resources can trigger sub-expansions.
		if i == 0 || i == last {
			continue
		}
		rt := getResourceTemplate(res.ID)
		if rt == nil {
			continue
		}
		if len(rt.PathSatisfaction.AsSource) != 0 {
			for j := i + 2; j < len(result); j++ {
				target := result[j]
				targetRt := getResourceTemplate(target.ID)
				if targetRt == nil {
					continue
				}
				if len(targetRt.PathSatisfaction.AsTarget) != 0 || j == last {
					addEdge(res.ID, target.ID)
				}
			}
		}
		// do the same logic for asTarget
		if len(rt.PathSatisfaction.AsTarget) != 0 {
			for j := i - 2; j >= 0; j-- {
				source := result[j]
				sourceRt := getResourceTemplate(source.ID)
				if sourceRt == nil {
					continue
				}
				if len(sourceRt.PathSatisfaction.AsSource) != 0 || j == 0 {
					addEdge(source.ID, res.ID)
				}
			}
		}
	}
	return edges, errs
}
// expandPath takes a candidate `path` from input.TempGraph and resolves each of its non-boundary (interior)
// positions to a set of candidate resources. The candidates are:
//   - the node already on the path, and
//   - any resource of a matching type found in resultGraph or ctx.RawView() that passes
//     checkNamespaceValidity.
//
// It then adds weighted edges between candidates of consecutive positions (and from the satisfaction
// edge's source / to its target) into input.TempGraph, so that a later shortest-path search can choose
// among them.
//
// 'undirected' is the undirected graph of the dataflow graph from 'ctx'. It is a separate input so the
// calculated graph can be reused for performance.
//
// If the path is a direct edge (two elements) that checkModifiesImportedResource says would modify an
// imported resource, the edge is removed from input.TempGraph and nothing else is done.
func expandPath(
	ctx solution.Solution,
	undirected construct.Graph,
	input ExpansionInput,
	path construct.Path,
	resultGraph construct.Graph,
) error {
	log := logging.GetLogger(ctx.Context()).Sugar()
	if len(path) == 2 {
		modifiesImport, err := checkModifiesImportedResource(input.SatisfactionEdge.Source.ID,
			input.SatisfactionEdge.Target.ID, ctx, nil)
		if err != nil {
			return err
		}
		if modifiesImport {
			// Because the direct edge will cause modifications to an imported resource, we need to remove the direct edge
			return input.TempGraph.RemoveEdge(input.SatisfactionEdge.Source.ID,
				input.SatisfactionEdge.Target.ID)
		}
	}
	log.Debugf("Resolving path %s", path)
	// candidate pairs a resource id with the adjustment passed to CalculateEdgeWeight as its divisor.
	// 0 means no adjustment, larger positive values make edges cheaper, and negative values penalize.
	type candidate struct {
		id             construct.ResourceId
		divideWeightBy int
	}
	var errs error
	nonBoundaryResources := path[1 : len(path)-1]
	// candidates maps the nonboundary index to the set of resources that could satisfy it (id -> weight adjustment).
	// this is a helper to make adding all the edges to the graph easier.
	candidates := make([]map[construct.ResourceId]int, len(nonBoundaryResources))
	// Seed each position with the node already on the path.
	for i, node := range nonBoundaryResources {
		candidates[i] = make(map[construct.ResourceId]int)
		candidates[i][node] = 0
		resource, err := input.TempGraph.Vertex(node)
		if err != nil {
			errs = errors.Join(errs, fmt.Errorf("error getting vertex for path[%d]: %w", i, err))
			continue
		}
		// we know phantoms are always able to be valid, so we want to ensure we make them valid based on src and target validity checks
		// right now we dont want validity checks to be blocking, just preference so we use them to modify the weight
		valid, err := checkCandidatesValidity(ctx, resource, path, input.Classification)
		if err != nil {
			errs = errors.Join(errs, fmt.Errorf("error checking validity of path[%d]: %w", i, err))
			continue
		}
		if !valid {
			candidates[i][node] = -1000
		}
	}
	if errs != nil {
		return errs
	}
	// addCandidates is the graph-walk callback. nerr carries the errors accumulated so far in the walk and
	// must be returned (possibly joined with new errors) from every branch, so that no error is lost.
	addCandidates := func(id construct.ResourceId, resource *construct.Resource, nerr error) error {
		matchIdx := matchesNonBoundary(id, nonBoundaryResources)
		if matchIdx < 0 {
			return nerr
		}
		valid, err := checkNamespaceValidity(ctx, resource, input.SatisfactionEdge.Target.ID)
		if err != nil {
			return errors.Join(nerr, fmt.Errorf("error checking namespace validity of %s: %w", resource.ID, err))
		}
		if !valid {
			return nerr
		}
		// Calculate edge weight for candidate
		err = input.TempGraph.AddVertex(resource)
		if err != nil && !errors.Is(err, graph.ErrVertexAlreadyExists) {
			return errors.Join(nerr, err)
		}
		if _, ok := candidates[matchIdx][id]; !ok {
			candidates[matchIdx][id] = 0
		}
		weight, err := determineCandidateWeight(ctx, input.SatisfactionEdge.Source.ID, input.SatisfactionEdge.Target.ID, id, resultGraph, undirected)
		if err != nil {
			return errors.Join(nerr, err)
		}
		// right now we dont want validity checks to be blocking, just preference so we use them to modify the weight
		valid, err = checkCandidatesValidity(ctx, resource, path, input.Classification)
		if err != nil {
			return errors.Join(nerr, err)
		}
		if !valid {
			weight = -1000
		}
		// Weights accumulate: a resource visited by both walks below is counted once per walk.
		candidates[matchIdx][id] += weight
		return nerr
	}
	// We need to add candidates which exist in our current result graph so we can reuse them. We do this in case
	// we have already performed expansions to ensure the namespaces are connected, etc
	err := construct.WalkGraph(resultGraph, addCandidates)
	if err != nil {
		errs = errors.Join(errs, fmt.Errorf("error during result graph walk graph: %w", err))
	}
	// Add all other candidates which exist within the graph
	err = construct.WalkGraph(ctx.RawView(), addCandidates)
	if err != nil {
		errs = errors.Join(errs, fmt.Errorf("error during raw view walk graph: %w", err))
	}
	edges, err := ctx.DataflowGraph().Edges()
	if err != nil {
		errs = errors.Join(errs, err)
	}
	if errs != nil {
		return errs
	}
	// addEdge adds a weighted source -> target edge to input.TempGraph unless:
	//   - no edge template exists for the pair (recorded as an error),
	//   - the template's uniqueness constraint (Unique.CanAdd) rejects it given the current dataflow edges,
	//   - the edge would modify an imported resource, or
	//   - the edge is not yet in the raw view, the template is unique on either side, and
	//     checkUniquenessValidity reports it invalid.
	// An existing temp-graph edge has its weight updated; an edge that would create a cycle is skipped.
	addEdge := func(source, target candidate) {
		weight := CalculateEdgeWeight(
			construct.SimpleEdge{Source: input.SatisfactionEdge.Source.ID, Target: input.SatisfactionEdge.Target.ID},
			source.id, target.id,
			source.divideWeightBy, target.divideWeightBy,
			input.Classification,
			ctx.KnowledgeBase())
		tmpl := ctx.KnowledgeBase().GetEdgeTemplate(source.id, target.id)
		if tmpl == nil {
			errs = errors.Join(errs, fmt.Errorf("could not find edge template for %s -> %s", source.id, target.id))
			return
		}
		if !tmpl.Unique.CanAdd(edges, source.id, target.id) {
			return
		}
		modifiesImport, err := checkModifiesImportedResource(source.id, target.id, ctx, tmpl)
		if err != nil {
			errs = errors.Join(errs, err)
			return
		}
		if modifiesImport {
			return
		}
		// if the edge doesnt exist in the actual graph and there is any uniqueness constraint,
		// then we need to check uniqueness validity
		_, err = ctx.RawView().Edge(source.id, target.id)
		if errors.Is(err, graph.ErrEdgeNotFound) {
			if tmpl.Unique.Source || tmpl.Unique.Target {
				valid, err := checkUniquenessValidity(ctx, source.id, target.id)
				if err != nil {
					errs = errors.Join(errs, err)
					return
				}
				if !valid {
					return
				}
			}
		} else if err != nil {
			errs = errors.Join(errs, fmt.Errorf("unexpected error from raw edge: %w", err))
			return
		}
		err = input.TempGraph.AddEdge(source.id, target.id, graph.EdgeWeight(weight))
		switch {
		case errors.Is(err, graph.ErrEdgeAlreadyExists):
			errs = errors.Join(errs, input.TempGraph.UpdateEdge(source.id, target.id, graph.EdgeWeight(weight)))
		case errors.Is(err, graph.ErrEdgeCreatesCycle):
			// ignore cycles
		case err != nil:
			errs = errors.Join(errs, fmt.Errorf("unexpected error adding edge to temp graph: %w", err))
		}
	}
	// Connect every candidate to every candidate of the previous position (or to the satisfaction source for the first position).
	for i, resCandidates := range candidates {
		for id, weight := range resCandidates {
			if i == 0 {
				addEdge(candidate{id: input.SatisfactionEdge.Source.ID}, candidate{id: id, divideWeightBy: weight})
				continue
			}
			sources := candidates[i-1]
			for source, w := range sources {
				addEdge(candidate{id: source, divideWeightBy: w}, candidate{id: id, divideWeightBy: weight})
			}
		}
	}
	// Connect the last position's candidates to the satisfaction target.
	if len(candidates) > 0 {
		for c, weight := range candidates[len(candidates)-1] {
			addEdge(candidate{id: c, divideWeightBy: weight}, candidate{id: input.SatisfactionEdge.Target.ID})
		}
	}
	return errs
}
// connectThroughNamespace tries to connect src to target through namespace-typed resources already
// downstream of src.
//
// If target has a namespace resource (per kb.GetResourcesNamespaceResource), the function looks at each
// downstream resource of src in the local layer whose qualified type equals that namespace resource's type.
// When such a resource has its own, different namespace resource, the function:
//  1. builds a path selection graph from that namespace resource to target, and
//  2. expands an edge from the downstream resource to target into result.Graph.
//
// connected is true if at least one of those expansions succeeded.
//
// NOTE(review): `result` is passed by value. The append to result.Edges at the end of the loop is
// therefore not visible to the caller; only mutations made through result.Graph are. Sub-expansion
// edges found here appear to be lost. Consider taking *ExpansionResult (and updating the caller in
// EdgeExpand.ExpandEdge).
func connectThroughNamespace(src, target *construct.Resource, sol solution.Solution, result ExpansionResult) (
	connected bool,
	errs error,
) {
	kb := sol.KnowledgeBase()
	// The error is intentionally ignored: a zero id (no namespace) simply means there is nothing to connect through.
	targetNamespaceResource, _ := kb.GetResourcesNamespaceResource(target)
	if targetNamespaceResource.IsZero() {
		return
	}
	downstreams, err := solution.Downstream(sol, src.ID, knowledgebase.ResourceLocalLayer)
	if err != nil {
		return connected, err
	}
	for _, downId := range downstreams {
		// Right now we only check for side effects of the same type
		// We may want to check for any side effects that could be namespaced into the target namespace since that would influence
		// the source resources connection to that target namespace resource
		if downId.QualifiedTypeName() != targetNamespaceResource.QualifiedTypeName() {
			continue
		}
		down, err := sol.RawView().Vertex(downId)
		if err != nil {
			errs = errors.Join(errs, err)
			continue
		}
		res, _ := kb.GetResourcesNamespaceResource(down)
		if res.IsZero() {
			continue
		}
		if res == targetNamespaceResource {
			continue
		}
		// if we have a namespace resource that is not the same as the target namespace resource
		tg, err := BuildPathSelectionGraph(
			sol.Context(),
			construct.SimpleEdge{Source: res, Target: target.ID},
			kb,
			"",
			true,
		)
		// NOTE(review): failures to build the path selection graph are skipped silently (not joined into errs).
		if err != nil {
			continue
		}
		// NOTE(review): tg is built from `res`, but the satisfaction edge starts at `down`. Confirm that
		// down.ID is expected to be present in tg for expandEdge's path search.
		input := ExpansionInput{
			SatisfactionEdge: construct.ResourceEdge{Source: down, Target: target},
			Classification:   "",
			TempGraph:        tg,
		}
		edges, err := expandEdge(sol, input, result.Graph)
		if err != nil {
			errs = errors.Join(errs, err)
			continue
		}
		result.Edges = append(result.Edges, edges...)
		connected = true
	}
	return
}
// matchesNonBoundary returns the index of the first interior path node whose provider, type and namespace
// match id (the node's name is ignored), or -1 if none does.
//
// NOTE(gg): if for some reason the path could contain a duplicated selector
// this would just add the resource to the first match. I don't
// think this should happen for a call into [ExpandEdge], but noting it just in case.
func matchesNonBoundary(id construct.ResourceId, nonBoundaryResources []construct.ResourceId) int {
	for idx := range nonBoundaryResources {
		node := nonBoundaryResources[idx]
		selector := construct.ResourceId{
			Provider:  node.Provider,
			Type:      node.Type,
			Namespace: node.Namespace,
		}
		if selector.Matches(id) {
			return idx
		}
	}
	return -1
}
package path_selection
import (
"context"
"errors"
"fmt"
"github.com/dominikbraun/graph"
"github.com/klothoplatform/klotho/pkg/collectionutil"
"github.com/klothoplatform/klotho/pkg/construct"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
"github.com/klothoplatform/klotho/pkg/logging"
)
const (
	// PHANTOM_PREFIX names the placeholder vertices created by makePhantom. It deliberately uses an invalid
	// character, so if it leaks into an actual input/output it will fail to parse.
	PHANTOM_PREFIX = "phantom$"

	// GLUE_WEIGHT is the base weight CalculateEdgeWeight assigns to an edge endpoint without a known functionality.
	GLUE_WEIGHT = 100
	// FUNCTIONAL_WEIGHT is the base weight CalculateEdgeWeight assigns to an endpoint with a known functionality
	// that is not the corresponding endpoint of the dependency being solved.
	FUNCTIONAL_WEIGHT = 100000
)
// BuildPathSelectionGraph builds a weighted, acyclic temp graph of candidate paths for dep.
//
// Direct-edge short circuit: this applies when ignoreDirectEdge is false, an edge template exists for
// dep.Source -> dep.Target, and both ids share a namespace. If the edge template's classifications, or
// the source's or target's own classifications, include `classification`, the returned graph contains
// only the direct edge.
//
// Otherwise every classification-satisfying path of resource types from the knowledge base (ClassPaths)
// is added. Interior nodes become phantom resources (see makePhantom), and edges are weighted with
// CalculateEdgeWeight.
func BuildPathSelectionGraph(
	ctx context.Context,
	dep construct.SimpleEdge,
	kb knowledgebase.TemplateKB,
	classification string,
	ignoreDirectEdge bool,
) (construct.Graph, error) {
	log := logging.GetLogger(ctx).Sugar()
	log.Debugf("Building path selection graph for %s", dep)
	tempGraph := construct.NewAcyclicGraph(graph.Weighted())
	// Check to see if there is a direct edge which satisfies the classification and if so short circuit in building the temp graph
	et := kb.GetEdgeTemplate(dep.Source, dep.Target)
	if !ignoreDirectEdge && et != nil && dep.Source.Namespace == dep.Target.Namespace {
		directEdgeSatisfies := collectionutil.Contains(et.Classification, classification)
		if !directEdgeSatisfies {
			srcRt, err := kb.GetResourceTemplate(dep.Source)
			if err != nil {
				return nil, err
			}
			// The target's own template must be consulted here (not the source's a second time), so that a
			// target classified as `classification` also lets the direct edge satisfy the dependency.
			dstRt, err := kb.GetResourceTemplate(dep.Target)
			if err != nil {
				return nil, err
			}
			directEdgeSatisfies = collectionutil.Contains(srcRt.Classification.Is, classification) ||
				collectionutil.Contains(dstRt.Classification.Is, classification)
		}
		if directEdgeSatisfies {
			err := tempGraph.AddVertex(&construct.Resource{ID: dep.Source})
			if err != nil && !errors.Is(err, graph.ErrVertexAlreadyExists) {
				return nil, fmt.Errorf("failed to add source vertex to path selection graph for %s: %w", dep, err)
			}
			err = tempGraph.AddVertex(&construct.Resource{ID: dep.Target})
			if err != nil && !errors.Is(err, graph.ErrVertexAlreadyExists) {
				return nil, fmt.Errorf("failed to add target vertex to path selection graph for %s: %w", dep, err)
			}
			err = tempGraph.AddEdge(dep.Source, dep.Target, graph.EdgeWeight(CalculateEdgeWeight(dep, dep.Source, dep.Target, 0, 0, classification, kb)))
			if err != nil {
				return nil, err
			}
			return tempGraph, nil
		}
	}
	// Panic is okay on the cast in following line since it will only happen on programming error
	kbGraph := kb.(interface{ Graph() knowledgebase.Graph }).Graph()
	err := tempGraph.AddVertex(&construct.Resource{ID: dep.Source, Properties: make(construct.Properties)})
	if err != nil {
		return nil, fmt.Errorf("failed to add source vertex to path selection graph for %s: %w", dep, err)
	}
	err = tempGraph.AddVertex(&construct.Resource{ID: dep.Target, Properties: make(construct.Properties)})
	if err != nil {
		return nil, fmt.Errorf("failed to add target vertex to path selection graph for %s: %w", dep, err)
	}
	satisfiedPaths := 0
	// addPath inserts one type path: the first and last entries map to dep's endpoints, while every
	// interior type gets a new phantom vertex.
	addPath := func(path []string) error {
		var prevId construct.ResourceId
		for i, typeName := range path {
			tmpl, err := kbGraph.Vertex(typeName)
			if err != nil {
				return fmt.Errorf("failed to get template for path[%d]: %w", i, err)
			}
			var id construct.ResourceId
			switch i {
			case 0:
				prevId = dep.Source
				continue
			case len(path) - 1:
				id = dep.Target
			default:
				id, err = makePhantom(tempGraph, tmpl.Id())
				if err != nil {
					return fmt.Errorf("failed to make phantom for path[%d]: %w", i, err)
				}
				res := &construct.Resource{ID: id, Properties: make(construct.Properties)}
				if err := tempGraph.AddVertex(res); err != nil {
					return fmt.Errorf("failed to add phantom vertex for path[%d]: %w", i, err)
				}
			}
			weight := graph.EdgeWeight(CalculateEdgeWeight(dep, prevId, id, 0, 0, classification, kb))
			if err := tempGraph.AddEdge(prevId, id, weight); err != nil {
				return fmt.Errorf("failed to add edge[%d] %s -> %s: %w", i-1, prevId, id, err)
			}
			prevId = id
		}
		satisfiedPaths++
		return nil
	}
	err = ClassPaths(kbGraph, dep.Source.QualifiedTypeName(), dep.Target.QualifiedTypeName(), classification, addPath)
	if err != nil {
		return nil, fmt.Errorf("failed to find paths for %s: %w", dep, err)
	}
	log.Debugf("Found %d paths for %s :: %s", satisfiedPaths, dep, classification)
	return tempGraph, nil
}
// makePhantom returns a copy of id renamed to PHANTOM_PREFIX followed by the smallest suffix in [0, 1000)
// for which g reports graph.ErrVertexNotFound. When every suffix is taken, it returns id unchanged with
// an error.
func makePhantom(g construct.Graph, id construct.ResourceId) (construct.ResourceId, error) {
	const maxSuffixes = 1000
	candidate := id
	for n := 0; n < maxSuffixes; n++ {
		candidate.Name = fmt.Sprintf("%s%d", PHANTOM_PREFIX, n)
		_, err := g.Vertex(candidate)
		if errors.Is(err, graph.ErrVertexNotFound) {
			return candidate, nil
		}
	}
	return id, fmt.Errorf("exhausted suffixes for creating phantom for %s", id)
}
// CalculateEdgeWeight returns the weight of the temp-graph edge source -> target, used when choosing
// between candidate paths for the dependency `dep`. Lower weights are preferred.
//
// Each endpoint contributes a base weight: FUNCTIONAL_WEIGHT if it has a known functionality and is not
// the corresponding endpoint of dep, otherwise GLUE_WEIGHT. The base is then scaled by that endpoint's
// divisor (divideSourceBy / divideTargetBy):
//   - 0 is treated as 1 (no adjustment);
//   - an endpoint whose template is classified as `classification` gets +10, making its edges cheaper;
//   - a positive divisor divides the base;
//   - a negative divisor multiplies the base by its magnitude (callers use this as a penalty).
//
// Finally, the sum is multiplied by the edge template's EdgeWeightMultiplier when one is set.
// A template lookup error or a missing template simply skips the classification bonus for that endpoint.
func CalculateEdgeWeight(
	dep construct.SimpleEdge,
	source, target construct.ResourceId,
	divideSourceBy, divideTargetBy int,
	classification string,
	kb knowledgebase.TemplateKB,
) int {
	if divideSourceBy == 0 {
		divideSourceBy = 1
	}
	if divideTargetBy == 0 {
		divideTargetBy = 1
	}
	// check to see if the resources match the classification being solved and account for their weights accordingly.
	// Both a nil error AND a non-nil template are required before dereferencing the template.
	sourceTemplate, err := kb.GetResourceTemplate(source)
	if err == nil && sourceTemplate != nil && collectionutil.Contains(sourceTemplate.Classification.Is, classification) {
		divideSourceBy += 10
	}
	targetTemplate, err := kb.GetResourceTemplate(target)
	if err == nil && targetTemplate != nil && collectionutil.Contains(targetTemplate.Classification.Is, classification) {
		divideTargetBy += 10
	}

	// Newly created ("phantom") edges start at GLUE_WEIGHT or FUNCTIONAL_WEIGHT per endpoint; divisors let
	// existing resources reduce that, which gives them priority over newly created ones.
	baseWeight := func(id, endpoint construct.ResourceId) int {
		if knowledgebase.GetFunctionality(kb, id) != knowledgebase.Unknown && !id.Matches(endpoint) {
			return FUNCTIONAL_WEIGHT
		}
		return GLUE_WEIGHT
	}
	scale := func(base, divideBy int) int {
		switch {
		case divideBy > 0:
			return base / divideBy
		case divideBy < 0:
			return base * -divideBy
		default:
			return 0
		}
	}
	weight := scale(baseWeight(source, dep.Source), divideSourceBy) +
		scale(baseWeight(target, dep.Target), divideTargetBy)

	et := kb.GetEdgeTemplate(source, target)
	if et != nil && et.EdgeWeightMultiplier != 0 {
		return int(float32(weight) * et.EdgeWeightMultiplier)
	}
	return weight
}
package path_selection
import (
"errors"
"fmt"
"slices"
"github.com/dominikbraun/graph"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/engine/solution"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
)
// GetPaths returns the existing raw-view paths that satisfy every path satisfaction rule of the
// source -> target edge.
//
// Each rule is expanded into concrete expansions (DeterminePathSatisfactionInputs), and each expansion
// must have at least one acceptable path:
//   - with a classification, a path is acceptable when none of its edges is DirectEdgeOnly, it satisfies
//     the classification, and it passes pathValidityChecks;
//   - without a classification, every path is accepted.
//
// If any expansion has no acceptable path, (nil, errs) is returned. When hasPathCheck is set, the search
// stops early once a path is found.
//
// Errors from computing paths for an expansion are accumulated and returned instead of being silently
// dropped.
func GetPaths(
	sol solution.Solution,
	source, target construct.ResourceId,
	pathValidityChecks func(source, target construct.ResourceId, path construct.Path) bool,
	hasPathCheck bool,
) ([]construct.Path, error) {
	var errs error
	var result []construct.Path
	pathsCache := map[construct.SimpleEdge][][]construct.ResourceId{}
	pathSatisfactions, err := sol.KnowledgeBase().GetPathSatisfactionsFromEdge(source, target)
	if err != nil {
		return result, err
	}
	sourceRes, err := sol.RawView().Vertex(source)
	if err != nil {
		return result, fmt.Errorf("has path could not find source resource %s: %w", source, err)
	}
	targetRes, err := sol.RawView().Vertex(target)
	if err != nil {
		return result, fmt.Errorf("has path could not find target resource %s: %w", target, err)
	}
	edge := construct.ResourceEdge{Source: sourceRes, Target: targetRes}
	for _, satisfaction := range pathSatisfactions {
		expansions, err := DeterminePathSatisfactionInputs(sol, satisfaction, edge)
		if err != nil {
			return result, err
		}
		for _, expansion := range expansions {
			simple := construct.SimpleEdge{Source: expansion.SatisfactionEdge.Source.ID, Target: expansion.SatisfactionEdge.Target.ID}
			paths, found := pathsCache[simple]
			if !found {
				var pathErr error
				paths, pathErr = graph.AllPathsBetween(sol.RawView(), simple.Source, simple.Target)
				if pathErr != nil {
					errs = errors.Join(errs, pathErr)
					continue
				}
				pathsCache[simple] = paths
			}
			if len(paths) == 0 {
				return nil, errs
			}
			// we have to track the result of each expansion because if we cant find a path for a single expansion
			// we denote that we dont have an actual path from src -> target
			var expansionResult []construct.ResourceId
			if expansion.Classification != "" {
			PATHS:
				for _, path := range paths {
					for i, res := range path {
						if i == 0 {
							continue
						}
						if et := sol.KnowledgeBase().GetEdgeTemplate(path[i-1], res); et != nil && et.DirectEdgeOnly {
							continue PATHS
						}
					}
					if !pathSatisfiesClassification(sol.KnowledgeBase(), path, expansion.Classification) {
						continue PATHS
					}
					if !pathValidityChecks(source, target, path) {
						continue PATHS
					}
					result = append(result, path)
					expansionResult = path
					if hasPathCheck {
						break
					}
				}
			} else {
				expansionResult = paths[0]
				for _, path := range paths {
					result = append(result, path)
				}
				// NOTE(review): unlike the classification branch above (where `break` only stops the PATHS
				// scan), this `break` exits the loop over expansions, so the remaining expansions are not
				// checked when hasPathCheck is set. Confirm this is intended.
				if hasPathCheck {
					break
				}
			}
			if expansionResult == nil {
				return nil, errs
			}
		}
	}
	return result, errs
}
// DeterminePathSatisfactionInputs turns one path satisfaction rule for `edge` into the concrete
// ExpansionInputs that must be satisfied.
//
// Each boundary of the edge starts as the edge's own resource. It may then be redirected in two steps:
//  1. through the rule's property reference (when PropertyReferenceChangesBoundary reports true);
//  2. through the rule's script, which decodes into the same id list.
//
// Every combination of distinct source and target ids found in the raw view becomes one ExpansionInput
// carrying the rule's classification.
//
// Errors are accumulated rather than aborting, so partial expansions may be returned alongside an error.
func DeterminePathSatisfactionInputs(
	sol solution.Solution,
	satisfaction knowledgebase.EdgePathSatisfaction,
	edge construct.ResourceEdge,
) (expansions []ExpansionInput, errs error) {
	srcIds := construct.ResourceList{edge.Source.ID}
	targetIds := construct.ResourceList{edge.Target.ID}

	missingResource := func(id construct.ResourceId, cause error) error {
		return fmt.Errorf(
			"failed to determine path satisfaction inputs. could not find resource %s: %w",
			id, cause,
		)
	}

	var err error
	if satisfaction.Source.PropertyReferenceChangesBoundary() {
		srcIds, err = solution.GetResourcesFromPropertyReference(sol, edge.Source.ID, satisfaction.Source.PropertyReference)
		if err != nil {
			errs = errors.Join(errs, missingResource(edge.Source.ID, err))
		}
	}
	if satisfaction.Target.PropertyReferenceChangesBoundary() {
		targetIds, err = solution.GetResourcesFromPropertyReference(sol, edge.Target.ID, satisfaction.Target.PropertyReference)
		if err != nil {
			errs = errors.Join(errs, missingResource(edge.Target.ID, err))
		}
	}

	if script := satisfaction.Source.Script; script != "" {
		dyn := solution.DynamicCtx(sol)
		if err = dyn.ExecuteDecode(script, knowledgebase.DynamicValueData{Resource: edge.Source.ID}, &srcIds); err != nil {
			errs = errors.Join(errs, fmt.Errorf(
				"failed to determine path satisfaction source inputs. could not run script for %s: %w",
				edge.Source.ID, err,
			))
		}
	}
	if script := satisfaction.Target.Script; script != "" {
		dyn := solution.DynamicCtx(sol)
		if err = dyn.ExecuteDecode(script, knowledgebase.DynamicValueData{Resource: edge.Target.ID}, &targetIds); err != nil {
			errs = errors.Join(errs, fmt.Errorf(
				"failed to determine path satisfaction target inputs. could not run script for %s: %w",
				edge.Target.ID, err,
			))
		}
	}

	for _, srcId := range srcIds {
		for _, targetId := range targetIds {
			if srcId == targetId {
				continue
			}
			src, lookupErr := sol.RawView().Vertex(srcId)
			if lookupErr != nil {
				errs = errors.Join(errs, missingResource(srcId, lookupErr))
				continue
			}
			dst, lookupErr := sol.RawView().Vertex(targetId)
			if lookupErr != nil {
				errs = errors.Join(errs, missingResource(targetId, lookupErr))
				continue
			}
			expansions = append(expansions, ExpansionInput{
				SatisfactionEdge: construct.ResourceEdge{Source: src, Target: dst},
				Classification:   satisfaction.Classification,
			})
		}
	}
	return expansions, errs
}
// pathSatisfiesClassification reports whether `path` can be used to satisfy `classification`.
//
// The path is rejected if it contains unnecessary hops (containsUnneccessaryHopsInPath). An empty
// classification accepts any remaining path.
//
// Otherwise, every resource must have a template, and no template may deny the classification. At least
// one resource template, or one edge template between consecutive resources, must carry the
// classification. Pairs without an edge template contribute nothing rather than causing a nil dereference.
func pathSatisfiesClassification(
	kb knowledgebase.TemplateKB,
	path []construct.ResourceId,
	classification string,
) bool {
	if containsUnneccessaryHopsInPath(path, kb) {
		return false
	}
	if classification == "" {
		return true
	}
	metClassification := false
	for i, res := range path {
		resTemplate, err := kb.GetResourceTemplate(res)
		// A resource without a usable template cannot be reasoned about, so the path is rejected.
		if err != nil || resTemplate == nil ||
			slices.Contains(resTemplate.PathSatisfaction.DenyClassifications, classification) {
			return false
		}
		if slices.Contains(resTemplate.Classification.Is, classification) {
			metClassification = true
		}
		if i == 0 {
			continue
		}
		// GetEdgeTemplate returns nil when no template exists for the pair.
		if et := kb.GetEdgeTemplate(path[i-1], res); et != nil && slices.Contains(et.Classification, classification) {
			metClassification = true
		}
	}
	return metClassification
}
// containsUnneccessaryHopsInPath reports whether path p routes through an interior resource (any element
// other than the first and last) that is considered an unnecessary hop toward the destination.
//
// An interior resource counts as such a hop when its template cannot be loaded, or when its template
// declares a functionality other than knowledgebase.Unknown. The endpoints themselves are never
// inspected, so paths of two or fewer elements never contain unnecessary hops.
func containsUnneccessaryHopsInPath(p []construct.ResourceId, kb knowledgebase.TemplateKB) bool {
	if len(p) <= 2 {
		return false
	}
	for _, hop := range p[1 : len(p)-1] {
		tmpl, err := kb.GetResourceTemplate(hop)
		if err != nil {
			return true
		}
		if tmpl.GetFunctionality() != knowledgebase.Unknown {
			return true
		}
	}
	return false
}
package reconciler
import (
"errors"
"fmt"
"reflect"
"github.com/klothoplatform/klotho/pkg/collectionutil"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/engine/solution"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
"github.com/klothoplatform/klotho/pkg/set"
"go.uber.org/zap"
)
// deleteRequest is one queued unit of work for RemoveResource. It names the resource to consider for
// removal and whether that removal was explicitly requested. Neighbours queued for cleanup are enqueued
// with explicit=false.
type deleteRequest struct {
	resource construct.ResourceId
	explicit bool
}
// RemoveResource removes resource from the solution and then tries to clean up anything that only existed
// to support it.
//
// Work is processed as a FIFO queue of deleteRequests. A queued resource is only removed when
// canDeleteResource allows it. When it is removed:
//   - every resource namespaced under it, plus the deployment dependents selected by
//     addAllDeploymentDependencies, is queued with the same explicit flag;
//   - all of its dataflow edges and remaining deployment-graph edges are removed, then the vertex itself;
//   - its former dataflow neighbours are queued as non-explicit requests so they are removed too if they
//     are no longer needed.
//
// Errors for individual requests are accumulated and returned joined; processing continues past them.
func RemoveResource(c solution.Solution, resource construct.ResourceId, explicit bool) error {
	zap.S().Debugf("reconciling removal of resource %s ", resource)
	queue := []deleteRequest{{
		resource: resource,
		explicit: explicit,
	}}
	var errs error
	for len(queue) > 0 {
		request := queue[0]
		queue = queue[1:]
		resource := request.resource
		explicit := request.explicit

		upstreams, downstreams, err := construct.Neighbors(c.DataflowGraph(), resource)
		if err != nil {
			errs = errors.Join(errs, err)
			continue
		}
		template, err := c.KnowledgeBase().GetResourceTemplate(resource)
		if err != nil {
			errs = errors.Join(errs, fmt.Errorf("unable to remove resource: error getting resource template for %s: %v", resource, err))
			continue
		}
		canDelete, err := canDeleteResource(c, resource, explicit, template, upstreams, downstreams)
		if err != nil {
			errs = errors.Join(errs, err)
			continue
		}
		if !canDelete {
			continue
		}

		// canDeleteResource may itself have removed neighbours (when a deletion criterion can be satisfied by
		// deleting them first). Refresh the neighbour sets so we do not try to remove edges to, or re-queue,
		// resources that no longer exist.
		upstreams, downstreams, err = construct.Neighbors(c.DataflowGraph(), resource)
		if err != nil {
			errs = errors.Join(errs, err)
			continue
		}

		// find all namespaced resources before removing edges and the initial resource, otherwise certain resources may
		// be moved out of their original namespace
		namespacedResources, err := findAllResourcesInNamespace(c, resource)
		if err != nil {
			errs = errors.Join(errs, err)
			continue
		}
		for res := range namespacedResources {
			// Since we may be explicitly deleting the namespace resource,
			// we will forward the same explicit flag to the namespace resource
			queue = appendToQueue(deleteRequest{resource: res, explicit: explicit}, queue)
			// find deployment dependencies to ensure the resource wont get recreated.
			// addAllDeploymentDependencies returns a nil queue on error, so only adopt its result on success;
			// assigning it unconditionally would silently drop every pending request.
			withDeps, err := addAllDeploymentDependencies(c, res, explicit, queue)
			if err != nil {
				errs = errors.Join(errs, err)
				continue
			}
			queue = withDeps
		}

		op := c.OperationalView()
		// We must remove all edges before removing the vertex
		for res := range upstreams {
			errs = errors.Join(errs, op.RemoveEdge(res, resource))
		}
		for res := range downstreams {
			errs = errors.Join(errs, op.RemoveEdge(resource, res))
		}
		errs = errors.Join(errs, deleteRemainingDeploymentDependencies(c, resource))
		err = construct.RemoveResource(op, resource)
		if err != nil {
			errs = errors.Join(errs, err)
			continue
		}
		// try to cleanup, if the resource is removable
		for res := range upstreams.Union(downstreams) {
			queue = appendToQueue(deleteRequest{resource: res, explicit: false}, queue)
		}
	}
	return errs
}
// appendToQueue adds request to queue unless a request for the same resource is already pending.
// A pending non-explicit request is upgraded in place when the new request is explicit; a pending
// explicit request is never downgraded. The (possibly grown) queue is returned.
func appendToQueue(request deleteRequest, queue []deleteRequest) []deleteRequest {
	for idx := range queue {
		if queue[idx].resource != request.resource {
			continue
		}
		if request.explicit && !queue[idx].explicit {
			queue[idx] = request
		}
		return queue
	}
	return append(queue, request)
}
// deleteRemainingDeploymentDependencies removes every deployment-graph edge that still touches resource,
// in both directions (upstream -> resource and resource -> downstream). RemoveResource calls it just
// before removing the vertex itself. A failed lookup skips that direction; failed removals do not stop
// the rest. All failures are returned joined.
func deleteRemainingDeploymentDependencies(
	ctx solution.Solution,
	resource construct.ResourceId,
) error {
	var errs error
	record := func(err error) {
		if err != nil {
			errs = errors.Join(errs, err)
		}
	}

	if ups, err := construct.DirectUpstreamDependencies(ctx.DeploymentGraph(), resource); err != nil {
		record(err)
	} else {
		for _, up := range ups {
			record(ctx.DeploymentGraph().RemoveEdge(up, resource))
		}
	}

	if downs, err := construct.DirectDownstreamDependencies(ctx.DeploymentGraph(), resource); err != nil {
		record(err)
	} else {
		for _, down := range downs {
			record(ctx.DeploymentGraph().RemoveEdge(resource, down))
		}
	}
	return errs
}
// addAllDeploymentDependencies queues for deletion the dependents of resource that cannot survive without it.
//
// A dependent is any resource returned by knowledgebase.Upstream on the deployment graph for resource at
// knowledgebase.ResourceDirectLayer. A dependent is queued (with the given explicit flag) when one of its
// properties references resource and that property is required or has an operational rule. A reference
// counts when the value is a ResourceId or PropertyRef, or a slice/array element of one of those types.
// Each dependent queued this way is processed recursively, so its own dependents are considered too.
// Dependents already queued with the same explicit flag are skipped.
//
// It does not modify any graph. The updated queue is returned; on error the returned queue is nil.
func addAllDeploymentDependencies(
	ctx solution.Solution,
	resource construct.ResourceId,
	explicit bool,
	queue []deleteRequest,
) ([]deleteRequest, error) {
	dependents, err := knowledgebase.Upstream(
		ctx.DeploymentGraph(),
		ctx.KnowledgeBase(),
		resource,
		knowledgebase.ResourceDirectLayer,
	)
	if err != nil {
		return nil, err
	}

	// refersToResource reports whether a single (non-collection) value points at resource.
	refersToResource := func(v any) bool {
		switch typed := v.(type) {
		case construct.ResourceId:
			return typed == resource
		case construct.PropertyRef:
			return typed.Resource == resource
		}
		return false
	}

	for _, depID := range dependents {
		depRes, err := ctx.RawView().Vertex(depID)
		if err != nil {
			return nil, err
		}
		depTemplate, err := ctx.KnowledgeBase().GetResourceTemplate(depID)
		if err != nil {
			return nil, err
		}
		request := deleteRequest{resource: depID, explicit: explicit}
		if collectionutil.Contains(queue, request) {
			continue
		}

		mustDelete := false
		// Look for a property on the dependent that references resource.
		err = depTemplate.LoopProperties(depRes, func(p knowledgebase.Property) error {
			value, err := depRes.GetProperty(p.Details().Path)
			if err != nil {
				return err
			}
			referenced := refersToResource(value)
			if !referenced {
				rv := reflect.ValueOf(value)
				if rv.Kind() == reflect.Slice || rv.Kind() == reflect.Array {
					for i := 0; i < rv.Len() && !referenced; i++ {
						referenced = refersToResource(rv.Index(i).Interface())
					}
				}
			}
			if referenced && (p.Details().OperationalRule != nil || p.Details().Required) {
				mustDelete = true
			}
			return nil
		})
		if err != nil {
			return nil, err
		}
		if !mustDelete {
			continue
		}

		queue = appendToQueue(request, queue)
		if queue, err = addAllDeploymentDependencies(ctx, depID, explicit, queue); err != nil {
			return nil, err
		}
	}
	return queue, nil
}
// canDeleteResource decides whether resource may be removed, given its current dataflow neighbours
// (upstreamNodes / downstreamNodes) and its template's DeleteContext.
//
// Functional resources (template functionality other than knowledgebase.Unknown) and imported resources are
// only deletable when explicit is set; explicit requests also bypass the DeleteContext criteria.
//
// Otherwise, for each criterion that currently blocks deletion:
//   - if ignoreCriteria reports that every edge to the blocking neighbours is DeletionDependent, those
//     neighbours are removed first via non-explicit RemoveResource calls;
//   - if not, false is returned.
//
// An error is returned if blocking neighbours remain after that attempt.
//
// Note: this can remove neighbouring resources from the solution as a side effect.
func canDeleteResource(
	ctx solution.Solution,
	resource construct.ResourceId,
	explicit bool,
	template *knowledgebase.ResourceTemplate,
	upstreamNodes set.Set[construct.ResourceId],
	downstreamNodes set.Set[construct.ResourceId],
) (bool, error) {
	res, err := ctx.RawView().Vertex(resource)
	if err != nil {
		return false, err
	}
	// dont allow deletion of imported or functional resources unless it is explicitly stated
	if (template.GetFunctionality() != knowledgebase.Unknown || res.Imported) && !explicit {
		return false, nil
	}
	log := zap.S().With(zap.String("id", resource.String()))
	deletionCriteria := template.DeleteContext
	ignoreUpstream := ignoreCriteria(ctx, resource, upstreamNodes, knowledgebase.DirectionUpstream)
	ignoreDownstream := ignoreCriteria(ctx, resource, downstreamNodes, knowledgebase.DirectionDownstream)

	// Check to see if there are upstream nodes for the resource trying to be deleted.
	// If upstream nodes exist, attempt to delete them before deciding that the deletion process cannot continue.
	if deletionCriteria.RequiresNoUpstream && !explicit && len(upstreamNodes) > 0 {
		log.Debugf("Cannot delete resource %s as it still has upstream dependencies", resource)
		if !ignoreUpstream {
			return false, nil
		}
		for up := range upstreamNodes {
			if err := RemoveResource(ctx, up, false); err != nil {
				return false, err
			}
		}
		upstream, err := construct.DirectUpstreamDependencies(ctx.DataflowGraph(), resource)
		if err != nil {
			return false, err
		}
		if len(upstream) > 0 {
			return false, fmt.Errorf("cannot delete resource %s as it still has %d upstream dependencies", resource, len(upstream))
		}
	}

	if deletionCriteria.RequiresNoDownstream && !explicit && len(downstreamNodes) > 0 {
		log.Debugf("Cannot delete resource %s as it still has downstream dependencies", resource)
		if !ignoreDownstream {
			return false, nil
		}
		for down := range downstreamNodes {
			if err := RemoveResource(ctx, down, false); err != nil {
				return false, err
			}
		}
		downstream, err := construct.DirectDownstreamDependencies(ctx.DataflowGraph(), resource)
		if err != nil {
			return false, err
		}
		if len(downstream) > 0 {
			return false, fmt.Errorf("cannot delete resource %s as it still has %d downstream dependencies", resource, len(downstream))
		}
	}

	// RequiresNoUpstreamOrDownstream is only blocking while BOTH sides are non-empty.
	if deletionCriteria.RequiresNoUpstreamOrDownstream && !explicit && len(downstreamNodes) > 0 && len(upstreamNodes) > 0 {
		log.Debugf("Cannot delete resource %s as it still has upstream and downstream dependencies", resource)
		if !ignoreUpstream && !ignoreDownstream {
			return false, nil
		}
		// Clearing either side satisfies the criterion, so only cascade into the side(s) whose edges are
		// deletion dependent. Cascading into the other side would delete resources that this one has no
		// deletion dependency on.
		if ignoreDownstream {
			for down := range downstreamNodes {
				if err := RemoveResource(ctx, down, false); err != nil {
					return false, err
				}
			}
		}
		if ignoreUpstream {
			for up := range upstreamNodes {
				if err := RemoveResource(ctx, up, false); err != nil {
					return false, err
				}
			}
		}
		downstream, err := construct.DirectDownstreamDependencies(ctx.DataflowGraph(), resource)
		if err != nil {
			return false, err
		}
		upstream, err := construct.DirectUpstreamDependencies(ctx.DataflowGraph(), resource)
		if err != nil {
			return false, err
		}
		if len(downstream) > 0 && len(upstream) > 0 {
			return false, fmt.Errorf(
				"cannot delete resource %s as it still has %d upstream and %d downstream dependencies",
				resource,
				len(upstream),
				len(downstream),
			)
		}
	}
	return true, nil
}
// ignoreCriteria reports whether resource's deletion criteria may be bypassed for the given neighbour set
// by deleting those neighbours first.
//
// Bypassing is allowed only when every edge between resource and each node has an edge template marked
// DeletionDependent. For DirectionDownstream the edge is resource -> node; for any other direction it is
// node -> resource. A missing edge template counts as not deletion dependent. An empty nodes set yields true.
func ignoreCriteria(
	ctx solution.Solution,
	resource construct.ResourceId,
	nodes set.Set[construct.ResourceId],
	direction knowledgebase.Direction,
) bool {
	downstream := direction == knowledgebase.DirectionDownstream
	for other := range nodes {
		from, to := other, resource
		if downstream {
			from, to = resource, other
		}
		edgeTmpl := ctx.KnowledgeBase().GetEdgeTemplate(from, to)
		if edgeTmpl == nil || !edgeTmpl.DeletionDependent {
			return false
		}
	}
	return true
}
// findAllResourcesInNamespace returns the IDs of all resources that live in namespace. A resource qualifies
// when both of the following hold:
//   - its ID's Namespace equals namespace.Name;
//   - it has a property marked as a namespace property whose value (a ResourceId, or a PropertyRef's
//     resource) matches namespace.
//
// Errors from template lookups and property reads are joined; any error discards the result.
func findAllResourcesInNamespace(ctx solution.Solution, namespace construct.ResourceId) (set.Set[construct.ResourceId], error) {
	members := make(set.Set[construct.ResourceId])
	walkErr := construct.WalkGraph(ctx.RawView(), func(id construct.ResourceId, resource *construct.Resource, nerr error) error {
		// Only resources whose ID carries this namespace's name can belong to it.
		if id.Namespace == "" || id.Namespace != namespace.Name {
			return nerr
		}
		rt, err := ctx.KnowledgeBase().GetResourceTemplate(id)
		switch {
		case err != nil:
			return errors.Join(nerr, err)
		case rt == nil:
			return errors.Join(nerr, fmt.Errorf("unable to find resource template for %s", id))
		}
		loopErr := rt.LoopProperties(resource, func(p knowledgebase.Property) error {
			details := p.Details()
			if !details.Namespace {
				return nil
			}
			value, err := resource.GetProperty(details.Path)
			if err != nil {
				return err
			}
			var ref construct.ResourceId
			switch v := value.(type) {
			case construct.ResourceId:
				ref = v
			case construct.PropertyRef:
				ref = v.Resource
			default:
				return nil
			}
			if ref.Matches(namespace) {
				members.Add(id)
			}
			return nil
		})
		return errors.Join(nerr, loopErr)
	})
	if walkErr != nil {
		return nil, walkErr
	}
	return members, nil
}
package reconciler
import (
"errors"
"fmt"
"github.com/dominikbraun/graph"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/engine/path_selection"
"github.com/klothoplatform/klotho/pkg/engine/solution"
"github.com/klothoplatform/klotho/pkg/set"
"go.uber.org/zap"
)
// RemovePath removes all paths between the source and target node.
//
// Some edges and nodes inside those paths must be kept and are not removed:
//   - nodes with upstream connections from outside the paths (nodesUsedOutsideOfContext);
//   - edges used by other path selections from source or to target (findEdgesUsedInOtherPathSelection).
//
// Afterwards every resource on the paths is passed to RemoveResource as a non-explicit request, so resources
// left orphaned are cleaned up while functional and imported ones are kept (see canDeleteResource).
//
// It returns graph.ErrTargetNotReachable when no path exists from source to target.
func RemovePath(
	source, target construct.ResourceId,
	ctx solution.Solution,
) error {
	zap.S().Infof("Removing path %s -> %s", source, target)
	allPaths, err := graph.AllPathsBetween(ctx.DataflowGraph(), source, target)
	if err != nil {
		return err
	}
	if len(allPaths) == 0 {
		return graph.ErrTargetNotReachable
	}

	inPaths := nodesInPaths(allPaths)
	sharedNodes, err := nodesUsedOutsideOfContext(inPaths, ctx)
	if err != nil {
		return err
	}
	sharedEdges, err := findEdgesUsedInOtherPathSelection(source, target, inPaths, ctx)
	if err != nil {
		return err
	}

	var errs error
	for _, p := range allPaths {
		errs = errors.Join(errs, removeSinglePath(source, target, p, sharedNodes, sharedEdges, ctx))
	}

	for _, p := range allPaths {
		for _, id := range p {
			// The resource may already be gone: removeSinglePath or an earlier iteration here (a resource can
			// appear in several paths) may have deleted it. That is expected, so not-found errors are ignored.
			// NOTE(review): RemoveResource returns errors.Join'd errors and errors.Is matches any component,
			// so one not-found inside a joined error also discards any other errors it carries.
			if err := RemoveResource(ctx, id, false); err != nil && !errors.Is(err, graph.ErrVertexNotFound) {
				errs = errors.Join(errs, err)
			}
		}
	}
	return errs
}
// removeSinglePath removes the edges of one source -> target path, except those that must survive.
// An edge is kept when either of these holds:
//   - its source node is in used (has upstream connections from outside the paths being removed) and is
//     neither source nor target;
//   - the edge is in edges (it is part of another path selection from source or to target).
//
// After removing an edge whose source is an intermediate node, that node is passed to RemoveResource as a
// non-explicit request, so it is cleaned up if nothing else needs it. Errors are joined and returned.
func removeSinglePath(
	source, target construct.ResourceId,
	path []construct.ResourceId,
	used set.Set[construct.ResourceId],
	edges set.Set[construct.SimpleEdge],
	ctx solution.Solution,
) error {
	var errs error
	for i := 1; i < len(path); i++ {
		prev, curr := path[i-1], path[i]
		// Removing prev -> curr would break another path through prev if prev is used elsewhere.
		prevShared := used.Contains(prev) && !target.Matches(prev) && !source.Matches(prev)
		if prevShared || edges.Contains(construct.SimpleEdge{Source: prev, Target: curr}) {
			continue
		}
		errs = errors.Join(errs, ctx.OperationalView().RemoveEdge(prev, curr))
		if i > 1 {
			errs = errors.Join(errs, RemoveResource(ctx, prev, false))
		}
	}
	return errs
}
// nodesInPaths returns the set of every resource that appears in at least one of paths.
func nodesInPaths(
	paths [][]construct.ResourceId,
) set.Set[construct.ResourceId] {
	seen := make(set.Set[construct.ResourceId])
	for p := range paths {
		for idx := range paths[p] {
			seen.Add(paths[p][idx])
		}
	}
	return seen
}
// nodesUsedOutsideOfContext returns all nodes that are used outside of the context.
// A node is used outside of the context when it has at least one predecessor (an upstream connection in the
// raw view, which reads from the dataflow graph) that is not itself in nodes, i.e. not on any of the paths
// between the source and target.
//
// We only care about upstream because any extra connections downstream will stay intact and wont result in
// other paths being affected.
//
// The only error returned is one from building the predecessor map.
func nodesUsedOutsideOfContext(
	nodes set.Set[construct.ResourceId],
	ctx solution.Solution,
) (set.Set[construct.ResourceId], error) {
	pred, err := ctx.RawView().PredecessorMap()
	if err != nil {
		return nil, err
	}
	used := make(set.Set[construct.ResourceId])
	for node := range nodes {
		for upstream := range pred[node] {
			if !nodes.Contains(upstream) {
				used.Add(node)
				// One outside predecessor is enough; no need to inspect the rest.
				break
			}
		}
	}
	return used, nil
}
// findEdgesUsedInOtherPathSelection collects the edges of path selections other than source -> target.
// These edges must therefore survive when that path is removed. Two kinds of path are considered:
//   - paths from each upstream of target to target, excluding paths that pass through source. Only upstreams
//     outside nodes whose template has AsSource path satisfaction count.
//   - paths from source to each downstream of source, excluding paths that pass through target. Only
//     downstreams outside nodes whose template has AsTarget path satisfaction count.
//
// Lookup failures are joined into the returned error but do not stop collection.
func findEdgesUsedInOtherPathSelection(
	source, target construct.ResourceId,
	nodes set.Set[construct.ResourceId],
	ctx solution.Solution,
) (set.Set[construct.SimpleEdge], error) {
	edges := make(set.Set[construct.SimpleEdge])
	var errs error
	acceptAll := func(_, _ construct.ResourceId, _ construct.Path) bool { return true }

	// collect records every edge of each selected path from -> to, skipping paths that pass through skip:
	// such a path is just a superset of the path being removed.
	collect := func(from, to, skip construct.ResourceId) {
		paths, err := path_selection.GetPaths(ctx, from, to, acceptAll, false)
		if err != nil {
			errs = errors.Join(errs, err)
			return
		}
		for _, p := range paths {
			if p.Contains(skip) {
				continue
			}
			for i, res := range p {
				if i > 0 {
					edges.Add(construct.SimpleEdge{Source: p[i-1], Target: res})
				}
			}
		}
	}

	// Other selections that end at target.
	upstreams, err := construct.AllUpstreamDependencies(ctx.DataflowGraph(), target)
	if err != nil {
		errs = errors.Join(errs, err)
	}
	for _, up := range upstreams {
		if nodes.Contains(up) {
			continue
		}
		rt, err := ctx.KnowledgeBase().GetResourceTemplate(up)
		switch {
		case err != nil:
			errs = errors.Join(errs, err)
		case rt == nil:
			errs = errors.Join(errs, fmt.Errorf("resource template %s not found", up))
		case len(rt.PathSatisfaction.AsSource) > 0:
			collect(up, target, source)
		}
	}

	// Other selections that start at source.
	downstreams, err := construct.AllDownstreamDependencies(ctx.DataflowGraph(), source)
	if err != nil {
		errs = errors.Join(errs, err)
	}
	for _, down := range downstreams {
		if nodes.Contains(down) {
			continue
		}
		rt, err := ctx.KnowledgeBase().GetResourceTemplate(down)
		switch {
		case err != nil:
			errs = errors.Join(errs, err)
		case rt == nil:
			errs = errors.Join(errs, fmt.Errorf("resource template %s not found", down))
		case len(rt.PathSatisfaction.AsTarget) > 0:
			collect(source, down, target)
		}
	}
	return edges, errs
}
package engine
import (
"context"
"errors"
"fmt"
"github.com/klothoplatform/klotho/pkg/logging"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/engine/constraints"
property_eval "github.com/klothoplatform/klotho/pkg/engine/operational_eval"
"github.com/klothoplatform/klotho/pkg/engine/solution"
"github.com/klothoplatform/klotho/pkg/graph_addons"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
)
// engineSolution is the engine's solution state. It provides the methods of [solution.Solution]; the decision
// log methods (RecordDecision / GetDecisions) come from the embedded solution.DecisionRecords.
type engineSolution struct {
	solution.DecisionRecords

	// KB is the knowledge base used to look up resource and edge templates.
	KB knowledgebase.TemplateKB
	// Dataflow is the dataflow graph; Deployment is the deployment-order graph.
	Dataflow   construct.Graph
	Deployment construct.Graph

	constraints  *constraints.Constraints
	propertyEval *property_eval.Evaluator
	globalTag    string
	// outputs is populated by captureOutputs, keyed by output name.
	outputs map[string]construct.Output
	// ctx is returned by Context().
	ctx context.Context
}
// NewSolution creates an empty solution.
//
// The dataflow graph is wrapped in a graph_addons.LoggingGraph that logs through ctx's logger, named
// "dataflow". The deployment graph is acyclic. The property evaluator is created last because it is built from
// the solution itself.
func NewSolution(ctx context.Context, kb knowledgebase.TemplateKB, globalTag string, constraints *constraints.Constraints) *engineSolution {
	dataflow := graph_addons.LoggingGraph[construct.ResourceId, *construct.Resource]{
		Log:   logging.GetLogger(ctx).Named("dataflow").Sugar(),
		Graph: construct.NewGraph(),
		Hash:  func(r *construct.Resource) construct.ResourceId { return r.ID },
	}
	sol := &engineSolution{
		KB:          kb,
		Dataflow:    dataflow,
		Deployment:  construct.NewAcyclicGraph(),
		constraints: constraints,
		globalTag:   globalTag,
		outputs:     map[string]construct.Output{},
		ctx:         ctx,
	}
	sol.propertyEval = property_eval.NewEvaluator(sol)
	return sol
}
// Solve runs the property evaluator and, only if that succeeds, captures the outputs requested by the
// output constraints.
func (s *engineSolution) Solve() error {
	if err := s.propertyEval.Evaluate(); err != nil {
		return err
	}
	return s.captureOutputs()
}
// Context returns the context the solution was created with.
func (s *engineSolution) Context() context.Context {
	return s.ctx
}

// RawView returns a view that performs only the explicitly requested graph operations.
func (s *engineSolution) RawView() construct.Graph {
	return solution.NewRawView(s)
}

// OperationalView returns a view whose writes make the affected resources and edges operational.
func (s *engineSolution) OperationalView() solution.OperationalView {
	view := (*MakeOperationalView)(s)
	return view
}

// DeploymentGraph returns the deployment-order graph.
func (s *engineSolution) DeploymentGraph() construct.Graph {
	return s.Deployment
}

// DataflowGraph returns the dataflow graph.
func (s *engineSolution) DataflowGraph() construct.Graph {
	return s.Dataflow
}

// KnowledgeBase returns the template knowledge base the solution was created with.
func (s *engineSolution) KnowledgeBase() knowledgebase.TemplateKB {
	return s.KB
}

// Constraints returns the constraints the solution was created with (may be nil).
func (s *engineSolution) Constraints() *constraints.Constraints {
	return s.constraints
}
// LoadGraph loads an existing graph into the solution. A nil graph is a no-op.
//
// Property values are first normalised with knowledgebase.TransformAllPropertyValues; then every vertex is
// added through the operational view. Each edge is added:
//   - raw (no operational processing) when both endpoints are imported;
//   - through the operational view when its edge template has AlwaysProcess set;
//   - raw otherwise.
//
// An edge without an edge template whose endpoints are not both imported is an error. Finally, the
// deployment dependencies implied by every property value are added.
func (s *engineSolution) LoadGraph(graph construct.Graph) error {
	if graph == nil {
		return nil
	}
	// Input graphs are often loaded from YAML, so property values may still be in their serialized form
	// (e.g. a string where a ResourceId is expected). Convert them to their proper types first.
	transformCtx := knowledgebase.DynamicValueContext{
		Graph:         graph,
		KnowledgeBase: s.KB,
	}
	if err := knowledgebase.TransformAllPropertyValues(transformCtx); err != nil {
		return err
	}

	op := s.OperationalView()
	raw := s.RawView()
	if err := op.AddVerticesFrom(graph); err != nil {
		return err
	}

	edges, err := graph.Edges()
	if err != nil {
		return err
	}
	for _, e := range edges {
		edgeTemplate := s.KB.GetEdgeTemplate(e.Source, e.Target)
		src, err := graph.Vertex(e.Source)
		if err != nil {
			return err
		}
		dst, err := graph.Vertex(e.Target)
		if err != nil {
			return err
		}

		// Choose which view the edge is added through.
		var into construct.Graph
		switch {
		case src.Imported && dst.Imported:
			into = raw
		case edgeTemplate == nil:
			return fmt.Errorf("edge template %s -> %s not found", e.Source, e.Target)
		case edgeTemplate.AlwaysProcess:
			into = op
		default:
			into = raw
		}
		if err := into.AddEdge(e.Source, e.Target); err != nil {
			return err
		}
	}

	// ensure any deployment dependencies due to properties are in place
	return construct.WalkGraph(s.RawView(), func(id construct.ResourceId, resource *construct.Resource, nerr error) error {
		walkErr := resource.WalkProperties(func(path construct.PropertyPath, werr error) error {
			// NOTE(review): the error from path.Get() is discarded here; confirm that is intended.
			prop, _ := path.Get()
			return errors.Join(werr, solution.AddDeploymentDependenciesFromVal(s, resource, prop))
		})
		return errors.Join(nerr, walkErr)
	})
}
// GlobalTag returns the global tag the solution was created with.
func (s *engineSolution) GlobalTag() string {
	tag := s.globalTag
	return tag
}
// captureOutputs populates s.outputs from the solution's output constraints.
//
// An output without a resource reference is recorded as a literal value. One with a reference is recorded
// only if the referenced resource exists in the dataflow graph; otherwise an error is collected. All collected
// errors are returned joined.
//
// A solution without constraints has nothing to capture, so this returns nil instead of dereferencing a nil
// pointer.
func (s *engineSolution) captureOutputs() error {
	cons := s.Constraints()
	if cons == nil {
		return nil
	}
	if s.outputs == nil {
		// Guard against an engineSolution that was not built by NewSolution; writing to a nil map panics.
		s.outputs = make(map[string]construct.Output, len(cons.Outputs))
	}
	var errs []error
	for _, oc := range cons.Outputs {
		if oc.Ref.Resource.IsZero() {
			s.outputs[oc.Name] = construct.Output{
				Value: oc.Value,
			}
			continue
		}
		if _, err := s.Dataflow.Vertex(oc.Ref.Resource); err != nil {
			errs = append(errs, fmt.Errorf("output %s error in reference: %w", oc.Name, err))
			continue
		}
		s.outputs[oc.Name] = construct.Output{
			Ref: oc.Ref,
		}
	}
	return errors.Join(errs...)
}
// Outputs returns the outputs captured by Solve, keyed by output name. The map is returned directly, not
// copied.
func (s *engineSolution) Outputs() map[string]construct.Output {
	captured := s.outputs
	return captured
}
package solution
import (
"fmt"
construct "github.com/klothoplatform/klotho/pkg/construct"
engine_errs "github.com/klothoplatform/klotho/pkg/engine/errors"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
)
// SolveDecision is a single decision recorded while solving.
type SolveDecision interface {
	// internal is a private method to prevent other packages from implementing this interface.
	// It's not necessary, but it could prevent some accidental bad practices from emerging.
	internal()
}

// AsEngineError is implemented by decisions that may carry an engine error.
type AsEngineError interface {
	// TryEngineError returns an EngineError if the decision is an error, otherwise nil.
	TryEngineError() engine_errs.EngineError
}

// AddResourceDecision is the decision record for adding Resource.
type AddResourceDecision struct {
	Resource construct.ResourceId
}

// RemoveResourceDecision is the decision record for removing Resource.
type RemoveResourceDecision struct {
	Resource construct.ResourceId
}

// AddDependencyDecision is the decision record for adding the dependency From -> To.
type AddDependencyDecision struct {
	From construct.ResourceId
	To   construct.ResourceId
}

// RemoveDependencyDecision is the decision record for removing the dependency From -> To.
type RemoveDependencyDecision struct {
	From construct.ResourceId
	To   construct.ResourceId
}

// SetPropertyDecision is the decision record for configuring Property on Resource. Configurer.ConfigureResource
// records it with the configuration's raw (untransformed) Value.
type SetPropertyDecision struct {
	Resource construct.ResourceId
	Property string
	Value    any
}

// PropertyValidationDecision is the result of validating Value for Property on Resource. Error is non-nil when
// validation failed.
type PropertyValidationDecision struct {
	Resource construct.ResourceId
	Property knowledgebase.Property
	Value    any
	Error    error
}

// ConfigValidationError is the engine error form of a failed PropertyValidationDecision.
type ConfigValidationError struct {
	PropertyValidationDecision
}
// The internal marker method makes each decision type satisfy SolveDecision.
func (AddResourceDecision) internal()        {}
func (AddDependencyDecision) internal()      {}
func (RemoveResourceDecision) internal()     {}
func (RemoveDependencyDecision) internal()   {}
func (SetPropertyDecision) internal()        {}
func (PropertyValidationDecision) internal() {}
// TryEngineError wraps the decision in a ConfigValidationError when validation failed (Error is non-nil)
// and returns nil otherwise.
func (d PropertyValidationDecision) TryEngineError() engine_errs.EngineError {
	if d.Error != nil {
		return ConfigValidationError{PropertyValidationDecision: d}
	}
	return nil
}
// Error formats the failure as "config validation error on <resource>#<property path>: <cause>".
// The embedded decision's Error field is accessed explicitly because e.Error names this method.
func (e ConfigValidationError) Error() string {
	propPath := e.Property.Details().Path
	cause := e.PropertyValidationDecision.Error
	return fmt.Sprintf("config validation error on %s#%s: %v", e.Resource, propPath, cause)
}

// ErrorCode reports engine_errs.ConfigInvalidCode.
func (e ConfigValidationError) ErrorCode() engine_errs.ErrorCode {
	return engine_errs.ConfigInvalidCode
}

// ToJSONMap returns the error's details for JSON serialization. It assumes the validation error is non-nil,
// which holds for values produced by TryEngineError.
func (e ConfigValidationError) ToJSONMap() map[string]any {
	m := make(map[string]any, 4)
	m["resource"] = e.Resource
	m["property"] = e.Property.Details().Path
	m["value"] = e.Value
	m["validation_error"] = e.PropertyValidationDecision.Error.Error()
	return m
}

// Unwrap exposes the underlying validation error to errors.Is / errors.As.
func (e ConfigValidationError) Unwrap() error {
	return e.PropertyValidationDecision.Error
}
package solution
import (
construct "github.com/klothoplatform/klotho/pkg/construct"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
)
// Downstream returns the resources downstream of rid at the given layer, computed by knowledgebase.Downstream
// over the solution's dataflow graph.
func Downstream(
	sol Solution,
	rid construct.ResourceId,
	layer knowledgebase.DependencyLayer,
) ([]construct.ResourceId, error) {
	df, kb := sol.DataflowGraph(), sol.KnowledgeBase()
	return knowledgebase.Downstream(df, kb, rid, layer)
}

// DownstreamFunctional delegates to knowledgebase.DownstreamFunctional over the solution's dataflow graph.
func DownstreamFunctional(sol Solution, resource construct.ResourceId) ([]construct.ResourceId, error) {
	df, kb := sol.DataflowGraph(), sol.KnowledgeBase()
	return knowledgebase.DownstreamFunctional(df, kb, resource)
}

// Upstream returns the resources upstream of resource at the given layer, computed by knowledgebase.Upstream
// over the solution's dataflow graph.
func Upstream(
	sol Solution,
	resource construct.ResourceId,
	layer knowledgebase.DependencyLayer,
) ([]construct.ResourceId, error) {
	df, kb := sol.DataflowGraph(), sol.KnowledgeBase()
	return knowledgebase.Upstream(df, kb, resource, layer)
}

// UpstreamFunctional delegates to knowledgebase.UpstreamFunctional over the solution's dataflow graph.
func UpstreamFunctional(sol Solution, resource construct.ResourceId) ([]construct.ResourceId, error) {
	df, kb := sol.DataflowGraph(), sol.KnowledgeBase()
	return knowledgebase.UpstreamFunctional(df, kb, resource)
}

// IsOperationalResourceSideEffect delegates to knowledgebase.IsOperationalResourceSideEffect over the
// solution's dataflow graph.
func IsOperationalResourceSideEffect(sol Solution, rid, sideEffect construct.ResourceId) (bool, error) {
	df, kb := sol.DataflowGraph(), sol.KnowledgeBase()
	return knowledgebase.IsOperationalResourceSideEffect(df, kb, rid, sideEffect)
}
package solution
import "sync"
// DecisionRecords is a mutex-guarded, append-only log of SolveDecisions. Its zero value is ready to use.
type DecisionRecords struct {
	mu      sync.Mutex
	records []SolveDecision
}

// RecordDecision appends d to the log. It is safe for concurrent use.
func (r *DecisionRecords) RecordDecision(d SolveDecision) {
	r.mu.Lock()
	defer r.mu.Unlock()
	// append handles a nil slice, so no special case is needed for the first record.
	r.records = append(r.records, d)
}

// GetDecisions returns the recorded decisions in recording order. It is nil when nothing has been recorded.
// The result is a copy, so callers can read or modify it without holding the lock. Changes to it never
// affect the log.
func (r *DecisionRecords) GetDecisions() []SolveDecision {
	r.mu.Lock()
	defer r.mu.Unlock()
	return append([]SolveDecision(nil), r.records...)
}
package solution
import (
"errors"
"fmt"
"reflect"
"strings"
"github.com/dominikbraun/graph"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/engine/constraints"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
)
//go:generate mockgen -source=./resource_configuration.go --destination=../operational_eval/resource_configurer_mock_test.go --package=operational_eval
// ResourceConfigurer applies a configuration value to a resource. A mock is generated from this interface
// with mockgen.
type ResourceConfigurer interface {
	// ConfigureResource applies configuration to resource using action. In the *Configurer implementation,
	// imported resources may only be configured when userInitiated is true.
	ConfigureResource(
		resource *construct.Resource,
		configuration knowledgebase.Configuration,
		data knowledgebase.DynamicValueData,
		action constraints.ConstraintOperator,
		userInitiated bool,
	) error
}

// Configurer implements ResourceConfigurer (via a pointer receiver) against the solution in Ctx.
type Configurer struct {
	Ctx Solution
}
// ConfigureResource applies one configuration (field + value) to resource according to action:
//   - constraints.EqualsConstraintOperator sets the property to the value;
//   - constraints.AddConstraintOperator appends the value to the property;
//   - constraints.RemoveConstraintOperator removes the value from the property.
//
// The raw value is first converted with knowledgebase.TransformToPropertyValue. For set and add, deployment
// edges are then added from resource to every resource the converted value references. On success a
// SetPropertyDecision carrying the untransformed configuration value is recorded.
//
// It returns an error in these cases:
//   - resource is nil;
//   - resource is imported and userInitiated is false;
//   - data.Resource differs from resource.ID;
//   - the template or property cannot be found;
//   - action is unknown;
//   - any step fails.
func (c *Configurer) ConfigureResource(
	resource *construct.Resource,
	configuration knowledgebase.Configuration,
	data knowledgebase.DynamicValueData,
	action constraints.ConstraintOperator,
	userInitiated bool,
) error {
	if resource == nil {
		return fmt.Errorf("resource does not exist")
	}
	if resource.Imported && !userInitiated {
		return fmt.Errorf("cannot configure imported resource %s", resource.ID)
	}
	if data.Resource != resource.ID {
		return fmt.Errorf("data resource (%s) does not match configuring resource (%s)", data.Resource, resource.ID)
	}
	field := configuration.Field
	rt, err := c.Ctx.KnowledgeBase().GetResourceTemplate(resource.ID)
	if err != nil {
		return err
	}
	if rt == nil {
		// Other callers of GetResourceTemplate also treat a nil template without an error as "not found".
		return fmt.Errorf("unable to find resource template for %s", resource.ID)
	}
	property := rt.GetProperty(field)
	if property == nil {
		// err is always nil here; wrapping it with %w would render as "%!w(<nil>)".
		return fmt.Errorf("failed to get property %s on resource %s", field, resource.ID)
	}
	val, err := knowledgebase.TransformToPropertyValue(
		resource.ID,
		field,
		configuration.Value,
		DynamicCtx(c.Ctx),
		data,
	)
	if err != nil {
		return err
	}
	switch action {
	case constraints.EqualsConstraintOperator:
		if err := property.SetProperty(resource, val); err != nil {
			return fmt.Errorf("failed to set property %s on resource %s: %w", field, resource.ID, err)
		}
		if err := AddDeploymentDependenciesFromVal(c.Ctx, resource, val); err != nil {
			return fmt.Errorf("failed to add deployment dependencies from property %s on resource %s: %w", field, resource.ID, err)
		}
	case constraints.AddConstraintOperator:
		if err := property.AppendProperty(resource, val); err != nil {
			return fmt.Errorf("failed to add property %s on resource %s: %w", field, resource.ID, err)
		}
		if err := AddDeploymentDependenciesFromVal(c.Ctx, resource, val); err != nil {
			return fmt.Errorf("failed to add deployment dependencies from property %s on resource %s: %w", field, resource.ID, err)
		}
	case constraints.RemoveConstraintOperator:
		if err := property.RemoveProperty(resource, val); err != nil {
			return fmt.Errorf("failed to remove property %s on resource %s: %w", field, resource.ID, err)
		}
	default:
		return fmt.Errorf("invalid action %s", action)
	}
	c.Ctx.RecordDecision(SetPropertyDecision{
		Resource: resource.ID,
		Property: configuration.Field,
		Value:    configuration.Value,
	})
	return nil
}
// AddDeploymentDependenciesFromVal adds a deployment-graph edge from resource to every resource referenced by
// val (as found by getResourcesFromValue), skipping references that match resource itself. Edges that already
// exist are not an error; other failures are collected and returned joined.
func AddDeploymentDependenciesFromVal(
	ctx Solution,
	resource *construct.Resource,
	val any,
) error {
	var errs error
	for _, dep := range getResourcesFromValue(val) {
		if resource.ID.Matches(dep) {
			continue
		}
		err := ctx.DeploymentGraph().AddEdge(resource.ID, dep)
		if err == nil || errors.Is(err, graph.ErrEdgeAlreadyExists) {
			continue
		}
		errs = errors.Join(errs, fmt.Errorf("failed to add deployment dependency from %s to %s: %w", resource.ID, dep, err))
	}
	return errs
}
// getResourcesFromValue collects every resource referenced by val. Each value form contributes:
//   - a ResourceId contributes itself;
//   - a PropertyRef contributes its Resource;
//   - slices, arrays and map values are searched recursively (map keys are not);
//   - anything else, including nil, contributes nothing.
//
// For maps the result order follows Go's map iteration order and is therefore unspecified.
func getResourcesFromValue(val any) (ids []construct.ResourceId) {
	switch v := val.(type) {
	case nil:
		return nil
	case construct.ResourceId:
		return []construct.ResourceId{v}
	case construct.PropertyRef:
		return []construct.ResourceId{v.Resource}
	}
	rv := reflect.ValueOf(val)
	switch rv.Kind() {
	case reflect.Slice, reflect.Array:
		for i := 0; i < rv.Len(); i++ {
			ids = append(ids, getResourcesFromValue(rv.Index(i).Interface())...)
		}
	case reflect.Map:
		for _, key := range rv.MapKeys() {
			ids = append(ids, getResourcesFromValue(rv.MapIndex(key).Interface())...)
		}
	}
	return ids
}
// GetResourcesFromPropertyReference resolves a '#'-separated chain of property paths starting at resource and
// returns the resources reached at the end of the chain.
//
// Starting from {resource}, each part is read as a property on every resource in the current set. A ResourceId
// value contributes that ID, and a slice/array value contributes each ResourceId element. The collected IDs
// become the set for the next part.
//
// Resources that cannot be found, non-ResourceId values and non-ResourceId elements are reported in errs.
// Properties that are unset (nil) or cannot be read are skipped silently. An empty propertyRef returns
// {resource}.
func GetResourcesFromPropertyReference(
	ctx Solution,
	resource construct.ResourceId,
	propertyRef string,
) (
	resources []construct.ResourceId,
	errs error,
) {
	resources = []construct.ResourceId{resource}
	if propertyRef == "" {
		return resources, nil
	}
	for _, part := range strings.Split(propertyRef, "#") {
		fieldValueResources := []construct.ResourceId{}
		for _, resID := range resources {
			r, err := ctx.RawView().Vertex(resID)
			if err != nil {
				errs = errors.Join(errs, fmt.Errorf(
					"failed to get resources from property reference. could not find resource %s: %w",
					resID, err,
				))
				continue
			}
			if r == nil {
				// No error to wrap here: %w with a nil error would render as "%!w(<nil>)".
				errs = errors.Join(errs, fmt.Errorf(
					"failed to get resources from property reference. could not find resource %s",
					resID,
				))
				continue
			}
			val, err := r.GetProperty(part)
			if err != nil || val == nil {
				continue
			}
			if id, ok := val.(construct.ResourceId); ok {
				fieldValueResources = append(fieldValueResources, id)
			} else if rval := reflect.ValueOf(val); rval.Kind() == reflect.Slice || rval.Kind() == reflect.Array {
				for i := 0; i < rval.Len(); i++ {
					idVal := rval.Index(i).Interface()
					if id, ok := idVal.(construct.ResourceId); ok {
						fieldValueResources = append(fieldValueResources, id)
					} else {
						errs = errors.Join(errs, fmt.Errorf(
							"failed to get resources from property reference. array property %s on resource %s is not a resource id",
							part, resID,
						))
					}
				}
			} else {
				errs = errors.Join(errs, fmt.Errorf(
					"failed to get resources from property reference. property %s on resource %s is not a resource id",
					part, resID,
				))
			}
		}
		resources = fieldValueResources
	}
	return resources, errs
}
package solution
import (
"context"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/engine/constraints"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
)
// Solution gives access to the state of an engine solve: the knowledge base, the constraints,
// the decisions recorded so far, and the dataflow and deployment graphs.
type Solution interface {
	// Context returns the context.Context associated with the solve.
	Context() context.Context
	// KnowledgeBase returns the template knowledge base used for resource and edge lookups.
	KnowledgeBase() knowledgebase.TemplateKB
	// Constraints returns the constraints being solved for.
	Constraints() *constraints.Constraints
	// RecordDecision records d as a decision made during the solve.
	RecordDecision(d SolveDecision)
	// GetDecisions returns the decisions recorded so far.
	GetDecisions() []SolveDecision
	// DataflowGraph returns the dataflow graph.
	DataflowGraph() construct.Graph
	// DeploymentGraph returns the deployment graph.
	DeploymentGraph() construct.Graph
	// OperationalView returns a graph that makes any resources or edges added operational as part of the operation.
	// Read operations come from the Dataflow graph.
	// Write operations will update both the Dataflow and Deployment graphs.
	OperationalView() OperationalView
	// RawView returns a graph that makes no changes beyond explicitly requested operations.
	// Read operations come from the Dataflow graph.
	// Write operations will update both the Dataflow and Deployment graphs.
	RawView() construct.Graph
	// GlobalTag returns the global tag for the solution context
	GlobalTag() string
	// Outputs returns the solution's outputs keyed by name.
	Outputs() map[string]construct.Output
}

// OperationalView is a construct.Graph that additionally exposes the operations used to make
// resources and edges operational, and to rename a resource.
type OperationalView interface {
	construct.Graph
	MakeResourcesOperational(resources []*construct.Resource) error
	UpdateResourceID(oldId, newId construct.ResourceId) error
	MakeEdgesOperational(edges []construct.Edge) error
}
// DynamicCtx builds a knowledgebase.DynamicValueContext backed by sol's dataflow graph and
// knowledge base.
func DynamicCtx(sol Solution) knowledgebase.DynamicValueContext {
	var dyn knowledgebase.DynamicValueContext
	dyn.Graph = sol.DataflowGraph()
	dyn.KnowledgeBase = sol.KnowledgeBase()
	return dyn
}
package solution
import (
"errors"
"fmt"
"github.com/dominikbraun/graph"
construct "github.com/klothoplatform/klotho/pkg/construct"
)
// RawAccessView is a construct.Graph over a Solution that performs exactly the requested
// operations: reads come from the dataflow graph, and writes go to the dataflow graph and, where
// applicable, the deployment graph.
type RawAccessView struct {
	inner Solution
}

// NewRawView returns a RawAccessView over inner.
func NewRawView(inner Solution) RawAccessView {
	var view RawAccessView
	view.inner = inner
	return view
}
// Traits reports the traits of the underlying dataflow graph.
func (view RawAccessView) Traits() *graph.Traits {
	dataflow := view.inner.DataflowGraph()
	return dataflow.Traits()
}
// AddVertex adds value to the dataflow graph and, unless its resource template is NoIac, to the
// deployment graph, recording an AddResourceDecision on success.
//
// graph.ErrVertexAlreadyExists is returned when the resource already exists in every graph it
// belongs to: the dataflow graph alone for NoIac resources, both graphs otherwise. Any other errors
// from the graphs are joined and returned.
func (view RawAccessView) AddVertex(value *construct.Resource, options ...func(*graph.VertexProperties)) error {
	rt, err := view.inner.KnowledgeBase().GetResourceTemplate(value.ID)
	if err != nil {
		return err
	}
	dfErr := view.inner.DataflowGraph().AddVertex(value, options...)
	dfExists := errors.Is(dfErr, graph.ErrVertexAlreadyExists)

	if rt.NoIac {
		// NoIac resources only live in the dataflow graph, so that graph alone decides whether the
		// resource already existed (consistent with the IaC case below).
		if dfErr != nil {
			return dfErr
		}
	} else {
		deplErr := view.inner.DeploymentGraph().AddVertex(value, options...)
		deplExists := errors.Is(deplErr, graph.ErrVertexAlreadyExists)
		if dfExists && deplExists {
			return graph.ErrVertexAlreadyExists
		}
		if deplErr != nil && !deplExists {
			err = errors.Join(err, deplErr)
		}
		// Existing in only one graph is tolerated: the add brings the two graphs back in sync.
		if dfErr != nil && !dfExists {
			err = errors.Join(err, dfErr)
		}
		if err != nil {
			return err
		}
	}

	view.inner.RecordDecision(AddResourceDecision{Resource: value.ID})
	return nil
}
// AddVerticesFrom adds every vertex of g to the view, visiting them in reverse topological order.
// Vertices that already exist are skipped; all other errors are joined and returned after every
// vertex has been attempted.
func (view RawAccessView) AddVerticesFrom(g construct.Graph) error {
	ids, err := construct.ReverseTopologicalSort(g)
	if err != nil {
		return err
	}
	var errs error
	for _, id := range ids {
		//! This will cause issues when we solve multiple graphs
		// this should copy the vertex instead of using the same pointer
		res, vertexErr := g.Vertex(id)
		if vertexErr != nil {
			errs = errors.Join(errs, vertexErr)
			continue
		}
		//? should the vertex overwrite?
		addErr := view.AddVertex(res)
		if addErr == nil || errors.Is(addErr, graph.ErrVertexAlreadyExists) {
			continue
		}
		errs = errors.Join(errs, addErr)
	}
	return errs
}
// Vertex looks up hash in the dataflow graph.
func (view RawAccessView) Vertex(hash construct.ResourceId) (*construct.Resource, error) {
	dataflow := view.inner.DataflowGraph()
	return dataflow.Vertex(hash)
}

// VertexWithProperties looks up hash, together with its vertex properties, in the dataflow graph.
func (view RawAccessView) VertexWithProperties(
	hash construct.ResourceId,
) (*construct.Resource, graph.VertexProperties, error) {
	dataflow := view.inner.DataflowGraph()
	return dataflow.VertexWithProperties(hash)
}
// RemoveVertex removes hash from the dataflow graph and, unless its resource template is NoIac,
// from the deployment graph. On success it records a RemoveResourceDecision.
//
// The resource template is resolved before either graph is modified, so a template lookup failure
// leaves both graphs untouched instead of returning an error after a partial removal.
func (view RawAccessView) RemoveVertex(hash construct.ResourceId) error {
	rt, err := view.inner.KnowledgeBase().GetResourceTemplate(hash)
	if err != nil {
		return err
	}
	err = view.inner.DataflowGraph().RemoveVertex(hash)
	if !rt.NoIac {
		err = errors.Join(err, view.inner.DeploymentGraph().RemoveVertex(hash))
	}
	if err != nil {
		return err
	}
	view.inner.RecordDecision(RemoveResourceDecision{Resource: hash})
	return nil
}
// AddEdge adds source -> target to the dataflow graph and, when the edge is deployable, to the
// deployment graph (reversed when the edge template sets DeploymentOrderReversed).
//
// The edge is mirrored into the deployment graph only when:
//   - the two resources are not both imported (neither is actually deployed, so they don't need
//     each other), and
//   - an edge template exists, and neither it nor either resource template is NoIac.
//
// graph.ErrEdgeAlreadyExists is returned only when the edge already existed in both graphs it
// belongs to; other errors from either graph are joined and returned. On success an
// AddDependencyDecision is recorded.
func (view RawAccessView) AddEdge(source, target construct.ResourceId, options ...func(*graph.EdgeProperties)) error {
	dfErr := view.inner.DataflowGraph().AddEdge(source, target, options...)

	src, err := view.inner.DataflowGraph().Vertex(source)
	if err != nil {
		return err
	}
	dst, err := view.inner.DataflowGraph().Vertex(target)
	if err != nil {
		return err
	}

	var deplErr error
	// For two imported resources only the deployment graph is skipped; the dataflow result is
	// still checked below rather than being discarded.
	if !(src.Imported && dst.Imported) {
		et := view.inner.KnowledgeBase().GetEdgeTemplate(source, target)
		srcRt, terr := view.inner.KnowledgeBase().GetResourceTemplate(source)
		if terr != nil {
			return terr
		}
		dstRt, terr := view.inner.KnowledgeBase().GetResourceTemplate(target)
		if terr != nil {
			return terr
		}
		if et != nil && !et.NoIac && !srcRt.NoIac && !dstRt.NoIac {
			if et.DeploymentOrderReversed {
				deplErr = view.inner.DeploymentGraph().AddEdge(target, source, options...)
			} else {
				deplErr = view.inner.DeploymentGraph().AddEdge(source, target, options...)
			}
			if errors.Is(dfErr, graph.ErrEdgeAlreadyExists) && errors.Is(deplErr, graph.ErrEdgeAlreadyExists) {
				return graph.ErrEdgeAlreadyExists
			}
		}
	}

	var errs error
	if dfErr != nil && !errors.Is(dfErr, graph.ErrEdgeAlreadyExists) {
		errs = errors.Join(errs, dfErr)
	}
	if deplErr != nil && !errors.Is(deplErr, graph.ErrEdgeAlreadyExists) {
		errs = errors.Join(errs, deplErr)
	}
	if errs != nil {
		return errs
	}
	view.inner.RecordDecision(AddDependencyDecision{
		From: source,
		To:   target,
	})
	return nil
}
// AddEdgesFrom adds every edge of g through AddEdge, carrying over each edge's properties
// (weight, attributes, data) as the graph.Graph contract for AddEdgesFrom requires. The attribute
// map is copied so the new edges don't share it with g. Errors for individual edges are joined
// and returned after all edges have been attempted.
func (view RawAccessView) AddEdgesFrom(g construct.Graph) error {
	edges, err := g.Edges()
	if err != nil {
		return err
	}
	var errs error
	for _, edge := range edges {
		props := edge.Properties
		copyProps := func(ep *graph.EdgeProperties) {
			ep.Weight = props.Weight
			ep.Data = props.Data
			if ep.Attributes == nil {
				ep.Attributes = make(map[string]string, len(props.Attributes))
			}
			for k, v := range props.Attributes {
				ep.Attributes[k] = v
			}
		}
		errs = errors.Join(errs, view.AddEdge(edge.Source, edge.Target, copyProps))
	}
	return errs
}
// Edge looks up source -> target in the dataflow graph.
func (view RawAccessView) Edge(source, target construct.ResourceId) (construct.ResourceEdge, error) {
	dataflow := view.inner.DataflowGraph()
	return dataflow.Edge(source, target)
}

// Edges lists the edges of the dataflow graph.
func (view RawAccessView) Edges() ([]construct.Edge, error) {
	dataflow := view.inner.DataflowGraph()
	return dataflow.Edges()
}
// UpdateEdge applies options to the source -> target edge in the dataflow graph and, when neither
// resource template is NoIac, to the corresponding deployment-graph edge (reversed when the edge
// template sets DeploymentOrderReversed).
//
// Before any graph lookup, each id's Name is passed through its template's SanitizeName.
// AddEdge does not mirror every dataflow edge into the deployment graph: it skips edges without
// an IaC edge template and edges between two imported resources. So a missing deployment edge
// (graph.ErrEdgeNotFound) is not treated as an error here. A missing dataflow edge still is.
func (view RawAccessView) UpdateEdge(
	source, target construct.ResourceId,
	options ...func(properties *graph.EdgeProperties),
) error {
	// Pair each endpoint with its label explicitly; deriving the label by comparing ids mislabels
	// self-edges (source == target).
	ends := []struct {
		label string
		id    *construct.ResourceId
	}{
		{label: "source", id: &source},
		{label: "target", id: &target},
	}
	for _, end := range ends {
		rt, err := view.inner.KnowledgeBase().GetResourceTemplate(*end.id)
		if err != nil {
			return fmt.Errorf("could not get template for %s: %w", end.label, err)
		}
		end.id.Name, err = rt.SanitizeName(end.id.Name)
		if err != nil {
			return fmt.Errorf("failed to sanitize name in %s: %w", *end.id, err)
		}
	}

	dfErr := view.inner.DataflowGraph().UpdateEdge(source, target, options...)
	var deplErr error
	et := view.inner.KnowledgeBase().GetEdgeTemplate(source, target)
	srcRt, terr := view.inner.KnowledgeBase().GetResourceTemplate(source)
	if terr != nil {
		return terr
	}
	dstRt, terr := view.inner.KnowledgeBase().GetResourceTemplate(target)
	if terr != nil {
		return terr
	}
	if !srcRt.NoIac && !dstRt.NoIac {
		if et != nil && et.DeploymentOrderReversed {
			deplErr = view.inner.DeploymentGraph().UpdateEdge(target, source, options...)
		} else {
			deplErr = view.inner.DeploymentGraph().UpdateEdge(source, target, options...)
		}
		if errors.Is(deplErr, graph.ErrEdgeNotFound) {
			deplErr = nil
		}
	}
	return errors.Join(dfErr, deplErr)
}
// RemoveEdge removes source -> target from the dataflow graph and, when neither resource template
// is NoIac, from the deployment graph (reversed when the edge template sets
// DeploymentOrderReversed). On success it records a RemoveDependencyDecision.
//
// Both resource templates are resolved before either graph is modified, so a lookup failure leaves
// the graphs untouched. AddEdge does not mirror every dataflow edge into the deployment graph (no
// IaC edge template, or both ends imported), so a missing deployment edge is not an error.
func (view RawAccessView) RemoveEdge(source, target construct.ResourceId) error {
	srcRt, terr := view.inner.KnowledgeBase().GetResourceTemplate(source)
	if terr != nil {
		return terr
	}
	dstRt, terr := view.inner.KnowledgeBase().GetResourceTemplate(target)
	if terr != nil {
		return terr
	}

	dfErr := view.inner.DataflowGraph().RemoveEdge(source, target)
	var deplErr error
	if !srcRt.NoIac && !dstRt.NoIac {
		et := view.inner.KnowledgeBase().GetEdgeTemplate(source, target)
		if et != nil && et.DeploymentOrderReversed {
			deplErr = view.inner.DeploymentGraph().RemoveEdge(target, source)
		} else {
			deplErr = view.inner.DeploymentGraph().RemoveEdge(source, target)
		}
		if errors.Is(deplErr, graph.ErrEdgeNotFound) {
			deplErr = nil
		}
	}
	if err := errors.Join(dfErr, deplErr); err != nil {
		return err
	}
	view.inner.RecordDecision(RemoveDependencyDecision{
		From: source,
		To:   target,
	})
	return nil
}
// AdjacencyMap returns the dataflow graph's adjacency map.
func (view RawAccessView) AdjacencyMap() (map[construct.ResourceId]map[construct.ResourceId]construct.Edge, error) {
	dataflow := view.inner.DataflowGraph()
	return dataflow.AdjacencyMap()
}

// PredecessorMap returns the dataflow graph's predecessor map.
func (view RawAccessView) PredecessorMap() (map[construct.ResourceId]map[construct.ResourceId]construct.Edge, error) {
	dataflow := view.inner.DataflowGraph()
	return dataflow.PredecessorMap()
}

// Clone is not supported on a raw view and always returns an error.
func (view RawAccessView) Clone() (construct.Graph, error) {
	return nil, errors.New("cannot clone a raw view")
}

// Order returns the number of vertices in the dataflow graph.
func (view RawAccessView) Order() (int, error) {
	dataflow := view.inner.DataflowGraph()
	return dataflow.Order()
}

// Size returns the number of edges in the dataflow graph.
func (view RawAccessView) Size() (int, error) {
	dataflow := view.inner.DataflowGraph()
	return dataflow.Size()
}
package engine
import (
"errors"
"fmt"
"github.com/dominikbraun/graph"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/engine/solution"
)
// MakeOperationalView is an engineSolution viewed as a graph whose write operations also hand the
// affected resources and edges to the property evaluator. Reads come from the dataflow graph.
type MakeOperationalView engineSolution

// Traits reports the traits of the dataflow graph.
func (view *MakeOperationalView) Traits() *graph.Traits {
	traits := view.Dataflow.Traits()
	return traits
}
// AddVertex adds value through the raw view, makes it operational, and then enforces any
// type-wide edge constraints. A constraint applies when one of its endpoints has an empty Name
// and the same qualified type as value. In that case an edge is added between value and the
// constraint's other endpoint, keeping value on the matching side. The first error aborts.
func (view *MakeOperationalView) AddVertex(value *construct.Resource, options ...func(*graph.VertexProperties)) error {
	if err := view.raw().AddVertex(value, options...); err != nil {
		return err
	}
	if err := view.MakeResourcesOperational([]*construct.Resource{value}); err != nil {
		return err
	}
	for _, edgeConstraint := range view.constraints.Edges {
		constrained := edgeConstraint.Target
		var err error
		switch {
		case constrained.Source.Name == "" &&
			constrained.Source.QualifiedTypeName() == value.ID.QualifiedTypeName():
			err = view.AddEdge(value.ID, constrained.Target)
		case constrained.Target.Name == "" &&
			constrained.Target.QualifiedTypeName() == value.ID.QualifiedTypeName():
			err = view.AddEdge(constrained.Source, value.ID)
		}
		if err != nil {
			return err
		}
	}
	return nil
}
// AddVerticesFrom adds every vertex of g through the raw view, in reverse topological order. If
// any lookup or add fails, the joined errors are returned and nothing is made operational.
// Otherwise all added resources are made operational in a single call.
func (view *MakeOperationalView) AddVerticesFrom(g construct.Graph) error {
	ordered, err := construct.ReverseTopologicalSort(g)
	if err != nil {
		return err
	}
	raw := view.raw()
	var errs error
	var resources []*construct.Resource
	for _, rid := range ordered {
		res, vertexErr := g.Vertex(rid)
		if vertexErr != nil {
			errs = errors.Join(errs, vertexErr)
			continue
		}
		errs = errors.Join(errs, raw.AddVertex(res))
		resources = append(resources, res)
	}
	if errs != nil {
		return errs
	}
	return view.MakeResourcesOperational(resources)
}
// raw returns a RawAccessView over the same engineSolution, for writes that should not trigger
// this view's operational behaviour.
func (view *MakeOperationalView) raw() solution.RawAccessView {
	sol := (*engineSolution)(view)
	return solution.NewRawView(sol)
}

// MakeResourcesOperational passes resources to the property evaluator's AddResources.
func (view *MakeOperationalView) MakeResourcesOperational(resources []*construct.Resource) error {
	err := view.propertyEval.AddResources(resources...)
	return err
}

// UpdateResourceID asks the property evaluator to rename oldId to newId.
func (view *MakeOperationalView) UpdateResourceID(oldId, newId construct.ResourceId) error {
	err := view.propertyEval.UpdateId(oldId, newId)
	return err
}
// Vertex looks up hash through the raw view (i.e. in the dataflow graph).
func (view *MakeOperationalView) Vertex(hash construct.ResourceId) (*construct.Resource, error) {
	raw := view.raw()
	return raw.Vertex(hash)
}

// VertexWithProperties looks up hash and its vertex properties through the raw view.
func (view *MakeOperationalView) VertexWithProperties(hash construct.ResourceId) (*construct.Resource, graph.VertexProperties, error) {
	raw := view.raw()
	return raw.VertexWithProperties(hash)
}

// RemoveVertex removes hash through the raw view and then from the property evaluator. Both
// steps always run, and their errors are joined.
func (view *MakeOperationalView) RemoveVertex(hash construct.ResourceId) error {
	rawErr := view.raw().RemoveVertex(hash)
	evalErr := view.propertyEval.RemoveResource(hash)
	return errors.Join(rawErr, evalErr)
}
// AddEdge adds source -> target and makes it operational.
//
// Both endpoints must already exist. The edge is written through the raw view only when an edge
// template exists for the pair. Without one, path solving has not run for the edge yet, and
// adding it could introduce circular dependencies via intermediate resources. The edge is then
// handed to the property evaluator, unless both endpoints are imported: their properties cannot
// be modified, so there is nothing to evaluate.
func (view *MakeOperationalView) AddEdge(source, target construct.ResourceId, options ...func(*graph.EdgeProperties)) (err error) {
	src, srcErr := view.Vertex(source)
	dst, dstErr := view.Vertex(target)
	var lookupErrs error
	if srcErr != nil {
		lookupErrs = errors.Join(lookupErrs, fmt.Errorf("no source found: %w", srcErr))
	}
	if dstErr != nil {
		lookupErrs = errors.Join(lookupErrs, fmt.Errorf("no target found: %w", dstErr))
	}
	if lookupErrs != nil {
		return fmt.Errorf("cannot add edge %s -> %s: %w", source, target, lookupErrs)
	}

	if view.KB.GetEdgeTemplate(source, target) != nil {
		if addErr := view.raw().AddEdge(source, target, options...); addErr != nil {
			return fmt.Errorf("cannot add edge %s -> %s: %w", source, target, addErr)
		}
	}

	if src.Imported && dst.Imported {
		return nil
	}
	return view.propertyEval.AddEdges(graph.Edge[construct.ResourceId]{Source: source, Target: target})
}
// MakeEdgesOperational passes edges to the property evaluator's AddEdges.
func (view *MakeOperationalView) MakeEdgesOperational(edges []construct.Edge) error {
	return view.propertyEval.AddEdges(edges...)
}

// AddEdgesFrom adds every edge of g through AddEdge, carrying over each edge's properties
// (weight, attributes, data) as the graph.Graph contract for AddEdgesFrom requires. The attribute
// map is copied so the new edges don't share it with g. Errors for individual edges are joined
// and returned after all edges have been attempted.
func (view *MakeOperationalView) AddEdgesFrom(g construct.Graph) error {
	edges, err := g.Edges()
	if err != nil {
		return err
	}
	var errs error
	for _, edge := range edges {
		props := edge.Properties
		copyProps := func(ep *graph.EdgeProperties) {
			ep.Weight = props.Weight
			ep.Data = props.Data
			if ep.Attributes == nil {
				ep.Attributes = make(map[string]string, len(props.Attributes))
			}
			for k, v := range props.Attributes {
				ep.Attributes[k] = v
			}
		}
		errs = errors.Join(errs, view.AddEdge(edge.Source, edge.Target, copyProps))
	}
	return errs
}
// Edge looks up source -> target in the dataflow graph.
func (view *MakeOperationalView) Edge(source, target construct.ResourceId) (construct.ResourceEdge, error) {
	edge, err := view.Dataflow.Edge(source, target)
	return edge, err
}

// Edges lists the edges of the dataflow graph.
func (view *MakeOperationalView) Edges() ([]construct.Edge, error) {
	edges, err := view.Dataflow.Edges()
	return edges, err
}

// UpdateEdge delegates to the raw view; updating an edge does not involve the property evaluator.
func (view *MakeOperationalView) UpdateEdge(source, target construct.ResourceId, options ...func(properties *graph.EdgeProperties)) error {
	raw := view.raw()
	return raw.UpdateEdge(source, target, options...)
}

// RemoveEdge removes source -> target through the raw view and then from the property evaluator.
// Both steps always run, and their errors are joined.
func (view *MakeOperationalView) RemoveEdge(source, target construct.ResourceId) error {
	rawErr := view.raw().RemoveEdge(source, target)
	evalErr := view.propertyEval.RemoveEdge(source, target)
	return errors.Join(rawErr, evalErr)
}

// AdjacencyMap returns the dataflow graph's adjacency map.
func (view *MakeOperationalView) AdjacencyMap() (map[construct.ResourceId]map[construct.ResourceId]construct.Edge, error) {
	adj, err := view.Dataflow.AdjacencyMap()
	return adj, err
}

// PredecessorMap returns the dataflow graph's predecessor map.
func (view *MakeOperationalView) PredecessorMap() (map[construct.ResourceId]map[construct.ResourceId]construct.Edge, error) {
	pred, err := view.Dataflow.PredecessorMap()
	return pred, err
}

// Clone is not supported on an operational view and always returns an error.
func (view *MakeOperationalView) Clone() (construct.Graph, error) {
	return nil, errors.New("cannot clone an operational view")
}

// Order returns the number of vertices in the dataflow graph.
func (view *MakeOperationalView) Order() (int, error) {
	order, err := view.Dataflow.Order()
	return order, err
}

// Size returns the number of edges in the dataflow graph.
func (view *MakeOperationalView) Size() (int, error) {
	size, err := view.Dataflow.Size()
	return size, err
}
package engine
import (
"errors"
"fmt"
"math"
"github.com/dominikbraun/graph"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/engine/path_selection"
"github.com/klothoplatform/klotho/pkg/engine/solution"
"github.com/klothoplatform/klotho/pkg/graph_addons"
klotho_io "github.com/klothoplatform/klotho/pkg/io"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
"github.com/klothoplatform/klotho/pkg/set"
visualizer "github.com/klothoplatform/klotho/pkg/visualizer"
)
// View selects which solution graph is visualized: IACView uses the deployment graph, any
// other view the dataflow graph.
type View string

// Tag is the icon treatment a resource template requests for a View (template.Views[view]).
type Tag string

// Supported views.
const (
	DataflowView View = "dataflow"
	IACView      View = "iac"
)

// Supported tags.
const (
	ParentIconTag Tag = "parent"
	BigIconTag    Tag = "big"
	SmallIconTag  Tag = "small"
	NoRenderTag   Tag = "no-render"
)
// VisualizeViews renders two visualizer files for the AWS provider:
//   - an "iac-" file built from the deployment graph, and
//   - a "dataflow-" file built by GetViewsDag for DataflowView.
//
// If building the IaC graph fails, only the error is returned. An error from GetViewsDag is
// returned together with both files.
func (e *Engine) VisualizeViews(ctx solution.Solution) ([]klotho_io.File, error) {
	iacTopo := &visualizer.File{FilenamePrefix: "iac-", Provider: "aws"}
	dataflowTopo := &visualizer.File{FilenamePrefix: "dataflow-", Provider: "aws"}

	iacGraph, err := visualizer.ConstructToVis(ctx.DeploymentGraph())
	if err != nil {
		return nil, err
	}
	iacTopo.Graph = iacGraph

	dataflowGraph, err := e.GetViewsDag(DataflowView, ctx)
	dataflowTopo.Graph = dataflowGraph
	return []klotho_io.File{iacTopo, dataflowTopo}, err
}
// GetResourceVizTag returns the Tag that resource's template declares for view. It returns
// NoRenderTag when the template cannot be found or declares nothing for view.
func GetResourceVizTag(kb knowledgebase.TemplateKB, view View, resource construct.ResourceId) Tag {
	tmpl, err := kb.GetResourceTemplate(resource)
	if err != nil || tmpl == nil {
		return NoRenderTag
	}
	if tag, ok := tmpl.Views[string(view)]; ok {
		return Tag(tag)
	}
	return NoRenderTag
}
// GetViewsDag builds the visualization graph for view. It reads the deployment graph for IACView
// and the dataflow graph otherwise. Resources are visited in reverse topological order, in three
// passes:
//  1. add a vertex for every parent/big icon (small-icon and no-render resources get none);
//     any error here aborts,
//  2. set each icon's parent and small-icon children (setupAncestry),
//  3. add edges between icons (makeEdges). This needs the ancestry from pass 2 so that edges to a
//     node's own ancestry can be excluded.
//
// Errors from passes 2 and 3 are joined and returned together with the graph built so far.
func (e *Engine) GetViewsDag(view View, sol solution.Solution) (visualizer.VisGraph, error) {
	viewDag := visualizer.NewVisGraph()
	var resGraph construct.Graph
	if view == IACView {
		resGraph = sol.DeploymentGraph()
	} else {
		resGraph = sol.DataflowGraph()
	}
	// An undirected copy of resGraph used to be built here but was never read; it is dropped.

	ids, err := construct.ReverseTopologicalSort(resGraph)
	if err != nil {
		return nil, err
	}

	var errs error
	// First pass gets all the vertices (groups or big icons)
	for _, id := range ids {
		switch tag := GetResourceVizTag(e.Kb, view, id); tag {
		case NoRenderTag, SmallIconTag:
			continue
		case ParentIconTag, BigIconTag:
			errs = errors.Join(errs, viewDag.AddVertex(&visualizer.VisResource{
				ID:       id,
				Children: make(set.Set[construct.ResourceId]),
				Tag:      string(tag),
			}))
		default:
			errs = errors.Join(errs, fmt.Errorf("unknown tag %s", tag))
		}
	}
	if errs != nil {
		return nil, errs
	}

	// Second pass sets up the small icons & parents
	for _, id := range ids {
		switch tag := GetResourceVizTag(e.Kb, view, id); tag {
		case ParentIconTag, BigIconTag:
			if err := e.setupAncestry(sol, view, viewDag, id); err != nil {
				errs = errors.Join(errs, fmt.Errorf("failed to handle %s icon %s: %w", tag, id, err))
			}
		}
	}

	// Third pass for the edges. Needs to happen after parents to exclude edges to a node's ancestry
	for _, id := range ids {
		switch tag := GetResourceVizTag(e.Kb, view, id); tag {
		case ParentIconTag, BigIconTag:
			if err := e.makeEdges(sol, view, viewDag, id); err != nil {
				errs = errors.Join(errs, fmt.Errorf("failed to handle %s icon %s: %w", tag, id, err))
			}
		}
	}
	return viewDag, errs
}
// makeEdges adds a viewDag edge from the icon id to every other icon it has a visualizable path to
// (see visPaths). Targets are skipped when they are id itself, one of id's ancestors, or one of
// id's descendants, since containment already expresses those relationships. Each edge's
// VisEdgeData.PathResources collects the intermediate resources of all paths found. Per-target
// errors are joined and returned after all targets are processed.
func (e *Engine) makeEdges(
	sol solution.Solution,
	view View,
	viewDag visualizer.VisGraph,
	id construct.ResourceId,
) error {
	ancestors, err := visualizer.VertexAncestors(viewDag, id)
	if err != nil {
		return err
	}
	targets, err := construct.TopologicalSort(viewDag)
	if err != nil {
		return err
	}
	var errs error
	for _, target := range targets {
		if target == id {
			// A self-edge carries no information. A path from id to itself could also be a single
			// element, which the path[1:len(path)-1] slicing would panic on.
			continue
		}
		if ancestors.Contains(target) {
			// Don't draw edges from a node to its ancestors, since it already lives inside of them
			continue
		}
		targetAncestors, err := visualizer.VertexAncestors(viewDag, target)
		if err != nil {
			errs = errors.Join(errs, err)
			continue
		}
		if targetAncestors.Contains(id) {
			// Don't draw edges from a node to its descendants
			continue
		}
		_, targetErr := viewDag.Vertex(target)
		if errors.Is(targetErr, graph.ErrVertexNotFound) {
			continue
		} else if targetErr != nil {
			errs = errors.Join(errs, targetErr)
			continue
		}
		paths, err := visPaths(sol, view, id, target)
		if err != nil {
			errs = errors.Join(errs, err)
			continue
		}
		if len(paths) == 0 {
			continue
		}
		allPathResources := make(set.Set[construct.ResourceId])
		for _, path := range paths {
			if len(path) <= 2 {
				// no intermediate resources
				continue
			}
			for _, pathRes := range path[1 : len(path)-1] {
				allPathResources.Add(pathRes)
			}
		}
		errs = errors.Join(errs, viewDag.AddEdge(id, target, graph.EdgeData(visualizer.VisEdgeData{
			PathResources: allPathResources,
		})))
	}
	return errs
}
// setupAncestry sets the Parent of the icon id in viewDag (chosen by findParent; the zero id
// when no parent group applies). It then collects the icon's small-icon children (setChildren).
// Edges between icons are not created here; makeEdges handles them in a later pass.
func (e *Engine) setupAncestry(
	sol solution.Solution,
	view View,
	viewDag visualizer.VisGraph,
	id construct.ResourceId,
) error {
	icon, err := viewDag.Vertex(id)
	if err != nil {
		return err
	}
	parent, err := e.findParent(view, sol, viewDag, id)
	if err != nil {
		return err
	}
	icon.Parent = parent
	return e.setChildren(sol, view, icon)
}
// setChildren adds to v.Children the resources that render inside v's icon:
//   - resources in v's local layer (knowledgebase.Downstream with ResourceLocalLayer) that are
//     tagged SmallIconTag for view, and
//   - any namespaced resource whose namespace property holds v.ID.
//
// Failing to compute the local layer or the topological order aborts immediately. Errors while
// inspecting individual namespaced resources (template, vertex, or property lookup) are joined
// and returned after all resources are processed. Previously they were collected and then
// dropped. Children found along the way are still added.
func (e *Engine) setChildren(sol solution.Solution, view View, v *visualizer.VisResource) error {
	local, err := knowledgebase.Downstream(
		sol.DataflowGraph(),
		sol.KnowledgeBase(),
		v.ID,
		knowledgebase.ResourceLocalLayer,
	)
	if err != nil {
		return fmt.Errorf("failed to get local layer for %s: %w", v.ID, err)
	}
	for _, localElem := range local {
		if GetResourceVizTag(e.Kb, view, localElem) == SmallIconTag {
			v.Children.Add(localElem)
		}
	}
	// After glue, also include any resources whose namespace is this resource
	ids, err := construct.TopologicalSort(sol.DataflowGraph())
	if err != nil {
		return err
	}
	var errs error
	for _, id := range ids {
		if id.Namespace == "" {
			continue
		}
		tmpl, err := sol.KnowledgeBase().GetResourceTemplate(id)
		if err != nil {
			errs = errors.Join(errs, err)
			continue
		}
		for _, p := range tmpl.Properties {
			if !p.Details().Namespace {
				continue
			}
			pres, err := sol.RawView().Vertex(id)
			if err != nil {
				errs = errors.Join(errs, err)
				break
			}
			val, err := pres.GetProperty(p.Details().Path)
			if err != nil {
				errs = errors.Join(errs, err)
				break
			}
			if val == v.ID {
				v.Children.Add(id)
				break
			}
		}
	}
	return errs
}
// findParent returns the resource that the icon id should be drawn inside, for view.
//
// Namespaced resources (id.Namespace != ""): the first template property marked Namespace that
// the loop visits is read from the resource.
//   - If the value is a ResourceId tagged ParentIconTag, it is the parent.
//   - If it is some other ResourceId, the search recurses from it, rolling up to the next parent.
//   - A non-ResourceId value is an error.
//
// If the template has no Namespace property, the general search below is used.
//
// General search: consider resources in id's local layer (knowledgebase.Downstream with
// ResourceLocalLayer) that are tagged ParentIconTag. Candidates are skipped when they are
// unreachable from id, or when their shortest path crosses a resource whose template has a known
// functionality. Among the rest, the one whose shortest dataflow path from id has the lowest
// weight wins.
//
// The zero ResourceId is returned when no candidate qualifies. Errors met while evaluating
// candidates are joined into err and returned together with the best parent found so far.
func (e *Engine) findParent(
	view View,
	sol solution.Solution,
	viewDag visualizer.VisGraph,
	id construct.ResourceId,
) (bestParent construct.ResourceId, err error) {
	if id.Namespace != "" {
		// namespaced resources' parents is always their namespace resource
		tmpl, err := sol.KnowledgeBase().GetResourceTemplate(id)
		if err != nil {
			return bestParent, err
		}
		thisRes, err := sol.RawView().Vertex(id)
		if err != nil {
			return bestParent, err
		}
		for _, p := range tmpl.Properties {
			if !p.Details().Namespace {
				continue
			}
			v, err := thisRes.GetProperty(p.Details().Path)
			if err != nil {
				return bestParent, fmt.Errorf("failed to get namespace property %s: %w", p.Details().Path, err)
			}
			if propId, ok := v.(construct.ResourceId); ok {
				if GetResourceVizTag(e.Kb, view, propId) == ParentIconTag {
					return propId, nil
				}
				// the property isn't shown as a parent (eg. Subnet or ALB Listener), so roll it up to the next parent
				return e.findParent(view, sol, viewDag, propId)
			} else {
				return bestParent, fmt.Errorf("namespace property %s is not a resource id (was: %T)", p.Details().Path, v)
			}
		}
	}
	glue, err := knowledgebase.Downstream(
		sol.DataflowGraph(),
		sol.KnowledgeBase(),
		id,
		knowledgebase.ResourceLocalLayer,
	)
	if err != nil {
		// naked return: the Downstream error is returned through the named result err
		return
	}
	pather, err := construct.ShortestPaths(sol.DataflowGraph(), id, construct.DontSkipEdges)
	if err != nil {
		// naked return: the ShortestPaths error is returned through the named result err
		return
	}
	bestParentWeight := math.MaxInt32
	var errs error
candidateLoop:
	for _, candidate := range glue {
		if GetResourceVizTag(e.Kb, view, candidate) != ParentIconTag {
			continue
		}
		path, err := pather.ShortestPath(candidate)
		if errors.Is(err, graph.ErrTargetNotReachable) {
			continue
		} else if err != nil {
			errs = errors.Join(errs, err)
			continue
		}
		// Only intermediate path elements are inspected (the endpoints are id and candidate).
		for _, pathElem := range path[1 : len(path)-1] {
			pathTmpl, err := e.Kb.GetResourceTemplate(pathElem)
			if err != nil {
				errs = errors.Join(errs, err)
				// skips only this path element; the candidate itself is still considered
				continue
			}
			// Don't cross functional boundaries for parent attribution
			if pathTmpl.GetFunctionality() != knowledgebase.Unknown {
				continue candidateLoop
			}
		}
		weight, err := graph_addons.PathWeight(sol.DataflowGraph(), graph_addons.Path[construct.ResourceId](path))
		if err != nil {
			errs = errors.Join(errs, err)
			continue
		}
		if weight < bestParentWeight {
			bestParentWeight = weight
			bestParent = candidate
		}
	}
	// errors from individual candidates are reported alongside whatever bestParent was found
	err = errs
	return
}
// visPaths returns the paths from source to target that qualify as a visual connection. It
// returns nil, without an error, when either template has no path satisfaction for its role
// (AsSource for source, AsTarget for target), or when target does not consume anything from
// source (knowledgebase.HasConsumedFromResource). Otherwise it returns checkPaths' result.
func visPaths(sol solution.Solution, view View, source, target construct.ResourceId) ([]construct.Path, error) {
	srcTemplate, err := sol.KnowledgeBase().GetResourceTemplate(source)
	if err != nil {
		return nil, fmt.Errorf("has path could not find source resource %s: %w", source, err)
	}
	if srcTemplate == nil {
		// Don't wrap a nil error: %w with a nil operand renders as "%!w(<nil>)".
		return nil, fmt.Errorf("has path could not find source resource %s: no resource template", source)
	}
	targetTemplate, err := sol.KnowledgeBase().GetResourceTemplate(target)
	if err != nil {
		return nil, fmt.Errorf("has path could not find target resource %s: %w", target, err)
	}
	if targetTemplate == nil {
		return nil, fmt.Errorf("has path could not find target resource %s: no resource template", target)
	}
	if len(targetTemplate.PathSatisfaction.AsTarget) == 0 || len(srcTemplate.PathSatisfaction.AsSource) == 0 {
		return nil, nil
	}
	sourceRes, err := sol.RawView().Vertex(source)
	if err != nil {
		return nil, fmt.Errorf("has path could not find source resource %s: %w", source, err)
	}
	targetRes, err := sol.RawView().Vertex(target)
	if err != nil {
		return nil, fmt.Errorf("has path could not find target resource %s: %w", target, err)
	}
	consumed, err := knowledgebase.HasConsumedFromResource(
		sourceRes,
		targetRes,
		solution.DynamicCtx(sol),
	)
	if err != nil {
		return nil, err
	}
	if !consumed {
		return nil, nil
	}
	return checkPaths(sol, view, source, target)
}
// checkPaths returns the paths path_selection.GetPaths finds from source to target, keeping only
// those whose intermediate resources include no big or parent icon for view. Paths with no
// intermediate resources (length <= 2) are always kept. This guard also prevents the
// path[1:len(path)-1] slice from panicking on degenerate paths shorter than two elements.
func checkPaths(sol solution.Solution, view View, source, target construct.ResourceId) ([]construct.Path, error) {
	paths, err := path_selection.GetPaths(
		sol,
		source,
		target,
		func(source, target construct.ResourceId, path construct.Path) bool {
			if len(path) <= 2 {
				return true
			}
			for _, res := range path[1 : len(path)-1] {
				switch GetResourceVizTag(sol.KnowledgeBase(), view, res) {
				case BigIconTag, ParentIconTag:
					// Don't consider paths that go through big/parent icons
					return false
				}
			}
			return true
		},
		false,
	)
	return paths, err
}
package engine
import (
"github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/engine/constraints"
"gopkg.in/yaml.v3"
)
// FileFormat is used for engine input/output to render or read a YAML file.
// An example yaml file is:
//
//	constraints:
//	  - scope: application
//	    operator: add
//	    node: p:t:a
//	resources:
//	  p:t:a:
//	  p:t:b:
//	edges:
//	  p:t:a -> p:t:b:
//
// The top-level "constraints" key holds Constraints (encoded as a constraint list). The remaining
// top-level keys are those produced and consumed by construct.YamlGraph for Graph.
type FileFormat struct {
	// Constraints is read from and written to the top-level "constraints" key.
	Constraints constraints.Constraints
	// Graph is read from and written to the remaining top-level keys via construct.YamlGraph.
	Graph construct.Graph
}
// MarshalYAML renders ff as one YAML mapping: a "constraints" key first, followed by the keys
// produced by encoding construct.YamlGraph{Graph: ff.Graph}. When the constraint list encodes to
// an empty node (e.g. `constraints: []`), a bare null scalar is emitted instead, so the output
// reads `constraints:`.
func (ff FileFormat) MarshalYAML() (interface{}, error) {
	constraintsNode := new(yaml.Node)
	if err := constraintsNode.Encode(ff.Constraints.ToList()); err != nil {
		return nil, err
	}
	if len(constraintsNode.Content) == 0 {
		constraintsNode = &yaml.Node{
			Kind:  yaml.ScalarNode,
			Tag:   "!!null",
			Value: "",
		}
	}

	graphNode := new(yaml.Node)
	if err := graphNode.Encode(construct.YamlGraph{Graph: ff.Graph}); err != nil {
		return nil, err
	}

	keyNode := &yaml.Node{Kind: yaml.ScalarNode, Value: "constraints"}
	content := make([]*yaml.Node, 0, 2+len(graphNode.Content))
	content = append(content, keyNode, constraintsNode)
	content = append(content, graphNode.Content...)
	return &yaml.Node{Kind: yaml.MappingNode, Content: content}, nil
}
// UnmarshalYAML decodes the same node twice. The first pass reads the "constraints" key and
// converts it to Constraints. The second pass decodes the whole node as a construct.YamlGraph to
// fill Graph. The first error stops decoding.
func (ff *FileFormat) UnmarshalYAML(node *yaml.Node) error {
	// Named to avoid shadowing the constraints package.
	var wrapper struct {
		Constraints constraints.ConstraintList `yaml:"constraints"`
	}
	err := node.Decode(&wrapper)
	if err != nil {
		return err
	}
	ff.Constraints, err = wrapper.Constraints.ToConstraints()
	if err != nil {
		return err
	}
	var yg construct.YamlGraph
	if err = node.Decode(&yg); err != nil {
		return err
	}
	ff.Graph = yg.Graph
	return nil
}
package errors
import (
"fmt"
"runtime"
"github.com/pkg/errors"
)
// WrappedError annotates Cause with a Message and the call stack captured when it was created
// (see WrapErrf).
type WrappedError struct {
	Message string
	Cause   error
	Stack   errors.StackTrace
}

// Error returns "Message: Cause", omitting whichever part is empty. A nil Cause is tolerated
// (WrapErrf accepts any error, including nil), so printing such a value no longer panics.
func (err *WrappedError) Error() string {
	switch {
	case err.Cause == nil:
		return err.Message
	case err.Message == "":
		return err.Cause.Error()
	default:
		return err.Message + ": " + err.Cause.Error()
	}
}

// Format implements fmt.Formatter. It writes, in order:
//   - the Message prefix,
//   - the captured stack, only when the '+' flag is set (e.g. %+v),
//   - the Cause, delegating to the Cause's own Format when it implements fmt.Formatter.
//
// A nil Cause is tolerated.
func (err *WrappedError) Format(s fmt.State, verb rune) {
	if err.Message != "" {
		if err.Cause == nil {
			fmt.Fprint(s, err.Message)
		} else {
			fmt.Fprint(s, err.Message+": ")
		}
	}
	if len(err.Stack) > 0 && s.Flag('+') {
		err.Stack.Format(s, verb)
	}
	switch cause := err.Cause.(type) {
	case nil:
		// nothing left to print
	case fmt.Formatter:
		cause.Format(s, verb)
	default:
		fmt.Fprint(s, cause.Error())
	}
}

// Unwrap returns the wrapped Cause, for errors.Is / errors.As.
func (err *WrappedError) Unwrap() error {
	return err.Cause
}
// WrapErrf wraps err with a message formatted from msg and args (fmt.Sprintf semantics). It
// records the stack starting at WrapErrf's caller.
func WrapErrf(err error, msg string, args ...interface{}) *WrappedError {
	return &WrappedError{
		Message: fmt.Sprintf(msg, args...),
		Cause:   err,
		Stack:   callers(2),
	}
}
// callers captures up to 32 program counters of the current goroutine's stack as a StackTrace.
// depth frames are skipped, counted from callers' own caller. runtime.Callers is passed depth+1
// to also skip callers itself.
func callers(depth int) errors.StackTrace {
	const maxDepth = 32
	var pcs [maxDepth]uintptr
	n := runtime.Callers(depth+1, pcs[:])
	trace := make(errors.StackTrace, 0, n)
	for _, pc := range pcs[:n] {
		trace = append(trace, errors.Frame(pc))
	}
	return trace
}
package filter
import (
"github.com/klothoplatform/klotho/pkg/filter/predicate"
)
// Filter selects elements of type T from a list of inputs.
type Filter[T any] interface {
	// Apply returns the subset of inputs accepted by the filter.
	Apply(inputs ...T) []T
	// Find returns the first input accepted by the filter and true, or the zero value and false
	// when there is none.
	Find(inputs ...T) (T, bool)
}
// SimpleFilter is a Filter that accepts exactly the inputs for which its Predicate holds.
type SimpleFilter[T any] struct {
	Predicate predicate.Predicate[T]
}

// Apply returns the inputs satisfying the Predicate, in input order. It returns nil when none
// match.
func (f SimpleFilter[T]) Apply(inputs ...T) []T {
	var matched []T
	for i := range inputs {
		if f.Predicate(inputs[i]) {
			matched = append(matched, inputs[i])
		}
	}
	return matched
}

// Find returns the first input satisfying the Predicate and true, or the zero value and false.
func (f SimpleFilter[T]) Find(inputs ...T) (T, bool) {
	for i := range inputs {
		if f.Predicate(inputs[i]) {
			return inputs[i], true
		}
	}
	var zero T
	return zero, false
}

// NewSimpleFilter returns a SimpleFilter whose Predicate requires every one of predicates to hold.
// They are evaluated in order, stopping at the first that fails.
func NewSimpleFilter[T any](predicates ...predicate.Predicate[T]) Filter[T] {
	combined := predicate.AllOf(predicates...)
	return SimpleFilter[T]{Predicate: combined}
}
package predicate
import (
"regexp"
)
// Predicate reports whether a value of type T satisfies some condition.
type Predicate[T any] func(p T) bool

// Not negates the supplied Predicate
func Not[T any](predicate Predicate[T]) Predicate[T] {
	return func(p T) bool {
		return !predicate(p)
	}
}

// AnyOf returns a Predicate that holds when at least one of predicates holds. They are evaluated
// in order, stopping at the first success. With no predicates it always returns false.
func AnyOf[T any](predicates ...Predicate[T]) Predicate[T] {
	return func(p T) bool {
		for _, predicate := range predicates {
			if predicate(p) {
				return true
			}
		}
		return false
	}
}

// AllOf returns a Predicate that holds when every one of predicates holds. They are evaluated in
// order, stopping at the first failure. With no predicates it always returns true.
func AllOf[T any](predicates ...Predicate[T]) Predicate[T] {
	return func(p T) bool {
		for _, predicate := range predicates {
			if !predicate(p) {
				return false
			}
		}
		return true
	}
}

// StringMatchesPattern returns a Predicate that reports whether a string contains a match of the
// regular expression pattern.
//
// The pattern is compiled once, when the Predicate is built, instead of on every evaluation. As
// with regexp.MustCompile, an invalid pattern panics, now at construction rather than first use.
func StringMatchesPattern(pattern string) Predicate[string] {
	re := regexp.MustCompile(pattern)
	return re.MatchString
}
package graph_addons
import (
"errors"
"github.com/dominikbraun/graph"
)
// LayeredGraph is a graph made of an ordered list of layers; g[0] is the first layer.
//   - Adding a vertex puts it in the first layer.
//   - Adding an edge puts it in the first layer that contains both endpoints. If no layer does,
//     the endpoints are added to the first layer and the edge is added there.
//   - Lookups search the layers in order and return the first hit.
//   - Remove and update operations are applied to all layers.
type LayeredGraph[K comparable, T any] []graph.Graph[K, T]

// LayeredGraphOf returns a LayeredGraph whose layers are g, in the given order.
func LayeredGraphOf[K comparable, T any](g ...graph.Graph[K, T]) LayeredGraph[K, T] {
	return LayeredGraph[K, T](g)
}
// Traits returns the combined traits of all layers: each flag is set when it is set on any layer.
//
// The result is a fresh copy. A layer's Traits() may return its own internal *graph.Traits (the
// dominikbraun/graph implementations do). Merging into that pointer directly, as the previous
// version did, silently rewrote layer 0's traits. Like the other LayeredGraph methods, this
// assumes at least one layer.
func (g LayeredGraph[K, T]) Traits() *graph.Traits {
	t := *g[0].Traits()
	for _, layer := range g[1:] {
		lt := layer.Traits()
		t.IsDirected = t.IsDirected || lt.IsDirected
		t.IsAcyclic = t.IsAcyclic || lt.IsAcyclic
		t.IsWeighted = t.IsWeighted || lt.IsWeighted
		t.IsRooted = t.IsRooted || lt.IsRooted
		t.PreventCycles = t.PreventCycles || lt.PreventCycles
	}
	return &t
}
// AddVertex adds value to the first layer.
func (g LayeredGraph[K, T]) AddVertex(value T, options ...func(*graph.VertexProperties)) error {
	first := g[0]
	return first.AddVertex(value, options...)
}

// AddVerticesFrom adds every vertex of o to the first layer.
func (g LayeredGraph[K, T]) AddVerticesFrom(o graph.Graph[K, T]) error {
	first := g[0]
	return first.AddVerticesFrom(o)
}
// Vertex returns hash from the first layer that has it. A lookup error other than
// graph.ErrVertexNotFound is returned immediately. If no layer has the vertex,
// graph.ErrVertexNotFound is returned.
func (g LayeredGraph[K, T]) Vertex(hash K) (v T, err error) {
	for _, layer := range g {
		v, err = layer.Vertex(hash)
		switch {
		case err == nil:
			return v, nil
		case !errors.Is(err, graph.ErrVertexNotFound):
			return v, err
		}
	}
	return v, graph.ErrVertexNotFound
}

// VertexWithProperties is Vertex, also returning the vertex properties from the matching layer.
func (g LayeredGraph[K, T]) VertexWithProperties(hash K) (v T, p graph.VertexProperties, err error) {
	for _, layer := range g {
		v, p, err = layer.VertexWithProperties(hash)
		switch {
		case err == nil:
			return v, p, nil
		case !errors.Is(err, graph.ErrVertexNotFound):
			return v, p, err
		}
	}
	return v, p, graph.ErrVertexNotFound
}
// RemoveVertex removes hash from every layer that has it. Layers without the vertex are ignored.
// The first other error stops the removal.
func (g LayeredGraph[K, T]) RemoveVertex(hash K) error {
	for i := range g {
		if err := g[i].RemoveVertex(hash); err != nil && !errors.Is(err, graph.ErrVertexNotFound) {
			return err
		}
	}
	return nil
}
// AddEdge adds sourceHash -> targetHash to the first layer that contains both vertices.
//
// If no single layer contains both, any endpoint not already in the first layer is copied there,
// using its value from the first layer that has it, and the edge is added to the first layer.
// graph.ErrVertexNotFound is returned when an endpoint exists in no layer. A lookup error other
// than not-found is returned immediately.
func (g LayeredGraph[K, T]) AddEdge(sourceHash, targetHash K, options ...func(*graph.EdgeProperties)) error {
	var src, tgt T
	srcLayer, tgtLayer := -1, -1
	for i, layer := range g {
		sV, srcErr := layer.Vertex(sourceHash)
		tV, tgtErr := layer.Vertex(targetHash)
		if srcErr == nil && tgtErr == nil {
			return layer.AddEdge(sourceHash, targetHash, options...)
		}
		if srcErr == nil && srcLayer == -1 {
			srcLayer = i
			src = sV
		}
		if errors.Is(srcErr, graph.ErrVertexNotFound) {
			srcErr = nil
		}
		if tgtErr == nil && tgtLayer == -1 {
			tgtLayer = i
			tgt = tV
		}
		if errors.Is(tgtErr, graph.ErrVertexNotFound) {
			tgtErr = nil
		}
		if err := errors.Join(srcErr, tgtErr); err != nil {
			return err
		}
	}
	// Don't fabricate a zero-value vertex for an endpoint that exists nowhere.
	if srcLayer == -1 || tgtLayer == -1 {
		return graph.ErrVertexNotFound
	}
	// No layer has both vertices, so bring both into the first layer and add the edge there. An
	// endpoint found in layer 0 is already there; re-adding it would fail with
	// graph.ErrVertexAlreadyExists and prevent the edge from being added.
	var err error
	if srcLayer != 0 {
		err = errors.Join(err, g[0].AddVertex(src))
	}
	if tgtLayer != 0 {
		err = errors.Join(err, g[0].AddVertex(tgt))
	}
	if err != nil {
		return err
	}
	return g[0].AddEdge(sourceHash, targetHash, options...)
}
// AddEdgesFrom adds every edge of o (with its properties) through AddEdge. It attempts every
// edge and returns all failures joined together, or nil if every edge was added.
func (g LayeredGraph[K, T]) AddEdgesFrom(o graph.Graph[K, T]) error {
	edges, err := o.Edges()
	if err != nil {
		return err
	}
	var errs error
	for _, edge := range edges {
		props := edge.Properties
		addErr := g.AddEdge(edge.Source, edge.Target, func(ep *graph.EdgeProperties) { *ep = props })
		errs = errors.Join(errs, addErr)
	}
	return errs
}
// Edge returns the sourceHash -> targetHash edge from the first layer that has it.
// A lookup error other than graph.ErrEdgeNotFound aborts the search and is returned;
// graph.ErrEdgeNotFound is returned if no layer has the edge.
func (g LayeredGraph[K, T]) Edge(sourceHash, targetHash K) (graph.Edge[T], error) {
	for i := range g {
		found, err := g[i].Edge(sourceHash, targetHash)
		switch {
		case err == nil:
			return found, nil
		case !errors.Is(err, graph.ErrEdgeNotFound):
			return graph.Edge[T]{}, err
		}
	}
	return graph.Edge[T]{}, graph.ErrEdgeNotFound
}
// Edges concatenates the edges of every layer, in layer order.
//
// The result may contain the same source/target pair more than once if the edge exists in several
// layers. This is intentional, because those copies may carry different properties.
func (g LayeredGraph[K, T]) Edges() ([]graph.Edge[K], error) {
	var all []graph.Edge[K]
	for i := range g {
		fromLayer, err := g[i].Edges()
		if err != nil {
			return nil, err
		}
		all = append(all, fromLayer...)
	}
	return all, nil
}
// UpdateEdge applies options to the source -> target edge in every layer that has it.
// Layers without the edge are skipped, so nil is returned even if no layer had it.
// Any other error aborts immediately.
func (g LayeredGraph[K, T]) UpdateEdge(source, target K, options ...func(properties *graph.EdgeProperties)) error {
	for i := range g {
		if err := g[i].UpdateEdge(source, target, options...); err != nil && !errors.Is(err, graph.ErrEdgeNotFound) {
			return err
		}
	}
	return nil
}
// RemoveEdge deletes the source -> target edge from every layer that has it. Missing edges are
// ignored; any other error aborts immediately.
func (g LayeredGraph[K, T]) RemoveEdge(source, target K) error {
	for i := range g {
		if err := g[i].RemoveEdge(source, target); err != nil && !errors.Is(err, graph.ErrEdgeNotFound) {
			return err
		}
	}
	return nil
}
// AdjacencyMap merges the adjacency maps of all layers (source -> target -> edge).
//
// When the same source -> target edge appears in several layers, the copy with the greatest
// weight wins. On equal weight the earliest (highest-priority) layer wins, consistent with Edge,
// which returns the first layer's edge.
func (g LayeredGraph[K, T]) AdjacencyMap() (map[K]map[K]graph.Edge[K], error) {
	merged := make(map[K]map[K]graph.Edge[K])
	for _, layer := range g {
		adj, err := layer.AdjacencyMap()
		if err != nil {
			return nil, err
		}
		mergeLayerEdgeMap(merged, adj)
	}
	return merged, nil
}

// PredecessorMap merges the predecessor maps of all layers (target -> source -> edge), using the
// same precedence rules as AdjacencyMap.
func (g LayeredGraph[K, T]) PredecessorMap() (map[K]map[K]graph.Edge[K], error) {
	merged := make(map[K]map[K]graph.Edge[K])
	for _, layer := range g {
		pred, err := layer.PredecessorMap()
		if err != nil {
			return nil, err
		}
		mergeLayerEdgeMap(merged, pred)
	}
	return merged, nil
}

// mergeLayerEdgeMap folds src, a lower-priority layer's nested edge map, into dst. Every outer key
// of src gets a (possibly empty) row in dst. An entry already present in dst comes from a
// higher-priority layer and is kept unless the incoming edge is strictly heavier.
func mergeLayerEdgeMap[K comparable](dst, src map[K]map[K]graph.Edge[K]) {
	for outer, inner := range src {
		row := dst[outer]
		if row == nil {
			row = make(map[K]graph.Edge[K], len(inner))
			dst[outer] = row
		}
		for k, e := range inner {
			if existing, ok := row[k]; ok && existing.Properties.Weight >= e.Properties.Weight {
				continue
			}
			row[k] = e
		}
	}
}
// Clone returns a new LayeredGraph whose layers are clones of this graph's layers, in the same
// order. The first layer-clone failure is returned.
func (g LayeredGraph[K, T]) Clone() (graph.Graph[K, T], error) {
	clone := make(LayeredGraph[K, T], 0, len(g))
	for _, layer := range g {
		copied, err := layer.Clone()
		if err != nil {
			return nil, err
		}
		clone = append(clone, copied)
	}
	return clone, nil
}
// Order returns the number of distinct vertices across all layers, measured as the number of
// keys in the merged AdjacencyMap.
// NOTE(review): this assumes each layer's AdjacencyMap has a key for every vertex, including
// isolated ones — confirm for every layer implementation in use.
func (g LayeredGraph[K, T]) Order() (int, error) {
	adjacency, err := g.AdjacencyMap()
	if err != nil {
		return 0, err
	}
	vertexCount := len(adjacency)
	return vertexCount, nil
}
// Size returns the number of distinct edges (unique source/target pairs) across all layers.
// An edge present in several layers is counted once.
//
// The previous implementation collected edges into a map keyed only by source, so a vertex
// with several outgoing edges was counted once.
func (g LayeredGraph[K, T]) Size() (int, error) {
	adj, err := g.AdjacencyMap()
	if err != nil {
		return 0, err
	}
	size := 0
	for _, targets := range adj {
		size += len(targets)
	}
	return size, nil
}
package graph_addons
import (
"github.com/dominikbraun/graph"
"go.uber.org/zap"
)
// LoggingGraph wraps a graph.Graph and logs AddVertex, RemoveVertex, AddEdge, UpdateEdge and
// RemoveEdge calls through Log: failures at error level, successes at debug level.
// AddVerticesFrom and AddEdgesFrom currently delegate without logging. Every other method comes
// from the embedded graph unchanged.
type LoggingGraph[K comparable, T any] struct {
graph.Graph[K, T]
// Log receives the log entries.
Log *zap.SugaredLogger
// Hash computes a vertex's key. It is only used to label AddVertex log entries and must be
// non-nil whenever AddVertex is called.
Hash func(T) K
}
// AddVertex adds value to the wrapped graph and logs the outcome, labelled with g.Hash(value).
// The wrapped graph's error is returned unchanged.
func (g LoggingGraph[K, T]) AddVertex(value T, options ...func(*graph.VertexProperties)) error {
	if err := g.Graph.AddVertex(value, options...); err != nil {
		g.Log.Errorf("AddVertex(%v) error: %v", g.Hash(value), err)
		return err
	}
	g.Log.Debugf("AddVertex(%v)", g.Hash(value))
	return nil
}
// AddVerticesFrom delegates to the wrapped graph without logging.
// TODO: log the added vertices the way AddVertex does.
func (g LoggingGraph[K, T]) AddVerticesFrom(other graph.Graph[K, T]) error {
	inner := g.Graph
	return inner.AddVerticesFrom(other)
}
// RemoveVertex removes hash from the wrapped graph and logs the outcome.
func (g LoggingGraph[K, T]) RemoveVertex(hash K) error {
	if err := g.Graph.RemoveVertex(hash); err != nil {
		g.Log.Errorf("RemoveVertex(%v) error: %v", hash, err)
		return err
	}
	g.Log.Debugf("RemoveVertex(%v)", hash)
	return nil
}
// AddEdge adds the edge to the wrapped graph and logs the outcome. On success, the edge is read
// back so that its Data (if any) can be included in the debug entry.
func (g LoggingGraph[K, T]) AddEdge(sourceHash K, targetHash K, options ...func(*graph.EdgeProperties)) error {
	if err := g.Graph.AddEdge(sourceHash, targetHash, options...); err != nil {
		g.Log.Errorf("AddEdge(%v -> %v) error: %v", sourceHash, targetHash, err)
		return err
	}
	// The read-back is only for the log line. If it fails, the zero edge (nil Data) is logged,
	// which is acceptable for a debug message.
	added, _ := g.Graph.Edge(sourceHash, targetHash)
	if data := added.Properties.Data; data != nil {
		g.Log.Debugf("AddEdge(%v -> %v, %+v)", sourceHash, targetHash, data)
	} else {
		g.Log.Debugf("AddEdge(%v -> %v)", sourceHash, targetHash)
	}
	return nil
}
// AddEdgesFrom delegates to the wrapped graph without logging.
// TODO: log the added edges the way AddEdge does.
func (g LoggingGraph[K, T]) AddEdgesFrom(other graph.Graph[K, T]) error {
	inner := g.Graph
	return inner.AddEdgesFrom(other)
}
// UpdateEdge updates the edge in the wrapped graph and logs the outcome.
func (g LoggingGraph[K, T]) UpdateEdge(source K, target K, options ...func(properties *graph.EdgeProperties)) error {
	if err := g.Graph.UpdateEdge(source, target, options...); err != nil {
		g.Log.Errorf("UpdateEdge(%v, %v) error: %v", source, target, err)
		return err
	}
	g.Log.Debugf("UpdateEdge(%v, %v)", source, target)
	return nil
}
// RemoveEdge removes the edge from the wrapped graph and logs the outcome.
func (g LoggingGraph[K, T]) RemoveEdge(source K, target K) error {
	if err := g.Graph.RemoveEdge(source, target); err != nil {
		g.Log.Errorf("RemoveEdge(%v, %v) error: %v", source, target, err)
		return err
	}
	g.Log.Debugf("RemoveEdge(%v, %v)", source, target)
	return nil
}
// Clone clones the wrapped graph and wraps the copy in a new LoggingGraph that shares this
// graph's logger and hash function.
func (g LoggingGraph[K, T]) Clone() (graph.Graph[K, T], error) {
	cloned, err := g.Graph.Clone()
	if err != nil {
		return nil, err
	}
	// Hash must be carried over: AddVertex on the clone calls it, and a nil func would panic.
	return LoggingGraph[K, T]{Graph: cloned, Log: g.Log, Hash: g.Hash}, nil
}
package graph_addons
import (
"fmt"
"reflect"
"sync"
"github.com/dominikbraun/graph"
)
// MemoryStore is like the default store returned by [graph.New] except that [AddVertex] and [AddEdge]
// are idempotent - they do not return an error if the vertex or edge already exists with the exact same value.
// A vertex counts as the same when it is equal by == or by its Equals method (see equaller) and has
// equal properties. An edge counts as the same when source, target, weight and attributes match and its
// Data is equal by ==, Equals, or reflect.DeepEqual (see edgesEqual).
// All access is guarded by lock.
type MemoryStore[K comparable, T comparable] struct {
// lock guards every map below.
lock sync.RWMutex
vertices map[K]T
vertexProperties map[K]graph.VertexProperties
// outEdges and inEdges store all outgoing and ingoing edges for all vertices. For O(1) access,
// these edges themselves are stored in maps whose keys are the hashes of the target vertices.
outEdges map[K]map[K]graph.Edge[K] // source -> target
inEdges map[K]map[K]graph.Edge[K] // target -> source
}
// equaller is implemented by values that define their own notion of equality. MemoryStore
// consults it when == reports a vertex value (or edge Data) as different, so that values the
// implementer considers equal can be re-added without an "already exists" error.
type equaller interface {
Equals(any) bool
}
// NewMemoryStore returns an empty MemoryStore, typed as a graph.Store, with all internal maps
// initialised so that it is ready to use.
func NewMemoryStore[K comparable, T comparable]() graph.Store[K, T] {
	s := new(MemoryStore[K, T])
	s.vertices = make(map[K]T)
	s.vertexProperties = make(map[K]graph.VertexProperties)
	s.outEdges = make(map[K]map[K]graph.Edge[K])
	s.inEdges = make(map[K]map[K]graph.Edge[K])
	return s
}
// vertexPropsEqual reports whether a and b have the same weight and exactly the same attribute
// key/value pairs. A nil and an empty attribute map compare equal.
func vertexPropsEqual(a, b graph.VertexProperties) bool {
	if a.Weight != b.Weight || len(a.Attributes) != len(b.Attributes) {
		return false
	}
	for key, want := range a.Attributes {
		got, present := b.Attributes[key]
		if !present || got != want {
			return false
		}
	}
	return true
}
// AddVertex stores t under key k with properties p. A nil p.Attributes is normalised to an empty
// map. Re-adding a vertex that is equal (by == or equaller) and has equal properties is a
// no-op. Otherwise, an existing key yields graph.VertexAlreadyExistsError.
func (s *MemoryStore[K, T]) AddVertex(k K, t T, p graph.VertexProperties) error {
	s.lock.Lock()
	defer s.lock.Unlock()

	if p.Attributes == nil {
		p.Attributes = make(map[string]string)
	}

	existing, exists := s.vertices[k]
	if !exists {
		s.vertices[k] = t
		s.vertexProperties[k] = p
		return nil
	}
	// Cheapest comparison first: plain ==.
	if t == existing && vertexPropsEqual(s.vertexProperties[k], p) {
		return nil
	}
	// Slower path: a user-defined Equals. The value has to be boxed into an interface first,
	// because Go does not allow a type assertion on a type-parameter value directly.
	var boxed any = t
	if eq, ok := boxed.(equaller); ok && eq.Equals(existing) && vertexPropsEqual(s.vertexProperties[k], p) {
		return nil
	}
	return &graph.VertexAlreadyExistsError[K, T]{
		Key:           k,
		ExistingValue: existing,
	}
}
// ListVertices returns the keys of all stored vertices, in unspecified (map iteration) order.
func (s *MemoryStore[K, T]) ListVertices() ([]K, error) {
	s.lock.RLock()
	defer s.lock.RUnlock()
	keys := make([]K, 0, len(s.vertices))
	for key := range s.vertices {
		keys = append(keys, key)
	}
	return keys, nil
}
// VertexCount returns how many vertices are stored. It never fails.
func (s *MemoryStore[K, T]) VertexCount() (int, error) {
	s.lock.RLock()
	defer s.lock.RUnlock()
	count := len(s.vertices)
	return count, nil
}
// Vertex returns the vertex stored under key and its properties, taking a read lock.
func (s *MemoryStore[K, T]) Vertex(key K) (T, graph.VertexProperties, error) {
	s.lock.RLock()
	defer s.lock.RUnlock()
	return s.vertexWithLock(key)
}
// vertexWithLock returns the vertex stored under k and its properties, or
// graph.VertexNotFoundError. The caller must already hold at least a read lock.
func (s *MemoryStore[K, T]) vertexWithLock(k K) (T, graph.VertexProperties, error) {
	if v, ok := s.vertices[k]; ok {
		return v, s.vertexProperties[k], nil
	}
	var zero T
	return zero, graph.VertexProperties{}, &graph.VertexNotFoundError[K]{Key: k}
}
// RemoveVertex deletes vertex k and its properties.
//
// It fails with graph.VertexNotFoundError if k is absent, and with graph.VertexHasEdgesError
// (reporting the number of incident edges) if any edge still starts or ends at k. Empty per-vertex
// edge maps for k are pruned, even when the removal then fails.
func (s *MemoryStore[K, T]) RemoveVertex(k K) error {
	// A write lock is required because this method deletes from the maps. The previous read
	// lock let concurrent readers and other RemoveVertex calls race on the same maps.
	s.lock.Lock()
	defer s.lock.Unlock()
	if _, ok := s.vertices[k]; !ok {
		return &graph.VertexNotFoundError[K]{Key: k}
	}
	count := 0
	if edges, ok := s.inEdges[k]; ok {
		inCount := len(edges)
		count += inCount
		if inCount == 0 {
			delete(s.inEdges, k)
		}
	}
	if edges, ok := s.outEdges[k]; ok {
		outCount := len(edges)
		count += outCount
		if outCount == 0 {
			delete(s.outEdges, k)
		}
	}
	if count > 0 {
		return &graph.VertexHasEdgesError[K]{Key: k, Count: count}
	}
	delete(s.vertices, k)
	delete(s.vertexProperties, k)
	return nil
}
// edgesEqual reports whether a and b have the same endpoints, weight, attributes and Data.
// Data is compared with == when either side is nil, otherwise with an equaller implementation
// (a's first, then b's), and finally with reflect.DeepEqual.
func edgesEqual[K comparable](a, b graph.Edge[K]) bool {
	pa, pb := a.Properties, b.Properties
	// Cheap checks first, so that mismatches fail fast.
	if a.Source != b.Source || a.Target != b.Target || pa.Weight != pb.Weight || len(pa.Attributes) != len(pb.Attributes) {
		return false
	}
	for key, want := range pa.Attributes {
		got, present := pb.Attributes[key]
		if !present || want != got {
			return false
		}
	}
	if pa.Data == nil || pb.Data == nil {
		// == is only safe with at least one nil side: comparing two uncomparable values of the
		// same dynamic type (such as maps) would panic.
		return pa.Data == pb.Data
	}
	if eq, ok := pa.Data.(equaller); ok {
		return eq.Equals(pb.Data)
	}
	if eq, ok := pb.Data.(equaller); ok {
		return eq.Equals(pa.Data)
	}
	// Reflection is the slow path, so it runs last. Unlike attributes, Data's type is unknown.
	return reflect.DeepEqual(pa.Data, pb.Data)
}
// AddEdge stores edge under sourceHash -> targetHash in both the outgoing and incoming indexes.
// Both vertices must exist. Re-adding an edge equal to the stored one (see edgesEqual) is a
// no-op success; a different edge on the same pair yields graph.EdgeAlreadyExistsError.
func (s *MemoryStore[K, T]) AddEdge(sourceHash, targetHash K, edge graph.Edge[K]) error {
	s.lock.Lock()
	defer s.lock.Unlock()

	if _, _, err := s.vertexWithLock(sourceHash); err != nil {
		return fmt.Errorf("could not get source vertex: %w", &graph.VertexNotFoundError[K]{Key: sourceHash})
	}
	if _, _, err := s.vertexWithLock(targetHash); err != nil {
		return fmt.Errorf("could not get target vertex: %w", &graph.VertexNotFoundError[K]{Key: targetHash})
	}

	if current, found := s.outEdges[sourceHash][targetHash]; found && !edgesEqual(current, edge) {
		return &graph.EdgeAlreadyExistsError[K]{Source: sourceHash, Target: targetHash}
	}
	if current, found := s.inEdges[targetHash][sourceHash]; found && !edgesEqual(current, edge) {
		return &graph.EdgeAlreadyExistsError[K]{Source: sourceHash, Target: targetHash}
	}

	out, ok := s.outEdges[sourceHash]
	if !ok {
		out = make(map[K]graph.Edge[K])
		s.outEdges[sourceHash] = out
	}
	out[targetHash] = edge

	in, ok := s.inEdges[targetHash]
	if !ok {
		in = make(map[K]graph.Edge[K])
		s.inEdges[targetHash] = in
	}
	in[sourceHash] = edge
	return nil
}
// UpdateEdge replaces the stored sourceHash -> targetHash edge with edge in both indexes.
// It fails with graph.EdgeNotFoundError if no such edge exists.
func (s *MemoryStore[K, T]) UpdateEdge(sourceHash, targetHash K, edge graph.Edge[K]) error {
	s.lock.Lock()
	defer s.lock.Unlock()
	_, err := s.edgeWithLock(sourceHash, targetHash)
	if err != nil {
		return err
	}
	out, in := s.outEdges[sourceHash], s.inEdges[targetHash]
	out[targetHash] = edge
	in[sourceHash] = edge
	return nil
}
// RemoveEdge deletes the sourceHash -> targetHash edge from both indexes. Removing an edge that
// does not exist is not an error. Emptied per-vertex maps are left in place.
func (s *MemoryStore[K, T]) RemoveEdge(sourceHash, targetHash K) error {
	s.lock.Lock()
	defer s.lock.Unlock()
	// Deleting from a nil inner map is a no-op, so no existence check is needed.
	delete(s.outEdges[sourceHash], targetHash)
	delete(s.inEdges[targetHash], sourceHash)
	return nil
}
// Edge returns the stored sourceHash -> targetHash edge, taking a read lock.
func (s *MemoryStore[K, T]) Edge(sourceHash, targetHash K) (graph.Edge[K], error) {
	s.lock.RLock()
	defer s.lock.RUnlock()
	e, err := s.edgeWithLock(sourceHash, targetHash)
	return e, err
}
// edgeWithLock returns the sourceHash -> targetHash edge, or graph.EdgeNotFoundError.
// The caller must already hold at least a read lock.
func (s *MemoryStore[K, T]) edgeWithLock(sourceHash, targetHash K) (graph.Edge[K], error) {
	// Indexing a missing outer key yields a nil inner map, and lookups on a nil map are safe.
	// A single two-level lookup therefore covers both "no edges from source" and "no edge to target".
	if edge, ok := s.outEdges[sourceHash][targetHash]; ok {
		return edge, nil
	}
	return graph.Edge[K]{}, &graph.EdgeNotFoundError[K]{Source: sourceHash, Target: targetHash}
}
// ListEdges returns every stored edge once, in unspecified order. The result is never nil.
func (s *MemoryStore[K, T]) ListEdges() ([]graph.Edge[K], error) {
	s.lock.RLock()
	defer s.lock.RUnlock()
	total := 0
	for _, targets := range s.outEdges {
		total += len(targets)
	}
	all := make([]graph.Edge[K], 0, total)
	for _, targets := range s.outEdges {
		for _, e := range targets {
			all = append(all, e)
		}
	}
	return all, nil
}
// CreatesCycle is a fast-path version of [graph.CreatesCycle]. It reports whether adding the edge
// source -> target would create a cycle, meaning target is source itself or one of its ancestors.
//
// It avoids calling [PredecessorMap], which generates large amounts of garbage to collect.
// The walk only reads predecessors, so inEdges can be used directly without making a copy.
func (s *MemoryStore[K, T]) CreatesCycle(source, target K) (bool, error) {
	if source == target {
		return true, nil
	}
	s.lock.RLock()
	defer s.lock.RUnlock()
	if _, _, err := s.vertexWithLock(source); err != nil {
		return false, fmt.Errorf("could not get source vertex: %w", err)
	}
	if _, _, err := s.vertexWithLock(target); err != nil {
		return false, fmt.Errorf("could not get target vertex: %w", err)
	}

	// Depth-first search over predecessors, starting at source.
	pending := []K{source}
	seen := make(map[K]struct{})
	for len(pending) > 0 {
		last := len(pending) - 1
		current := pending[last]
		pending = pending[:last]
		if _, done := seen[current]; done {
			continue
		}
		// Reaching target means it is an ancestor of source, so source -> target closes a loop.
		if current == target {
			return true, nil
		}
		seen[current] = struct{}{}
		for parent := range s.inEdges[current] {
			pending = append(pending, parent)
		}
	}
	return false, nil
}
package graph_addons
import (
"fmt"
"github.com/dominikbraun/graph"
)
// Path is an ordered sequence of vertex keys, where consecutive elements are connected by edges.
type Path[K comparable] []K
// PathWeight returns the weight of path in g.
//
// For an unweighted graph it returns len(path) (the number of vertices) without looking up any
// edges. For a weighted graph it returns the sum of the weights of every consecutive edge
// path[i-1] -> path[i]. If any of those edges is missing, the returned error wraps the lookup
// error and the weight is 0.
//
// The previous loop bound (i < len(path)-1) skipped the final edge of every path.
func PathWeight[K comparable, V any](g graph.Graph[K, V], path Path[K]) (weight int, err error) {
	if !g.Traits().IsWeighted {
		return len(path), nil
	}
	for i := 1; i < len(path); i++ {
		e, edgeErr := g.Edge(path[i-1], path[i])
		if edgeErr != nil {
			return 0, fmt.Errorf("edge(path[%d], path[%d]): %w", i-1, i, edgeErr)
		}
		weight += e.Properties.Weight
	}
	return weight, nil
}
// Contains reports whether k appears anywhere in the path (linear scan).
func (p Path[K]) Contains(k K) bool {
	for i := range p {
		if p[i] == k {
			return true
		}
	}
	return false
}
package graph_addons
import (
"errors"
"github.com/dominikbraun/graph"
)
// RemoveVertexAndEdges removes every edge that starts or ends at id, then removes id itself.
// All edge-removal failures are collected and returned together, and in that case the vertex
// is left in place.
func RemoveVertexAndEdges[K comparable, T any](g graph.Graph[K, T], id K) error {
	edges, err := g.Edges()
	if err != nil {
		return err
	}
	var removeErr error
	for _, e := range edges {
		if e.Source == id || e.Target == id {
			removeErr = errors.Join(removeErr, g.RemoveEdge(e.Source, e.Target))
		}
	}
	if removeErr != nil {
		return removeErr
	}
	return g.RemoveVertex(id)
}
package graph_addons
import (
"errors"
"fmt"
"github.com/dominikbraun/graph"
)
// ReplaceVertex replaces the vertex oldId with newValue, whose key is hasher(newValue).
//
// If the key is unchanged this is a no-op. Otherwise newValue is added with oldId's vertex
// properties, and every edge touching oldId is re-created against the new key with the same
// edge properties. oldId is removed only if every edge was moved successfully; edge failures
// are collected and returned together.
func ReplaceVertex[K comparable, T any](g graph.Graph[K, T], oldId K, newValue T, hasher func(T) K) error {
	newKey := hasher(newValue)
	if newKey == oldId {
		return nil
	}

	_, props, err := g.VertexWithProperties(oldId)
	if err != nil {
		return err
	}
	copyVertexProps := func(vp *graph.VertexProperties) { *vp = props }
	if err := g.AddVertex(newValue, copyVertexProps); err != nil {
		return fmt.Errorf("could not add new vertex %v: %w", newKey, err)
	}

	edges, err := g.Edges()
	if err != nil {
		return err
	}
	var errs error
	for _, e := range edges {
		fromOld, toOld := e.Source == oldId, e.Target == oldId
		if !fromOld && !toOld {
			continue
		}
		moved := e
		if fromOld {
			moved.Source = newKey
		}
		if toOld {
			moved.Target = newKey
		}
		edgeProps := e.Properties
		// Remove first, then add: the same order as before, which matters for self-loops.
		removeErr := g.RemoveEdge(e.Source, e.Target)
		addErr := g.AddEdge(moved.Source, moved.Target, func(ep *graph.EdgeProperties) { *ep = edgeProps })
		if joined := errors.Join(removeErr, addErr); joined != nil {
			errs = errors.Join(errs, fmt.Errorf("failed to update edge %v -> %v: %w", e.Source, e.Target, joined))
		}
	}
	if errs != nil {
		return errs
	}
	return g.RemoveVertex(oldId)
}
package graph_addons
import "github.com/dominikbraun/graph"
// ReverseLess wraps less so that the resulting ordering is reversed: the returned function
// reports whether b sorts before a according to less.
func ReverseLess[K any](less func(K, K) bool) func(K, K) bool {
	reversed := func(x, y K) bool {
		return less(y, x)
	}
	return reversed
}
// ReverseTopologicalSort provides a stable reverse topological ordering of g. The sort is driven
// by the adjacency map, so each vertex comes after every vertex its outgoing edges point to.
// less orders vertices that become ready at the same time. Cycles are broken deterministically
// rather than reported (see topologicalSort).
func ReverseTopologicalSort[K comparable, T any](g graph.Graph[K, T], less func(K, K) bool) ([]K, error) {
	outgoing, err := g.AdjacencyMap()
	if err != nil {
		return nil, err
	}
	order, err := topologicalSort(outgoing, less)
	return order, err
}
// ReverseGraph returns a new graph, built with graph.NewLike(g), that has g's vertices and every
// edge of g flipped (target -> source) with its properties preserved.
func ReverseGraph[K comparable, T any](g graph.Graph[K, T]) (graph.Graph[K, T], error) {
	reversed := graph.NewLike(g)
	if err := reversed.AddVerticesFrom(g); err != nil {
		return nil, err
	}
	edges, err := g.Edges()
	if err != nil {
		return nil, err
	}
	for _, e := range edges {
		edgeProps := e.Properties
		if err := reversed.AddEdge(e.Target, e.Source, func(ep *graph.EdgeProperties) { *ep = edgeProps }); err != nil {
			return nil, err
		}
	}
	return reversed, nil
}
package graph_addons
import (
"sort"
"github.com/dominikbraun/graph"
)
// TopologicalSort provides a stable topological ordering of g. The sort is driven by the
// predecessor map, so each vertex comes after all of its predecessors. less orders vertices that
// become ready at the same time. Cycles are broken deterministically rather than reported (see
// topologicalSort).
func TopologicalSort[K comparable, T any](g graph.Graph[K, T], less func(K, K) bool) ([]K, error) {
	incoming, err := g.PredecessorMap()
	if err != nil {
		return nil, err
	}
	order, err := topologicalSort(incoming, less)
	return order, err
}
// topologicalSort performs a topological sort on a graph with the given dependencies.
// Whether the sort is regular or reverse is determined by whether the `deps` map is a PredecessorMap or AdjacencyMap:
// a key of `deps` is emitted only once every vertex in its inner map has been emitted.
// The `less` function orders vertices that become ready at the same time, making the result deterministic.
// This is a modified implementation of graph.StableTopologicalSort. The primary difference is
// `enqueueArbitrary`: when no vertex is ready but some remain (i.e. the rest are stuck on a cycle),
// one remaining vertex is emitted anyway instead of failing. As a result, every key of `deps` appears
// in the result exactly once, and vertices on a cycle are ordered by that tie-break rather than strictly.
//
// NOTE: `deps` is consumed. Entries are deleted from it and from its inner maps as vertices are
// emitted, so callers must pass a map they own. The returned error is currently always nil; an
// empty `deps` yields a nil slice. Each emitted vertex triggers a scan over all remaining entries of
// `deps`, so the sort is at least quadratic in the number of vertices.
func topologicalSort[K comparable](deps map[K]map[K]graph.Edge[K], less func(K, K) bool) ([]K, error) {
if len(deps) == 0 {
return nil, nil
}
queue := make([]K, 0)
// queued records every vertex ever enqueued, so a vertex is never enqueued twice via the frontier.
queued := make(map[K]struct{})
enqueue := func(vs ...K) {
for _, vertex := range vs {
queue = append(queue, vertex)
queued[vertex] = struct{}{}
}
}
// Seed the queue with vertices that have no dependencies, in `less` order.
for vertex, vdeps := range deps {
if len(vdeps) == 0 {
enqueue(vertex)
}
}
sort.Slice(queue, func(i, j int) bool {
return less(queue[i], queue[j])
})
// enqueueArbitrary enqueues an arbitrary but deterministic id from the remaining unvisited ids.
// It should only be used if len(queue) == 0 && len(deps) > 0
enqueueArbitrary := func() {
remaining := make([]K, 0, len(deps))
for vertex := range deps {
remaining = append(remaining, vertex)
}
sort.Slice(remaining, func(i, j int) bool {
// Start based first on the number of remaining deps, prioritizing vertices with fewer deps
// to make it most likely to break any cycles, reducing the amount of arbitrary choices.
ic := len(deps[remaining[i]])
jc := len(deps[remaining[j]])
if ic != jc {
return ic < jc
}
// Tie-break using the less function on contents themselves
return less(remaining[i], remaining[j])
})
enqueue(remaining[0])
}
// Every vertex has at least one dependency, so a cycle must be broken before starting.
if len(queue) == 0 {
enqueueArbitrary()
}
order := make([]K, 0, len(deps))
visited := make(map[K]struct{})
for len(queue) > 0 {
currentVertex := queue[0]
queue = queue[1:]
if _, ok := visited[currentVertex]; ok {
continue
}
order = append(order, currentVertex)
visited[currentVertex] = struct{}{}
delete(deps, currentVertex)
// Remove the emitted vertex from every remaining dependency set. Vertices whose set
// becomes empty (and that are not already queued) form the next frontier.
frontier := make([]K, 0)
for vertex, predecessors := range deps {
delete(predecessors, currentVertex)
if len(predecessors) != 0 {
continue
}
if _, ok := queued[vertex]; ok {
continue
}
frontier = append(frontier, vertex)
}
sort.Slice(frontier, func(i, j int) bool {
return less(frontier[i], frontier[j])
})
enqueue(frontier...)
// Nothing is ready, but vertices remain: they are blocked by a cycle, so break it.
if len(queue) == 0 && len(deps) > 0 {
enqueueArbitrary()
}
}
return order, nil
}
package graph_addons
import (
"errors"
"github.com/dominikbraun/graph"
)
// WalkGraphFunc is called by WalkUp/WalkDown once per path visited. p runs from the walk's start
// vertex to the vertex being visited, inclusive. nerr is the error returned by the most recent
// previous call that did not return StopWalk or SkipPath (nil on the first call).
//
// The returned error controls the walk. StopWalk (checked with errors.Is) ends the walk, which then
// returns nerr. SkipPath does not extend p any further. Any other value (including nil) replaces the
// accumulated error, and the walk continues.
type WalkGraphFunc[K comparable] func(p Path[K], nerr error) error
var (
// StopWalk, returned from a WalkGraphFunc, ends the walk early.
StopWalk = errors.New("stop walk")
// SkipPath, returned from a WalkGraphFunc, prevents the current path from being extended.
SkipPath = errors.New("skip path")
)
// WalkUp walks up through the graph starting at `start` in BFS order, extending each path from
// its last vertex to that vertex's predecessors. The start vertex itself is not passed to f;
// the first paths visited are start plus one predecessor.
func WalkUp[K comparable, T any](g graph.Graph[K, T], start K, f WalkGraphFunc[K]) error {
	predecessors, err := g.PredecessorMap()
	if err != nil {
		return err
	}
	return walk(g, start, f, predecessors)
}
// WalkDown walks down through the graph starting at `start` in BFS order, extending each path
// from its last vertex to the targets of that vertex's outgoing edges. The start vertex itself is
// not passed to f; the first paths visited are start plus one successor.
func WalkDown[K comparable, T any](g graph.Graph[K, T], start K, f WalkGraphFunc[K]) error {
	successors, err := g.AdjacencyMap()
	if err != nil {
		return err
	}
	return walk(g, start, f, successors)
}
// walk performs a breadth-first walk over deps (an adjacency or predecessor map), calling f on
// every loop-free path that begins at start and has at least one more vertex. Paths are
// extended from their last vertex through deps, and a vertex already on a path is never
// appended again. See WalkGraphFunc for how f's return value steers the walk. g is currently
// unused.
func walk[K comparable, T any](
	g graph.Graph[K, T],
	start K,
	f WalkGraphFunc[K],
	deps map[K]map[K]graph.Edge[K],
) error {
	var queue []Path[K]
	// push queues current+next unless next is already on current (which would form a loop).
	// It always builds a fresh slice. Appending to current directly could write into a backing
	// array shared with sibling paths pushed in the same loop: for x with spare capacity,
	// y := append(x, 3) followed by z := append(x, 4) makes y end in 4 as well.
	push := func(current Path[K], next K) {
		if current.Contains(next) {
			return
		}
		extended := make(Path[K], 0, len(current)+1)
		extended = append(extended, current...)
		extended = append(extended, next)
		queue = append(queue, extended)
	}

	root := Path[K]{start}
	for next := range deps[start] {
		push(root, next)
	}

	var lastErr error
	for len(queue) > 0 {
		path := queue[0]
		queue = queue[1:]
		result := f(path, lastErr)
		switch {
		case errors.Is(result, StopWalk):
			return lastErr
		case errors.Is(result, SkipPath):
			continue
		}
		lastErr = result
		tail := path[len(path)-1]
		for next := range deps[tail] {
			push(path, next)
		}
	}
	return lastErr
}
package infra
import (
"bytes"
"errors"
"fmt"
"os"
"path/filepath"
"runtime/pprof"
construct "github.com/klothoplatform/klotho/pkg/construct"
engine "github.com/klothoplatform/klotho/pkg/engine"
"github.com/klothoplatform/klotho/pkg/engine/constraints"
"github.com/klothoplatform/klotho/pkg/infra/iac"
"github.com/klothoplatform/klotho/pkg/infra/kubernetes"
statereader "github.com/klothoplatform/klotho/pkg/infra/state_reader"
statetemplate "github.com/klothoplatform/klotho/pkg/infra/state_reader/state_template"
kio "github.com/klothoplatform/klotho/pkg/io"
"github.com/klothoplatform/klotho/pkg/knowledgebase/reader"
"github.com/klothoplatform/klotho/pkg/logging"
"github.com/klothoplatform/klotho/pkg/templates"
"github.com/spf13/cobra"
"go.uber.org/zap"
"gopkg.in/yaml.v3"
)
// commonCfg holds the persistent flags shared by every IaC subcommand (registered in AddIacCli).
var commonCfg struct {
// verbose is bound to --verbose / -v and enables verbose logging.
verbose bool
// jsonLog is bound to --json-log and switches the log encoding to JSON.
jsonLog bool
}
// generateIacCfg holds the flags of the Generate command (registered in AddIacCli).
var generateIacCfg struct {
provider string
inputGraph string
outputDir string
appName string
// NOTE(review): verbose and jsonLog are not bound to any flag in AddIacCli; the shared
// commonCfg fields are used instead. Confirm nothing else in the package reads these.
verbose bool
jsonLog bool
// profileTo is the path of an optional CPU profile output file.
profileTo string
}
// getImportConstraintsCfg holds the flags of the GetLiveState command (registered in AddIacCli).
var getImportConstraintsCfg struct {
provider string
// inputGraph is an optional engine input graph that gives the state reader extra context.
inputGraph string
// stateFile is the path of the provider state file to read (required).
stateFile string
}
// AddIacCli wires the IaC CLI onto root. It adds the persistent --verbose/--json-log flags, sets up
// logger initialisation (pre-run) and flushing (post-run), and registers the Generate and
// GetLiveState subcommands with their flags. It currently always returns nil.
func AddIacCli(root *cobra.Command) error {
	persistent := root.PersistentFlags()
	persistent.BoolVarP(&commonCfg.verbose, "verbose", "v", false, "Verbose flag")
	persistent.BoolVar(&commonCfg.jsonLog, "json-log", false, "Output logs in JSON format.")

	root.PersistentPreRun = func(cmd *cobra.Command, args []string) {
		opts := logging.LogOpts{
			Verbose:         commonCfg.verbose,
			CategoryLogsDir: "", // IaC doesn't generate enough logs to warrant category-specific logs
		}
		if commonCfg.jsonLog {
			opts.Encoding = "json"
		}
		zap.ReplaceGlobals(opts.NewLogger())
	}
	root.PersistentPostRun = func(cmd *cobra.Command, args []string) {
		// Best-effort flush; the error is deliberately ignored.
		_ = zap.L().Sync()
	}

	generateCmd := &cobra.Command{
		Use:   "Generate",
		Short: "Generate IaC for a given graph",
		RunE:  GenerateIac,
	}
	genFlags := generateCmd.Flags()
	genFlags.StringVarP(&generateIacCfg.provider, "provider", "p", "pulumi", "Provider to use")
	genFlags.StringVarP(&generateIacCfg.inputGraph, "input-graph", "i", "", "Input graph to use")
	genFlags.StringVarP(&generateIacCfg.outputDir, "output-dir", "o", "", "Output directory to use")
	genFlags.StringVarP(&generateIacCfg.appName, "app-name", "a", "", "App name to use")
	genFlags.StringVar(&generateIacCfg.profileTo, "profiling", "", "Profile to file")
	root.AddCommand(generateCmd)

	liveStateCmd := &cobra.Command{
		Use:   "GetLiveState",
		Short: "Reads the state file from the provider specified and translates it to Klotho Engine state graph.",
		RunE:  GetLiveState,
	}
	liveFlags := liveStateCmd.Flags()
	liveFlags.StringVarP(&getImportConstraintsCfg.provider, "provider", "p", "pulumi", "Provider to use")
	liveFlags.StringVarP(&getImportConstraintsCfg.inputGraph, "input-graph", "i", "", "Input graph to use to provide additional context to the state file.")
	liveFlags.StringVarP(&getImportConstraintsCfg.stateFile, "state-file", "s", "", "State file to use")
	root.AddCommand(liveStateCmd)

	return nil
}
// GetLiveState implements the GetLiveState command. It reads a provider state file (plus an
// optional engine input graph for extra context), translates it into a construct graph, and
// writes that graph to stdout as YAML.
//
// NOTE(review): the state is always parsed with the Pulumi reader regardless of --provider;
// only the state templates are loaded per provider.
func GetLiveState(cmd *cobra.Command, args []string) error {
	log := zap.S().Named("LiveState")

	// Validate cheap inputs before the expensive knowledge-base load.
	statePath := getImportConstraintsCfg.stateFile
	if statePath == "" {
		return errors.New("state file path is empty")
	}

	// Locals are named so they don't shadow the `templates` and `reader` packages.
	kb, err := reader.NewKBFromFs(templates.ResourceTemplates, templates.EdgeTemplates, templates.Models)
	if err != nil {
		return fmt.Errorf("loading knowledge base: %w", err)
	}
	log.Info("Loaded knowledge base")

	stateTemplates, err := statetemplate.LoadStateTemplates(getImportConstraintsCfg.provider)
	if err != nil {
		return fmt.Errorf("loading state templates for provider %q: %w", getImportConstraintsCfg.provider, err)
	}
	log.Info("Loaded state templates")

	log.Info("Reading state file")
	stateBytes, err := os.ReadFile(statePath)
	if err != nil {
		return fmt.Errorf("reading state file: %w", err)
	}

	var input engine.FileFormat
	if inputPath := getImportConstraintsCfg.inputGraph; inputPath != "" {
		inputF, err := os.Open(inputPath)
		if err != nil {
			return fmt.Errorf("opening input graph: %w", err)
		}
		defer inputF.Close()
		if err := yaml.NewDecoder(inputF).Decode(&input); err != nil {
			return fmt.Errorf("decoding input graph %s: %w", inputPath, err)
		}
	}

	stateReader := statereader.NewPulumiReader(input.Graph, stateTemplates, kb)
	result, err := stateReader.ReadState(bytes.NewReader(stateBytes))
	if err != nil {
		return fmt.Errorf("reading state: %w", err)
	}
	return yaml.NewEncoder(os.Stdout).Encode(construct.YamlGraph{Graph: result})
}
// GenerateIac implements the Generate command. It optionally starts CPU profiling, loads the
// input graph into a solution context, translates it with the Kubernetes plugin and then the
// provider's IaC plugin (only "pulumi" is supported), and writes the resulting files to the
// configured output directory.
func GenerateIac(cmd *cobra.Command, args []string) error {
	if profilePath := generateIacCfg.profileTo; profilePath != "" {
		if err := os.MkdirAll(filepath.Dir(profilePath), 0755); err != nil {
			return fmt.Errorf("failed to create profile directory: %w", err)
		}
		profileF, err := os.OpenFile(profilePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
		if err != nil {
			return fmt.Errorf("failed to open profile file: %w", err)
		}
		// Registered before StartCPUProfile so that the file is closed even if starting fails;
		// StopCPUProfile is a no-op when profiling is not running.
		defer func() {
			pprof.StopCPUProfile()
			profileF.Close()
		}()
		if err := pprof.StartCPUProfile(profileF); err != nil {
			return fmt.Errorf("failed to start profile: %w", err)
		}
	}

	if generateIacCfg.inputGraph == "" {
		return fmt.Errorf("input graph required")
	}
	inputF, err := os.Open(generateIacCfg.inputGraph)
	if err != nil {
		return fmt.Errorf("failed to open input graph: %w", err)
	}
	defer inputF.Close()

	var input construct.YamlGraph
	if err := yaml.NewDecoder(inputF).Decode(&input); err != nil {
		return err
	}

	kb, err := reader.NewKBFromFs(templates.ResourceTemplates, templates.EdgeTemplates, templates.Models)
	if err != nil {
		return err
	}
	solCtx := engine.NewSolution(cmd.Context(), kb, "", &constraints.Constraints{})
	if err := solCtx.LoadGraph(input.Graph); err != nil {
		return err
	}

	var files []kio.File
	k8sPlugin := kubernetes.Plugin{
		AppName: generateIacCfg.appName,
		KB:      kb,
	}
	k8sFiles, err := k8sPlugin.Translate(solCtx)
	if err != nil {
		return err
	}
	files = append(files, k8sFiles...)

	// The provider is validated only after Kubernetes translation, preserving the original order.
	if generateIacCfg.provider != "pulumi" {
		return fmt.Errorf("provider %s not supported", generateIacCfg.provider)
	}
	pulumiPlugin := iac.Plugin{
		Config: &iac.PulumiConfig{AppName: generateIacCfg.appName},
		KB:     kb,
	}
	iacFiles, err := pulumiPlugin.Translate(solCtx)
	if err != nil {
		return err
	}
	files = append(files, iacFiles...)

	return kio.OutputTo(files, generateIacCfg.outputDir)
}
package iac
import (
	"fmt"

	construct "github.com/klothoplatform/klotho/pkg/construct"
	"github.com/klothoplatform/klotho/pkg/engine/solution"
	kio "github.com/klothoplatform/klotho/pkg/io"
)
// RenderDockerfiles is a temporary workaround for rendering trivial Dockerfiles for resources.
// Ideally this isn't explicit but instead handled by a template in some fashion.
//
// Every aws:ecr_image resource in the deployment graph that has a non-nil BaseImage gets a file
// at its Dockerfile path containing the single line `FROM <BaseImage>`. Resources are visited in
// construct.ReverseTopologicalSort order. An error is returned, instead of a panic, when BaseImage
// or Dockerfile is not a string (including a missing Dockerfile for an image with a BaseImage).
func RenderDockerfiles(ctx solution.Solution) ([]kio.File, error) {
	dg := ctx.DeploymentGraph()
	resources, err := construct.ReverseTopologicalSort(dg)
	if err != nil {
		return nil, err
	}
	var files []kio.File
	for _, rid := range resources {
		if rid.QualifiedTypeName() != "aws:ecr_image" {
			continue
		}
		res, err := dg.Vertex(rid)
		if err != nil {
			return nil, err
		}
		baseImage, err := res.GetProperty("BaseImage")
		if err != nil {
			return nil, err
		}
		if baseImage == nil {
			continue
		}
		base, ok := baseImage.(string)
		if !ok {
			return nil, fmt.Errorf("%s: BaseImage must be a string, got %T", rid, baseImage)
		}
		dockerfile, err := res.GetProperty("Dockerfile")
		if err != nil {
			return nil, err
		}
		path, ok := dockerfile.(string)
		if !ok {
			return nil, fmt.Errorf("%s: Dockerfile must be a string when BaseImage is set, got %T", rid, dockerfile)
		}
		files = append(files, &kio.RawFile{
			FPath:   path,
			Content: []byte("FROM " + base),
		})
	}
	return files, nil
}
package iac
import (
"errors"
construct "github.com/klothoplatform/klotho/pkg/construct"
)
// AddExtraResources adds the helper resources that some resource types need in the compiled graph.
// It handles EKS clusters (Kubernetes provider and ingress rule), public and private subnets (route
// table association) and target groups (attachments). Other types are ignored.
func (tc *TemplatesCompiler) AddExtraResources(r construct.ResourceId) error {
	switch r.QualifiedTypeName() {
	case "aws:eks_cluster":
		// Both helpers always run; their errors are combined.
		return errors.Join(
			addKubernetesProvider(tc.graph, r),
			addIngressRuleToCluster(tc.graph, r),
		)
	case "aws:public_subnet", "aws:private_subnet":
		return addRouteTableAssociation(tc.graph, r)
	case "aws:target_group":
		return addTargetGroupAttachment(tc.graph, r)
	default:
		return nil
	}
}
// addKubernetesProvider adds a pulumi:kubernetes_provider resource named after cluster. The
// provider is configured from the cluster's KubeConfig property and gets an edge to that kubeconfig.
// Every direct upstream dependency of the cluster with a property `<name>` equal to the cluster
// gains a `<name>Provider` property pointing at the provider, plus a single edge to the provider.
func addKubernetesProvider(g construct.Graph, cluster construct.ResourceId) error {
	clusterRes, err := g.Vertex(cluster)
	if err != nil {
		return err
	}
	kubeConfig, ok := clusterRes.Properties["KubeConfig"].(construct.ResourceId)
	if !ok {
		return errors.New("cluster must have KubeConfig property")
	}
	gb := construct.NewGraphBatch(g)
	provider := &construct.Resource{
		ID: construct.ResourceId{
			Provider: "pulumi",
			Type:     "kubernetes_provider",
			Name:     cluster.Name,
		},
		Properties: construct.Properties{
			"KubeConfig": kubeConfig,
		},
	}
	gb.AddVertices(provider)
	gb.AddEdges(construct.Edge{Source: provider.ID, Target: kubeConfig})

	upstream, err := construct.DirectUpstreamDependencies(g, cluster)
	if err != nil {
		return errors.Join(gb.Err, err)
	}
	for _, dep := range upstream {
		depR, err := g.Vertex(dep)
		if err != nil {
			// Previously ignored, which led to a nil dereference below.
			return errors.Join(gb.Err, err)
		}
		// Collect first, then mutate: keys added to a map during a range over it may or may
		// not be visited by that same loop.
		var refs []string
		for name, prop := range depR.Properties {
			if prop == cluster {
				refs = append(refs, name)
			}
		}
		if len(refs) == 0 {
			continue
		}
		for _, name := range refs {
			depR.Properties[name+"Provider"] = provider.ID
		}
		// One edge per dependency, even if several properties reference the cluster.
		gb.AddEdges(construct.Edge{Source: dep, Target: provider.ID})
	}
	return gb.Err
}
// addIngressRuleToCluster adds an ingress aws:security_group_rule (namespaced under the cluster)
// to the cluster's security group. The rule allows all traffic from the CIDR blocks of the
// cluster's Subnets, and an edge runs from the rule to the cluster.
// TODO move this into engine
func addIngressRuleToCluster(g construct.Graph, cluster construct.ResourceId) error {
	clusterRes, err := g.Vertex(cluster)
	if err != nil {
		return err
	}
	subnets, ok := clusterRes.Properties["Subnets"].([]construct.ResourceId)
	if !ok {
		return errors.New("cluster must have Subnets property")
	}

	cidrBlocks := make([]construct.PropertyRef, 0, len(subnets))
	for _, subnet := range subnets {
		cidrBlocks = append(cidrBlocks, construct.PropertyRef{Resource: subnet, Property: "cidr_block"})
	}

	ruleID := construct.ResourceId{
		Provider:  "aws",
		Type:      "security_group_rule",
		Namespace: cluster.Name,
		Name:      "ingress",
	}
	sgRule := &construct.Resource{
		ID: ruleID,
		Properties: construct.Properties{
			"Description": "Allows access to cluster from the VPCs private and public subnets",
			"FromPort":    0,
			"ToPort":      0,
			"Protocol":    "-1",
			"CidrBlocks":  cidrBlocks,
			"SecurityGroupId": construct.PropertyRef{
				Resource: cluster,
				Property: "cluster_security_group_id",
			},
			"Type": "ingress",
		},
	}

	batch := construct.NewGraphBatch(g)
	batch.AddVertices(sgRule)
	batch.AddEdges(construct.Edge{Source: ruleID, Target: cluster})
	return batch.Err
}
// addRouteTableAssociation adds an aws:route_table_association (namespaced under the subnet) for
// each aws:route_table among the subnet's direct upstream dependencies. Each association gets
// edges to both the subnet and the route table.
// TODO move this into engine
func addRouteTableAssociation(g construct.Graph, subnet construct.ResourceId) error {
	upstream, err := construct.DirectUpstreamDependencies(g, subnet)
	if err != nil {
		return err
	}
	batch := construct.NewGraphBatch(g)
	for _, candidate := range upstream {
		if candidate.QualifiedTypeName() == "aws:route_table" {
			assocID := construct.ResourceId{
				Provider:  "aws",
				Type:      "route_table_association",
				Namespace: subnet.Name,
				Name:      "association",
			}
			batch.AddVertices(&construct.Resource{
				ID: assocID,
				Properties: construct.Properties{
					"Subnet":     subnet,
					"RouteTable": candidate,
				},
			})
			batch.AddEdges(
				construct.Edge{Source: assocID, Target: subnet},
				construct.Edge{Source: assocID, Target: candidate},
			)
		}
	}
	return batch.Err
}
// addTargetGroupAttachment adds one aws:target_group_attachment per entry in the target group's
// Targets property. Each attachment is namespaced under the target group, named after the target,
// and has edges to both the target group and the target.
// TODO move this into engine
func addTargetGroupAttachment(g construct.Graph, tg construct.ResourceId) error {
	tgRes, err := g.Vertex(tg)
	if err != nil {
		return err
	}
	targets, ok := tgRes.Properties["Targets"].([]construct.ResourceId)
	if !ok {
		return errors.New("target group must have Targets property")
	}
	batch := construct.NewGraphBatch(g)
	for _, target := range targets {
		attachmentID := construct.ResourceId{
			Provider:  "aws",
			Type:      "target_group_attachment",
			Namespace: tg.Name,
			Name:      target.Name,
		}
		batch.AddVertices(&construct.Resource{
			ID: attachmentID,
			Properties: construct.Properties{
				"Port":           construct.PropertyRef{Resource: target, Property: "Port"},
				"TargetGroupArn": construct.PropertyRef{Resource: tg, Property: "Arn"},
				"TargetId":       construct.PropertyRef{Resource: target, Property: "Id"},
			},
		})
		batch.AddEdges(
			construct.Edge{Source: attachmentID, Target: tg},
			construct.Edge{Source: attachmentID, Target: target},
		)
	}
	return batch.Err
}
package iac
import (
"fmt"
"sort"
"strings"
)
type (
	// MapMarshaller describes a map-like value that can render itself as a string.
	// NOTE(review): neither TsMap nor TsList implements this interface in this file chunk.
	MapMarshaller interface {
		Map() map[string]any
		String() string
		SetKey(val any)
	}
	// ListMarshaller describes a list-like value that can render itself as a string.
	ListMarshaller interface {
		List() []any
		String() string
		Append(val any)
	}

	// TsMap is a map rendered as a TypeScript object literal by String.
	TsMap map[string]any
	// TsList is a slice rendered as a TypeScript array literal by String.
	TsList []any
)

// String renders m as `{k1: v1, k2: v2}`. Keys are sorted so the output is deterministic.
// Values are formatted with %v, so nested TsMap/TsList values render recursively.
func (m TsMap) String() string {
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	var sb strings.Builder
	sb.WriteByte('{')
	for idx, k := range keys {
		if idx > 0 {
			sb.WriteString(", ")
		}
		fmt.Fprintf(&sb, "%s: %v", k, m[k])
	}
	sb.WriteByte('}')
	return sb.String()
}

// String renders l as `[v1, v2]`, formatting each element with %v.
func (l TsList) String() string {
	var sb strings.Builder
	sb.WriteByte('[')
	for idx, item := range l {
		if idx > 0 {
			sb.WriteString(", ")
		}
		fmt.Fprintf(&sb, "%v", item)
	}
	sb.WriteByte(']')
	return sb.String()
}
package iac
import (
"bufio"
"bytes"
"embed"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"regexp"
"sort"
"strings"
"text/template"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/engine/solution"
kio "github.com/klothoplatform/klotho/pkg/io"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
"github.com/klothoplatform/klotho/pkg/templateutils"
)
type (
// PulumiConfig is the data passed to the Pulumi.yaml and Pulumi.<AppName>.yaml templates
// rendered by Translate.
PulumiConfig struct {
// AppName is used in the stack file name Pulumi.<AppName>.yaml. sanitizeConfig strips
// everything except ASCII letters, digits, '-' and '_' before it is used.
AppName string
}
// Plugin translates a solution's deployment graph into a Pulumi TypeScript project
// (index.ts, package.json, Pulumi yaml files, tsconfig.json).
Plugin struct {
// Config must be non-nil: Translate and sanitizeConfig dereference it.
Config *PulumiConfig
// KB is the template knowledge base. NOTE(review): not referenced in this file chunk.
KB knowledgebase.TemplateKB
}
)
// Name returns the identifier of this IaC plugin.
func (p Plugin) Name() string {
	const pluginName = "pulumi3"
	return pluginName
}
var (
// files holds the Pulumi project and stack templates, the shared globals.ts snippet
// (inlined by renderGlobals) and tsconfig.json (copied verbatim by Translate).
//go:embed Pulumi.yaml.tmpl Pulumi.dev.yaml.tmpl templates/globals.ts templates/tsconfig.json
files embed.FS
// standardTemplates holds the per-resource-type templates (factory.ts, package.json and
// *.ts.tmpl) for the aws and kubernetes providers. Translate exposes it under "templates".
//go:embed templates/aws/*/factory.ts templates/aws/*/package.json templates/aws/*/*.ts.tmpl
//go:embed templates/kubernetes/*/factory.ts templates/kubernetes/*/package.json templates/kubernetes/*/*.ts.tmpl
standardTemplates embed.FS
// pulumiBase renders Pulumi.yaml. It is parsed at package initialisation.
pulumiBase = templateutils.MustTemplate(files, "Pulumi.yaml.tmpl")
// pulumiStack renders the stack file that Translate writes as Pulumi.<AppName>.yaml.
pulumiStack = templateutils.MustTemplate(files, "Pulumi.dev.yaml.tmpl")
)
// Translate renders sol into a Pulumi TypeScript project. It returns index.ts,
// package.json, Pulumi.yaml, Pulumi.<AppName>.yaml and tsconfig.json, followed by any
// files produced by RenderDockerfiles.
//
// Side effects: sanitizeConfig rewrites p.Config.AppName in place (Config is a pointer,
// so the caller's config changes), and addPulumiKubernetesProviders adds vertices and
// edges to sol.DeploymentGraph().
func (p Plugin) Translate(sol solution.Solution) ([]kio.File, error) {
err := p.sanitizeConfig()
if err != nil {
return nil, err
}
// TODO We'll eventually want to split the output into different files, but we don't know exactly what that looks
// like yet. For now, just write to a single file, "index.ts"
// index.ts is assembled in this pooled buffer in order: imports, globals, every resource,
// the $outputs object and the $urns map.
buf := getBuffer()
defer releaseBuffer(buf)
templatesFS, err := fs.Sub(standardTemplates, "templates")
if err != nil {
return nil, err
}
// Must run before the compiler is built so the provider resources are part of the
// graph that gets variables and is rendered below.
err = addPulumiKubernetesProviders(sol.DeploymentGraph())
if err != nil {
return nil, fmt.Errorf("error adding pulumi kubernetes providers: %w", err)
}
tc := &TemplatesCompiler{
graph: sol.DeploymentGraph(),
templates: &templateStore{fs: templatesFS},
}
tc.vars, err = VariablesFromGraph(tc.graph)
if err != nil {
return nil, err
}
if err := tc.RenderImports(buf); err != nil {
return nil, err
}
buf.WriteString("\n\n")
if err := renderGlobals(buf); err != nil {
return nil, err
}
// Resources are rendered in the order returned by construct.ReverseTopologicalSort.
resources, err := construct.ReverseTopologicalSort(tc.graph)
if err != nil {
return nil, err
}
// Keep rendering after a failure so that all resource errors are reported together.
var errs error
for _, r := range resources {
errs = errors.Join(errs, tc.RenderResource(buf, r))
buf.WriteString("\n")
}
if errs != nil {
return nil, errs
}
buf.WriteString("\n")
renderStackOutputs(tc, buf, sol.Outputs())
buf.WriteString("\n")
tc.renderUrnMap(buf, resources)
// Copy out of the pooled buffer: the deferred releaseBuffer hands it back to the pool.
indexTs := &kio.RawFile{
FPath: `index.ts`,
Content: make([]byte, buf.Len()),
}
copy(indexTs.Content, buf.Bytes())
pJson, err := tc.PackageJSON()
if err != nil {
return nil, err
}
pulumiYaml, err := addTemplate("Pulumi.yaml", pulumiBase, p.Config)
if err != nil {
return nil, err
}
// NOTE: this `:=` declares a local pulumiStack that shadows the package-level template.
// The right-hand side is evaluated first, so it still refers to the package-level one.
pulumiStack, err := addTemplate(fmt.Sprintf("Pulumi.%s.yaml", p.Config.AppName), pulumiStack, p.Config)
if err != nil {
return nil, err
}
var content []byte
content, err = files.ReadFile("templates/tsconfig.json")
if err != nil {
return nil, err
}
tsConfig := &kio.RawFile{
FPath: "tsconfig.json",
Content: content,
}
// NOTE: from here on, `files` is this local slice and shadows the package-level embed.FS.
files := []kio.File{indexTs, pJson, pulumiYaml, pulumiStack, tsConfig}
dockerfiles, err := RenderDockerfiles(sol)
if err != nil {
return nil, err
}
files = append(files, dockerfiles...)
return files, nil
}
// renderStackOutputs writes `export const $outputs = {...}` to buf, with one entry per
// output, sorted by name so the generated file is deterministic.
//
// An output with a non-zero Ref is rendered through tc.PropertyRefValue. Any other output
// has its Value JSON-encoded. If resolving or encoding a value fails, the entry is
// rendered as `null` (best effort) rather than aborting the whole file.
func renderStackOutputs(tc *TemplatesCompiler, buf *bytes.Buffer, outputs map[string]construct.Output) {
	buf.WriteString("export const $outputs = {\n")
	names := make([]string, 0, len(outputs))
	for name := range outputs {
		names = append(names, name)
	}
	sort.Strings(names)
	for _, name := range names {
		output := outputs[name]
		if output.Ref.IsZero() {
			encoded, err := json.Marshal(output.Value)
			if err != nil {
				fmt.Fprintf(buf, "\t%s: null,\n", name)
				continue
			}
			fmt.Fprintf(buf, "\t%s: %s,\n", name, encoded)
			continue
		}
		val, err := tc.PropertyRefValue(output.Ref)
		if err != nil {
			fmt.Fprintf(buf, "\t%s: null,\n", name)
			continue
		}
		// %v, not %s: PropertyRefValue can return bool/int/float64 (via convertArg), which
		// %s would render as e.g. `%!s(int=80)`. For strings and Stringers (TsMap, TsList)
		// %v prints the same text as %s.
		fmt.Fprintf(buf, "\t%s: %v,\n", name, val)
	}
	buf.WriteString("}\n")
}
// renderUrnMap writes `export const $urns = {...}` to buf. The map has one entry per
// resource that has a variable in tc.vars, in the order given by resources. Each entry maps
// the resource ID string to `(<var> as any).urn`.
func (tc *TemplatesCompiler) renderUrnMap(buf *bytes.Buffer, resources []construct.ResourceId) {
	buf.WriteString("export const $urns = {\n")
	for _, id := range resources {
		if obj, ok := tc.vars[id]; ok {
			// The `as any` cast means objects without a `urn` property yield `undefined`
			// in TS/JS instead of failing to type-check.
			fmt.Fprintf(buf, "\t\"%s\": (%s as any).urn,\n", id, obj)
		}
	}
	buf.WriteString("}\n")
}
// appNameSanitizeRegex matches each run of characters not allowed in AppName. It is
// compiled once at package scope rather than on every sanitizeConfig call.
var appNameSanitizeRegex = regexp.MustCompile("[^a-zA-Z0-9-_]+")

// sanitizeConfig removes every character other than ASCII letters, digits, '-' and '_'
// from p.Config.AppName, in place. This makes the name safe to embed in generated file
// names such as Pulumi.<AppName>.yaml.
//
// It returns an error when p.Config is nil, instead of panicking on the dereference.
func (p *Plugin) sanitizeConfig() error {
	if p.Config == nil {
		return errors.New("pulumi plugin config is nil")
	}
	p.Config.AppName = appNameSanitizeRegex.ReplaceAllString(p.Config.AppName, "")
	return nil
}
// renderGlobals copies the embedded templates/globals.ts into w so its declarations can be
// inlined into index.ts. Each line is trimmed, and blank lines and lines starting with
// "import" are dropped (imports are rendered separately). A leading "export " is removed.
// A final empty line is written after the content.
//
// It returns any error from opening or reading the file, or from writing to w.
func renderGlobals(w io.Writer) error {
	globalsFile, err := files.Open("templates/globals.ts")
	if err != nil {
		return err
	}
	defer globalsFile.Close()

	scan := bufio.NewScanner(globalsFile)
	for scan.Scan() {
		text := strings.TrimSpace(scan.Text())
		if text == "" || strings.HasPrefix(text, "import") {
			continue
		}
		text = strings.TrimPrefix(text, "export ")
		if _, err := fmt.Fprintln(w, text); err != nil {
			return err
		}
	}
	// Scan returns false on read errors as well as on EOF. Without this check, a failed
	// read would silently produce truncated globals.
	if err := scan.Err(); err != nil {
		return fmt.Errorf("reading templates/globals.ts: %w", err)
	}
	_, err = fmt.Fprintln(w)
	return err
}
// addPulumiKubernetesProviders adds Pulumi kubernetes provider resources to g and wires
// kubernetes resources to them. It makes two passes over the graph:
//
//  1. For every resource matching kubeconfigId (a kubernetes:kube_config), it adds a
//     kubernetes:kubernetes_provider with the same Name and a "KubeConfig" property
//     pointing at the kube config, plus an edge provider -> kube config.
//  2. For every resource whose Provider is "kubernetes" and whose "Cluster" property is
//     set, it finds the kube config among the cluster's direct upstream dependencies. It
//     then sets the resource's "Provider" property to the matching provider and adds an
//     edge resource -> provider.
//
// Errors from individual resources are joined with the walk's accumulated error (nerr).
func addPulumiKubernetesProviders(g construct.Graph) error {
// kube config ID -> provider created for it in the first pass.
providers := make(map[construct.ResourceId]*construct.Resource)
// NOTE(review): Name/Namespace are empty, so Matches presumably accepts any
// kubernetes:kube_config; confirm against construct.ResourceId.Matches.
kubeconfigId := construct.ResourceId{Provider: "kubernetes", Type: "kube_config"}
err := construct.WalkGraph(g, func(id construct.ResourceId, resource *construct.Resource, nerr error) error {
if !kubeconfigId.Matches(id) {
return nerr
}
provider := &construct.Resource{
ID: construct.ResourceId{
Provider: "kubernetes",
Type: "kubernetes_provider",
Name: id.Name,
},
Properties: construct.Properties{
"KubeConfig": id,
},
}
// NOTE(review): this mutates g while WalkGraph is iterating over it. Confirm the walk
// tolerates newly added vertices.
err := g.AddVertex(provider)
if err != nil {
return errors.Join(nerr, err)
}
err = g.AddEdge(provider.ID, id)
if err != nil {
return errors.Join(nerr, err)
}
providers[id] = provider
return nerr
})
if err != nil {
return err
}
err = construct.WalkGraph(g, func(id construct.ResourceId, resource *construct.Resource, nerr error) error {
if id.Provider != "kubernetes" {
return nerr
}
cluster, err := resource.GetProperty("Cluster")
if err != nil {
return errors.Join(nerr, err)
}
// Kubernetes resources without a Cluster property are left untouched.
if cluster == nil {
return nerr
}
clusterId, ok := cluster.(construct.ResourceId)
if !ok {
return errors.Join(nerr, fmt.Errorf("resource %s is a kubernetes resource but does not have an id as cluster property (is: %T)", id, cluster))
}
upstreams, err := construct.DirectUpstreamDependencies(g, clusterId)
if err != nil {
return errors.Join(nerr, err)
}
// Use the first kube config found upstream of the cluster. If none is found,
// kubeconfig stays the zero ID, the lookup below fails, and an error is reported.
var kubeconfig construct.ResourceId
for _, upstream := range upstreams {
if kubeconfigId.Matches(upstream) {
kubeconfig = upstream
break
}
}
provider, ok := providers[kubeconfig]
if !ok {
return errors.Join(nerr, fmt.Errorf("resource %s is a kubernetes resource but does not have a provider resource for cluster %s", id, clusterId))
}
err = resource.SetProperty("Provider", provider.ID)
if err != nil {
return errors.Join(nerr, err)
}
err = g.AddEdge(id, provider.ID)
if err != nil {
return errors.Join(nerr, err)
}
return nerr
})
return err
}
// addTemplate executes t with data and returns the output as a RawFile at path name.
// Template execution errors are wrapped with the file name.
func addTemplate(name string, t *template.Template, data any) (*kio.RawFile, error) {
	// Deliberately a fresh buffer rather than one from bufPool: the returned RawFile keeps
	// a reference to the buffer's bytes.
	var out bytes.Buffer
	if err := t.Execute(&out, data); err != nil {
		return nil, fmt.Errorf("error executing template %s: %w", name, err)
	}
	file := &kio.RawFile{FPath: name, Content: out.Bytes()}
	return file, nil
}
package iac
import (
"bytes"
"fmt"
"sync"
"text/template"
construct "github.com/klothoplatform/klotho/pkg/construct"
)
// bufPool recycles *bytes.Buffer values used for short-lived rendering.
var bufPool = sync.Pool{
	New: func() any { return new(bytes.Buffer) },
}

// getBuffer takes a buffer from bufPool and empties it before returning it. Pair each call
// with releaseBuffer once nothing references the buffer's bytes any more.
func getBuffer() *bytes.Buffer {
	b := bufPool.Get().(*bytes.Buffer)
	b.Reset()
	return b
}

// releaseBuffer returns b to bufPool. b must not be used after this call.
func releaseBuffer(b *bytes.Buffer) {
	bufPool.Put(b)
}

// executeToString executes tmpl with data into a pooled buffer and returns the result as a
// string. The buffer's bytes are copied, so the buffer can safely go back to the pool.
func executeToString(tmpl *template.Template, data any) (string, error) {
	b := getBuffer()
	defer releaseBuffer(b)
	if err := tmpl.Execute(b, data); err != nil {
		return "", err
	}
	return b.String(), nil
}
// PropertyRefValue resolves ref to the value that should be rendered for it.
//
// If the referenced resource's template has a property template for ref.Property, that
// template is executed with the resource's ID, its variable name (tc.vars) and its input
// args, and the resulting string is returned. Otherwise the property's current value is
// read from the graph and converted with convertArg.
//
// It returns an error if the resource template or vertex cannot be found, if the property
// value is nil, or if the resource has no path for the property.
func (tc *TemplatesCompiler) PropertyRefValue(ref construct.PropertyRef) (any, error) {
	resTmpl, err := tc.ResourceTemplate(ref.Resource)
	if err != nil {
		return nil, err
	}
	res, err := tc.graph.Vertex(ref.Resource)
	if err != nil {
		return nil, err
	}

	// Indexing a nil map is safe in Go, so PropertyTemplates needs no separate nil check.
	if mapping, found := resTmpl.PropertyTemplates[ref.Property]; found {
		args, err := tc.getInputArgs(res, resTmpl)
		if err != nil {
			return nil, err
		}
		return executeToString(mapping, PropertyTemplateData{
			Resource: ref.Resource,
			Object:   tc.vars[ref.Resource],
			Input:    args,
		})
	}

	propPath, err := res.PropertyPath(ref.Property)
	if err != nil {
		return nil, err
	}
	if propPath == nil {
		return nil, fmt.Errorf("unsupported property ref %s", ref)
	}
	value, _ := propPath.Get()
	if value == nil {
		return nil, fmt.Errorf("property ref %s is nil", ref)
	}
	return tc.convertArg(value, nil)
}
package iac
import (
"errors"
"fmt"
"io"
"path"
"reflect"
"regexp"
"sort"
"strings"
"text/template"
"github.com/iancoleman/strcase"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/set"
"go.uber.org/zap"
)
// templateInputArgs is the data passed to a resource's create/import template. getInputArgs
// fills it with the converted property values, plus "dependsOn", "Name" and one entry per
// global variable.
type templateInputArgs map[string]any
// validIdentifierPattern matches ASCII JavaScript/TypeScript identifiers. ModelCaseWrapper
// keys that match are emitted as-is; all other keys are wrapped in double quotes.
var validIdentifierPattern = regexp.MustCompile(`^[a-zA-Z_$][a-zA-Z_$0-9]*$`)
// RenderResource writes the TypeScript for resource rid to out:
//   - `const <var> = `, unless the template's create function returns void;
//   - the rendered create template, or the import template if the resource is Imported;
//   - one `export const <var>_<name> = <value>` per export template.
//
// Exports are emitted in sorted name order so that index.ts is deterministic. Ranging over
// the Exports map directly would give a different order on every run. Errors from
// individual exports are joined and returned after all exports have been attempted.
func (tc *TemplatesCompiler) RenderResource(out io.Writer, rid construct.ResourceId) error {
	resTmpl, err := tc.ResourceTemplate(rid)
	if err != nil {
		return err
	}
	r, err := tc.graph.Vertex(rid)
	if err != nil {
		return err
	}
	inputs, err := tc.getInputArgs(r, resTmpl)
	if err != nil {
		return err
	}
	if resTmpl.OutputType != "void" {
		if _, err := fmt.Fprintf(out, "const %s = ", tc.vars[rid]); err != nil {
			return err
		}
	}

	body := resTmpl.Template
	if r.Imported {
		if resTmpl.ImportResource == nil {
			return fmt.Errorf("resource %s is imported but has no import resource template", rid)
		}
		body = resTmpl.ImportResource
	}
	if err := body.Execute(out, inputs); err != nil {
		return fmt.Errorf("could not render resource %s: %w", rid, err)
	}

	exportNames := make([]string, 0, len(resTmpl.Exports))
	for name := range resTmpl.Exports {
		exportNames = append(exportNames, name)
	}
	sort.Strings(exportNames)

	exportData := PropertyTemplateData{
		Resource: rid,
		Object:   tc.vars[rid],
		Input:    inputs,
	}
	var errs error
	for _, export := range exportNames {
		if _, err := fmt.Fprintf(out, "\nexport const %s_%s = ", tc.vars[rid], export); err != nil {
			errs = errors.Join(errs, fmt.Errorf("could not render export name %s: %w", export, err))
			continue
		}
		if err := resTmpl.Exports[export].Execute(out, exportData); err != nil {
			errs = errors.Join(errs, fmt.Errorf("could not render export value %s: %w", export, err))
		}
	}
	return errs
}
// convertArg converts a property value into a value that can be rendered into a
// TypeScript template:
//   - construct.ResourceId: the resource's variable name (tc.vars);
//   - construct.PropertyRef: the result of tc.PropertyRefValue;
//   - string: templateString (which quotes it);
//   - bool, int, float64: returned unchanged;
//   - nil: nil. Templates must guard optional properties with an `if`;
//   - slice/array: TsList of converted elements, skipping invalid or nil elements;
//   - map: TsMap of converted values, skipping missing or zero values. Keys must be
//     strings and are converted according to templateArg.Wrapper (default lowerCamelCase);
//   - set.HashedSet[string, any]: converted as its ToSlice();
//   - anything else: jsonValue{Raw: arg}.
//
// templateArg may be nil. It is passed unchanged to nested conversions.
func (tc *TemplatesCompiler) convertArg(arg any, templateArg *Arg) (any, error) {
	switch arg := arg.(type) {
	case construct.ResourceId:
		return tc.vars[arg], nil
	case construct.PropertyRef:
		return tc.PropertyRefValue(arg)
	case string:
		// templateString quotes the string value.
		return templateString(arg), nil
	case bool, int, float64:
		// Safe to render as-is.
		return arg, nil
	case nil:
		// TODO when we're more confident in the properties, replace the `nil` with `undefined`.
		// Optional properties must be guarded by an `if` in the template.
		return nil, nil
	}

	val := reflect.ValueOf(arg)
	switch val.Kind() {
	case reflect.Slice, reflect.Array:
		list := make(TsList, 0, val.Len())
		for i := 0; i < val.Len(); i++ {
			elem := val.Index(i)
			if isNilReflectValue(elem) {
				continue
			}
			output, err := tc.convertArg(elem.Interface(), templateArg)
			if err != nil {
				return nil, err
			}
			list = append(list, output)
		}
		return list, nil

	case reflect.Map:
		m := make(TsMap, val.Len())
		for _, key := range val.MapKeys() {
			entry := val.MapIndex(key)
			if !entry.IsValid() || entry.IsZero() {
				continue
			}
			keyStr, isString := key.Interface().(string)
			if !isString {
				return nil, fmt.Errorf("map key is not a string")
			}
			keyResult := strcase.ToLowerCamel(keyStr)
			if templateArg != nil {
				switch templateArg.Wrapper {
				case CamelCaseWrapper:
					keyResult = strcase.ToCamel(keyStr)
				case ModelCaseWrapper:
					// Keep valid identifiers verbatim; quote everything else so it is a
					// legal object key.
					if validIdentifierPattern.MatchString(keyStr) {
						keyResult = keyStr
					} else {
						keyResult = fmt.Sprintf(`"%s"`, keyStr)
					}
				}
			}
			output, err := tc.convertArg(entry.Interface(), templateArg)
			if err != nil {
				return nil, err
			}
			m[keyResult] = output
		}
		return m, nil

	case reflect.Struct:
		if hashset, ok := val.Interface().(set.HashedSet[string, any]); ok {
			return tc.convertArg(hashset.ToSlice(), templateArg)
		}
	}
	return jsonValue{Raw: arg}, nil
}

// isNilReflectValue reports whether v is invalid, or is a nil value of a kind that can be
// nil. reflect.Value.IsNil panics for other kinds (string, struct, int, ...), which made
// slices such as []string or []construct.ResourceId crash convertArg. The kind is
// therefore checked first.
func isNilReflectValue(v reflect.Value) bool {
	if !v.IsValid() {
		return true
	}
	switch v.Kind() {
	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map,
		reflect.Pointer, reflect.Slice, reflect.UnsafePointer:
		return v.IsNil()
	}
	return false
}
// getInputArgs builds the data passed to resource r's create/import template.
//
// For each property of r:
//   - Args with the TemplateWrapper wrapper are rendered through their nested template file
//     (useNestedTemplate);
//   - a PropertyRef back to r itself is deferred and rendered afterwards through the
//     matching entry in template.PropertyTemplates;
//   - all other values go through convertArg.
//
// Values that convert to nil are omitted. Any property error stops processing: the
// joined errors are returned with an empty args map.
//
// The following entries are then added:
//   - "dependsOn": a TS array expression listing r's direct downstream dependencies,
//     sorted and excluding aws:region, aws:availability_zone and aws:account_id.
//     kubernetes:manifest and kubernetes:kustomize_directory dependencies are spread as
//     `...Object.values(<applied resources>)`.
//   - "Name": templateString(r.ID.Name).
//   - one entry per global variable, mapped to its own name.
//
// These later entries overwrite any property of the same name.
func (tc *TemplatesCompiler) getInputArgs(r *construct.Resource, template *ResourceTemplate) (templateInputArgs, error) {
var errs error
inputs := make(map[string]any, len(r.Properties)+len(globalVariables)+2) // +2 for Name and dependsOn
// Properties that reference r itself. They are rendered after the other inputs are known,
// because their property templates read those inputs.
selfReferences := make(map[string]construct.PropertyRef)
for name, value := range r.Properties {
templateArg := template.Args[name]
var argValue any
var err error
if templateArg.Wrapper == TemplateWrapper {
argValue, err = tc.useNestedTemplate(template, value, templateArg)
if err != nil {
errs = errors.Join(errs, fmt.Errorf("could not use nested template for arg %q: %w", name, err))
continue
}
} else if ref, ok := value.(construct.PropertyRef); ok && ref.Resource == r.ID {
selfReferences[name] = ref
} else {
if name == "EnvironmentVariables" {
zap.S().Debugf("EnvironmentVariables: %v", value)
}
argValue, err = tc.convertArg(value, &templateArg)
if err != nil {
errs = errors.Join(errs, fmt.Errorf("could not convert arg %q: %w", name, err))
continue
}
}
if argValue != nil {
inputs[name] = argValue
}
}
// NOTE(review): inputs is shared and mutated in this loop, and map iteration order is
// random. One self-reference's template may or may not see another's result.
for name, value := range selfReferences {
if mapping, ok := template.PropertyTemplates[value.Property]; ok {
data := PropertyTemplateData{
Resource: r.ID,
Object: tc.vars[r.ID],
Input: inputs,
}
result, err := executeToString(mapping, data)
if err != nil {
errs = errors.Join(errs, fmt.Errorf("could not execute self-reference %q: %w", name, err))
continue
}
inputs[name] = result
} else {
errs = errors.Join(errs, fmt.Errorf("could not find mapping for self-reference %q", name))
}
}
if errs != nil {
return templateInputArgs{}, errs
}
downstream, err := construct.DirectDownstreamDependencies(tc.graph, r.ID)
if err != nil {
return templateInputArgs{}, err
}
var dependsOn []string
var applied appliedOutputs
for _, dep := range downstream {
switch dep.QualifiedTypeName() {
// These types are never listed in dependsOn.
case "aws:region", "aws:availability_zone", "aws:account_id":
continue
case "kubernetes:manifest", "kubernetes:kustomize_directory":
ao := tc.NewAppliedOutput(construct.PropertyRef{
Resource: dep,
// resources: pulumi.Output<{
// [key: string]: pulumi.CustomResource;
// }>
Property: "resources",
}, "")
applied = append(applied, ao)
dependsOn = append(dependsOn, fmt.Sprintf("...Object.values(%s)", ao.Name))
default:
dependsOn = append(dependsOn, tc.vars[dep])
}
}
// Sorted so the generated code is deterministic.
sort.Strings(dependsOn)
if len(applied) > 0 {
// The array references applied output values, so it is rendered inside the
// appliedOutputs wrapper instead of as a plain `[...]` literal.
buf := getBuffer()
defer releaseBuffer(buf)
err = applied.Render(buf, func(w io.Writer) error {
_, err := w.Write([]byte("["))
if err != nil {
return err
}
for i, dep := range dependsOn {
_, err = w.Write([]byte(dep))
if err != nil {
return err
}
if i < len(dependsOn)-1 {
_, err = w.Write([]byte(", "))
if err != nil {
return err
}
}
}
_, err = w.Write([]byte("]"))
if err != nil {
return err
}
return nil
})
if err != nil {
return templateInputArgs{}, err
}
inputs["dependsOn"] = buf.String()
} else {
inputs["dependsOn"] = "[" + strings.Join(dependsOn, ", ") + "]"
}
inputs["Name"] = templateString(r.ID.Name)
// Each global variable maps to its own name, so a template can reference it directly.
for g := range globalVariables {
inputs[g] = g
}
return inputs, nil
}
// useNestedTemplate renders val through the nested template file
// <resTmpl.Path>/<snake_case(arg.Name)>.ts.tmpl, read from tc.templates.fs. The template
// can call modelCase, lowerCamelCase, camelCase (wrappers around convertArg) and getVar
// (a resource's variable name).
//
// It returns an error if the file cannot be opened or read, is empty, fails to parse, or
// fails to execute.
func (tc *TemplatesCompiler) useNestedTemplate(resTmpl *ResourceTemplate, val any, arg Arg) (string, error) {
	nestedTemplatePath := path.Join(resTmpl.Path, strcase.ToSnake(arg.Name)+".ts.tmpl")
	f, err := tc.templates.fs.Open(nestedTemplatePath)
	if err != nil {
		return "", fmt.Errorf("could not find template for %s: %w", nestedTemplatePath, err)
	}
	// Previously the file was never closed.
	defer f.Close()

	contents, err := io.ReadAll(f)
	if err != nil {
		return "", fmt.Errorf("could not read template for %s: %w", nestedTemplatePath, err)
	}
	if len(contents) == 0 {
		// err is nil here, so there is nothing to wrap. Wrapping it with %w rendered
		// as "%!w(<nil>)".
		return "", fmt.Errorf("no contents in template for %s", nestedTemplatePath)
	}

	tmpl, err := template.New(nestedTemplatePath).Funcs(template.FuncMap{
		"modelCase":      tc.modelCase,
		"lowerCamelCase": tc.lowerCamelCase,
		"camelCase":      tc.camelCase,
		"getVar": func(id construct.ResourceId) string {
			return tc.vars[id]
		},
	}).Parse(string(contents))
	if err != nil {
		return "", fmt.Errorf("could not parse template for %s: %w", nestedTemplatePath, err)
	}

	result := getBuffer()
	// Safe to return to the pool: String() copies the bytes.
	defer releaseBuffer(result)
	if err := tmpl.Execute(result, val); err != nil {
		return "", fmt.Errorf("could not execute template for %s: %w", nestedTemplatePath, err)
	}
	return result.String(), nil
}
// modelCase converts val with convertArg using ModelCaseWrapper. Map keys are kept
// verbatim when they are valid identifiers and quoted otherwise.
func (tc *TemplatesCompiler) modelCase(val any) (any, error) {
	wrapper := Arg{Wrapper: ModelCaseWrapper}
	return tc.convertArg(val, &wrapper)
}
// lowerCamelCase converts val with convertArg using LowerCamelCaseWrapper
// (map keys become lowerCamelCase).
func (tc *TemplatesCompiler) lowerCamelCase(val any) (any, error) {
	wrapper := Arg{Wrapper: LowerCamelCaseWrapper}
	return tc.convertArg(val, &wrapper)
}
// camelCase converts val with convertArg using CamelCaseWrapper (map keys go through
// strcase.ToCamel).
func (tc *TemplatesCompiler) camelCase(val any) (any, error) {
	wrapper := Arg{Wrapper: CamelCaseWrapper}
	return tc.convertArg(val, &wrapper)
}
package iac
import (
"context"
_ "embed"
"errors"
"fmt"
"io"
"path/filepath"
"regexp"
"sort"
"strings"
"text/template"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/query"
sitter "github.com/smacker/go-tree-sitter"
"github.com/smacker/go-tree-sitter/typescript/typescript"
)
type (
// ResourceTemplate is a resource's TypeScript factory file, parsed into Go templates
// (see ParseTemplate).
ResourceTemplate struct {
// Name is the name passed to ParseTemplate.
Name string
// Imports holds the deduplicated, sorted import lines found in the file, with trailing
// ';' removed.
Imports []string
// OutputType is the create function's return type. "void" means the body is inlined
// without a `const <var> = ` prefix.
OutputType string
// Template renders the create function's return expression, or the whole body for void.
Template *template.Template
// PropertyTemplates maps a property name to a template from the properties function.
PropertyTemplates map[string]*template.Template
// Args are the properties declared on the template's `Args` interface.
Args map[string]Arg
// Path is the directory that nested templates are resolved against (useNestedTemplate).
// NOTE(review): not set by ParseTemplate; presumably set by the template store.
Path string
// Exports maps an export name to a template rendering its value.
Exports map[string]*template.Template
// ImportResource is used instead of Template for imported resources. It is nil if the
// file has no import function.
ImportResource *template.Template
}
// PropertyTemplateData is the data passed to property and export templates.
PropertyTemplateData struct {
// Resource is the ID of the resource being rendered.
Resource construct.ResourceId
// Object is the TypeScript variable name of that resource.
Object string
// Input holds the resource's template input args.
Input templateInputArgs
}
// Arg is one property of a template's `Args` interface.
Arg struct {
// Name is the property name.
Name string
// Type is the property's TypeScript type text.
Type string
// Wrapper is the wrapper type name captured by the args query, or "" if there is none.
Wrapper WrapperType
}
// WrapperType names a wrapper that changes how an argument is rendered.
WrapperType string
)
// Wrapper types recognised on template Args (parsed in parseArgs).
const (
// TemplateWrapper renders the argument through the nested template file
// <template dir>/<snake_case arg name>.ts.tmpl (see useNestedTemplate).
TemplateWrapper WrapperType = "TemplateWrapper"
// CamelCaseWrapper makes convertArg convert map keys with strcase.ToCamel.
CamelCaseWrapper WrapperType = "CamelCaseWrapper"
// LowerCamelCaseWrapper converts map keys with strcase.ToLowerCamel, which is also
// convertArg's default.
LowerCamelCaseWrapper WrapperType = "LowerCamelCaseWrapper"
// ModelCaseWrapper keeps map keys that are valid identifiers and quotes all others.
ModelCaseWrapper WrapperType = "ModelCaseWrapper"
)
// Tree-sitter queries (embedded .scm files) used to pull template pieces out of a
// resource's TypeScript file, plus regexes used to post-process the extracted code.
var (
// findCreateFuncQuery locates the create function (captures return_type and body).
//go:embed find_create_func.scm
findCreateFuncQuery string
// findImportsQuery captures each import statement ("import").
//go:embed find_imports.scm
findImportsQuery string
// findPropsFuncQuery locates the properties function (captures body).
//go:embed find_props_func.scm
findPropsFuncQuery string
// findPropertyQuery captures each key/value pair in an object literal.
//go:embed find_property.scm
findPropertyQuery string
// findArgs captures interface properties (name, property_name, property_type, nested).
//go:embed find_args.scm
findArgs string
// findExportFunc locates the exports function (captures body).
//go:embed find_export_func.scm
findExportFunc string
// findImportFunc locates the import function (captures return_type and body).
//go:embed find_import_func.scm
findImportFunc string
// curlyEscapes matches the `~~{{` escape for a literal `{{` in template output.
curlyEscapes = regexp.MustCompile(`~~{{`)
// templateComments matches one or more '/' immediately followed by "TMPL" and an optional
// space. The match is removed, so the code after such a marker appears in the template.
templateComments = regexp.MustCompile(`//*TMPL ?`)
)
// ParseTemplate parses the TypeScript source in r into a ResourceTemplate named name. It
// extracts the create template, the property templates, the imports, the Args
// declarations, the export templates and the optional import-resource template.
//
// It returns an error if any extraction fails, or if the import function's return type
// differs from the create function's.
func (tc *TemplatesCompiler) ParseTemplate(name string, r io.Reader) (*ResourceTemplate, error) {
	root, err := parseFile(r)
	if err != nil {
		return nil, err
	}
	rt := &ResourceTemplate{Name: name}
	if rt.Template, rt.OutputType, err = createNodeToTemplate(root, name); err != nil {
		return nil, err
	}
	if rt.PropertyTemplates, err = propertiesNodeToTemplate(root, name); err != nil {
		return nil, err
	}
	if rt.Imports, err = importsFromTemplate(root); err != nil {
		return nil, err
	}
	if rt.Args, err = parseArgs(root, name); err != nil {
		return nil, err
	}
	// Export templates keep a reference to rt so they can render its property templates
	// at execution time.
	if rt.Exports, err = exportsNodeToTemplate(tc, rt, root, name); err != nil {
		return nil, err
	}
	importTmpl, importOutputType, err := importFuncNodeToTemplate(root, name)
	if err != nil {
		return nil, err
	}
	if importOutputType != "" && importOutputType != rt.OutputType {
		return nil, fmt.Errorf("output type mismatch: %s != %s", importOutputType, rt.OutputType)
	}
	rt.ImportResource = importTmpl
	return rt, nil
}
// parseArgs collects the properties of the interface named "Args" in node, keyed by
// property name. A property's Wrapper is set when the query captures a "nested" wrapper
// type. name is currently unused. The error is always nil.
func parseArgs(node *sitter.Node, name string) (map[string]Arg, error) {
	next := doQuery(node, findArgs)
	args := map[string]Arg{}
	for m, ok := next(); ok; m, ok = next() {
		if m["name"].Content() != "Args" {
			continue
		}
		arg := Arg{
			Name: m["property_name"].Content(),
			Type: m["property_type"].Content(),
		}
		if nested := m["nested"]; nested != nil {
			arg.Wrapper = WrapperType(nested.Content())
		}
		args[arg.Name] = arg
	}
	return args, nil
}
// getReturn returns the expression of the first top-level return statement in the function
// body `node`, unwrapped so that for `return 1` the result's Content() is `1`, not
// `return 1`. It returns nil if there is no such statement.
//
// Only direct children are inspected, rather than using a query, so that return statements
// nested in anonymous functions are not picked up.
func getReturn(node *sitter.Node) *sitter.Node {
	count := int(node.NamedChildCount())
	for idx := 0; idx < count; idx++ {
		stmt := node.NamedChild(idx)
		if stmt.Type() != "return_statement" {
			continue
		}
		if stmt.ChildCount() == 0 {
			return nil
		}
		return stmt.NamedChild(0)
	}
	return nil
}
// createNodeToTemplate builds the main resource template from the first create function
// found in node. It returns the template and the function's declared return type.
//
// For a void function the body statements are used (see bodyContents). Otherwise the
// top-level return expression is used. In either case `args.X` becomes `{{.X}}`, TMPL
// comment markers are removed and `~~{{` escapes become a literal `{{`. The template can
// call dir, filepathBase and matches on templateString values.
//
// It returns an error if there is no create function, if a non-void function has no
// top-level return, or if parsing fails.
func createNodeToTemplate(node *sitter.Node, name string) (*template.Template, string, error) {
createFunc := doQuery(node, findCreateFuncQuery)
create, found := createFunc()
if !found {
return nil, "", fmt.Errorf("no create function found in %s", name)
}
outputType := create["return_type"].Content()
var expressionBody string
if outputType == "void" {
expressionBody = bodyContents(create["body"])
} else {
body := getReturn(create["body"])
if body == nil {
return nil, "", fmt.Errorf("no 'return' found in %s body:```\n%s\n```", name, create["body"].Content())
}
expressionBody = body.Content()
}
expressionBody = parameterizeArgs(expressionBody, "")
expressionBody = templateComments.ReplaceAllString(expressionBody, "")
// transform escaped double curly brace literals e.g. ~~{{ .ID }} -> {{ `{{` }} .ID }}
expressionBody = curlyEscapes.ReplaceAllString(expressionBody, "{{ `{{` }}")
tmpl, err := template.New(name).Funcs(template.FuncMap{
"dir": func(path templateString) string {
// dir returns the parent directory of the current path
return filepath.Dir(string(path))
},
"filepathBase": func(path templateString) string {
// filepathBase returns the final element (basename) of the current path
return filepath.Base(string(path))
},
"matches": func(pattern string, value templateString) bool {
// matches returns true if the value matches the pattern. An invalid pattern
// yields false, because the regexp error is ignored.
matched, _ := regexp.MatchString(pattern, string(value))
return matched
}}).Parse(expressionBody)
return tmpl, outputType, err
}
// propertiesNodeToTemplate builds one template per key of the object in the properties
// function. Inside a value, `args.X` becomes `{{.Input.X}}` and `object.Y` becomes
// `{{.Object}}.Y`. It returns (nil, nil) if the file has no properties function. Parse
// errors are joined and returned together with the templates that did parse. name is
// currently unused.
func propertiesNodeToTemplate(node *sitter.Node, name string) (map[string]*template.Template, error) {
	match, ok := doQuery(node, findPropsFuncQuery)()
	if !ok {
		return nil, nil
	}
	templates := make(map[string]*template.Template)
	var errs error
	next := doQuery(match["body"], findPropertyQuery)
	for prop, found := next(); found; prop, found = next() {
		key := prop["key"].Content()
		src := parameterizeObject(parameterizeArgs(prop["value"].Content(), ".Input"))
		t, err := template.New(key).Parse(src)
		if err != nil {
			errs = errors.Join(errs, fmt.Errorf("error parsing property %q: %w", key, err))
			continue
		}
		templates[key] = t
	}
	return templates, errs
}
// exportsNodeToTemplate builds one template per key of the object in the exports function.
// Inside a value, `args.X` becomes `{{.Input.X}}`, `object.Y` becomes `{{.Object}}.Y` and
// `props.Z` becomes `{{ property "Z" .Resource }}`.
//
// The `property` template function executes tmpl.PropertyTemplates[Z] for the given
// resource. It recomputes that resource's input args (tc.getInputArgs) on every call.
//
// It returns (nil, nil) if the file has no exports function. Parse errors are joined and
// returned together with the templates that did parse. name is currently unused.
func exportsNodeToTemplate(tc *TemplatesCompiler, tmpl *ResourceTemplate, node *sitter.Node, name string) (map[string]*template.Template, error) {
exportFunc := doQuery(node, findExportFunc)
exportsNode, found := exportFunc()
if !found {
return nil, nil
}
exportsTemplates := make(map[string]*template.Template)
var errs error
nextProp := doQuery(exportsNode["body"], findPropertyQuery)
for {
propMatches, found := nextProp()
if !found {
break
}
propName := propMatches["key"].Content()
valueBase := propMatches["value"].Content()
valueBase = parameterizeArgs(valueBase, ".Input")
valueBase = parameterizeObject(valueBase)
valueBase = parameterizeProps(valueBase)
t, err := template.New(propName).Funcs(template.FuncMap{
// property renders the named property template of tmpl for resource rid. The
// closure's propName parameter shadows the loop's propName.
"property": func(propName string, rid construct.ResourceId) (any, error) {
mapping, ok := tmpl.PropertyTemplates[propName]
if !ok {
return nil, fmt.Errorf("no property template found for %s", propName)
}
refRes, err := tc.graph.Vertex(rid)
if err != nil {
return nil, err
}
inputArgs, err := tc.getInputArgs(refRes, tmpl)
if err != nil {
return nil, err
}
data := PropertyTemplateData{
Resource: rid,
Object: tc.vars[rid],
Input: inputArgs,
}
return executeToString(mapping, data)
},
}).Parse(valueBase)
if err != nil {
errs = errors.Join(errs, fmt.Errorf("error parsing property %q: %w", propName, err))
continue
}
exportsTemplates[propName] = t
}
return exportsTemplates, errs
}
// importFuncNodeToTemplate builds the template used for imported resources from the
// top-level return expression of the import function. It applies the same rewriting as
// the create template: `args.X` becomes `{{.X}}`, TMPL markers are removed and `~~{{`
// becomes a literal `{{`. It returns (nil, "", nil) if there is no import function, and
// otherwise the template and the function's declared return type.
func importFuncNodeToTemplate(node *sitter.Node, name string) (*template.Template, string, error) {
	imp, found := doQuery(node, findImportFunc)()
	if !found {
		return nil, "", nil
	}
	retType := imp["return_type"].Content()
	ret := getReturn(imp["body"])
	if ret == nil {
		return nil, "", fmt.Errorf("no 'return' found in %s body:```\n%s\n```", name, imp["body"].Content())
	}
	src := parameterizeArgs(ret.Content(), "")
	src = templateComments.ReplaceAllString(src, "")
	// Escaped double-curly literals: ~~{{ .ID }} -> {{ `{{` }} .ID }}
	src = curlyEscapes.ReplaceAllString(src, "{{ `{{` }}")
	tmpl, err := template.New(name).Parse(src)
	return tmpl, retType, err
}
// tsLang is the tree-sitter TypeScript grammar used for every parse and query in this file.
var tsLang = typescript.GetLanguage()

// doQuery runs the tree-sitter query queryText against root and returns an iterator over
// the matches.
func doQuery(root *sitter.Node, queryText string) query.NextMatchFunc {
	next := query.Exec(tsLang, root, queryText)
	return next
}
// parseFile reads all of r and parses it as TypeScript, returning the root node of the
// syntax tree. It returns read or parse errors unchanged.
func parseFile(r io.Reader) (*sitter.Node, error) {
	src, err := io.ReadAll(r)
	if err != nil {
		return nil, err
	}
	p := sitter.NewParser()
	p.SetLanguage(tsLang)
	tree, err := p.ParseCtx(context.TODO(), nil, src)
	if err != nil {
		return nil, err
	}
	root := tree.RootNode()
	return root, nil
}
// bodyContents returns the statements of a 'statement_block', joined by newlines, without
// the surrounding braces or the original indentation. This lets a void function's body be
// inlined into index.ts. A single trailing ';' is dropped because one is added later and
// ';;' must be avoided. A node that does not start with '{' is returned as-is.
func bodyContents(node *sitter.Node) string {
	if node.ChildCount() == 0 || node.Child(0).Content() != "{" {
		return node.Content()
	}
	n := int(node.NamedChildCount())
	stmts := make([]string, 0, n)
	for idx := 0; idx < n; idx++ {
		stmts = append(stmts, node.NamedChild(idx).Content())
	}
	return strings.TrimSuffix(strings.Join(stmts, "\n"), ";")
}
var (
	// curlyArgsEscapes matches one or more '{' directly before "args.".
	curlyArgsEscapes = regexp.MustCompile(`({+)(args\.)`)
	// parameterizeArgsRegex matches a whole-word "args" followed by ".field".
	parameterizeArgsRegex = regexp.MustCompile(`\bargs(\.\w+)`)
)

// parameterizeArgs rewrites each "args.Foo" in contents as "{{<argPrefix>.Foo}}" using a
// simple regex. For example, argPrefix ".Input" gives "{{.Input.Foo}}".
//
// Braces directly before "args." are escaped first. Otherwise "{args.Foo}" would become
// "{{{.Foo}}}", which is not a valid Go template. With the escape it becomes
// "{{`{`}}{{.Foo}}}": ugly, but it executes to the intended "{<value>}".
func parameterizeArgs(contents string, argPrefix string) string {
	escaped := curlyArgsEscapes.ReplaceAllString(contents, "{{`$1`}}$2")
	replacement := fmt.Sprintf(`{{%s$1}}`, argPrefix)
	return parameterizeArgsRegex.ReplaceAllString(escaped, replacement)
}
var (
	// curlyObjectEscapes matches one or more '{' directly before "object.".
	curlyObjectEscapes = regexp.MustCompile(`({+)(object\.)`)
	// parameterizeObjectRegex matches a whole-word "object" followed by ".field".
	parameterizeObjectRegex = regexp.MustCompile(`\bobject(\.\w+)`)
)

// parameterizeObject works like parameterizeArgs, but rewrites "object.foo" as
// "{{.Object}}.foo". Braces directly before "object." are escaped first.
func parameterizeObject(contents string) string {
	escaped := curlyObjectEscapes.ReplaceAllString(contents, "{{`$1`}}$2")
	return parameterizeObjectRegex.ReplaceAllString(escaped, `{{.Object}}$1`)
}
var (
	// curlyPropsEscapes matches one or more '{' directly before "props.".
	curlyPropsEscapes = regexp.MustCompile(`({+)(props\.)`)
	// parameterizePropsRegex matches a whole-word "props." and captures the property name.
	parameterizePropsRegex = regexp.MustCompile(`\bprops\.(\w+)`)
)

// parameterizeProps works like parameterizeArgs, but rewrites "props.foo" as
// `{{ property "foo" .Resource }}`. This calls the `property` template function (defined
// for export templates) with the property name and the current resource. Braces directly
// before "props." are escaped first so that "{props.foo}" stays a valid template.
func parameterizeProps(contents string) string {
	contents = curlyPropsEscapes.ReplaceAllString(contents, "{{`$1`}}$2")
	contents = parameterizePropsRegex.ReplaceAllString(contents, `{{ property "$1" .Resource }}`)
	return contents
}
// importsFromTemplate returns every import statement in node, deduplicated and sorted.
// Trailing ';' characters are removed first, so the same import written with and without
// a semicolon appears only once. The error is always nil.
func importsFromTemplate(node *sitter.Node) ([]string, error) {
	seen := make(map[string]struct{})
	next := doQuery(node, findImportsQuery)
	for match, ok := next(); ok; match, ok = next() {
		line := strings.TrimRight(match["import"].Content(), ";")
		seen[line] = struct{}{}
	}
	out := make([]string, 0, len(seen))
	for line := range seen {
		out = append(out, line)
	}
	sort.Strings(out)
	return out, nil
}
package iac
import (
"bytes"
"encoding/json"
"errors"
"io"
"io/fs"
construct "github.com/klothoplatform/klotho/pkg/construct"
kio "github.com/klothoplatform/klotho/pkg/io"
"go.uber.org/zap"
)
// TemplatesCompiler renders the resources of a deployment graph into TypeScript using the
// per-resource templates.
type TemplatesCompiler struct {
// templates gives access to the template files through its fs field.
templates *templateStore
// graph is the deployment graph being rendered.
graph construct.Graph
// vars maps each resource ID to the TypeScript variable name it is rendered as.
vars variables
}
// globalVariables are variables set in the global template and available to all resources.
// getInputArgs maps each of these names to itself in every resource's input args, so a
// template can reference them by name.
var globalVariables = map[string]struct{}{
"kloConfig": {},
"awsConfig": {},
"protect": {},
"awsProfile": {},
"accountId": {},
"region": {},
"aws": {},
"pulumi": {},
}
// PackageJsonFile is a package.json that can be merged with others and written out as a
// kio.File.
type PackageJsonFile struct {
// Dependencies holds the "dependencies" section (package -> version).
Dependencies map[string]string
// DevDependencies holds the "devDependencies" section (package -> version).
DevDependencies map[string]string
// OtherFields keeps every other top-level field as raw JSON, so WriteTo can write it back
// unchanged.
OtherFields map[string]json.RawMessage
}
// PackageJSON merges the package.json of every resource type in the graph into one file.
// Resources are visited in construct.ReverseTopologicalSort order. Resources whose
// package.json cannot be loaded are skipped, and their errors are joined and returned
// together with the partially merged file.
func (tc TemplatesCompiler) PackageJSON() (*PackageJsonFile, error) {
	ids, err := construct.ReverseTopologicalSort(tc.graph)
	if err != nil {
		return nil, err
	}
	merged := &PackageJsonFile{}
	var errs error
	for _, id := range ids {
		pj, err := tc.GetPackageJSON(id)
		switch {
		case err != nil:
			errs = errors.Join(errs, err)
		case pj != nil:
			merged.Merge(pj)
		}
	}
	return merged, errs
}
// GetPackageJSON loads <provider>/<type>/package.json for v's resource type from the
// template FS.
//
// It returns (nil, nil) if the type has no package.json, and (nil, err) on a read or
// decode error. Previously a decode error also returned a partially filled file, which
// callers must not use.
func (tc TemplatesCompiler) GetPackageJSON(v construct.ResourceId) (*PackageJsonFile, error) {
	templateFilePath := v.Provider + "/" + v.Type + `/package.json`
	contents, err := fs.ReadFile(tc.templates.fs, templateFilePath)
	switch {
	case errors.Is(err, fs.ErrNotExist):
		return nil, nil
	case err != nil:
		return nil, err
	}
	var packageContent PackageJsonFile
	if err := json.NewDecoder(bytes.NewReader(contents)).Decode(&packageContent); err != nil {
		return nil, err
	}
	return &packageContent, nil
}
// Merge folds other's dependencies and devDependencies into f, initializing
// f's maps if needed. In both sections, when f and other pin the same package
// to different versions, f's version is kept and a warning is logged. Fields
// other than the two dependency sections are not merged. A nil other only
// initializes f's maps.
func (f *PackageJsonFile) Merge(other *PackageJsonFile) {
	if f.Dependencies == nil {
		f.Dependencies = make(map[string]string)
	}
	if f.DevDependencies == nil {
		f.DevDependencies = make(map[string]string)
	}
	if other == nil {
		return
	}
	mergePackageJsonDeps(f.Dependencies, other.Dependencies)
	// devDependencies follow the same first-wins rule as dependencies, so the
	// outcome doesn't depend on which section a package appears in.
	mergePackageJsonDeps(f.DevDependencies, other.DevDependencies)
	// Ignore all other (non-supported / unmergeable) fields
}

// mergePackageJsonDeps copies src's entries into dst. It keeps dst's version,
// and logs a warning, when both name the same package with different versions.
func mergePackageJsonDeps(dst, src map[string]string) {
	for k, v := range src {
		currentVersion, ok := dst[k]
		if !ok {
			dst[k] = v
			continue
		}
		if currentVersion != v {
			zap.S().Warnf(`Found conflicting dependencies in package.json.
Found version of package, %s = %s.
Found version of package, %s = %s.
Using version %s`, k, currentVersion, k, v, currentVersion)
		}
	}
}
// UnmarshalJSON decodes a package.json document. The "dependencies" and
// "devDependencies" objects go into their typed maps; every remaining top-level
// key is stored, still encoded, in OtherFields.
func (f *PackageJsonFile) UnmarshalJSON(b []byte) error {
	var fields map[string]json.RawMessage
	if err := json.Unmarshal(b, &fields); err != nil {
		return err
	}
	sections := []struct {
		key  string
		dest *map[string]string
	}{
		{key: "dependencies", dest: &f.Dependencies},
		{key: "devDependencies", dest: &f.DevDependencies},
	}
	for _, s := range sections {
		raw, present := fields[s.key]
		if !present {
			continue
		}
		if err := json.Unmarshal(raw, s.dest); err != nil {
			return err
		}
		delete(fields, s.key)
	}
	f.OtherFields = fields
	return nil
}
// Path is the file's output path, "package.json", relative to the output directory.
func (f *PackageJsonFile) Path() string {
	const fileName = "package.json"
	return fileName
}
// WriteTo encodes f as indented JSON to w and reports the bytes written.
// "dependencies" and "devDependencies" are always emitted (null when unset).
// Entries from OtherFields are written alongside them and win on a key clash.
func (f *PackageJsonFile) WriteTo(w io.Writer) (n int64, err error) {
	doc := make(map[string]interface{}, 2+len(f.OtherFields))
	doc["dependencies"] = f.Dependencies
	doc["devDependencies"] = f.DevDependencies
	for key, raw := range f.OtherFields {
		doc[key] = raw
	}
	counter := &kio.CountingWriter{Delegate: w}
	encoder := json.NewEncoder(counter)
	encoder.SetIndent("", " ")
	err = encoder.Encode(doc)
	n = int64(counter.BytesWritten)
	return n, err
}
// Clone returns a copy of f backed by fresh maps, so adding or removing entries
// in the copy never affects f. The maps are non-nil in the copy even when nil
// in f. The json.RawMessage values of OtherFields are shared, not duplicated.
func (f *PackageJsonFile) Clone() kio.File {
	deps := make(map[string]string, len(f.Dependencies))
	for name, version := range f.Dependencies {
		deps[name] = version
	}
	devDeps := make(map[string]string, len(f.DevDependencies))
	for name, version := range f.DevDependencies {
		devDeps[name] = version
	}
	others := make(map[string]json.RawMessage, len(f.OtherFields))
	for key, raw := range f.OtherFields {
		others[key] = raw
	}
	return &PackageJsonFile{
		Dependencies:    deps,
		DevDependencies: devDeps,
		OtherFields:     others,
	}
}
package iac
import (
"fmt"
"io/fs"
construct "github.com/klothoplatform/klotho/pkg/construct"
)
// templateStore reads resource templates from fs and caches the parsed result
// per qualified type name in resourceTemplates (created lazily).
type templateStore struct {
	fs fs.FS

	resourceTemplates map[string]*ResourceTemplate
}
// ResourceTemplate returns the parsed template for id's type. On first use it
// reads "<provider>/<type>/factory.ts" from the template store, parses it, and
// caches the result under the qualified type name. Later calls for the same
// type return the cached template.
func (tc *TemplatesCompiler) ResourceTemplate(id construct.ResourceId) (*ResourceTemplate, error) {
	ts := tc.templates
	typeName := id.QualifiedTypeName()
	if ts.resourceTemplates == nil {
		ts.resourceTemplates = make(map[string]*ResourceTemplate)
	}
	if tmpl, ok := ts.resourceTemplates[typeName]; ok {
		return tmpl, nil
	}
	path := id.Provider + "/" + id.Type
	f, err := ts.fs.Open(path + `/factory.ts`)
	if err != nil {
		return nil, fmt.Errorf("could not find template for %s: %w", typeName, err)
	}
	// The file is only needed while parsing. Without this, every distinct type
	// loaded here would leak an open file.
	defer f.Close()
	template, err := tc.ParseTemplate(typeName, f)
	if err != nil {
		return nil, fmt.Errorf("could not parse template for %s: %w", typeName, err)
	}
	template.Path = path
	ts.resourceTemplates[typeName] = template
	return template, nil
}
package iac
import (
"encoding/json"
"errors"
"fmt"
"io"
"sort"
"strconv"
construct "github.com/klothoplatform/klotho/pkg/construct"
)
// templateString is a string that prints quoted (Go string-literal syntax) when
// rendered in a template. It still tests as false in `{{ if }}` when empty,
// because the template engine checks the underlying string.
//
// Example:
//
//	data.Value = templateString("")
//
//	{{ if .Value }}value: {{ .Value }}{{ else }}no value: {{ .Value }}{{ end }}
//
// results in:
//
//	no value: ""
type templateString string

// String returns s as a double-quoted, escaped string literal.
func (s templateString) String() string {
	raw := string(s)
	return strconv.Quote(raw)
}
// appliedOutput is one input of a generated pulumi `apply`: Ref is the
// expression being resolved and Name is the callback parameter bound to its value.
type appliedOutput struct {
	Ref  string
	Name string
}

// appliedOutputs is an ordered list of apply inputs.
type appliedOutputs []appliedOutput
// NewAppliedOutput builds the apply input for ref.
//   - Name defaults to the variable name of ref's resource when name is empty.
//   - Ref is that variable, suffixed with ".<Property>" when ref names a property.
func (tc TemplatesCompiler) NewAppliedOutput(ref construct.PropertyRef, name string) appliedOutput {
	resourceVar := tc.vars[ref.Resource]
	if name == "" {
		name = resourceVar
	}
	expr := resourceVar
	if ref.Property != "" {
		expr = fmt.Sprintf("%s.%s", resourceVar, ref.Property)
	}
	return appliedOutput{Ref: expr, Name: name}
}
// dedupe removes exact duplicates (same Ref and Name) from ao in place, keeping
// the first occurrence, then sorts what remains.
//
// Distinct entries that share a Name would bind the same callback parameter
// twice. Each such clash is reported in the returned (joined) error; the
// entries themselves are kept.
func (ao *appliedOutputs) dedupe() error {
	if ao == nil || len(*ao) == 0 {
		return nil
	}
	var err error
	seen := make(map[appliedOutput]struct{}, len(*ao))
	names := make(map[string]struct{}, len(*ao))
	// Filter in place: kept entries are compacted to the front of the same
	// backing array. The write index never passes the read index, so this is safe.
	kept := (*ao)[:0]
	for _, v := range *ao {
		if _, dup := seen[v]; dup {
			continue
		}
		seen[v] = struct{}{}
		if _, clash := names[v.Name]; clash {
			err = errors.Join(err, fmt.Errorf("duplicate applied output name %q", v.Name))
		}
		names[v.Name] = struct{}{}
		kept = append(kept, v)
	}
	*ao = kept
	sort.Sort(*ao)
	return err
}
// Len implements sort.Interface.
func (ao appliedOutputs) Len() int {
	return len(ao)
}

// Less orders entries by Ref, breaking ties by Name. Name is compared only when
// the Refs are equal, which keeps this a strict weak ordering as sort.Sort
// requires.
func (ao appliedOutputs) Less(i, j int) bool {
	if ao[i].Ref != ao[j].Ref {
		return ao[i].Ref < ao[j].Ref
	}
	return ao[i].Name < ao[j].Name
}

// Swap implements sort.Interface.
func (ao appliedOutputs) Swap(i, j int) {
	ao[i], ao[j] = ao[j], ao[i]
}
// Render writes the applied outputs to out, calling f in between to write the
// body of the apply callback.
//   - No outputs: nothing is written and f is not called.
//   - One output: `<ref>.apply(<name> => { return <body>\n})`.
//   - Several outputs: `pulumi.all([<refs>])\n.apply(([<names>]) => {\n return <body>\n})`.
//
// Write errors and f's error are joined into the result.
func (ao appliedOutputs) Render(out io.Writer, f func(io.Writer) error) error {
	var errs error
	write := func(msg string, args ...interface{}) {
		_, err := fmt.Fprintf(out, msg, args...)
		errs = errors.Join(errs, err)
	}
	// writeList writes field(o) for each output, comma-separated. Values go
	// through a "%s" verb so a '%' in a reference or name is emitted literally
	// instead of being read as a formatting directive.
	writeList := func(field func(appliedOutput) string) {
		for i, o := range ao {
			if i > 0 {
				write(", ")
			}
			write("%s", field(o))
		}
	}
	switch len(ao) {
	case 0:
		return nil
	case 1:
		write("%s.apply(%s => { return ",
			ao[0].Ref,
			ao[0].Name,
		)
	default:
		write("pulumi.all([")
		writeList(func(o appliedOutput) string { return o.Ref })
		write("])\n.apply(([")
		writeList(func(o appliedOutput) string { return o.Name })
		write("]) => {\n return ")
	}
	errs = errors.Join(errs, f(out))
	write("\n})")
	return errs
}
// jsonValue wraps a value so that printing it in a template emits its JSON
// encoding. Template functions (map access, number comparisons, ...) can still
// reach the raw value through `.Raw`.
//
// The [templateString] trick cannot be reused here. `type jsonValue any` would
// declare an interface type, and methods cannot be declared on interface types.
// As a result a jsonValue is never false-y in a template; test `.Raw` instead.
type jsonValue struct {
	Raw any
}

// String returns the JSON encoding of j.Raw.
//
// It panics if j.Raw cannot be marshaled. That is unlikely, and during template
// evaluation (which prints via fmt) fmt recovers panics raised by String methods.
// If text/template ever supports MarshalText (which can return an error), this
// should migrate to it.
func (j jsonValue) String() string {
	encoded, err := json.Marshal(j.Raw)
	if err != nil {
		panic(err)
	}
	return string(encoded)
}
package iac
import (
"errors"
"fmt"
"io"
"sort"
"strings"
construct "github.com/klothoplatform/klotho/pkg/construct"
)
// RenderImports writes to out the import statements required by the templates
// of all resources in the graph: de-duplicated, sorted, one per line, followed
// by a newline. If any resource's template cannot be loaded, nothing is written
// and the joined errors are returned.
func (tc *TemplatesCompiler) RenderImports(out io.Writer) error {
	ids, err := construct.ReverseTopologicalSort(tc.graph)
	if err != nil {
		return err
	}
	unique := make(map[string]struct{})
	var tmplErrs error
	for _, id := range ids {
		tmpl, tmplErr := tc.ResourceTemplate(id)
		if tmplErr != nil {
			tmplErrs = errors.Join(tmplErrs, tmplErr)
			continue
		}
		for _, stmt := range tmpl.Imports {
			unique[stmt] = struct{}{}
		}
	}
	if tmplErrs != nil {
		return tmplErrs
	}
	lines := make([]string, 0, len(unique))
	for stmt := range unique {
		lines = append(lines, stmt)
	}
	sort.Strings(lines)
	_, err = fmt.Fprintf(out, "%s\n", strings.Join(lines, "\n"))
	return err
}
package iac
import (
"strings"
construct "github.com/klothoplatform/klotho/pkg/construct"
)
type (
	// variables maps each resource to the variable name assigned to it (see VariablesFromGraph).
	variables map[construct.ResourceId]string
)
// reservedVariables are identifiers that a resource's variable must never use.
var reservedVariables = map[string]struct{}{
	// TypeScript reserved keywords that cannot be variable names, from
	// https://github.com/microsoft/TypeScript/issues/2536#issuecomment-87194347
	"break": {}, "case": {}, "catch": {}, "class": {}, "const": {}, "continue": {},
	"debugger": {}, "default": {}, "delete": {}, "do": {}, "else": {}, "enum": {},
	"export": {}, "extends": {}, "false": {}, "finally": {}, "for": {}, "function": {},
	"if": {}, "import": {}, "in": {}, "instanceof": {}, "new": {}, "null": {},
	"return": {}, "super": {}, "switch": {}, "this": {}, "throw": {}, "true": {},
	"try": {}, "typeof": {}, "var": {}, "void": {}, "while": {}, "with": {},
	// Further reserved words from the same list.
	"as": {}, "implements": {}, "interface": {}, "let": {}, "package": {},
	"private": {}, "protected": {}, "public": {}, "static": {}, "yield": {},
	// Names reserved by klotho itself.
	"$urns": {}, "$outputs": {},
}
// VariablesFromGraph assigns every resource in g a variable name. Names are
// lowercased with '-' replaced by '_'.
//   - A resource whose Name is unique in the graph uses that name alone, unless
//     it collides with a global or reserved identifier.
//   - Otherwise the name is qualified with the type (when Type+Name is unique),
//     the namespace (when every resource with that Name shares the same type),
//     or both.
func VariablesFromGraph(g construct.Graph) (variables, error) {
	resources, err := construct.ReverseTopologicalSort(g)
	if err != nil {
		return nil, err
	}
	vars := make(variables, len(resources))
	type varInfo struct {
		all   []construct.ResourceId
		types map[string][]construct.ResourceId
	}
	// Group resources by Name, and within each name by Type, to detect ambiguity.
	nameInfo := make(map[string]*varInfo)
	for _, r := range resources {
		info, ok := nameInfo[r.Name]
		if !ok {
			info = &varInfo{
				types: make(map[string][]construct.ResourceId),
			}
			nameInfo[r.Name] = info
		}
		info.all = append(info.all, r)
		info.types[r.Type] = append(info.types[r.Type], r)
	}
	sanitizeName := func(parts ...string) string {
		for i, a := range parts {
			a = strings.ToLower(a)
			a = strings.ReplaceAll(a, "-", "_")
			parts[i] = a
		}
		return strings.Join(parts, "_")
	}
	// isTaken reports whether name is declared by the global template or is a reserved word.
	isTaken := func(name string) bool {
		_, isGlobal := globalVariables[name]
		_, isReserved := reservedVariables[name]
		return isGlobal || isReserved
	}
	for _, r := range resources {
		info := nameInfo[r.Name]
		// if there's only one resource wanting the name, it gets it, unless the name
		// clashes with a global or reserved identifier. Both the raw and the sanitized
		// form are checked: sanitizing lowercases, so "Class" would otherwise become
		// the keyword `class` and "AWS" would shadow the global `aws`.
		if len(info.all) == 1 {
			name := sanitizeName(r.Name)
			if !isTaken(r.Name) && !isTaken(name) {
				vars[r] = name
				continue
			}
		}
		typeResources := info.types[r.Type]
		// Type + Name unambiguously identifies the resource
		if len(typeResources) == 1 {
			vars[r] = sanitizeName(r.Type, r.Name)
			continue
		}
		if len(info.all) == len(typeResources) {
			// Namespace + Name unambiguously identifies the resource
			vars[r] = sanitizeName(r.Namespace, r.Name)
			continue
		}
		// This doesn't account for providers being different (and the rest being the same),
		// but the chances of that are low. So not implementing that until we have a real use case.
		vars[r] = sanitizeName(r.Type, r.Namespace, r.Name)
	}
	return vars, nil
}
package kubernetes
import (
"fmt"
"math/rand"
"reflect"
"strconv"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/set"
"gopkg.in/yaml.v3"
)
// ObjectOutput is the rendered YAML for one kubernetes object, plus the chart
// values it references. Each key in Values is a placeholder that appears in
// Content as `{{ .Values.<key> }}`, mapped to the property it stands for.
type ObjectOutput struct {
	Content []byte
	Values  map[string]construct.PropertyRef
}
// excludedObjects lists the kubernetes resource types that includeObjectInChart
// rejects, so they are never rendered as templates inside a generated chart.
// Only Provider and Type are set; matching is done with ResourceId.Matches.
var excludedObjects = []construct.ResourceId{
	{Provider: "kubernetes", Type: "helm_chart"},
	{Provider: "kubernetes", Type: "kustomize_directory"},
	{Provider: "kubernetes", Type: "manifest"},
	{Provider: "kubernetes", Type: "kube_config"},
}
// includeObjectInChart reports whether res may be rendered into a helm chart,
// i.e. whether it matches none of excludedObjects.
func includeObjectInChart(res construct.ResourceId) bool {
	excluded := false
	for i := 0; i < len(excludedObjects) && !excluded; i++ {
		excluded = excludedObjects[i].Matches(res)
	}
	return !excluded
}
// AddObject renders res's "Object" property as YAML. Property references inside
// the object become `{{ .Values.<key> }}` placeholders and are recorded in the
// returned output's Values.
func AddObject(res *construct.Resource) (*ObjectOutput, error) {
	object, err := res.GetProperty("Object")
	if err != nil {
		return nil, fmt.Errorf("unable to find object property on resource %s: %w", res.ID, err)
	}
	out := &ObjectOutput{Values: make(map[string]construct.PropertyRef)}
	converted, err := out.convertObject(object)
	if err != nil {
		return nil, fmt.Errorf("unable to convert object property on resource %s: %w", res.ID, err)
	}
	rendered, err := yaml.Marshal(converted)
	if err != nil {
		// NOTE: the output (without Content) is still returned alongside the error.
		return out, fmt.Errorf("unable to marshal object property on resource %s: %w", res.ID, err)
	}
	out.Content = rendered
	return out, nil
}
// convertObject converts arg (part of a kubernetes object's definition) into a
// value ready for YAML marshaling:
//   - a kubernetes construct.ResourceId becomes the resource's name; other
//     providers are an error;
//   - a construct.PropertyRef becomes a `{{ .Values.<key> }}` placeholder, and
//     the reference is recorded under <key> in h.Values;
//   - strings become templateString; bool, int, float64 and nil pass through;
//   - slices, arrays and string-keyed maps are converted element by element,
//     skipping nil elements;
//   - a set.HashedSet[string, any] is converted as its slice of elements.
//
// Anything else is an error.
func (h ObjectOutput) convertObject(arg any) (any, error) {
	switch arg := arg.(type) {
	case construct.ResourceId:
		if arg.Provider != "kubernetes" {
			return nil, fmt.Errorf("resource %s is not a kubernetes resource", arg)
		}
		return arg.Name, nil
	case construct.PropertyRef:
		// Draw keys until an unused one is found, so two references in the same
		// object can never overwrite each other's entry in Values.
		var valuesString string
		for {
			valuesString = generateStringSuffix(5)
			if _, taken := h.Values[valuesString]; !taken {
				break
			}
		}
		h.Values[valuesString] = arg
		return fmt.Sprintf("{{ .Values.%s }}", valuesString), nil
	case string:
		// use templateString to quote the string value
		// NOTE(review): yaml.Marshal encodes a templateString by its underlying
		// string, not via String(); confirm the intended quoting.
		return templateString(arg), nil
	case bool, int, float64:
		// safe to use as-is
		return arg, nil
	case nil:
		// don't add to inputs
		return nil, nil
	default:
		switch val := reflect.ValueOf(arg); val.Kind() {
		case reflect.Slice, reflect.Array:
			yamlList := []any{}
			for i := 0; i < val.Len(); i++ {
				elem := val.Index(i)
				if isNilValue(elem) {
					continue
				}
				output, err := h.convertObject(elem.Interface())
				if err != nil {
					return nil, err
				}
				yamlList = append(yamlList, output)
			}
			return yamlList, nil
		case reflect.Map:
			yamlMap := make(map[string]any)
			for _, key := range val.MapKeys() {
				elem := val.MapIndex(key)
				if isNilValue(elem) {
					continue
				}
				keyStr, found := key.Interface().(string)
				if !found {
					return nil, fmt.Errorf("map key is not a string")
				}
				output, err := h.convertObject(elem.Interface())
				if err != nil {
					return nil, err
				}
				yamlMap[keyStr] = output
			}
			return yamlMap, nil
		case reflect.Struct:
			if hashset, ok := val.Interface().(set.HashedSet[string, any]); ok {
				return h.convertObject(hashset.ToSlice())
			}
			fallthrough
		default:
			return nil, fmt.Errorf("unable to convert arg %v to yaml", arg)
		}
	}
}

// isNilValue reports whether v is invalid or is a nil chan, func, interface,
// map, pointer, slice or unsafe pointer. Unlike reflect.Value.IsNil, it does not
// panic for other kinds (e.g. the string elements of a []string).
func isNilValue(v reflect.Value) bool {
	if !v.IsValid() {
		return true
	}
	switch v.Kind() {
	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map,
		reflect.Pointer, reflect.Slice, reflect.UnsafePointer:
		return v.IsNil()
	}
	return false
}
// templateString is a string whose String form is a double-quoted literal.
type templateString string

// String quotes s using Go string-literal escaping.
func (s templateString) String() string {
	quoted := strconv.Quote(string(s))
	return quoted
}
// generateStringSuffix returns n pseudo-random ASCII letters (a-z, A-Z) drawn
// from the math/rand global source.
func generateStringSuffix(n int) string {
	const alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
	out := make([]rune, n)
	for i := 0; i < n; i++ {
		out[i] = rune(alphabet[rand.Intn(len(alphabet))])
	}
	return string(out)
}
package kubernetes
import (
"errors"
"fmt"
"github.com/dominikbraun/graph"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/engine/solution"
kio "github.com/klothoplatform/klotho/pkg/io"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
"gopkg.in/yaml.v3"
"helm.sh/helm/v3/pkg/chart"
)
// Plugin moves kubernetes resources of a solution into generated helm charts.
type Plugin struct {
	AppName string
	KB      *knowledgebase.KnowledgeBase

	// files accumulates the chart files produced while translating.
	files []kio.File
	// resourcesInChart maps a chart's ID to the resources placed in that chart.
	resourcesInChart map[construct.ResourceId][]construct.ResourceId
}
// Name identifies this plugin.
func (p Plugin) Name() string {
	const pluginName = "kubernetes"
	return pluginName
}
// HELM_CHARTS_DIR is the path prefix of generated charts. Each chart lives in
// HELM_CHARTS_DIR/<cluster name>/<chart name>.
const (
	HELM_CHARTS_DIR = "helm_charts"
)
// Translate places every kubernetes resource of the deployment graph (other
// than the excludedObjects types) into a helm chart for its cluster. It prefers
// the cluster's "internal-chart" and falls back to its "application-chart";
// charts are created on first use. It returns the chart files produced. Errors
// from all resources are joined rather than stopping at the first.
func (p Plugin) Translate(ctx solution.Solution) ([]kio.File, error) {
	internalCharts := make(map[string]*construct.Resource)
	customerCharts := make(map[string]*construct.Resource)
	p.resourcesInChart = make(map[construct.ResourceId][]construct.ResourceId)

	// chartFor returns the chart called name for cluster, creating it and caching
	// it in charts (keyed by cluster name) on first use.
	chartFor := func(charts map[string]*construct.Resource, name string, cluster construct.ResourceId) (*construct.Resource, error) {
		if c, ok := charts[cluster.Name]; ok {
			return c, nil
		}
		c, err := p.createChart(name, cluster, ctx)
		if err != nil {
			return nil, err
		}
		charts[cluster.Name] = c
		return c, nil
	}

	err := construct.WalkGraphReverse(ctx.DeploymentGraph(), func(id construct.ResourceId, resource *construct.Resource, nerr error) error {
		if id.Provider != "kubernetes" || !includeObjectInChart(id) {
			return nerr
		}
		cluster, err := resource.GetProperty("Cluster")
		if err != nil {
			return errors.Join(nerr, err)
		}
		clusterId, ok := cluster.(construct.ResourceId)
		if !ok {
			return errors.Join(nerr, fmt.Errorf("cluster property is not a resource id"))
		}
		// attempt to add to internal chart
		internalChart, err := chartFor(internalCharts, "internal-chart", clusterId)
		if err != nil {
			return errors.Join(nerr, err)
		}
		placed, err := p.placeResourceInChart(ctx, resource, internalChart)
		if err != nil {
			// Join with nerr so errors from earlier resources are not lost.
			return errors.Join(nerr, err)
		}
		if placed {
			return nerr
		}
		// attempt to add to app chart for cluster if it cannot be in the internal chart
		appChart, err := chartFor(customerCharts, "application-chart", clusterId)
		if err != nil {
			return errors.Join(nerr, err)
		}
		placed, err = p.placeResourceInChart(ctx, resource, appChart)
		if err != nil {
			return errors.Join(nerr, err)
		}
		if !placed {
			return errors.Join(nerr, fmt.Errorf("could not place resource %s in chart", resource.ID))
		}
		return nerr
	})
	return p.files, err
}
// placeResourceInChart tries to move resource into chart. On success:
//   - every deployment-graph edge touching resource is re-pointed at chart;
//   - resource's vertex is removed;
//   - the resource's object is rendered to
//     "<chart Directory>/templates/<type>_<name>.yaml";
//   - its placeholder values are appended to the chart's "Values" property.
//
// It returns (false, nil) and leaves the graph untouched when re-pointing an
// edge would create a cycle.
func (p *Plugin) placeResourceInChart(ctx solution.Solution, resource *construct.Resource, chart *construct.Resource) (
	bool,
	error,
) {
	edges, err := ctx.DeploymentGraph().Edges()
	if err != nil {
		return false, err
	}
	var edgesToRemove, edgesToAdd []graph.Edge[construct.ResourceId]
	// Rehearse the rewiring on a clone so a cycle is detected before the real
	// graph is modified.
	tmpGraph, err := ctx.DeploymentGraph().Clone()
	if err != nil {
		return false, err
	}
	for _, e := range edges {
		if e.Source != resource.ID && e.Target != resource.ID {
			continue
		}
		newEdge := e
		if e.Source == resource.ID {
			newEdge.Source = chart.ID
		}
		if e.Target == resource.ID {
			newEdge.Target = chart.ID
		}
		if err := tmpGraph.RemoveEdge(e.Source, e.Target); err != nil {
			return false, err
		}
		edgesToRemove = append(edgesToRemove, e)
		if newEdge.Source == newEdge.Target {
			continue
		}
		err := tmpGraph.AddEdge(newEdge.Source, newEdge.Target)
		switch {
		case errors.Is(err, graph.ErrEdgeCreatesCycle):
			return false, nil
		case errors.Is(err, graph.ErrEdgeAlreadyExists):
			// The chart already has this edge (e.g. from another resource placed
			// in it). Re-adding it to the real graph below would fail.
			continue
		case err != nil:
			return false, err
		}
		edgesToAdd = append(edgesToAdd, newEdge)
	}

	// Do the remaining fallible work, which doesn't touch the graph, before
	// mutating it, so a failure here can't leave the resource half-moved.
	chartDir, err := chart.GetProperty("Directory")
	if err != nil {
		return false, err
	}
	output, err := AddObject(resource)
	if err != nil {
		return false, err
	}

	for _, e := range edgesToRemove {
		if err := ctx.DeploymentGraph().RemoveEdge(e.Source, e.Target); err != nil {
			return false, err
		}
	}
	for _, e := range edgesToAdd {
		if err := ctx.DeploymentGraph().AddEdge(e.Source, e.Target); err != nil {
			return false, err
		}
	}
	if err := ctx.DeploymentGraph().RemoveVertex(resource.ID); err != nil {
		return false, fmt.Errorf("could not remove vertex %s from graph: %w", resource.ID, err)
	}

	p.resourcesInChart[chart.ID] = append(p.resourcesInChart[chart.ID], resource.ID)
	p.files = append(p.files, &kio.RawFile{
		FPath:   fmt.Sprintf("%s/templates/%s_%s.yaml", chartDir, resource.ID.Type, resource.ID.Name),
		Content: output.Content,
	})
	if err := chart.AppendProperty("Values", output.Values); err != nil {
		return true, err
	}
	return true, nil
}
// writeChartYaml renders the Chart.yaml for chart resource c into c's
// "Directory". All metadata is fixed except the chart name, which is c.ID.Name.
func writeChartYaml(c *construct.Resource) (kio.File, error) {
	meta := &chart.Metadata{
		Name:        c.ID.Name,
		APIVersion:  "v2",
		AppVersion:  "0.0.1",
		Version:     "0.0.1",
		KubeVersion: ">= 1.19.0-0",
		Type:        "application",
	}
	content, err := yaml.Marshal(meta)
	if err != nil {
		return nil, err
	}
	dir, err := c.GetProperty("Directory")
	if err != nil {
		return nil, err
	}
	file := &kio.RawFile{
		FPath:   fmt.Sprintf("%s/Chart.yaml", dir),
		Content: content,
	}
	return file, nil
}
// createChart creates the helm_chart resource called name for clusterId and:
//   - sets its "Directory" (HELM_CHARTS_DIR/<cluster>/<name>) and "Cluster" properties;
//   - adds it to the graph with an edge to the cluster;
//   - queues its Chart.yaml in p.files.
//
// The chart resource is returned even when a later step fails.
func (p *Plugin) createChart(name string, clusterId construct.ResourceId, ctx solution.Solution) (*construct.Resource, error) {
	id := construct.ResourceId{
		Provider:  "kubernetes",
		Type:      "helm_chart",
		Namespace: clusterId.Name,
		Name:      name,
	}
	chart, err := knowledgebase.CreateResource(ctx.KnowledgeBase(), id)
	if err != nil {
		return chart, err
	}
	chartDir := fmt.Sprintf("%s/%s/%s", HELM_CHARTS_DIR, chart.ID.Namespace, chart.ID.Name)
	setup := []func() error{
		func() error { return chart.SetProperty("Directory", chartDir) },
		func() error { return chart.SetProperty("Cluster", clusterId) },
		func() error { return ctx.RawView().AddVertex(chart) },
		func() error { return ctx.RawView().AddEdge(chart.ID, clusterId) },
	}
	for _, step := range setup {
		if err := step(); err != nil {
			return chart, err
		}
	}
	file, err := writeChartYaml(chart)
	if err != nil {
		return chart, err
	}
	p.files = append(p.files, file)
	return chart, nil
}
package stateconverter
import (
"encoding/json"
"errors"
"io"
"slices"
"strings"
"github.com/klothoplatform/klotho/pkg/construct"
statetemplate "github.com/klothoplatform/klotho/pkg/infra/state_reader/state_template"
"go.uber.org/zap"
)
// PulumiState is the JSON array of resources the converter decodes from its input.
type PulumiState []Resource

// Resource is the part of a pulumi state resource that the converter reads.
type Resource struct {
	Urn     string                 `json:"urn"`
	Type    string                 `json:"type"`
	Outputs map[string]interface{} `json:"outputs"`
}

// pulumiStateConverter converts pulumi state using templates keyed by pulumi resource type.
type pulumiStateConverter struct {
	templates map[string]statetemplate.StateTemplate
}
// ConvertState decodes pulumi state from reader and converts every resource
// whose type has a template. Resources without a template are skipped, with a
// debug log. Conversion errors are joined and returned together with whatever
// was converted successfully.
func (p pulumiStateConverter) ConvertState(reader io.Reader) (State, error) {
	var raw PulumiState
	if err := json.NewDecoder(reader).Decode(&raw); err != nil {
		return nil, err
	}
	state := make(State)
	var convErrs error
	for _, r := range raw {
		tmpl, known := p.templates[r.Type]
		if !known {
			zap.S().Debugf("no mapping found for resource type %s", r.Type)
			continue
		}
		converted, err := p.convertResource(r, tmpl)
		if err != nil {
			convErrs = errors.Join(convErrs, err)
			continue
		}
		state[converted.ID] = converted.Properties
	}
	return state, convErrs
}
// ConvertResource converts one pulumi state resource. It returns (nil, nil),
// with a debug log, when no template exists for the resource's type.
func (p pulumiStateConverter) ConvertResource(resource Resource) (
	*construct.Resource,
	error,
) {
	if tmpl, ok := p.templates[resource.Type]; ok {
		return p.convertResource(resource, tmpl)
	}
	zap.S().Debugf("no mapping found for resource type %s", resource.Type)
	return nil, nil
}
// convertResource builds the Klotho resource for one pulumi state resource
// using template.
//   - The name is the last ':'-separated segment of the URN.
//   - Provider and type come from template.QualifiedTypeName ("provider:type").
//   - Outputs listed in template.PropertyMappings are copied under their mapped
//     property names; other outputs are dropped.
//
// It returns an error when QualifiedTypeName contains no ':'.
func (p pulumiStateConverter) convertResource(resource Resource, template statetemplate.StateTemplate) (
	*construct.Resource,
	error,
) {
	// The resource name is the final segment of the URN.
	parts := strings.Split(resource.Urn, ":")
	name := parts[len(parts)-1]
	// Validate before indexing: a mapping file without "provider:type" would
	// otherwise panic on typeParts[1].
	typeParts := strings.Split(template.QualifiedTypeName, ":")
	if len(typeParts) < 2 {
		return nil, errors.New(`state template for resource type "` + resource.Type +
			`" has invalid qualified_type_name "` + template.QualifiedTypeName + `" (expected provider:type)`)
	}
	id := construct.ResourceId{
		Provider: typeParts[0],
		Type:     typeParts[1],
		Name:     name,
	}
	properties := make(construct.Properties)
	for k, v := range resource.Outputs {
		if mapping, ok := template.PropertyMappings[k]; ok {
			properties[mapping] = v
		}
	}
	//TODO: find a better way to handle subnet types
	// Subnets whose CidrBlock is one of the two CIDRs below are marked public, any
	// other mapped CidrBlock private. Type is left unset when CidrBlock wasn't mapped.
	if id.QualifiedTypeName() == "aws:subnet" {
		if rawCidr, ok := properties["CidrBlock"]; ok {
			if cidr, ok := rawCidr.(string); ok && slices.Contains([]string{"10.0.0.0/18", "10.0.64.0/18"}, cidr) {
				properties["Type"] = "public"
			} else {
				properties["Type"] = "private"
			}
		}
	}
	return &construct.Resource{
		ID:         id,
		Properties: properties,
	}, nil
}
package stateconverter
import (
"io"
"github.com/klothoplatform/klotho/pkg/construct"
statetemplate "github.com/klothoplatform/klotho/pkg/infra/state_reader/state_template"
)
//go:generate mockgen -source=./state_converter.go --destination=../state_converter_mock_test.go --package=statereader
// State is converted IaC state: the properties of each resource, keyed by resource ID.
type State map[construct.ResourceId]construct.Properties

// StateConverter converts IaC state into Klotho's model.
type StateConverter interface {
	// ConvertState converts the state to the Klotho state
	ConvertState(io.Reader) (State, error)
	// ConvertResource converts a single state resource. The pulumi
	// implementation returns (nil, nil) for types it has no mapping for.
	ConvertResource(Resource) (*construct.Resource, error)
}
// NewStateConverter returns a StateConverter backed by templates, keyed by IaC
// resource type.
// NOTE(review): provider is currently ignored; a pulumi converter is always returned.
func NewStateConverter(provider string, templates map[string]statetemplate.StateTemplate) StateConverter {
	converter := &pulumiStateConverter{templates: templates}
	return converter
}
package statereader
import (
"errors"
"fmt"
"io"
"reflect"
"strings"
"github.com/dominikbraun/graph"
"github.com/klothoplatform/klotho/pkg/construct"
stateconverter "github.com/klothoplatform/klotho/pkg/infra/state_reader/state_converter"
statetemplate "github.com/klothoplatform/klotho/pkg/infra/state_reader/state_template"
"github.com/klothoplatform/klotho/pkg/knowledgebase"
"github.com/klothoplatform/klotho/pkg/knowledgebase/properties"
)
//go:generate mockgen -source=./state_reader.go --destination=./state_reader_mock_test.go --package=statereader
// StateReader is an interface for reading state from a state store
type StateReader interface {
	// ReadState reads the state from the state store
	ReadState(io.Reader) (construct.Graph, error)
}

// propertyCorrelation sets state properties on resources, resolving values
// that refer to other resources.
type propertyCorrelation interface {
	setProperty(resource *construct.Resource, property string, value any) error
	checkValueForReferences(
		step knowledgebase.OperationalStep,
		value string,
		src construct.ResourceId,
		propertyRef string,
	) (*construct.Edge, *construct.PropertyRef, error)
}

// stateReader loads converted state into graph, using kb for resource templates.
type stateReader struct {
	templates map[string]statetemplate.StateTemplate
	kb        knowledgebase.TemplateKB
	converter stateconverter.StateConverter
	graph     construct.Graph
}

// propertyCorrelator implements propertyCorrelation. resources holds the
// resources that were in the graph when it was created.
type propertyCorrelator struct {
	ctx       knowledgebase.DynamicValueContext
	resources []*construct.Resource
}
// NewPulumiReader returns a StateReader that decodes pulumi state with templates
// and loads it into g. ReadState creates a new graph if g is nil.
func NewPulumiReader(g construct.Graph, templates map[string]statetemplate.StateTemplate, kb knowledgebase.TemplateKB) StateReader {
	converter := stateconverter.NewStateConverter("pulumi", templates)
	return &stateReader{
		graph:     g,
		templates: templates,
		kb:        kb,
		converter: converter,
	}
}
// ReadState converts the state read from reader and adds its resources to the
// reader's graph (a new graph if none was supplied). It then resolves
// properties that reference other resources, adding the matching edges.
// After conversion succeeds, the (possibly partly populated) graph is returned
// even alongside an error.
func (p stateReader) ReadState(reader io.Reader) (construct.Graph, error) {
	internalState, err := p.converter.ConvertState(reader)
	if err != nil {
		return nil, err
	}
	if p.graph == nil {
		p.graph = construct.NewGraph()
	}
	if err := p.loadGraph(internalState); err != nil {
		return p.graph, err
	}
	adj, err := p.graph.AdjacencyMap()
	if err != nil {
		return p.graph, err
	}
	existing := make([]*construct.Resource, 0, len(adj))
	for id := range adj {
		r, err := p.graph.Vertex(id)
		if err != nil {
			return p.graph, err
		}
		existing = append(existing, r)
	}
	correlator := propertyCorrelator{
		ctx:       knowledgebase.DynamicValueContext{Graph: p.graph, KnowledgeBase: p.kb},
		resources: existing,
	}
	if err := p.loadProperties(internalState, correlator); err != nil {
		return p.graph, err
	}
	return p.graph, nil
}
// loadGraph makes sure every resource in state is a vertex of the graph and
// sets its plain properties (keys without '#') through the resource template.
// Keys containing '#' are skipped here. Failures are collected per resource and
// property, and joined.
func (p stateReader) loadGraph(state stateconverter.State) error {
	var errs error
	for id, props := range state {
		resource, err := p.graph.Vertex(id)
		if err != nil && !errors.Is(err, graph.ErrVertexNotFound) {
			errs = errors.Join(errs, err)
			continue
		}
		if resource == nil {
			resource = &construct.Resource{ID: id, Properties: make(construct.Properties)}
			if err := p.graph.AddVertex(resource); err != nil {
				errs = errors.Join(errs, err)
				continue
			}
		}
		rt, err := p.kb.GetResourceTemplate(id)
		switch {
		case err != nil:
			errs = errors.Join(errs, err)
			continue
		case rt == nil:
			errs = errors.Join(errs, fmt.Errorf("resource template not found for resource %s", id))
			continue
		}
		for key, value := range props {
			if strings.Contains(key, "#") {
				continue
			}
			prop := rt.GetProperty(key)
			if prop == nil {
				errs = errors.Join(errs, fmt.Errorf("property %s not found in resource template %s", key, id))
				continue
			}
			errs = errors.Join(errs, prop.SetProperty(resource, value))
		}
	}
	return errs
}
// loadProperties runs every property of every resource in state through pc,
// which resolves references to other resources. Missing vertices and property
// failures are all reported in the joined error.
func (p stateReader) loadProperties(state stateconverter.State, pc propertyCorrelation) error {
	var errs error
	for id, props := range state {
		resource, lookupErr := p.graph.Vertex(id)
		if lookupErr != nil {
			errs = errors.Join(errs, lookupErr)
			continue
		}
		for key, value := range props {
			setErr := pc.setProperty(resource, key, value)
			errs = errors.Join(errs, setErr)
		}
	}
	return errs
}
// setProperty sets one state property on resource.
//
// The key may carry a "#<ref>" suffix naming the property to match on
// referenced resources; otherwise the operational rule's UsePropertyRef is used.
// Properties whose template has no operational rule with resource selectors are
// set verbatim.
//
// Otherwise a string value, or each element of a list of strings, is resolved
// through checkValueForReferences. Matches are stored as resource IDs (for
// resource properties) or property refs, and the corresponding edges are added
// to the graph. Values of other kinds are not set. A non-string list element is
// an error.
func (p propertyCorrelator) setProperty(
	resource *construct.Resource,
	property string,
	value any,
) error {
	edges := make([]*construct.Edge, 0)
	rt, err := p.ctx.KnowledgeBase.GetResourceTemplate(resource.ID)
	if err != nil {
		return err
	} else if rt == nil {
		return fmt.Errorf("resource template not found for resource %s", resource.ID)
	}
	parts := strings.Split(property, "#")
	property = parts[0]
	prop := rt.GetProperty(property)
	if prop == nil {
		return fmt.Errorf("property %s not found in resource template %s", property, resource.ID)
	}
	opRule := prop.Details().OperationalRule
	if opRule == nil || len(opRule.Step.Resources) == 0 {
		return resource.SetProperty(property, value)
	}
	var ref string
	if len(parts) > 1 {
		ref = parts[1]
	} else {
		ref = opRule.Step.UsePropertyRef
	}
	// asString returns the string held by v, unwrapping interfaces (elements of
	// a []any arrive as interface values). It accepts named string types too.
	asString := func(v reflect.Value) (string, bool) {
		for v.Kind() == reflect.Interface && !v.IsNil() {
			v = v.Elem()
		}
		if v.Kind() != reflect.String {
			return "", false
		}
		return v.String(), true
	}
	switch rval := reflect.ValueOf(value); rval.Kind() {
	case reflect.String:
		// rval.String() rather than value.(string): the latter panics for named string types.
		edge, pref, err := p.checkValueForReferences(opRule.Step, rval.String(), resource.ID, ref)
		if err != nil {
			return err
		}
		if edge != nil {
			edges = append(edges, edge)
		}
		if pref != nil {
			switch prop.(type) {
			case *properties.ResourceProperty:
				err = prop.SetProperty(resource, pref.Resource)
				if err != nil {
					return err
				}
			default:
				err = prop.SetProperty(resource, pref)
				if err != nil {
					return err
				}
			}
		}
	case reflect.Slice, reflect.Array:
		var val []any
		for i := 0; i < rval.Len(); i++ {
			elem, ok := asString(rval.Index(i))
			if !ok {
				return fmt.Errorf("property %s of %s: element %d has type %T, expected a string",
					property, resource.ID, i, rval.Index(i).Interface())
			}
			edge, pref, err := p.checkValueForReferences(opRule.Step, elem, resource.ID, ref)
			if err != nil {
				return err
			}
			if edge != nil {
				edges = append(edges, edge)
			}
			if pref != nil {
				val = append(val, *pref)
			}
		}
		collectionProp, ok := prop.(knowledgebase.CollectionProperty)
		if !ok {
			return fmt.Errorf("property %s is not a collection property", property)
		}
		switch collectionProp.Item().(type) {
		case *properties.ResourceProperty:
			resources := make([]construct.ResourceId, 0)
			for _, v := range val {
				resources = append(resources, v.(construct.PropertyRef).Resource)
			}
			err = prop.SetProperty(resource, resources)
			if err != nil {
				return err
			}
		default:
			err = prop.SetProperty(resource, val)
			if err != nil {
				return err
			}
		}
	}
	for _, edge := range edges {
		err := p.ctx.Graph.AddEdge(edge.Source, edge.Target)
		if err != nil {
			return err
		}
	}
	return nil
}
// checkValueForReferences checks whether value (a property value of src) refers
// to another resource. step's resource selectors supply the candidate types.
//
//   - If a known resource of a candidate type has propertyRef equal to value,
//     that resource is used.
//   - Otherwise, when step has exactly one selector that yields at least one
//     candidate, a resource of the first candidate's provider/type named value
//     is looked up in the graph, or added to it. Its propertyRef is set to value.
//
// On a match it returns the edge between src and the resource (oriented by
// step.Direction) and a reference to propertyRef on that resource. It returns
// all nils when no reference can be established.
func (p propertyCorrelator) checkValueForReferences(
	step knowledgebase.OperationalStep,
	value string,
	src construct.ResourceId,
	propertyRef string,
) (*construct.Edge, *construct.PropertyRef, error) {
	// link builds the edge between src and target (src -> target when the step is
	// downstream, target -> src otherwise) and the reference to target's propertyRef.
	link := func(target construct.ResourceId) (*construct.Edge, *construct.PropertyRef, error) {
		edge := &construct.Edge{Source: target, Target: src}
		if step.Direction == knowledgebase.DirectionDownstream {
			edge = &construct.Edge{Source: src, Target: target}
		}
		return edge, &construct.PropertyRef{Resource: target, Property: propertyRef}, nil
	}

	var possibleIds []construct.ResourceId
	data := knowledgebase.DynamicValueData{Resource: src}
	for _, selector := range step.Resources {
		ids, err := selector.ExtractResourceIds(p.ctx, data)
		if err != nil {
			return nil, nil, err
		}
		possibleIds = append(possibleIds, ids...)
		for _, id := range ids {
			for _, resource := range p.resources {
				if !id.Matches(resource.ID) {
					continue
				}
				val, err := p.ctx.FieldValue(propertyRef, resource.ID)
				if err != nil {
					return nil, nil, err
				}
				if val == value {
					return link(resource.ID)
				}
			}
		}
	}

	// Fall back to a resource named after the value, but only when the selector
	// is unambiguous and actually produced a candidate (possibleIds[0] would
	// otherwise panic).
	if len(step.Resources) != 1 || len(possibleIds) == 0 {
		return nil, nil, nil
	}
	idToUse := possibleIds[0]
	id := construct.ResourceId{Provider: idToUse.Provider, Type: idToUse.Type, Name: value}
	newRes, err := p.ctx.Graph.Vertex(id)
	if err != nil && !errors.Is(err, graph.ErrVertexNotFound) {
		return nil, nil, err
	}
	isNew := newRes == nil
	if isNew {
		newRes = &construct.Resource{
			ID:         id,
			Properties: make(construct.Properties),
		}
	}
	if err := newRes.SetProperty(propertyRef, value); err != nil {
		return nil, nil, err
	}
	// A resource already in the graph (e.g. created for an earlier value) must
	// not be added again; AddVertex would fail because it already exists.
	if isNew {
		if err := p.ctx.Graph.AddVertex(newRes); err != nil {
			return nil, nil, err
		}
	}
	return link(newRes.ID)
}
package statetemplate
import (
"embed"
"gopkg.in/yaml.v3"
)
// StateTemplate is a template for reading state from a state store
type StateTemplate struct {
	// QualifiedTypeName is the Klotho type ("provider:type") that resources
	// converted with this template get.
	QualifiedTypeName string `json:"qualified_type_name" yaml:"qualified_type_name"`
	// IaCQualifiedType is the IaC resource type this template applies to;
	// LoadStateTemplates keys templates by it.
	IaCQualifiedType string `json:"iac_qualified_type" yaml:"iac_qualified_type"`
	// PropertyMappings maps an IaC output name to the Klotho property it populates.
	PropertyMappings map[string]string `json:"property_mappings" yaml:"property_mappings"`
}
// PulumiTemplates embeds the state-mapping YAML files, laid out as
// mappings/<provider>/<file>.yaml.
//
//go:embed mappings/*/*.yaml
var PulumiTemplates embed.FS
// LoadStateTemplates decodes every embedded mapping file under
// mappings/<provider>/ and returns the templates keyed by IaCQualifiedType.
// Files are read in directory order; a later file with the same
// IaCQualifiedType replaces an earlier one. The first read or decode error
// aborts loading.
func LoadStateTemplates(provider string) (map[string]StateTemplate, error) {
	dir := "mappings/" + provider
	entries, err := PulumiTemplates.ReadDir(dir)
	if err != nil {
		return nil, err
	}
	templates := make(map[string]StateTemplate, len(entries))
	for _, entry := range entries {
		raw, err := PulumiTemplates.ReadFile(dir + "/" + entry.Name())
		if err != nil {
			return nil, err
		}
		var tmpl StateTemplate
		if err := yaml.Unmarshal(raw, &tmpl); err != nil {
			return nil, err
		}
		templates[tmpl.IaCQualifiedType] = tmpl
	}
	return templates, nil
}
package io
import "io"
// CountingWriter forwards writes to Delegate and keeps a running total of the
// bytes the delegate accepted.
type CountingWriter struct {
	Delegate     io.Writer
	BytesWritten int
}

// Write writes p to the delegate and adds the byte count it reports (even on
// error) to BytesWritten.
func (w *CountingWriter) Write(p []byte) (n int, err error) {
	n, err = w.Delegate.Write(p)
	w.BytesWritten += n
	return n, err
}
package io
import (
"errors"
"fmt"
"io"
"os"
"path/filepath"
"github.com/spf13/afero"
)
// FileRef is a lightweight representation of a file, deferring reading its contents until `WriteTo` is called.
// The contents are read from RootConfigPath/FPath; FPath is also the output path.
type FileRef struct {
	FPath          string
	RootConfigPath string
}

// Clone returns an independent copy of r, so changing the clone's fields never
// affects r.
func (r *FileRef) Clone() File {
	dup := *r
	return &dup
}

// Path returns the file's output path.
func (r *FileRef) Path() string {
	return r.FPath
}

// WriteTo streams the file at RootConfigPath/FPath into w.
func (r *FileRef) WriteTo(w io.Writer) (int64, error) {
	f, err := os.Open(filepath.Join(r.RootConfigPath, r.FPath))
	if err != nil {
		return 0, err
	}
	defer f.Close()
	return io.Copy(w, f)
}
// OutputTo writes files beneath dest on the operating system's filesystem.
func OutputTo(files []File, dest string) error {
	osFs := afero.NewOsFs()
	return OutputToFS(osFs, files, dest)
}
func OutputToFS(fs afero.Fs, files []File, dest string) error {
errChan := make(chan error)
for idx := range files {
go func(f File) {
path := filepath.Join(dest, f.Path())
dir := filepath.Dir(path)
err := fs.MkdirAll(dir, 0777)
if err != nil {
errChan <- fmt.Errorf("could not create directory for %s: %w", path, err)
return
}
file, err := fs.OpenFile(path, os.O_RDWR, 0777)
if os.IsNotExist(err) {
file, err = fs.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0777)
} else if err == nil {
err = file.Truncate(0)
}
if err != nil {
errChan <- fmt.Errorf("could not open file for writing %s: %w", path, err)
return
}
_, err = f.WriteTo(file)
file.Close()
if err != nil {
errChan <- fmt.Errorf("could not write file %s: %w", path, err)
return
}
errChan <- nil
}(files[idx])
}
var errs error
for i := 0; i < len(files); i++ {
errs = errors.Join(errs, <-errChan)
}
return errs
}
package io
import (
"io"
)
// RawFile represents a file with its included `Content` in case the compiler needs to read/manipulate it.
// If the content is not needed except to `WriteTo`, then try using [FileRef] instead.
type RawFile struct {
	// FPath is the file's path, returned by Path.
	FPath string
	// Content is the complete file body.
	Content []byte
}

// File is an output file: it has a path, can serialize its contents to a writer, and can
// produce a copy of itself.
type File interface {
	Path() string
	WriteTo(io.Writer) (int64, error)
	Clone() File
}

// Clone returns a deep copy. The new RawFile owns a separate (always non-nil) Content
// buffer, so changing either copy's bytes does not affect the other.
func (r *RawFile) Clone() File {
	buf := append(make([]byte, 0, len(r.Content)), r.Content...)
	return &RawFile{FPath: r.FPath, Content: buf}
}

// Path returns FPath.
func (r *RawFile) Path() string { return r.FPath }

// WriteTo writes Content to w in a single Write call and reports the byte count as int64.
func (r *RawFile) WriteTo(w io.Writer) (written int64, err error) {
	var n int
	n, err = w.Write(r.Content)
	written = int64(n)
	return written, err
}
package ioutil
import (
"fmt"
"io"
"github.com/klothoplatform/klotho/pkg/multierr"
)
type (
	// WriteToHelper simplifies the use of a [io.Writer], and specifically in a way that helps you implement
	// [io.WriterTo]. It does so by wrapping the Writer along with a reference to the count and err that WriterTo
	// requires. When you write to the WriteToHelper, it either delegates to the Writer if there has not been an error,
	// or else ignores the write if there has been. If it delegates, it also updates the count and err values.
	WriteToHelper struct {
		// out is the destination writer.
		out io.Writer
		// count points at the caller's running byte total.
		count *int64
		// err points at the caller's error; once it is non-nil, writes are skipped.
		err *error
	}
)

// NewWriteToHelper creates a new WriteToHelper which delegates to the given Writer and updates the given count and err
// as needed.
//
// A good pattern for how to use this is:
//
//	func (wt *MyWriterTo) WriteTo(w io.Writer) (count int64, err error) {
//		wh := ioutil.NewWriteToHelper(w, &count, &err)
//		wh.Write("hello")
//		wh.Write("world")
//		return
//	}
//
// The "wh" helper will delegate each of its writes to the "w" Writer, updating count and err as needed along the way.
// If the Writer ever returns a non-nil error, subsequent write operations on the "wh" helper will be ignored.
// The helper only holds pointers, so copies of it share the same count and err.
func NewWriteToHelper(out io.Writer, count *int64, err *error) WriteToHelper {
	return WriteToHelper{
		out:   out,
		count: count,
		err:   err,
	}
}
// AddErr records err in the helper's shared error value. The first error is stored as-is.
// Later errors are accumulated into a multierr.Error so that none are lost. A nil err is
// ignored. Once any error is recorded, subsequent Write/Writef calls become no-ops.
func (w WriteToHelper) AddErr(err error) {
	if err == nil {
		return
	}
	if *w.err == nil {
		*w.err = err
		return
	}
	multiErr, ok := (*w.err).(multierr.Error)
	if !ok {
		multiErr = multierr.Error{}
		multiErr.Append(*w.err)
	}
	multiErr.Append(err)
	// multiErr is a local copy of the stored value; write it back so the appended error
	// is actually visible through *w.err.
	*w.err = multiErr
}
// Write writes s verbatim. It is passed as an argument, so it is never interpreted as a format.
func (w WriteToHelper) Write(s string) {
	w.Writef(`%s`, s)
}

// Writef formats with fmt.Fprintf into the wrapped writer, adds the byte count to *count,
// and stores the write error (possibly nil) in *err. It does nothing once *err is non-nil.
func (w WriteToHelper) Writef(format string, a ...any) {
	if *w.err == nil {
		n, writeErr := fmt.Fprintf(w.out, format, a...)
		*w.count += int64(n)
		*w.err = writeErr
	}
}
package cleanup
import (
"context"
"errors"
"os"
"os/signal"
"sync"
"syscall"
"go.uber.org/zap"
)
// Callback is a cleanup hook that receives the signal which triggered the cleanup.
type Callback func(signal syscall.Signal) error

var (
	// callbacks holds the registered cleanup hooks in registration order.
	callbacks []Callback
	// callbackMu guards callbacks.
	callbackMu sync.Mutex
)

// OnKill registers callback to be run by Execute.
func OnKill(callback Callback) {
	callbackMu.Lock()
	defer callbackMu.Unlock()
	callbacks = append(callbacks, callback)
}

// Execute runs every registered callback with signal, in registration order. All callbacks
// run even if earlier ones fail. Their errors are combined with errors.Join (nil when none
// failed). The registry lock is held for the whole run, so a callback must not call OnKill
// or Execute.
func Execute(signal syscall.Signal) error {
	callbackMu.Lock()
	defer callbackMu.Unlock()
	failures := make([]error, 0, len(callbacks))
	for i := range callbacks {
		if cbErr := callbacks[i](signal); cbErr != nil {
			failures = append(failures, cbErr)
		}
	}
	return errors.Join(failures...)
}
// InitializeHandler subscribes to SIGTERM, SIGINT and SIGQUIT and returns a cancellable
// child of ctx. When the first such signal arrives, the callbacks registered via OnKill
// are run through Execute (a failure is logged, not returned). The returned context is
// then cancelled. The watcher goroutine handles exactly one signal and then exits. The
// subscription is never stopped, so later signals are captured rather than triggering
// their default action.
func InitializeHandler(ctx context.Context) context.Context {
	child, cancel := context.WithCancel(ctx)
	incoming := make(chan os.Signal, 1)
	signal.Notify(incoming, syscall.SIGTERM, syscall.SIGINT, syscall.SIGQUIT)
	go func() {
		received := <-incoming
		zap.S().Infof("Received signal: %v", received)
		if err := Execute(received.(syscall.Signal)); err != nil {
			zap.S().Errorf("Error running executing cleanup: %v", err)
		}
		cancel()
	}()
	return child
}
// SignalProcessGroup sends signal to every process in the process group pid, using
// kill(2) with a negative pid. A failure is logged rather than returned.
func SignalProcessGroup(pid int, signal syscall.Signal) {
	zap.S().Infof("Sending %s signal to process group: %v", signal, pid)
	// Use the negative PID to signal the entire process group, and send the signal the
	// caller asked for (not a hard-coded SIGTERM).
	err := syscall.Kill(-pid, signal)
	if err != nil {
		zap.S().Errorf("Error sending %s to process group: %v", signal, err)
	}
}
package constructs
import (
"errors"
"fmt"
"github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/k2/constructs/template"
"github.com/klothoplatform/klotho/pkg/k2/constructs/template/property"
"github.com/klothoplatform/klotho/pkg/k2/model"
"go.uber.org/zap"
"io/fs"
)
type (
	// BindingDeclaration is a binding as declared in state: a link from the From construct
	// to the To construct, plus the binding's raw inputs.
	BindingDeclaration struct {
		From   model.URN
		To     model.URN
		Inputs map[string]model.Input
	}
	// Binding is the evaluated form of a BindingDeclaration, attached to an owner construct.
	// It carries the same evaluation state as a Construct (inputs, resources, edges, outputs,
	// initial graph). evaluateBinding merges it into the owner via applyBinding.
	Binding struct {
		// Owner is the construct under evaluation that this binding belongs to.
		Owner *Construct
		// From and To are the constructs at each end of the binding. They may be nil:
		// newBinding ignores the lookup result — NOTE(review): confirm nil is the miss value.
		From *Construct
		To   *Construct
		// Priority orders evaluation among an owner's bindings (lower first); copied from the template.
		Priority        int
		BindingTemplate template.BindingTemplate
		Meta            map[string]any
		Inputs          construct.Properties
		// Resources maps template resource keys to resolved resources.
		Resources          map[string]*Resource
		Edges              []*Edge
		OutputDeclarations map[string]OutputDeclaration
		Outputs            map[string]any
		// InitialGraph holds vertices/edges that applyBinding upserts into the owner's graph.
		InitialGraph construct.Graph
	}
)
// GetInputs returns the binding's input values.
func (b *Binding) GetInputs() construct.Properties { return b.Inputs }

// GetInputValue looks up a single input value by name.
func (b *Binding) GetInputValue(name string) (value any, err error) { return b.Inputs.GetProperty(name) }

// GetTemplateResourcesIterator iterates the resource templates declared by the binding template.
func (b *Binding) GetTemplateResourcesIterator() template.Iterator[string, template.ResourceTemplate] {
	return b.BindingTemplate.ResourcesIterator()
}

// GetTemplateEdges returns the edge templates declared by the binding template.
func (b *Binding) GetTemplateEdges() []template.EdgeTemplate { return b.BindingTemplate.Edges }

// GetEdges returns the binding's resolved edges.
func (b *Binding) GetEdges() []*Edge { return b.Edges }

// SetEdges replaces the binding's resolved edges.
func (b *Binding) SetEdges(edges []*Edge) { b.Edges = edges }

// GetResource returns the resolved resource stored under resourceId, if any.
func (b *Binding) GetResource(resourceId string) (resource *Resource, ok bool) {
	found, exists := b.Resources[resourceId]
	return found, exists
}

// SetResource stores resource under resourceId.
func (b *Binding) SetResource(resourceId string, resource *Resource) { b.Resources[resourceId] = resource }

// GetResources returns all resolved resources keyed by template resource key.
func (b *Binding) GetResources() map[string]*Resource { return b.Resources }

// GetInputRules returns the input rules declared by the binding template.
func (b *Binding) GetInputRules() []template.InputRuleTemplate { return b.BindingTemplate.InputRules }

// GetTemplateOutputs returns the output templates declared by the binding template.
func (b *Binding) GetTemplateOutputs() map[string]template.OutputTemplate {
	return b.BindingTemplate.Outputs
}

// GetInitialGraph returns the binding's initial resource graph.
func (b *Binding) GetInitialGraph() construct.Graph { return b.InitialGraph }

// DeclareOutput records (or overwrites) the output declaration stored under key.
func (b *Binding) DeclareOutput(key string, declaration OutputDeclaration) {
	b.OutputDeclarations[key] = declaration
}

// GetURN returns the owner construct's URN, or the zero URN when the binding has no owner.
func (b *Binding) GetURN() model.URN {
	if owner := b.Owner; owner != nil {
		return owner.GetURN()
	}
	return model.URN{}
}
// String renders the binding in the Edge format "<from> -> <to> :: <data>", where each end
// is a reference to the corresponding construct's URN. A nil From or To is rendered with
// a zero URN instead of panicking, because String is used when building error messages.
func (b *Binding) String() string {
	var from, to model.URN
	if b.From != nil {
		from = b.From.URN
	}
	if b.To != nil {
		to = b.To.URN
	}
	e := Edge{
		From: template.ResourceRef{ConstructURN: from},
		To:   template.ResourceRef{ConstructURN: to},
	}
	return e.String()
}
// GetPropertySource exposes the binding's evaluation state to templates. It always has
// "inputs", "resources", "edges" and "meta". It adds "from" and "to" entries describing the
// constructs at either end, but only when they are non-nil; the "to" entry also carries
// that construct's "outputs".
func (b *Binding) GetPropertySource() *template.PropertySource {
	source := make(map[string]any)
	source["inputs"] = b.Inputs
	source["resources"] = b.Resources
	source["edges"] = b.Edges
	source["meta"] = b.Meta
	if from := b.From; from != nil {
		source["from"] = map[string]any{
			"urn":       from.URN,
			"inputs":    from.Inputs,
			"resources": from.Resources,
			"edges":     from.Edges,
			"meta":      from.Meta,
		}
	}
	if to := b.To; to != nil {
		source["to"] = map[string]any{
			"urn":       to.URN,
			"inputs":    to.Inputs,
			"resources": to.Resources,
			"edges":     to.Edges,
			"meta":      to.Meta,
			"outputs":   to.Outputs,
		}
	}
	return template.NewPropertySource(source)
}

// GetConstruct returns the construct that owns this binding.
func (b *Binding) GetConstruct() *Construct { return b.Owner }
// newBinding initializes a new binding instance using the template associated with the owner construct.
// The owner, from and to construct types are parsed from their URN subtypes and used to load
// the binding template. When the template cannot be found (an *fs.PathError), a default
// template with no inputs, outputs or resources is used. The declaration's inputs are then
// converted and initialized on the binding.
// returns: the new binding instance or an error if one occurred
func (ce *ConstructEvaluator) newBinding(owner model.URN, d BindingDeclaration) (*Binding, error) {
	ownerType, err := property.ParseConstructType(owner.Subtype)
	if err != nil {
		return nil, err
	}
	fromType, err := property.ParseConstructType(d.From.Subtype)
	if err != nil {
		return nil, err
	}
	toType, err := property.ParseConstructType(d.To.Subtype)
	if err != nil {
		return nil, err
	}

	ownerConstruct, _ := ce.Constructs.Get(owner)
	fromConstruct, _ := ce.Constructs.Get(d.From)
	toConstruct, _ := ce.Constructs.Get(d.To)

	bt, err := template.LoadBindingTemplate(ownerType, fromType, toType)
	var missing *fs.PathError
	switch {
	case errors.As(err, &missing):
		zap.S().Debugf("template not found for binding %s -> %s -> %s", ownerType, fromType, toType)
		bt = template.BindingTemplate{
			From:      fromType,
			To:        toType,
			Priority:  0,
			Inputs:    template.NewProperties(nil),
			Outputs:   make(map[string]template.OutputTemplate),
			Resources: make(map[string]template.ResourceTemplate),
		}
	case err != nil:
		return nil, fmt.Errorf("failed to load binding template %s -> %s -> %s: %w", ownerType.String(), fromType.String(), toType.String(), err)
	}

	b := &Binding{
		Owner:              ownerConstruct,
		From:               fromConstruct,
		To:                 toConstruct,
		BindingTemplate:    bt,
		Priority:           bt.Priority,
		Meta:               make(map[string]any),
		Inputs:             make(map[string]any),
		Resources:          make(map[string]*Resource),
		Edges:              []*Edge{},
		OutputDeclarations: make(map[string]OutputDeclaration),
		Outputs:            make(map[string]any),
		InitialGraph:       construct.NewGraph(),
	}

	converted, err := ce.convertInputs(d.Inputs)
	if err != nil {
		return nil, fmt.Errorf("invalid inputs for binding %s -> %s: %w", d.From, d.To, err)
	}
	if err := ce.initializeInputs(b, converted); err != nil {
		return nil, fmt.Errorf("input initialization failed for binding %s -> %s: %w", d.From, d.To, err)
	}
	return b, nil
}

// ForEachInput calls f for each input declared by the binding template, using the binding's inputs.
func (b *Binding) ForEachInput(f func(property.Property) error) error {
	return b.BindingTemplate.ForEachInput(b.Inputs, f)
}
package constructs
import (
"errors"
"fmt"
"github.com/klothoplatform/klotho/pkg/k2/constructs/template"
inputs2 "github.com/klothoplatform/klotho/pkg/k2/constructs/template/property"
"sort"
"github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/engine/solution"
"github.com/klothoplatform/klotho/pkg/k2/model"
)
type (
	// Construct is the in-memory evaluation state of one construct instance: its template,
	// its inputs, and the resources, edges, outputs and initial graph produced while
	// evaluating it and its bindings.
	Construct struct {
		// URN identifies the construct instance.
		URN model.URN
		// ConstructTemplate is the template loaded for the construct's type.
		ConstructTemplate template.ConstructTemplate
		Meta              map[string]any
		// Inputs holds the input values, including the reserved "Name" input set by newConstruct.
		Inputs construct.Properties
		// Resources maps template resource keys to resolved resources.
		Resources map[string]*Resource
		Edges     []*Edge
		// OutputDeclarations maps output names to their declared reference or value.
		OutputDeclarations map[string]OutputDeclaration
		Outputs            map[string]any
		// InitialGraph is passed to the engine as the SolveRequest's InitialState.
		InitialGraph construct.Graph
		// Bindings are the bindings involving this construct, in declaration order.
		Bindings []*Binding
		Solution solution.Solution
	}
	// Resource is a resolved resource: its engine ID plus its properties.
	Resource struct {
		Id         construct.ResourceId
		Properties construct.Properties
	}
	// Edge connects two resource references and carries optional edge data.
	Edge struct {
		From template.ResourceRef
		To   template.ResourceRef
		Data construct.EdgeData
	}
	// OutputDeclaration declares a construct output as either a property reference (Ref)
	// or a literal Value; evaluateOutputs rejects declarations that set both.
	OutputDeclaration struct {
		Name  string
		Ref   construct.PropertyRef
		Value any
	}
)
// GetInputValue looks up a single input value by name.
func (c *Construct) GetInputValue(name string) (value any, err error) { return c.Inputs.GetProperty(name) }

// GetTemplateResourcesIterator iterates the resource templates declared by the construct template.
func (c *Construct) GetTemplateResourcesIterator() template.Iterator[string, template.ResourceTemplate] {
	return c.ConstructTemplate.ResourcesIterator()
}

// GetTemplateEdges returns the edge templates declared by the construct template.
func (c *Construct) GetTemplateEdges() []template.EdgeTemplate { return c.ConstructTemplate.Edges }

// GetEdges returns the construct's resolved edges.
func (c *Construct) GetEdges() []*Edge { return c.Edges }

// SetEdges replaces the construct's resolved edges.
func (c *Construct) SetEdges(edges []*Edge) { c.Edges = edges }

// GetInputRules returns the input rules declared by the construct template.
func (c *Construct) GetInputRules() []template.InputRuleTemplate { return c.ConstructTemplate.InputRules }

// GetTemplateOutputs returns the output templates declared by the construct template.
func (c *Construct) GetTemplateOutputs() map[string]template.OutputTemplate {
	return c.ConstructTemplate.Outputs
}

// GetPropertySource exposes "inputs", "resources", "edges" and "meta" to templates.
func (c *Construct) GetPropertySource() *template.PropertySource {
	source := map[string]any{}
	source["inputs"] = c.Inputs
	source["resources"] = c.Resources
	source["edges"] = c.Edges
	source["meta"] = c.Meta
	return template.NewPropertySource(source)
}

// GetResource returns the resolved resource stored under resourceId, if any.
func (c *Construct) GetResource(resourceId string) (resource *Resource, ok bool) {
	found, exists := c.Resources[resourceId]
	return found, exists
}

// SetResource stores resource under resourceId.
func (c *Construct) SetResource(resourceId string, resource *Resource) { c.Resources[resourceId] = resource }

// GetResources returns all resolved resources keyed by template resource key.
func (c *Construct) GetResources() map[string]*Resource { return c.Resources }

// GetInitialGraph returns the construct's initial resource graph.
func (c *Construct) GetInitialGraph() construct.Graph { return c.InitialGraph }

// DeclareOutput records (or overwrites) the output declaration stored under key.
func (c *Construct) DeclareOutput(key string, declaration OutputDeclaration) {
	c.OutputDeclarations[key] = declaration
}

// GetURN returns the construct's URN.
func (c *Construct) GetURN() model.URN { return c.URN }

// GetInputs returns the construct's input values.
func (c *Construct) GetInputs() construct.Properties { return c.Inputs }
// PrettyPrint renders the edge as "<from> -> <to>".
func (e *Edge) PrettyPrint() string {
	return fmt.Sprintf("%s -> %s", e.From.String(), e.To.String())
}

// String renders the edge as "<from> -> <to> :: <data>", formatting Data with %v.
func (e *Edge) String() string {
	return fmt.Sprintf("%s :: %v", e.PrettyPrint(), e.Data)
}
// OrderedBindings returns a copy of the construct's bindings sorted by priority (lowest to
// highest). Bindings with the same priority keep their declaration order (stable sort).
// c.Bindings itself is not modified. Returns nil when there are no bindings.
func (c *Construct) OrderedBindings() []*Binding {
	if len(c.Bindings) == 0 {
		return nil
	}
	sorted := append([]*Binding{}, c.Bindings...)
	// The less function must index the slice being sorted: sort.SliceStable moves elements
	// of `sorted`, so i and j are not positions in c.Bindings. Equal priorities compare as
	// "not less", which lets the stable sort preserve declaration order.
	sort.SliceStable(sorted, func(i, j int) bool {
		return sorted[i].Priority < sorted[j].Priority
	})
	return sorted
}
// GetConstruct returns c itself, so a Construct can act as its own owning construct.
func (c *Construct) GetConstruct() *Construct { return c }

// ForEachInput calls f for each input declared by the construct template, using the construct's inputs.
func (c *Construct) ForEachInput(f func(input inputs2.Property) error) error {
	return c.ConstructTemplate.ForEachInput(c.Inputs, f)
}
// newConstruct creates a new Construct instance from the given URN and inputs.
// The URN must be a construct URN (a resource URN whose Type is "construct").
// "Name" is reserved: it may not appear in i and is always set from the URN's ResourceID.
// Any inputs that are not provided will be populated with default values from the construct template.
func (ce *ConstructEvaluator) newConstruct(constructUrn model.URN, i construct.Properties) (*Construct, error) {
	if _, reserved := i["Name"]; reserved {
		return nil, errors.New("'Name' is a reserved input key")
	}
	isConstructURN := constructUrn.IsResource() && constructUrn.Type == "construct"
	if !isConstructURN {
		return nil, errors.New("invalid construct URN")
	}

	// Resolve and load the template for this construct type.
	var templateId inputs2.ConstructType
	if err := templateId.FromURN(constructUrn); err != nil {
		return nil, err
	}
	ct, err := template.LoadConstructTemplate(templateId)
	if err != nil {
		return nil, err
	}

	c := &Construct{
		URN:                constructUrn,
		ConstructTemplate:  ct,
		Meta:               make(map[string]any),
		Inputs:             make(construct.Properties),
		Resources:          make(map[string]*Resource),
		Edges:              []*Edge{},
		OutputDeclarations: make(map[string]OutputDeclaration),
		Outputs:            make(map[string]any),
		InitialGraph:       construct.NewGraph(),
	}
	if err := c.Inputs.SetProperty("Name", constructUrn.ResourceID); err != nil {
		return nil, err
	}
	if err := ce.initializeInputs(c, i); err != nil {
		return nil, err
	}
	return c, nil
}
package constructs
import (
"context"
"errors"
"fmt"
"github.com/klothoplatform/klotho/pkg/async"
"github.com/klothoplatform/klotho/pkg/k2/constructs/template"
"github.com/klothoplatform/klotho/pkg/k2/constructs/template/property"
"reflect"
"slices"
"strings"
"github.com/dominikbraun/graph"
"github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/engine"
stateconverter "github.com/klothoplatform/klotho/pkg/infra/state_reader/state_converter"
"github.com/klothoplatform/klotho/pkg/k2/model"
"github.com/klothoplatform/klotho/pkg/k2/stack"
"github.com/klothoplatform/klotho/pkg/logging"
)
// ConstructEvaluator turns construct state into an engine SolveRequest. It evaluates a
// construct and its bindings (resources, edges, input rules, outputs) and marshals the
// result into constraints plus an initial graph.
type ConstructEvaluator struct {
	// DryRun relaxes input validation: convertInputs only rejects unresolved inputs when it is 0.
	DryRun            model.DryRun
	stateManager      *model.StateManager
	stackStateManager *stack.StateManager
	stateConverter    stateconverter.StateConverter
	// Constructs holds the constructs created during evaluation, keyed by URN.
	Constructs *async.ConcurrentMap[model.URN, *Construct]
}
// NewConstructEvaluator builds a ConstructEvaluator backed by the given state managers.
// The state converter is loaded up front; if that fails, no evaluator is returned.
func NewConstructEvaluator(sm *model.StateManager, ssm *stack.StateManager) (*ConstructEvaluator, error) {
	converter, err := loadStateConverter()
	if err != nil {
		return nil, err
	}
	ce := &ConstructEvaluator{
		stateManager:      sm,
		stackStateManager: ssm,
		stateConverter:    converter,
		Constructs:        new(async.ConcurrentMap[model.URN, *Construct]),
	}
	return ce, nil
}
// Evaluate evaluates the construct identified by constructUrn, then its bindings, and
// converts the result into a SolveRequest: the marshalled constraints plus the construct's
// initial graph as the initial state.
func (ce *ConstructEvaluator) Evaluate(constructUrn model.URN, state model.State, ctx context.Context) (engine.SolveRequest, error) {
	c, err := ce.evaluateConstruct(constructUrn, state, ctx)
	if err != nil {
		return engine.SolveRequest{}, fmt.Errorf("error evaluating construct %s: %w", constructUrn, err)
	}
	if err = ce.evaluateBindings(ctx, c); err != nil {
		return engine.SolveRequest{}, fmt.Errorf("error evaluating bindings: %w", err)
	}

	marshaller := ConstructMarshaller{ConstructEvaluator: ce}
	list, err := marshaller.Marshal(constructUrn)
	if err != nil {
		return engine.SolveRequest{}, fmt.Errorf("error marshalling construct to constraints: %w", err)
	}
	constraints, err := list.ToConstraints()
	if err != nil {
		return engine.SolveRequest{}, fmt.Errorf("error converting constraint list to constraints: %w", err)
	}

	req := engine.SolveRequest{Constraints: constraints, InitialState: c.InitialGraph}
	return req, nil
}
// evaluateInputRules evaluates the input rules of the owner (a construct or binding) in order.
//
// An input rule is a conditional expression that determines a set of resources, edges, and
// outputs based on the owner's inputs. A rule is evaluated by checking its "if" condition and
// then applying the "then" or "else" body according to the result. The "if" condition is a
// template that can access the owner's inputs. Input rules cannot use interpolation in the
// "if" condition.
//
// Example (illustrative — NOTE(review): confirm the exact template call syntax):
//
//	- if: {{ eq inputs("foo") "bar" }}
//	  then:
//	    resources:
//	      "my-resource":
//	        properties:
//	          foo: "bar"
//
// In the example, inputs is a function that returns the value of the input with the given key.
// Evaluation stops at the first rule that fails.
func (ce *ConstructEvaluator) evaluateInputRules(o InfraOwner) error {
	rules := o.GetInputRules()
	for i := range rules {
		data := &DynamicValueData{currentOwner: o}
		if err := ce.evaluateInputRule(data, rules[i]); err != nil {
			return fmt.Errorf("could not evaluate input rule: %w", err)
		}
	}
	return nil
}
// evaluateInputRule dispatches a rule: rules without a ForEach expression are conditional
// (if/then/else) rules; all others are for-each rules.
func (ce *ConstructEvaluator) evaluateInputRule(dv *DynamicValueData, rule template.InputRuleTemplate) error {
	if rule.ForEach == "" {
		return ce.evaluateIfRule(dv, rule)
	}
	return ce.evaluateForEachRule(dv, rule)
}
// evaluateConstruct creates and evaluates the construct identified by constructUrn from state
// and registers it in ce.Constructs.
//
// Evaluation order:
//
//	Construct Inputs (converted, then validated and defaulted by newConstruct)
//	Binding initialization (binding templates loaded, binding inputs initialized)
//	Construct resources imported from inputs
//	Construct Resources
//	Construct Edges
//	Construct Input Rules
//	Construct Outputs
//
// Bindings are evaluated afterwards (see evaluateBindings), in priority order. For each one,
// imports, resources, edges, input rules and outputs are evaluated in that order, and the
// binding is then applied to the construct.
func (ce *ConstructEvaluator) evaluateConstruct(constructUrn model.URN, state model.State, ctx context.Context) (*Construct, error) {
	cState, ok := state.Constructs[constructUrn.ResourceID]
	if !ok {
		return nil, fmt.Errorf("could not get state for construct: %s", constructUrn)
	}
	inputs, err := ce.convertInputs(cState.Inputs)
	if err != nil {
		return nil, fmt.Errorf("invalid inputs for construct: %w", err)
	}
	c, err := ce.newConstruct(constructUrn, inputs)
	if err != nil {
		return nil, fmt.Errorf("could not create construct: %w", err)
	}
	ce.Constructs.Set(constructUrn, c)
	if err = ce.initBindings(c, state); err != nil {
		return nil, fmt.Errorf("could not initialize bindings: %w", err)
	}
	if err = ce.importResourcesFromInputs(c, ctx); err != nil {
		return nil, fmt.Errorf("could not import resources: %w", err)
	}
	if err = ce.evaluateResources(c); err != nil {
		return nil, fmt.Errorf("could not evaluate resources: %w", err)
	}
	if err = ce.evaluateEdges(c); err != nil {
		return nil, fmt.Errorf("could not evaluate edges: %w", err)
	}
	if err = ce.evaluateInputRules(c); err != nil {
		// Wrapped like every other step so the failing phase is identifiable.
		return nil, fmt.Errorf("could not evaluate input rules: %w", err)
	}
	if err = ce.evaluateOutputs(c); err != nil {
		return nil, fmt.Errorf("could not evaluate outputs: %w", err)
	}
	return c, nil
}
// getBindingDeclarations collects every binding in state that involves constructURN:
//   - bindings declared on that construct itself (From = constructURN, To = the binding's URN), and
//   - bindings declared on other constructs that target constructURN (From = the declaring construct).
//
// Constructs are visited in sorted key order, so the declaration order is deterministic
// across runs. Go randomizes map iteration order, and declaration order decides how
// equal-priority bindings are ordered. The error result is currently always nil.
func (ce *ConstructEvaluator) getBindingDeclarations(constructURN model.URN, state model.State) ([]BindingDeclaration, error) {
	keys := make([]string, 0, len(state.Constructs))
	for k := range state.Constructs {
		keys = append(keys, k)
	}
	slices.Sort(keys)

	var bindings []BindingDeclaration
	for _, k := range keys {
		c := state.Constructs[k]
		if c.URN.Equals(constructURN) {
			for _, b := range c.Bindings {
				bindings = append(bindings, newBindingDeclaration(constructURN, b))
			}
			continue
		}
		for _, b := range c.Bindings {
			if b.URN.Equals(constructURN) {
				bindings = append(bindings, newBindingDeclaration(*c.URN, b))
			}
		}
	}
	return bindings, nil
}
// newBindingDeclaration describes binding b, declared on the construct constructURN, as a
// declaration from that construct to b's target URN with b's inputs.
func newBindingDeclaration(constructURN model.URN, b model.Binding) BindingDeclaration {
	var decl BindingDeclaration
	decl.From = constructURN
	decl.To = *b.URN
	decl.Inputs = b.Inputs
	return decl
}
// initBindings creates a Binding for every declaration that involves construct c and
// appends it to c.Bindings in declaration order. It rejects declarations that touch
// neither end of c, and inputs that use the reserved names "from" or "to".
func (ce *ConstructEvaluator) initBindings(c *Construct, state model.State) error {
	declarations, err := ce.getBindingDeclarations(c.URN, state)
	if err != nil {
		return fmt.Errorf("could not get bindings: %w", err)
	}
	for _, decl := range declarations {
		if !(decl.From.Equals(c.URN) || decl.To.Equals(c.URN)) {
			return fmt.Errorf("binding %s -> %s is not valid on construct of type %s", decl.From, decl.To, c.ConstructTemplate.Id)
		}
		if _, reserved := decl.Inputs["from"]; reserved {
			return errors.New("from is a reserved input name")
		}
		if _, reserved := decl.Inputs["to"]; reserved {
			return errors.New("to is a reserved input name")
		}
		binding, bindErr := ce.newBinding(c.URN, decl)
		if bindErr != nil {
			return fmt.Errorf("could not create binding: %w", bindErr)
		}
		c.Bindings = append(c.Bindings, binding)
	}
	return nil
}
// evaluateBindings evaluates c's bindings in the order given by OrderedBindings and stops
// at the first failure.
func (ce *ConstructEvaluator) evaluateBindings(ctx context.Context, c *Construct) error {
	ordered := c.OrderedBindings()
	for i := range ordered {
		if err := ce.evaluateBinding(ctx, ordered[i]); err != nil {
			return fmt.Errorf("could not evaluate binding: %w", err)
		}
	}
	return nil
}
// evaluateBinding evaluates one binding and merges the result into its owner. A binding
// whose template lacks a From or To type name is skipped. The "to" construct's resources
// are imported only when the owner is the binding's From construct. After that, resources,
// edges, input rules and outputs are evaluated, and the binding is applied to the owner.
func (ce *ConstructEvaluator) evaluateBinding(ctx context.Context, b *Binding) error {
	switch {
	case b == nil:
		return fmt.Errorf("binding is nil")
	case b.Owner == nil:
		return fmt.Errorf("binding owner is nil")
	case b.BindingTemplate.From.Name == "" || b.BindingTemplate.To.Name == "":
		return nil // assume that this binding does not modify the current construct
	}
	owner := b.Owner

	if err := ce.importResourcesFromInputs(b, ctx); err != nil {
		return fmt.Errorf("could not import resources: %w", err)
	}
	ownerIsSource := b.From != nil && owner.URN.Equals(b.From.GetURN())
	if ownerIsSource {
		if err := ce.importBindingToResources(ctx, b); err != nil {
			return fmt.Errorf("could not import binding resources: %w", err)
		}
	}
	if err := ce.evaluateResources(b); err != nil {
		return fmt.Errorf("could not evaluate resources: %w", err)
	}
	if err := ce.evaluateEdges(b); err != nil {
		return fmt.Errorf("could not evaluate edges: %w", err)
	}
	if err := ce.evaluateInputRules(b); err != nil {
		return fmt.Errorf("could not evaluate input rules: %w", err)
	}
	if err := ce.evaluateOutputs(b); err != nil {
		return fmt.Errorf("could not evaluate outputs: %w", err)
	}
	if err := ce.applyBinding(b.Owner, b); err != nil {
		return fmt.Errorf("could not apply bindings: %w", err)
	}
	return nil
}
// evaluateEdges resolves each of the owner's edge templates against its property source
// and appends the results to the owner's edges. It stops at the first failure.
func (ce *ConstructEvaluator) evaluateEdges(o InfraOwner) error {
	dv := &DynamicValueData{currentOwner: o, propertySource: o.GetPropertySource()}
	templates := o.GetTemplateEdges()
	for i := range templates {
		resolved, err := ce.resolveEdge(dv, templates[i])
		if err != nil {
			return fmt.Errorf("could not resolve edge: %w", err)
		}
		o.SetEdges(append(o.GetEdges(), resolved))
	}
	return nil
}
// applyBinding merges an evaluated binding into construct c:
//   - resources: if c already has a resource under the same key, the binding's properties are
//     merged into it (binding values win); otherwise the binding's resource is added;
//   - edges: appended unless c already has an edge with the same From/To;
//   - output declarations: added unless c already declares the key (a warning is logged);
//   - initial graph: the binding graph's vertices (in topological order) and edges are upserted
//     into c.InitialGraph, skipping ones that already exist. Edge weight, attributes and data
//     are copied so the merge does not lose edge metadata.
func (ce *ConstructEvaluator) applyBinding(c *Construct, binding *Binding) error {
	log := logging.GetLogger(context.Background()).Sugar()

	// Merge resources
	for key, bRes := range binding.Resources {
		if res, exists := c.Resources[key]; exists {
			res.Properties = mergeProperties(res.Properties, bRes.Properties)
		} else {
			c.Resources[key] = bRes
		}
	}

	// Merge edges
	for _, edge := range binding.Edges {
		if !edgeExists(c.Edges, edge) {
			c.Edges = append(c.Edges, edge)
		}
	}

	// Merge output declarations; the construct's own declarations take precedence.
	for key, output := range binding.OutputDeclarations {
		if _, exists := c.OutputDeclarations[key]; exists {
			log.Warnf("Output %s already exists in construct, skipping binding output", key)
			continue
		}
		c.OutputDeclarations[key] = output
	}

	// Upsert the vertices
	ids, err := construct.TopologicalSort(binding.InitialGraph)
	if err != nil {
		return fmt.Errorf("could not topologically sort binding %s graph: %w", binding, err)
	}
	resources, err := construct.ResolveIds(binding.InitialGraph, ids)
	if err != nil {
		return fmt.Errorf("could not resolve ids from binding %s graph: %w", binding, err)
	}
	for _, vertex := range resources {
		if err := c.InitialGraph.AddVertex(vertex); err != nil {
			if errors.Is(err, graph.ErrVertexAlreadyExists) {
				log.Debugf("Vertex already exists, skipping: %v", vertex)
				continue
			}
			return fmt.Errorf("could not add vertex %v from binding %s graph: %w", vertex, binding, err)
		}
	}

	// Upsert the edges, carrying their properties over.
	edges, err := binding.InitialGraph.Edges()
	if err != nil {
		return fmt.Errorf("could not get edges from binding %s graph: %w", binding, err)
	}
	for _, edge := range edges {
		opts := []func(*graph.EdgeProperties){
			graph.EdgeWeight(edge.Properties.Weight),
			graph.EdgeData(edge.Properties.Data),
		}
		for k, v := range edge.Properties.Attributes {
			opts = append(opts, graph.EdgeAttribute(k, v))
		}
		if err := c.InitialGraph.AddEdge(edge.Source, edge.Target, opts...); err != nil {
			if errors.Is(err, graph.ErrEdgeAlreadyExists) {
				log.Debugf("Edge already exists, skipping: %v -> %v", edge.Source, edge.Target)
				continue
			}
			return fmt.Errorf("could not add edge %v -> %v from binding %s graph: %w", edge.Source, edge.Target, binding, err)
		}
	}
	return nil
}
// mergeProperties returns a new map containing every entry of existing overlaid with every
// entry of incoming; on key collisions the incoming value wins. Neither argument is modified.
func mergeProperties(existing, incoming construct.Properties) construct.Properties {
	merged := make(construct.Properties, len(existing)+len(incoming))
	for _, src := range []construct.Properties{existing, incoming} {
		for k, v := range src {
			merged[k] = v
		}
	}
	return merged
}
// edgeExists reports whether edges already contains an edge with the same From and To as
// newEdge. Edge data is not compared.
func edgeExists(edges []*Edge, newEdge *Edge) bool {
	return slices.ContainsFunc(edges, func(e *Edge) bool {
		return e.From == newEdge.From && e.To == newEdge.To
	})
}
// evaluateResources resolves each of the owner's resource templates against its property
// source and stores the result on the owner under the template key. On the first failure,
// iteration is stopped via template.StopIteration and that error is returned.
func (ce *ConstructEvaluator) evaluateResources(o InfraOwner) error {
	dv := &DynamicValueData{currentOwner: o, propertySource: o.GetPropertySource()}
	var resolveErr error
	o.GetTemplateResourcesIterator().ForEach(func(key string, rt template.ResourceTemplate) error {
		r, err := ce.resolveResource(dv, key, rt)
		if err != nil {
			resolveErr = err
			return template.StopIteration
		}
		o.SetResource(key, r)
		return nil
	})
	return resolveErr
}
// GetPropertyFunc returns a lookup function that reads "<path>.<key>" from ps. It yields
// nil when the property is absent.
func GetPropertyFunc(ps *template.PropertySource, path string) func(string) any {
	return func(key string) any {
		if v, ok := ps.GetProperty(path + "." + key); ok {
			return v
		}
		return nil
	}
}
// evaluateForEachRule evaluates a for-each input rule. The ForEach expression is executed
// into a bool (it also populates dv.currentSelection); if it is false, nothing happens.
// Otherwise, for every item of the current selection:
//   - the rule's Prefix is interpolated and appended (dot-separated) to the inherited key prefix;
//   - the Do body's resources are resolved under the prefixed keys;
//   - the Do body's edges are appended to the owner;
//   - nested rules are evaluated.
//
// Any failure, including one while resolving a resource, aborts the evaluation.
func (ce *ConstructEvaluator) evaluateForEachRule(dv *DynamicValueData, rule template.InputRuleTemplate) error {
	parentPrefix := dv.resourceKeyPrefix
	ctx := DynamicValueContext{
		constructs: ce.Constructs,
	}
	var selected bool
	if err := ctx.ExecuteUnmarshal(rule.ForEach, dv, &selected); err != nil {
		return fmt.Errorf("result parsing failed: %w", err)
	}
	if !selected {
		return nil
	}
	for _, hasNext := dv.currentSelection.Next(); hasNext; _, hasNext = dv.currentSelection.Next() {
		prefix, err := ce.interpolateValue(dv, rule.Prefix)
		if err != nil {
			return fmt.Errorf("could not interpolate resource prefix: %w", err)
		}
		item := &DynamicValueData{
			currentOwner:     dv.currentOwner,
			currentSelection: dv.currentSelection,
			propertySource:   dv.propertySource,
		}
		if prefix != "" && prefix != nil {
			if parentPrefix != "" {
				item.resourceKeyPrefix = strings.Join([]string{parentPrefix, fmt.Sprintf("%s", prefix)}, ".")
			} else {
				item.resourceKeyPrefix = fmt.Sprintf("%s", prefix)
			}
		} else {
			item.resourceKeyPrefix = parentPrefix
		}

		// Capture resolution errors explicitly; the iterator's own result is not relied upon.
		var resourceErr error
		rule.Do.ResourcesIterator().ForEach(func(key string, resource template.ResourceTemplate) error {
			if item.resourceKeyPrefix != "" {
				key = fmt.Sprintf("%s.%s", item.resourceKeyPrefix, key)
			}
			r, err := ce.resolveResource(item, key, resource)
			if err != nil {
				resourceErr = fmt.Errorf("could not resolve resource %s: %w", key, err)
				return template.StopIteration
			}
			item.currentOwner.SetResource(key, r)
			return nil
		})
		if resourceErr != nil {
			return resourceErr
		}

		for _, edge := range rule.Do.Edges {
			e, err := ce.resolveEdge(item, edge)
			if err != nil {
				return fmt.Errorf("could not resolve edge: %w", err)
			}
			item.currentOwner.SetEdges(append(item.currentOwner.GetEdges(), e))
		}
		for _, nested := range rule.Do.Rules {
			if err := ce.evaluateInputRule(item, nested); err != nil {
				return fmt.Errorf("could not evaluate input rule: %w", err)
			}
		}
	}
	return nil
}
// evaluateIfRule evaluates a conditional input rule. The rule's Prefix is interpolated and
// appended (dot-separated) to the inherited resource key prefix. The If expression is then
// executed into a bool, and exactly one branch is applied: Then when the condition is true,
// Else when it is false. A missing branch is a no-op. Applying a branch does three things:
//   - resolves the branch's resources under prefixed keys;
//   - appends its edges to the owner;
//   - evaluates its nested rules.
//
// Any failure, including one while resolving a resource, is returned.
func (ce *ConstructEvaluator) evaluateIfRule(dv *DynamicValueData, rule template.InputRuleTemplate) error {
	parentPrefix := dv.resourceKeyPrefix
	prefix, err := ce.interpolateValue(dv, rule.Prefix)
	if err != nil {
		return fmt.Errorf("could not interpolate resource prefix: %w", err)
	}
	dv = &DynamicValueData{
		currentOwner:     dv.currentOwner,
		currentSelection: dv.currentSelection,
		propertySource:   dv.propertySource,
	}
	if prefix != "" && prefix != nil {
		if parentPrefix != "" {
			dv.resourceKeyPrefix = strings.Join([]string{parentPrefix, fmt.Sprintf("%s", prefix)}, ".")
		} else {
			dv.resourceKeyPrefix = fmt.Sprintf("%s", prefix)
		}
	} else {
		dv.resourceKeyPrefix = parentPrefix
	}

	ctx := DynamicValueContext{
		constructs: ce.Constructs,
	}
	var boolResult bool
	if err := ctx.ExecuteUnmarshal(rule.If, dv, &boolResult); err != nil {
		return fmt.Errorf("result parsing failed: %w", err)
	}

	// Pick the branch strictly by the condition: a true condition without a Then branch
	// must not fall through to Else.
	branch := rule.Else
	if boolResult {
		branch = rule.Then
	}
	if branch == nil {
		return nil
	}
	body := *branch

	// Capture resolution errors explicitly; the iterator's own result is not relied upon.
	var resourceErr error
	body.ResourcesIterator().ForEach(func(key string, resource template.ResourceTemplate) error {
		if dv.resourceKeyPrefix != "" {
			key = fmt.Sprintf("%s.%s", dv.resourceKeyPrefix, key)
		}
		r, err := ce.resolveResource(dv, key, resource)
		if err != nil {
			resourceErr = fmt.Errorf("could not resolve resource %s: %w", key, err)
			return template.StopIteration
		}
		dv.currentOwner.SetResource(key, r)
		return nil
	})
	if resourceErr != nil {
		return resourceErr
	}

	for _, edge := range body.Edges {
		e, err := ce.resolveEdge(dv, edge)
		if err != nil {
			return fmt.Errorf("could not resolve edge: %w", err)
		}
		dv.currentOwner.SetEdges(append(dv.currentOwner.GetEdges(), e))
	}
	for _, nested := range body.Rules {
		if err := ce.evaluateInputRule(dv, nested); err != nil {
			return fmt.Errorf("could not evaluate input rule: %w", err)
		}
	}
	return nil
}
// resolveResource interpolates resource template rt and merges it into the owner's resource
// stored under key, or into a new empty resource when none exists. The merged resource is
// returned; the caller is responsible for storing it.
//
// A non-empty Type must be "<provider>:<type>". It sets the resource ID, which must match
// any ID the resource already has. Properties are merged shallowly:
//   - missing/nil existing values are set;
//   - map[string]any values are merged key by key;
//   - []any values are appended;
//   - everything else, including nil template values and containers of other concrete types,
//     replaces the existing value.
func (ce *ConstructEvaluator) resolveResource(dv *DynamicValueData, key string, rt template.ResourceTemplate) (*Resource, error) {
	if dv.currentOwner == nil {
		return nil, fmt.Errorf("current owner is nil")
	}
	// Update the resource if it already exists.
	resource, ok := dv.currentOwner.GetResource(key)
	if !ok || resource == nil {
		resource = &Resource{Properties: map[string]any{}}
	} else if resource.Properties == nil {
		resource.Properties = map[string]any{}
	}

	tmpl, err := ce.interpolateValue(dv, rt)
	if err != nil {
		return nil, fmt.Errorf("could not interpolate resource %s: %w", key, err)
	}
	resTmpl, ok := tmpl.(template.ResourceTemplate)
	if !ok {
		return nil, fmt.Errorf("could not interpolate resource %s: unexpected result type %T", key, tmpl)
	}

	if resTmpl.Type != "" {
		typeParts := strings.Split(resTmpl.Type, ":")
		if len(typeParts) != 2 {
			return nil, fmt.Errorf("invalid resource type: %s", resTmpl.Type)
		}
		id := construct.ResourceId{
			Provider:  typeParts[0],
			Type:      typeParts[1],
			Namespace: resTmpl.Namespace,
			Name:      resTmpl.Name,
		}
		if resource.Id == (construct.ResourceId{}) {
			resource.Id = id
		} else if resource.Id != id {
			return nil, fmt.Errorf("resource id mismatch: %s", key)
		}
	}

	// #TODO: deep merge the properties by evaluating the properties recursively
	for k, v := range resTmpl.Properties {
		existing := resource.Properties[k]
		// reflect.TypeOf(nil) is nil and would panic on Kind(), so nil values are plain assignments.
		if existing == nil || v == nil {
			resource.Properties[k] = v
			continue
		}
		switch reflect.TypeOf(v).Kind() {
		case reflect.Map:
			src, srcOK := v.(map[string]any)
			dst, dstOK := existing.(map[string]any)
			if srcOK && dstOK {
				for mk, mv := range src {
					dst[mk] = mv
				}
				continue
			}
		case reflect.Slice:
			src, srcOK := v.([]any)
			dst, dstOK := existing.([]any)
			if srcOK && dstOK {
				resource.Properties[k] = append(dst, src...)
				continue
			}
		}
		// Scalars, and containers whose concrete types cannot be merged, replace the existing value.
		resource.Properties[k] = v
	}
	return resource, nil
}
// resolveEdge interpolates an edge template into an Edge. From and To must interpolate to
// non-nil template.ResourceRef values. Data must interpolate to construct.EdgeData, or to
// nil, which yields zero-value edge data. Values of any other type are reported as errors
// instead of panicking.
func (ce *ConstructEvaluator) resolveEdge(dv *DynamicValueData, edge template.EdgeTemplate) (*Edge, error) {
	from, err := ce.interpolateValue(dv, edge.From)
	if err != nil {
		return nil, err
	}
	if from == nil {
		return nil, fmt.Errorf("from is nil")
	}
	fromRef, ok := from.(template.ResourceRef)
	if !ok {
		return nil, fmt.Errorf("from is not a resource reference: %T", from)
	}

	to, err := ce.interpolateValue(dv, edge.To)
	if err != nil {
		return nil, err
	}
	if to == nil {
		return nil, fmt.Errorf("to is nil")
	}
	toRef, ok := to.(template.ResourceRef)
	if !ok {
		return nil, fmt.Errorf("to is not a resource reference: %T", to)
	}

	data, err := ce.interpolateValue(dv, edge.Data)
	if err != nil {
		return nil, err
	}
	var edgeData construct.EdgeData
	if data != nil {
		edgeData, ok = data.(construct.EdgeData)
		if !ok {
			return nil, fmt.Errorf("edge data has unexpected type %T", data)
		}
	}

	return &Edge{
		From: fromRef,
		To:   toRef,
		Data: edgeData,
	}, nil
}
// evaluateOutputs interpolates each of the owner's output templates (in alphabetical key
// order, for determinism) and declares the result on the owner. An output whose value is a
// template.ResourceRef is declared as a property reference: the ref is serialized via
// marshalRef and parsed. Any other value is declared as a literal.
func (ce *ConstructEvaluator) evaluateOutputs(o InfraOwner) error {
	outputs := o.GetTemplateOutputs()
	names := make([]string, 0, len(outputs))
	for name := range outputs {
		names = append(names, name)
	}
	slices.Sort(names)

	for _, key := range names {
		dv := &DynamicValueData{currentOwner: o, propertySource: o.GetPropertySource()}
		interpolated, err := ce.interpolateValue(dv, outputs[key])
		if err != nil {
			return fmt.Errorf("failed to interpolate value for output %s: %w", key, err)
		}
		outputTemplate, isTemplate := interpolated.(template.OutputTemplate)
		if !isTemplate {
			return fmt.Errorf("invalid output template for output %s", key)
		}

		var (
			value any
			ref   construct.PropertyRef
		)
		if r, isRef := outputTemplate.Value.(template.ResourceRef); isRef {
			serializedRef, err := ce.marshalRef(o, r)
			if err != nil {
				return fmt.Errorf("failed to serialize ref for output %s: %w", key, err)
			}
			var refString string
			switch sr := serializedRef.(type) {
			case string:
				refString = sr
			case fmt.Stringer:
				refString = sr.String()
			default:
				return fmt.Errorf("invalid ref string for output %s", key)
			}
			if err := ref.Parse(refString); err != nil {
				return fmt.Errorf("failed to parse ref string for output %s: %w", key, err)
			}
		} else {
			value = outputTemplate.Value
		}

		if value != nil && ref != (construct.PropertyRef{}) {
			return fmt.Errorf("output declaration must be a reference or a value for output %s", key)
		}
		o.DeclareOutput(key, OutputDeclaration{Name: key, Ref: ref, Value: value})
	}
	return nil
}
// convertInputs copies the value of every model input into a construct.Properties keyed by
// input name. Unless the evaluator is in dry-run mode (DryRun != 0), every input must have
// status model.InputStatusResolved; an unresolved input produces an error.
func (ce *ConstructEvaluator) convertInputs(inputs map[string]model.Input) (construct.Properties, error) {
	result := make(construct.Properties)
	requireResolved := ce.DryRun == 0
	for name, input := range inputs {
		if requireResolved && input.Status != model.InputStatusResolved {
			return nil, fmt.Errorf("input %s is not resolved", name)
		}
		result[name] = input.Value
	}
	return result, nil
}
type HasInputs interface {
ForEachInput(f func(input property.Property) error) error
GetInputs() construct.Properties
}
// initializeInputs populates the inputs of c from the supplied values i.
//
// For every input property of c:
//   - when i holds a value at the input's path, that value is stored on c;
//   - when i has no value there (construct.ErrPropertyDoesNotExist), the property's default
//     value, or its zero value when the default is nil, is stored instead;
//   - any other lookup error is recorded.
//
// A required input whose value is nil or equal to the property's zero value is reported as
// "input <path> is required". Problems are accumulated with errors.Join so callers see
// every invalid input at once.
func (ce *ConstructEvaluator) initializeInputs(c HasInputs, i construct.Properties) error {
	// isEmpty reports whether v is nil or equal to the input's zero value. Comparing two
	// interface values that hold the same uncomparable dynamic type (maps, slices, funcs, or
	// structs containing them) with == panics at run time, so == is only used when v is
	// comparable and reflect.DeepEqual is used otherwise.
	isEmpty := func(input property.Property, v any) bool {
		if v == nil {
			return true
		}
		zero := input.ZeroValue()
		if reflect.ValueOf(v).Comparable() {
			return v == zero
		}
		return reflect.DeepEqual(v, zero)
	}

	var inputErrors error
	// The callback always returns nil and accumulates problems in inputErrors.
	// NOTE(review): an error returned by ForEachInput itself is discarded here.
	_ = c.ForEachInput(func(input property.Property) error {
		path := input.Details().Path
		v, err := i.GetProperty(path)
		switch {
		case err == nil:
			if isEmpty(input, v) && input.Details().Required {
				inputErrors = errors.Join(inputErrors, fmt.Errorf("input %s is required", path))
				return nil
			}
			if err := input.SetProperty(c.GetInputs(), v); err != nil {
				inputErrors = errors.Join(inputErrors, err)
			}
		case errors.Is(err, construct.ErrPropertyDoesNotExist):
			dv, dvErr := input.GetDefaultValue(DynamicValueContext{}, nil)
			if dvErr != nil {
				// NOTE(review): a failing default leaves the input unset without reporting an
				// error, even for required inputs; confirm this best-effort is intended.
				return nil
			}
			if dv == nil {
				dv = input.ZeroValue()
			}
			if isEmpty(input, dv) && input.Details().Required {
				inputErrors = errors.Join(inputErrors, fmt.Errorf("input %s is required", path))
				return nil
			}
			if dv == nil {
				return nil // no default value (e.g., for collections or other types with type arguments, i.e., generics)
			}
			if err := input.SetProperty(c.GetInputs(), dv); err != nil {
				inputErrors = errors.Join(inputErrors, err)
			}
		default:
			inputErrors = errors.Join(inputErrors, err)
		}
		return nil
	})
	return inputErrors
}
package constructs
import (
"fmt"
"reflect"
"sort"
"github.com/klothoplatform/klotho/pkg/k2/constructs/template"
"github.com/klothoplatform/klotho/pkg/reflectutil"
"github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/engine/constraints"
"github.com/klothoplatform/klotho/pkg/k2/model"
)
// ConstructMarshaller converts an evaluated Construct into a list of engine constraints,
// resolving template resource references through its ConstructEvaluator.
type ConstructMarshaller struct {
	// ConstructEvaluator supplies the evaluated constructs and resolves resource references.
	ConstructEvaluator *ConstructEvaluator
}
// Marshal looks up the construct identified by constructURN and converts its resources,
// edges and output declarations into one constraint list. The combined list is stably
// sorted with the list's NaturalSort ordering before it is returned.
func (m *ConstructMarshaller) Marshal(constructURN model.URN) (constraints.ConstraintList, error) {
	c, found := m.ConstructEvaluator.Constructs.Get(constructURN)
	if !found {
		return nil, fmt.Errorf("could not find construct %s", constructURN)
	}

	var result constraints.ConstraintList
	for _, res := range c.Resources {
		rc, err := m.marshalResource(c, res)
		if err != nil {
			return nil, fmt.Errorf("could not marshal resource: %w", err)
		}
		result = append(result, rc...)
	}
	for _, edge := range c.Edges {
		ec, err := m.marshalEdge(c, edge)
		if err != nil {
			return nil, fmt.Errorf("could not marshall edge: %w", err)
		}
		result = append(result, ec...)
	}
	for _, decl := range c.OutputDeclarations {
		oc, err := m.marshalOutput(decl)
		if err != nil {
			return nil, fmt.Errorf("could not marshall output: %w", err)
		}
		result = append(result, oc...)
	}

	sort.SliceStable(result, result.NaturalSort)
	return result, nil
}
// marshalResource emits a must_exist application constraint for r, followed by one equals
// resource constraint per property of r. Resource references inside each property value
// are resolved with marshalRefs first.
func (m *ConstructMarshaller) marshalResource(o InfraOwner, r *Resource) (constraints.ConstraintList, error) {
	result := constraints.ConstraintList{
		&constraints.ApplicationConstraint{
			Operator: "must_exist",
			Node:     r.Id,
		},
	}
	for name, raw := range r.Properties {
		resolved, err := m.marshalRefs(o, raw)
		if err != nil {
			return nil, fmt.Errorf("could not marshal resource properties: %w", err)
		}
		result = append(result, &constraints.ResourceConstraint{
			Operator: "equals",
			Target:   r.Id,
			Property: name,
			Value:    resolved,
		})
	}
	return result, nil
}
// marshalEdge converts e into a single must_exist edge constraint.
//
// Each endpoint is resolved with marshalRef and must yield a construct.ResourceId, or a
// string that parses as one. Any other result (for instance the construct.PropertyRef that
// marshalRef returns for a reference naming a property) is reported as an error instead of
// panicking on a failed type assertion. Resource references inside e.Data are resolved with
// marshalRefs; a nil result yields the zero EdgeData.
func (m *ConstructMarshaller) marshalEdge(o InfraOwner, e *Edge) (constraints.ConstraintList, error) {
	// endpoint resolves one side of the edge; side ("from"/"to") is used in error messages.
	endpoint := func(side string, r template.ResourceRef) (construct.ResourceId, error) {
		var id construct.ResourceId
		ref, err := m.ConstructEvaluator.marshalRef(o, r)
		if err != nil {
			return id, fmt.Errorf("could not serialize %s resource id: %w", side, err)
		}
		switch ref := ref.(type) {
		case construct.ResourceId:
			return ref, nil
		case string:
			if err := id.Parse(ref); err != nil {
				return id, fmt.Errorf("could not parse %s resource id: %w", side, err)
			}
			return id, nil
		default:
			return id, fmt.Errorf("could not parse %s resource id: unexpected reference type %T", side, ref)
		}
	}

	from, err := endpoint("from", e.From)
	if err != nil {
		return nil, err
	}
	to, err := endpoint("to", e.To)
	if err != nil {
		return nil, err
	}

	data, err := m.marshalRefs(o, e.Data)
	if err != nil {
		return nil, fmt.Errorf("could not marshal edge data: %w", err)
	}
	var edgeData construct.EdgeData
	if data != nil {
		var ok bool
		if edgeData, ok = data.(construct.EdgeData); !ok {
			return nil, fmt.Errorf("could not marshal edge data: unexpected type %T", data)
		}
	}

	return constraints.ConstraintList{&constraints.EdgeConstraint{
		Operator: "must_exist",
		Target: constraints.Edge{
			Source: from,
			Target: to,
		},
		Data: edgeData,
	}}, nil
}
// marshalOutput wraps a single OutputDeclaration in a must_exist output constraint. A
// non-zero Ref is used when present; otherwise the declaration's Value is carried over.
func (m *ConstructMarshaller) marshalOutput(o OutputDeclaration) (constraints.ConstraintList, error) {
	oc := &constraints.OutputConstraint{
		Operator: "must_exist",
		Name:     o.Name,
	}
	if o.Ref == (construct.PropertyRef{}) {
		oc.Value = o.Value
	} else {
		oc.Ref = o.Ref
	}
	return constraints.ConstraintList{oc}, nil
}
// marshalRefs resolves template resource references found in rawVal into engine references.
//
// A template.ResourceRef (or non-nil *template.ResourceRef) is replaced by the result of
// marshalRef, i.e. a construct.ResourceId or construct.PropertyRef. Values that already are a
// construct.ResourceId or construct.PropertyRef, as well as nil, invalid or zero values, are
// returned unchanged. For maps and slices/arrays, each non-zero element is marshalled
// recursively and written back in place, and rawVal itself is returned. For structs, settable
// fields that are not ResourceRefs are visited recursively, but the recursive results are
// discarded, so only in-place updates to nested maps/slices take effect.
//
// NOTE(review): the in-place writes assume the container can hold the marshalled value —
// SetMapIndex/Set panic if the element type cannot hold e.g. a construct.ResourceId, and Set
// panics on elements of an array reached by value (not addressable). Confirm callers only pass
// containers with interface-typed elements.
func (m *ConstructMarshaller) marshalRefs(o InfraOwner, rawVal any) (any, error) {
if rawVal == nil {
return rawVal, nil
}
switch val := rawVal.(type) {
case *template.ResourceRef:
// a typed nil pointer is not == nil above, so it is handled here
if val == nil {
return rawVal, nil
}
return m.ConstructEvaluator.marshalRef(o, *val)
case template.ResourceRef:
return m.ConstructEvaluator.marshalRef(o, val)
case construct.ResourceId, construct.PropertyRef:
return rawVal, nil
}
// Fall back to reflection to walk containers; a pointer is dereferenced so its pointee's
// contents (struct fields, slice elements) can be modified in place.
ref := reflect.ValueOf(rawVal)
if ref.Kind() == reflect.Ptr {
if ref.IsNil() {
return rawVal, nil
}
ref = ref.Elem()
}
if !ref.IsValid() || ref.IsZero() {
return rawVal, nil
}
switch ref.Kind() {
case reflect.Struct:
for i := 0; i < ref.NumField(); i++ {
field := ref.Field(i)
fieldValue := reflectutil.GetConcreteElement(field)
if field.CanInterface() {
if _, ok := fieldValue.Interface().(template.ResourceRef); ok {
// If we encounter a ResourceRef in a struct, we skip it
// Since the result is not also a ResourceRef
continue
}
}
// Fields of a struct reached by value (not through a pointer) and unexported fields
// are not settable, so they are skipped.
if !field.CanSet() {
continue
}
// The result is discarded: only in-place changes made inside the field's value persist.
if _, err := m.marshalRefs(o, fieldValue.Interface()); err != nil {
return nil, err
}
}
case reflect.Map:
for _, key := range ref.MapKeys() {
field := reflectutil.GetConcreteElement(ref.MapIndex(key))
if !field.IsValid() || field.IsZero() {
continue
}
serializedField, err := m.marshalRefs(o, field.Interface())
if err != nil {
return nil, err
}
// Replace the entry with the marshalled value (e.g. a ResourceRef becomes a ResourceId).
ref.SetMapIndex(key, reflect.ValueOf(serializedField))
}
case reflect.Slice, reflect.Array:
for i := 0; i < ref.Len(); i++ {
field := reflectutil.GetConcreteElement(ref.Index(i))
if !field.IsValid() || field.IsZero() {
continue
}
serializedField, err := m.marshalRefs(o, field.Interface())
if err != nil {
return nil, err
}
// Overwrite the element in place with its marshalled value.
ref.Index(i).Set(reflect.ValueOf(serializedField))
}
}
return rawVal, nil
}
package constructs
import (
"github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/k2/constructs/template"
"github.com/klothoplatform/klotho/pkg/k2/constructs/template/property"
"github.com/klothoplatform/klotho/pkg/k2/model"
)
type (
ResourceOwner interface {
GetResource(resourceId string) (resource *Resource, ok bool)
SetResource(resourceId string, resource *Resource)
GetResources() map[string]*Resource
GetTemplateResourcesIterator() template.Iterator[string, template.ResourceTemplate]
template.InterpolationSource
}
EdgeOwner interface {
GetTemplateEdges() []template.EdgeTemplate
GetEdges() []*Edge
SetEdges(edges []*Edge)
template.InterpolationSource
}
InfraOwner interface {
GetURN() model.URN
GetInputRules() []template.InputRuleTemplate
ResourceOwner
EdgeOwner
GetTemplateOutputs() map[string]template.OutputTemplate
DeclareOutput(key string, declaration OutputDeclaration)
ForEachInput(f func(input property.Property) error) error
GetInputValue(name string) (value any, err error)
GetInitialGraph() construct.Graph
GetConstruct() *Construct
}
)
// marshalRef converts a template ResourceRef into an engine reference.
//
// ref.ResourceKey is first looked up among owner's resources; when it is not found there, it
// is parsed as a resource ID string. The result is a [construct.PropertyRef] when ref names a
// Property, otherwise a bare [construct.ResourceId].
func (ce *ConstructEvaluator) marshalRef(owner InfraOwner, ref template.ResourceRef) (any, error) {
	var id construct.ResourceId
	if res, found := owner.GetResource(ref.ResourceKey); found {
		id = res.Id
	} else if err := id.Parse(ref.ResourceKey); err != nil {
		return nil, err
	}

	if ref.Property == "" {
		return id, nil
	}
	return construct.PropertyRef{
		Resource: id,
		Property: ref.Property,
	}, nil
}
package constructs
import (
"bytes"
"encoding/json"
"fmt"
"reflect"
"slices"
"strings"
"text/template"
"github.com/klothoplatform/klotho/pkg/async"
"github.com/klothoplatform/klotho/pkg/construct"
template2 "github.com/klothoplatform/klotho/pkg/k2/constructs/template"
"github.com/klothoplatform/klotho/pkg/k2/constructs/template/properties"
"github.com/klothoplatform/klotho/pkg/k2/model"
"github.com/klothoplatform/klotho/pkg/templateutils"
"go.uber.org/zap"
)
type (
	// DynamicValueContext is the receiver of the functions exposed to templates (see
	// TemplateFunctions). constructs holds the constructs known to the evaluator.
	DynamicValueContext struct {
		constructs *async.ConcurrentMap[model.URN, *Construct]
	}

	// DynamicValueData is the data object templates and interpolation expressions are
	// evaluated against: the owner being evaluated, the current selection, an optional
	// property source overriding the owner's, and a prefix for resource keys.
	DynamicValueData struct {
		currentOwner      InfraOwner
		currentSelection  DynamicValueSelection
		propertySource    *template2.PropertySource
		resourceKeyPrefix string
	}
)
// TemplateFunctions returns the function map available to dynamic-value templates:
// fieldRef, pathAncestor, pathAncestorExists and toJSON, merged with the common helpers
// provided by templateutils.WithCommonFuncs.
func (ctx DynamicValueContext) TemplateFunctions() template.FuncMap {
	funcs := make(template.FuncMap, 4)
	funcs["fieldRef"] = ctx.FieldRef
	funcs["pathAncestor"] = ctx.PathAncestor
	funcs["pathAncestorExists"] = ctx.PathAncestorExists
	funcs["toJSON"] = ctx.toJson
	return templateutils.WithCommonFuncs(funcs)
}
// Parse compiles tmpl as a text/template named "config" with TemplateFunctions available.
func (ctx DynamicValueContext) Parse(tmpl string) (*template.Template, error) {
	return template.New("config").
		Funcs(ctx.TemplateFunctions()).
		Parse(tmpl)
}
// ExecuteUnmarshal parses tmpl, executes it against data and decodes the rendered output
// into value (see ExecuteTemplateUnmarshal).
func (ctx DynamicValueContext) ExecuteUnmarshal(tmpl string, data any, value any) error {
	parsed, err := ctx.Parse(tmpl)
	if err == nil {
		err = ctx.ExecuteTemplateUnmarshal(parsed, data, value)
	}
	return err
}
// Unmarshal decodes rendered template output held in data into v, delegating to
// properties.UnmarshalAny.
func (ctx DynamicValueContext) Unmarshal(data *bytes.Buffer, v any) error {
	err := properties.UnmarshalAny(data, v)
	return err
}
// ExecuteTemplateUnmarshal executes the template t using data as input and unmarshals the
// rendered result into v.
//
// Template execution errors are returned unchanged. A decoding failure is returned wrapped
// (so errors.Is/As still see the cause) together with the rendered text and the target type.
func (ctx DynamicValueContext) ExecuteTemplateUnmarshal(
	t *template.Template,
	data any,
	v any,
) error {
	buf := new(bytes.Buffer)
	if err := t.Execute(buf, data); err != nil {
		return err
	}
	// Capture the rendered text before decoding: Unmarshal receives the buffer itself and
	// may read from it, which would leave an empty or partial string for the error message.
	rendered := buf.String()
	if err := ctx.Unmarshal(buf, v); err != nil {
		return fmt.Errorf("cannot decode template result '%s' into %T: %w", rendered, v, err)
	}
	return nil
}
// Self returns the owner (an InfraOwner) whose values are being evaluated.
func (data *DynamicValueData) Self() any {
	var owner any = data.currentOwner
	return owner
}

// Selected returns the selection most recently made with Select.
func (data *DynamicValueData) Selected() DynamicValueSelection {
	current := data.currentSelection
	return current
}
// Select looks up path and, when found, makes the value at that path the current selection.
//
// When a selection with a non-nil Value is already active, the lookup is relative to that
// Value; otherwise it uses data's property source, falling back to the owner's. It reports
// whether path was found.
func (data *DynamicValueData) Select(path string) bool {
	source := data.propertySource
	switch {
	case data.currentSelection.Value != nil:
		source = template2.NewPropertySource(data.currentSelection.Value)
	case source == nil:
		source = data.currentOwner.GetPropertySource()
	}

	found, ok := source.GetProperty(path)
	if !ok {
		return false
	}
	data.currentSelection = SelectItem(found)
	return true
}
// lookupRootProperty returns the value stored under key in the property source used for
// template lookups: data's explicit propertySource when set, otherwise the current owner's.
// The lookup's found flag is ignored, and nil is returned when no property source is
// available (no explicit source and no owner).
func (data *DynamicValueData) lookupRootProperty(key string) any {
	ps := data.propertySource
	if ps == nil && data.currentOwner != nil {
		ps = data.currentOwner.GetPropertySource()
	}
	if ps == nil {
		return nil
	}
	val, _ := ps.GetProperty(key)
	return val
}

// Inputs returns the inputs of the current owner.
func (data *DynamicValueData) Inputs() any {
	return data.lookupRootProperty("inputs")
}

// Resources returns the resources of the current owner.
func (data *DynamicValueData) Resources() any {
	return data.lookupRootProperty("resources")
}

// Edges returns the edges of the current owner.
func (data *DynamicValueData) Edges() any {
	return data.lookupRootProperty("edges")
}

// Meta returns the metadata of the current owner.
func (data *DynamicValueData) Meta() any {
	return data.lookupRootProperty("meta")
}

// Prefix returns the resource key prefix configured on data.
func (data *DynamicValueData) Prefix() string {
	return data.resourceKeyPrefix
}

// From returns the 'from' construct if the current owner is a binding.
func (data *DynamicValueData) From() any {
	return data.lookupRootProperty("from")
}

// To returns the 'to' construct if the current owner is a binding.
func (data *DynamicValueData) To() any {
	return data.lookupRootProperty("to")
}
// Log writes message, formatted printf-style with args, to the global zap logger at the
// given level and returns an empty string, so it can be called from a template without
// emitting output. Levels are matched case-insensitively against debug, info, warn and
// error; any other level logs a warning. It is a debugging aid for template authors and
// should only be used during development.
func (data *DynamicValueData) Log(level string, message string, args ...any) string {
	l := zap.L()
	// Describe the owner only when one is set: reflect.TypeOf(nil) returns a nil Type, and
	// asking that for its Kind panics.
	if owner := data.currentOwner; owner != nil {
		l = l.With(
			zap.String("ownerType", fmt.Sprintf("%T", owner)),
			zap.String("ownerURN", fmt.Sprint(owner.GetURN())),
		)
	}
	sugar := l.Sugar()
	switch strings.ToLower(level) {
	case "debug":
		sugar.Debugf(message, args...)
	case "info":
		sugar.Infof(message, args...)
	case "warn":
		sugar.Warnf(message, args...)
	case "error":
		sugar.Errorf(message, args...)
	default:
		sugar.Warnf(message, args...)
	}
	return ""
}
// toJson backs the "toJSON" template function: it renders value as a JSON string so that
// complex values without a TextUnmarshaler implementation can be passed through templates.
func (ctx DynamicValueContext) toJson(value any) (string, error) {
	encoded, err := json.Marshal(value)
	if err == nil {
		return string(encoded), nil
	}
	return "", err
}
// PathAncestor returns the string form of the ancestor of path that lies depth levels
// above it; a depth of 0 returns path itself. It is an error for depth to be negative or to
// be greater than or equal to the number of path elements.
func (ctx DynamicValueContext) PathAncestor(path construct.PropertyPath, depth int) (string, error) {
	switch {
	case depth < 0:
		return "", fmt.Errorf("depth must be >= 0")
	case depth == 0:
		return path.String(), nil
	case depth >= len(path):
		return "", fmt.Errorf("depth %d must be less than path length %d", depth, len(path))
	}
	return path[:len(path)-depth].String(), nil
}
// PathAncestorExists reports whether path has an ancestor depth levels above it, i.e.
// whether depth is non-negative and less than the number of path elements.
func (ctx DynamicValueContext) PathAncestorExists(path construct.PropertyPath, depth int) bool {
	// PathAncestor rejects negative depths, so such an ancestor never exists.
	return depth >= 0 && len(path) > depth
}
// FieldRef builds a construct.PropertyRef that points at field on resource, where resource
// is any argument accepted by TemplateArgToRID.
func (ctx DynamicValueContext) FieldRef(field string, resource any) (construct.PropertyRef, error) {
	id, err := TemplateArgToRID(resource)
	if err != nil {
		return construct.PropertyRef{}, err
	}
	ref := construct.PropertyRef{Property: field, Resource: id}
	return ref, nil
}
// TemplateArgToRID converts a template argument into a construct.ResourceId.
//
// Accepted arguments are a construct.ResourceId, a construct.Resource, a non-nil pointer to
// either, or a string in resource ID text form (parsed with UnmarshalText). Anything else,
// including nil pointers, is an error.
func TemplateArgToRID(arg any) (construct.ResourceId, error) {
	switch arg := arg.(type) {
	case construct.ResourceId:
		return arg, nil
	case *construct.ResourceId:
		if arg != nil {
			return *arg, nil
		}
	case construct.Resource:
		return arg.ID, nil
	case *construct.Resource:
		if arg != nil {
			return arg.ID, nil
		}
	case string:
		var resId construct.ResourceId
		err := resId.UnmarshalText([]byte(arg))
		return resId, err
	}
	return construct.ResourceId{}, fmt.Errorf("invalid argument type %T", arg)
}
// DynamicValueSelection is a cursor over a selected value.
//
// Source is the value that was selected. For a non-empty map source, mapKeys holds the
// map's keys in a stable order (sorted by their "%v" representation) and next is the
// position of the next key to visit. For slice, array and string sources, Index is the
// position of the next element to visit. After each successful call to Next, Value holds the
// element just returned and, for map sources, Key holds that element's key formatted with
// "%v".
type DynamicValueSelection struct {
	Source  any
	mapKeys []reflect.Value
	next    int
	Key     string
	Value   any
	Index   int
}

// SelectItem creates a selection over src. For a non-empty map, the keys are captured and
// sorted up front so that Next visits entries reproducibly; any other value is stored as-is
// and iterated by position.
func SelectItem(src any) DynamicValueSelection {
	sel := DynamicValueSelection{Source: src}
	rv := reflect.ValueOf(src)
	if rv.Kind() != reflect.Map || rv.Len() == 0 {
		return sel
	}
	keys := rv.MapKeys()
	slices.SortStableFunc(keys, func(a, b reflect.Value) int {
		return strings.Compare(stringValue(a.Interface()), stringValue(b.Interface()))
	})
	sel.mapKeys = keys
	return sel
}

// Next returns the next value in the selection and true, or nil and false once there are no
// more values or the source cannot be iterated (scalars, structs, pointers, empty maps, ...).
// If the selection is a map, Key is set to the key of the returned value.
// If the selection is a slice, array or string, Index advances past the returned element.
//
// This function is intended to be used by an orchestration layer across multiple go templates
// and is unavailable inside the templates themselves
func (s *DynamicValueSelection) Next() (any, bool) {
	srcValue := reflect.ValueOf(s.Source)
	if !srcValue.IsValid() {
		return nil, false
	}
	if len(s.mapKeys) > 0 {
		if s.next >= len(s.mapKeys) {
			return nil, false
		}
		key := s.mapKeys[s.next]
		value := srcValue.MapIndex(key).Interface()
		s.Key = stringValue(key.Interface())
		s.Value = value
		s.next++
		return value, true
	}
	switch srcValue.Kind() {
	case reflect.Slice, reflect.Array, reflect.String:
		// positional sources handled below
	default:
		// Len/Index would panic on any other kind.
		return nil, false
	}
	if s.Index >= srcValue.Len() {
		return nil, false
	}
	value := srcValue.Index(s.Index).Interface()
	s.Value = value
	s.Index++
	return value, true
}

// stringValue formats v with the default "%v" verb; it is used to order and label map keys.
func stringValue(v any) string {
	return fmt.Sprintf("%v", v)
}
package graph
import (
"github.com/dominikbraun/graph"
"github.com/klothoplatform/klotho/pkg/graph_addons"
"github.com/klothoplatform/klotho/pkg/k2/model"
)
// Graph is a graph whose vertices are model.URNs, each hashed to itself (see UrnHasher).
type Graph = graph.Graph[model.URN, model.URN]

// Edge is an edge between two URN vertices of a Graph.
type Edge = graph.Edge[model.URN]
// NewGraphWithOptions creates a Graph backed by a graph_addons in-memory store, hashing
// vertices with UrnHasher and applying options exactly as given.
func NewGraphWithOptions(options ...func(*graph.Traits)) Graph {
	store := graph_addons.NewMemoryStore[model.URN, model.URN]()
	return graph.NewWithStore(UrnHasher, store, options...)
}
// NewGraph creates a directed Graph. The caller's options are applied first, followed by
// graph.Directed().
//
// The options are copied into a fresh slice before Directed is added. Appending to the
// variadic slice directly can write into spare capacity of a slice owned by the caller
// (when called as NewGraph(opts...)), which races with concurrent calls sharing that slice
// and can be clobbered by the caller's own later appends.
func NewGraph(options ...func(*graph.Traits)) Graph {
	opts := make([]func(*graph.Traits), 0, len(options)+1)
	opts = append(opts, options...)
	opts = append(opts, graph.Directed())
	return NewGraphWithOptions(opts...)
}
// NewAcyclicGraph creates a directed Graph configured with graph.PreventCycles(). The
// caller's options are applied first, followed by graph.Directed() and graph.PreventCycles().
//
// The options are copied into a fresh slice so the caller's backing array (when called as
// NewAcyclicGraph(opts...)) is never written to.
func NewAcyclicGraph(options ...func(*graph.Traits)) Graph {
	opts := make([]func(*graph.Traits), 0, len(options)+2)
	opts = append(opts, options...)
	opts = append(opts, graph.Directed(), graph.PreventCycles())
	return NewGraphWithOptions(opts...)
}
// UrnHasher is the vertex hash function used by the Graph constructors in this package: a
// URN serves as its own hash.
func UrnHasher(urn model.URN) model.URN {
	return urn
}
// ResolveDeploymentGroups partitions the vertices of g into ordered groups.
//
// Vertices are visited in the order produced by graph_addons.ReverseTopologicalSort (with
// URN comparison as the ordering function). A vertex joins the current group unless it shares
// a direct edge, in either direction, with a vertex already in that group; in that case the
// current group is closed and a new group is started with the vertex. Only direct edges to
// members of the current group are considered.
func ResolveDeploymentGroups(g graph.Graph[model.URN, model.URN]) ([][]model.URN, error) {
	sorted, err := graph_addons.ReverseTopologicalSort(g, func(a, b model.URN) bool {
		return a.Compare(b) < 0
	})
	if err != nil {
		return nil, err
	}

	var groups [][]model.URN
	var current []model.URN
	for _, node := range sorted {
		if hasEdges(node, current, g) {
			groups = append(groups, current)
			current = []model.URN{node}
			continue
		}
		current = append(current, node)
	}
	if len(current) > 0 {
		groups = append(groups, current)
	}
	return groups, nil
}
// hasEdges reports whether node is directly connected to any member of group by an edge in
// either direction (node -> member or member -> node). An Edge lookup that returns an error
// is treated as "no such edge".
func hasEdges(node model.URN, group []model.URN, g graph.Graph[model.URN, model.URN]) bool {
	connected := func(src, dst model.URN) bool {
		_, err := g.Edge(src, dst)
		return err == nil
	}
	for _, member := range group {
		if connected(node, member) || connected(member, node) {
			return true
		}
	}
	return false
}
package constructs
import (
"context"
"errors"
"fmt"
"strings"
"github.com/dominikbraun/graph"
"github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/engine/solution"
stateconverter "github.com/klothoplatform/klotho/pkg/infra/state_reader/state_converter"
statetemplate "github.com/klothoplatform/klotho/pkg/infra/state_reader/state_template"
"github.com/klothoplatform/klotho/pkg/k2/constructs/template/property"
"github.com/klothoplatform/klotho/pkg/k2/model"
"github.com/klothoplatform/klotho/pkg/logging"
)
// importFrom imports the resources and edges of construct ic's solution into o's initial graph.
//
// A resource from ic's dataflow graph is copied (with Imported set) into o's initial graph
// only if its template has at least one property that is both Required and DeployTime. Those
// properties are filled from ic's live stack state when the resource appears there. Without
// live state, dry-run mode (ce.DryRun > 0) substitutes a "preview(id=...)" placeholder and
// non-dry-run mode returns an error. Resources without such a property are skipped, and
// references to skipped resources are then cleared from the imported resources' properties.
// Finally the solution's edges are added to the initial graph; edges with an endpoint that
// is not present there (graph.ErrVertexNotFound) are skipped.
func (ce *ConstructEvaluator) importFrom(ctx context.Context, o InfraOwner, ic *Construct) error {
log := logging.GetLogger(ctx).Sugar()
initGraph := o.GetInitialGraph()
sol := ic.Solution
stackState, hasState := ce.stackStateManager.ConstructStackState[ic.URN]
// NOTE(gg): using topo sort to get all resources, order doesn't matter
resourceIds, err := construct.TopologicalSort(sol.DataflowGraph())
if err != nil {
return fmt.Errorf("could not get resources from %s solution: %w", ic.URN, err)
}
// resources collects the imported resources by ID for filterImportProperties.
resources := make(map[construct.ResourceId]*construct.Resource)
for _, rId := range resourceIds {
// liveStateRes stays nil unless the stack state for ic contains this resource.
var liveStateRes *construct.Resource
if hasState {
if state, ok := stackState.Resources[rId]; ok {
liveStateRes, err = ce.stateConverter.ConvertResource(stateconverter.Resource{
Urn: string(state.URN),
Type: string(state.Type),
Outputs: state.Outputs,
})
if err != nil {
return fmt.Errorf("could not convert state for %s.%s: %w", ic.URN, rId, err)
}
log.Debugf("Imported %s from state", rId)
}
}
originalRes, err := sol.DataflowGraph().Vertex(rId)
if err != nil {
return fmt.Errorf("could not get resource %s.%s from solution: %w", ic.URN, rId, err)
}
tmpl, err := sol.KnowledgeBase().GetResourceTemplate(rId)
if err != nil {
return fmt.Errorf("could not get resource template %s.%s: %w", ic.URN, rId, err)
}
// NOTE(review): this is a shallow copy — nested maps/slices are still shared with the
// solution's resource, so clearing a nested path in filterImportProperties may also mutate
// the original solution; confirm whether a deep copy is needed.
props := make(construct.Properties)
for k, v := range originalRes.Properties {
props[k] = v
}
hasImportId := false
// set a fake import id, otherwise index.ts will have things like
// Type.get("name", <no value>)
for k, prop := range tmpl.Properties {
// Only Required + DeployTime properties are populated (presumably the identifiers an
// import needs).
if prop.Details().Required && prop.Details().DeployTime {
if liveStateRes == nil {
if ce.DryRun > 0 {
props[k] = fmt.Sprintf("preview(id=%s)", rId)
hasImportId = true
continue
} else {
return fmt.Errorf("could not get live state resource %s (%s)", ic.URN, rId)
}
}
liveIdProp, err := liveStateRes.GetProperty(k)
if err != nil {
return fmt.Errorf("could not get property %s for %s: %w", k, rId, err)
}
props[k] = liveIdProp
hasImportId = true
}
}
// Resources that have no such property are not imported at all.
if !hasImportId {
continue
}
res := &construct.Resource{
ID: originalRes.ID,
Properties: props,
Imported: true,
}
log.Debugf("Imported %s from solution", rId)
if err := initGraph.AddVertex(res); err != nil {
return fmt.Errorf("could not create imported resource %s from %s: %w", rId, ic.URN, err)
}
resources[rId] = res
}
// Clear property references to resources of ic that were skipped above.
err = filterImportProperties(resources)
if err != nil {
return fmt.Errorf("could not filter import properties for %s: %w", ic.URN, err)
}
edges, err := sol.DataflowGraph().Edges()
if err != nil {
return fmt.Errorf("could not get edges from %s solution: %w", ic.URN, err)
}
for _, e := range edges {
err := initGraph.AddEdge(e.Source, e.Target, func(ep *graph.EdgeProperties) {
ep.Data = e.Properties.Data
})
switch {
case err == nil:
log.Debugf("Imported edge %s -> %s from solution", e.Source, e.Target)
case errors.Is(err, graph.ErrVertexNotFound):
// An endpoint is absent from the initial graph (e.g. it was not imported); skip the edge.
log.Debugf("Skipping import edge %s -> %s from solution", e.Source, e.Target)
default:
return fmt.Errorf("could not create imported edge %s -> %s from %s: %w", e.Source, e.Target, ic.URN, err)
}
}
return nil
}
// filterImportProperties clears, in every resource of resources, each property whose value
// is a construct.ResourceId or construct.PropertyRef pointing at a resource that is not a key
// of resources (i.e. one that was skipped from importing). Failures to clear a property are
// collected and returned together via errors.Join.
func filterImportProperties(resources map[construct.ResourceId]*construct.Resource) error {
	var errs []error
	for id, r := range resources {
		_ = r.WalkProperties(func(path construct.PropertyPath, _ error) error {
			value, found := path.Get()
			if !found {
				return nil
			}
			var target construct.ResourceId
			switch ref := value.(type) {
			case construct.ResourceId:
				target = ref
			case construct.PropertyRef:
				target = ref.Resource
			default:
				return nil
			}
			if _, imported := resources[target]; imported {
				return nil
			}
			if err := path.Remove(nil); err != nil {
				errs = append(errs,
					fmt.Errorf("error clearing %s: %w", construct.PropertyRef{Resource: id, Property: path.String()}, err),
				)
			}
			return nil
		})
	}
	return errors.Join(errs...)
}
// importResourcesFromInputs imports resources from the construct-type inputs of the provided
// [InfraOwner], o.
//
// For every input whose type starts with "construct", the resolved input value must be a
// resource URN of type "construct" naming a construct known to ce; that construct's solution
// is then imported into o via importFrom. It returns an error if an input value does not
// represent a valid construct, the construct cannot be found, or importing fails.
func (ce *ConstructEvaluator) importResourcesFromInputs(o InfraOwner, ctx context.Context) error {
	return o.ForEachInput(func(input property.Property) error {
		if !strings.HasPrefix(input.Type(), "construct") {
			return nil // only construct-typed inputs carry importable resources
		}
		name := input.Details().Path
		value, err := o.GetInputValue(name)
		if err != nil {
			return fmt.Errorf("could not get input %s: %w", name, err)
		}
		urn, isURN := value.(model.URN)
		if valid := isURN && urn.IsResource() && urn.Type == "construct"; !valid {
			return fmt.Errorf("input %s is not a construct URN", name)
		}
		source, found := ce.Constructs.Get(urn)
		if !found {
			return fmt.Errorf("could not find construct %s", urn)
		}
		if err := ce.importFrom(ctx, o, source); err != nil {
			return fmt.Errorf("could not import resources from %s: %w", urn, err)
		}
		return nil
	})
}
// importBindingToResources imports the resources of the binding's target construct (b.To)
// into the binding itself.
func (ce *ConstructEvaluator) importBindingToResources(ctx context.Context, b *Binding) error {
	target := b.To
	return ce.importFrom(ctx, b, target)
}
// RegisterOutputValues replaces the Outputs of the construct identified by urn. It does
// nothing when ce does not know that construct.
func (ce *ConstructEvaluator) RegisterOutputValues(urn model.URN, outputs map[string]any) {
	c, found := ce.Constructs.Get(urn)
	if !found {
		return
	}
	c.Outputs = outputs
}
// AddSolution records sol as the engine solution of the construct identified by urn.
//
// It panics if urn does not identify a known construct. That only happens through a
// programming error, and the explicit panic names the missing URN instead of surfacing as an
// anonymous nil-pointer dereference.
func (ce *ConstructEvaluator) AddSolution(urn model.URN, sol solution.Solution) {
	c, ok := ce.Constructs.Get(urn)
	if !ok {
		panic(fmt.Sprintf("AddSolution: construct %s not found", urn))
	}
	c.Solution = sol
}
// loadStateConverter loads the "pulumi" state templates and builds a state converter for the
// "pulumi" provider from them.
func loadStateConverter() (stateconverter.StateConverter, error) {
	const provider = "pulumi"
	templates, err := statetemplate.LoadStateTemplates(provider)
	if err != nil {
		return nil, err
	}
	return stateconverter.NewStateConverter(provider, templates), nil
}
package constructs
import (
"errors"
"fmt"
template2 "github.com/klothoplatform/klotho/pkg/k2/constructs/template"
"github.com/klothoplatform/klotho/pkg/k2/model"
"github.com/klothoplatform/klotho/pkg/reflectutil"
"go.uber.org/zap"
"reflect"
"regexp"
"strconv"
"strings"
)
var (
	// interpolationPattern matches every interpolation group in a string, e.g. both groups in
	// ${inputs:foo.bar}-baz-${resources:Boz}. Submatch 1 is the prefix, submatch 2 the key.
	interpolationPattern = regexp.MustCompile(`\$\{([^:]+):([^}]+)}`)
	// isolatedInterpolationPattern matches a string consisting of exactly one interpolation
	// group, e.g. ${inputs:foo.bar}.
	isolatedInterpolationPattern = regexp.MustCompile(`^\$\{([^:]+):([^}]+)}$`)
	// spreadPattern matches a string ending in "...}", i.e. an interpolation that uses the
	// spread operator such as ${inputs:foo...}.
	spreadPattern = regexp.MustCompile(`\.\.\.}$`)
)
// interpolateValue interpolates a value based on the context of the construct.
//
// The format of a raw value is ${<prefix>:<key>} where prefix is the type of value to
// interpolate and key is the key to interpolate.
//
// The key can be a path to a value in the context.
// For example, ${inputs:foo.bar} will interpolate the value of the key bar in the foo input.
//
// The target of a dot-separated path can be a map or a struct.
// The path can also include brackets to access an array or an element whose key contains a dot.
// For example, ${inputs:foo[0].bar} will interpolate the value of the key bar in the first
// element of the foo input array.
//
// Inside a slice, an element can use the spread operator to expand an array into the
// enclosing array. For example, the element ${inputs:foo...} inserts every element of the foo
// input array at that position.
//
// A rawValue can contain a combination of interpolation expressions, literals, and go templates.
// For example, "${inputs:foo.bar}-baz-${resources:Boz}" is a valid rawValue.
//
// Slices are returned as []any and maps as map[string]any (keys formatted with fmt.Sprint).
// Structs are rebuilt field by field:
//   - unexported fields keep their original values;
//   - an exported field whose value interpolates to nil is reset to its zero value;
//   - an interpolated value that can neither be assigned nor converted (within the same kind)
//     to the field's type is reported as an error.
func (ce *ConstructEvaluator) interpolateValue(dv *DynamicValueData, rawValue any) (any, error) {
	if ref, ok := rawValue.(template2.ResourceRef); ok {
		switch ref.Type {
		case template2.ResourceRefTypeInterpolated:
			return ce.interpolateValue(dv, ref.ResourceKey)
		case template2.ResourceRefTypeTemplate:
			ref.ConstructURN = dv.currentOwner.GetURN()
			rk, err := ce.interpolateValue(dv, ref.ResourceKey)
			if err != nil {
				return nil, err
			}
			ref.ResourceKey = fmt.Sprintf("%s", rk)
			return ref, nil
		default:
			return rawValue, nil
		}
	}

	v := reflectutil.GetConcreteElement(reflect.ValueOf(rawValue))
	if !v.IsValid() {
		return rawValue, nil
	}
	rawValue = v.Interface()

	switch v.Kind() {
	case reflect.String:
		resolvedVal, err := ce.interpolateString(dv, v.String())
		if err != nil {
			return nil, err
		}
		return resolvedVal, nil

	case reflect.Slice:
		var interpolated []any
		for i := 0; i < v.Len(); i++ {
			element := v.Index(i)
			// A string element ending in "...}" is a spread expression: strip the "..." and
			// splice the (slice) result into the output instead of nesting it.
			if s, ok := reflectutil.GetConcreteValue(element).(string); ok && spreadPattern.MatchString(s) {
				unspreadPath := strings.TrimSuffix(s, "...}") + "}"
				spreadValue, err := ce.interpolateValue(dv, unspreadPath)
				if err != nil {
					return nil, err
				}
				if spreadValue == nil {
					continue
				}
				spread := reflect.ValueOf(spreadValue)
				if spread.Kind() != reflect.Slice {
					return nil, errors.New("spread value must be a slice")
				}
				for j := 0; j < spread.Len(); j++ {
					interpolated = append(interpolated, spread.Index(j).Interface())
				}
				continue
			}
			value, err := ce.interpolateValue(dv, element.Interface())
			if err != nil {
				return nil, err
			}
			interpolated = append(interpolated, value)
		}
		return interpolated, nil

	case reflect.Map:
		interpolated := make(map[string]any, v.Len())
		for _, k := range v.MapKeys() {
			key, err := ce.interpolateValue(dv, k.Interface())
			if err != nil {
				return nil, err
			}
			value, err := ce.interpolateValue(dv, v.MapIndex(k).Interface())
			if err != nil {
				return nil, err
			}
			interpolated[fmt.Sprint(key)] = value
		}
		return interpolated, nil

	case reflect.Struct:
		structType := v.Type()
		newStruct := reflect.New(structType).Elem()
		// Start from a copy of the original: reflection can neither read (Interface) nor set
		// unexported fields, so they are carried over unchanged instead of panicking.
		newStruct.Set(v)
		for i := 0; i < structType.NumField(); i++ {
			field := structType.Field(i)
			if !field.IsExported() {
				continue
			}
			fieldValue, err := ce.interpolateValue(dv, v.Field(i).Interface())
			if err != nil {
				return nil, err
			}
			target := newStruct.Field(i)
			if fieldValue == nil {
				target.Set(reflect.Zero(field.Type))
				continue
			}
			resolved := reflect.ValueOf(fieldValue)
			switch {
			case resolved.Type().AssignableTo(field.Type):
				target.Set(resolved)
			case resolved.Kind() == field.Type.Kind() && resolved.Type().ConvertibleTo(field.Type):
				// e.g. a plain string interpolated into a field of a named string type
				target.Set(resolved.Convert(field.Type))
			default:
				return nil, fmt.Errorf("cannot assign interpolated value of type %T to field %s of %s (type %s)",
					fieldValue, field.Name, structType, field.Type)
			}
		}
		return newStruct.Interface(), nil

	default:
		return rawValue, nil
	}
}
// interpolateString resolves the go template expressions and ${prefix:key} interpolation
// expressions in rawValue.
//
// Go template expressions ("{{ ... }}") are executed first, against dv. If the result is a
// single isolated interpolation expression, the referenced value is returned as-is (it need
// not be a string). Otherwise every interpolation expression is replaced by the fmt.Sprint
// form of its value and the resulting string is returned. The first interpolation error
// encountered is returned.
func (ce *ConstructEvaluator) interpolateString(dv *DynamicValueData, rawValue string) (any, error) {
	// handle go template expressions
	if strings.Contains(rawValue, "{{") {
		ctx := DynamicValueContext{constructs: ce.Constructs}
		if err := ctx.ExecuteUnmarshal(rawValue, dv, &rawValue); err != nil {
			return nil, err
		}
	}

	ps := dv.propertySource
	if ps == nil {
		ps = dv.currentOwner.GetPropertySource()
	}

	// if the rawValue is an isolated interpolation expression, interpolate it and return the raw value
	if isolatedInterpolationPattern.MatchString(rawValue) {
		return ce.interpolateExpression(dv.currentOwner, ps, rawValue)
	}

	// Mixed expressions are always interpolated as strings. ReplaceAllStringFunc cannot be
	// aborted, so the first error is recorded and every later match is left untouched; this
	// keeps a later successful match from masking an earlier failure.
	var firstErr error
	interpolated := interpolationPattern.ReplaceAllStringFunc(rawValue, func(match string) string {
		if firstErr != nil {
			return match
		}
		val, err := ce.interpolateExpression(dv.currentOwner, ps, match)
		if err != nil {
			firstErr = err
			return match
		}
		return fmt.Sprint(val)
	})
	if firstErr != nil {
		return nil, firstErr
	}
	return interpolated, nil
}
// interpolateExpression evaluates a single ${prefix:key} interpolation expression against ps.
//
// prefix selects the root value in ps and must be inputs, resources, edges or meta, or start
// with "from." or "to.". The result is, in order of precedence:
//   - a template2.ResourceRef of type IaC when key has the form <resource-key>#<property>;
//   - the resource's name when key has the form <resource>.Name under a resources root and
//     that resource exists;
//   - a template2.ResourceRef of type IaC when the value at key is a *Resource;
//   - the value at key, with any ResourceRef re-associated with the URN of the construct it
//     was interpolated from.
//
// A key whose value cannot be retrieved yields (nil, nil); the lookup error is only logged
// at debug level.
func (ce *ConstructEvaluator) interpolateExpression(owner InfraOwner, ps *template2.PropertySource, match string) (any, error) {
	if ps == nil {
		return nil, errors.New("property source is nil")
	}

	// Split the match into prefix and key
	parts := interpolationPattern.FindStringSubmatch(match)
	if parts == nil {
		return nil, fmt.Errorf("invalid interpolation expression: %s", match)
	}
	prefix, key := parts[1], parts[2]

	// Choose the correct root property from the source based on the prefix
	validPrefix := prefix == "inputs" || prefix == "resources" || prefix == "edges" || prefix == "meta" ||
		strings.HasPrefix(prefix, "from.") ||
		strings.HasPrefix(prefix, "to.")
	if !validPrefix {
		return nil, fmt.Errorf("invalid prefix: %s", prefix)
	}
	p, ok := ps.GetProperty(prefix)
	if !ok {
		return nil, fmt.Errorf("could not get %s", prefix)
	}

	prefixParts := strings.Split(prefix, ".")

	// associate any ResourceRefs with the URN of the property source they're being interpolated from
	// if the prefix is "from" or "to", the URN of the property source is the "urn" field of that level in the property source
	var refUrn model.URN
	if strings.HasSuffix(prefix, "resources") {
		urnKey := "urn"
		if prefixParts[0] == "from" || prefixParts[0] == "to" {
			urnKey = fmt.Sprintf("%s.urn", prefixParts[0])
		}
		psURN, ok := template2.GetTypedProperty[model.URN](ps, urnKey)
		if !ok {
			psURN = owner.GetURN()
		}
		refUrn = psURN
	} else {
		if propTrace, err := reflectutil.TracePath(reflect.ValueOf(p), key); err == nil {
			if refConstruct, ok := reflectutil.LastOfType[*Construct](propTrace); ok {
				refUrn = refConstruct.URN
			}
		}
		if refUrn.Equals(model.URN{}) {
			refUrn = owner.GetURN()
		}
	}

	// return an IaC reference if the key matches the IaC reference pattern
	if iacParts := iacRefPattern.FindStringSubmatch(key); iacParts != nil {
		return template2.ResourceRef{
			ResourceKey:  iacParts[1],
			Property:     iacParts[2],
			Type:         template2.ResourceRefTypeIaC,
			ConstructURN: refUrn,
		}, nil
	}

	// special case for resources allowing for accessing the name of a resource directly
	// instead of using .Id.Name; when the shortcut does not apply (unexpected collection type
	// or unknown resource) fall through to the generic lookup instead of panicking
	if prefixParts[len(prefixParts)-1] == "resources" {
		keyParts := reflectutil.SplitPath(key)
		if len(keyParts) > 1 && keyParts[1] == ".Name" {
			resourceKey := strings.Trim(keyParts[0], ".[]")
			if resources, ok := p.(map[string]*Resource); ok {
				if r := resources[resourceKey]; r != nil {
					return r.Id.Name, nil
				}
			}
		}
	}

	// Retrieve the value from the designated property source
	value, err := ce.getValueFromSource(p, key, false)
	if err != nil {
		zap.S().Debugf("could not get value from source: %s", err)
		return nil, nil
	}

	var refProperty string
	if keyAndRef := strings.Split(key, "#"); len(keyAndRef) == 2 {
		refProperty = keyAndRef[1]
	}

	switch r := value.(type) {
	case *Resource:
		// a Resource is returned as a ResourceRef
		return template2.ResourceRef{
			ResourceKey:  r.Id.String(),
			Property:     refProperty,
			Type:         template2.ResourceRefTypeIaC,
			ConstructURN: refUrn,
		}, nil
	case template2.ResourceRef:
		r.ConstructURN = refUrn
		return r, nil
	}
	return value, nil
}
var (
	// iacRefPattern matches a whole-string IaC reference of the form
	// <resource-key>#<property>, capturing the resource key (group 1) and the
	// property path (group 2, which may contain dots).
	iacRefPattern = regexp.MustCompile(`^([a-zA-Z0-9_-]+)#([a-zA-Z0-9._-]+)$`)

	// indexPattern matches a whole-string, non-negative array index such as `[3]`.
	indexPattern = regexp.MustCompile(`^\[\d+]$`)
)
// getValueFromSource retrieves a value from source by walking key.
//
// key is a path of map keys, struct fields, and `[n]` slice/array indices,
// optionally followed by a single `#<property>` reference suffix. More than one
// '#' is an error.
// The flat parameter is used to determine if the key is a flat key or a path (mixed keys aren't supported at the moment)
// e.g (flat = true): key = "foo.bar" -> value = collection["foo.bar"], flat = false: key = "foo.bar" -> value = collection["foo"]["bar"]
//
// While walking:
//   - a *Resource reached with a single-part key yields a template ResourceRef whose
//     ResourceKey is that part and whose Property is the '#' suffix; otherwise the
//     next part is looked up as a field of the resource's Properties;
//   - a model.URN is resolved to its construct through ce.Constructs and the next
//     part is looked up as a field of that construct.
//
// If a step fails after at least one step succeeded, the remaining parts are retried
// as a single flat key (with the '#' suffix re-attached) against the last value
// reached. Presumably this is so map keys that contain dots can be addressed.
func (ce *ConstructEvaluator) getValueFromSource(source any, key string, flat bool) (any, error) {
	value := reflect.ValueOf(source)
	keyAndRef := strings.Split(key, "#")
	if len(keyAndRef) > 2 {
		return nil, fmt.Errorf("invalid engine reference property reference: %s", key)
	}
	var refProperty string
	if len(keyAndRef) == 2 {
		refProperty = keyAndRef[1]
		key = keyAndRef[0]
	}

	// Split the key into parts if not flat
	parts := []string{key}
	if !flat {
		parts = reflectutil.SplitPath(key)
	}
	for i, part := range parts {
		parts[i] = strings.TrimPrefix(part, ".")
	}

	var err error
	var lastValidValue reflect.Value
	lastValidIndex := -1
	// Traverse the map/struct/array according to the parts
	for i, part := range parts {
		// A nil source (or nil intermediate) cannot be traversed; report it instead
		// of panicking in value.Interface() below.
		if !value.IsValid() {
			err = fmt.Errorf("could not get value for key: %s", key)
			break
		}
		if indexPattern.MatchString(part) {
			// The part is an array index such as "[0]"
			var index int
			index, err = strconv.Atoi(strings.TrimSuffix(strings.TrimPrefix(part, "["), "]"))
			if err != nil {
				err = fmt.Errorf("could not parse index: %w", err)
				break
			}
			value = reflectutil.GetConcreteElement(value)
			// The kinds must be listed separately: `reflect.Slice | reflect.Array` is a
			// bitwise OR whose result equals reflect.Slice alone.
			switch kind := value.Kind(); kind {
			case reflect.Slice, reflect.Array:
				if index >= value.Len() {
					err = fmt.Errorf("index out of bounds: %d", index)
				} else {
					value = value.Index(index)
				}
			default:
				err = fmt.Errorf("invalid type: %s", kind)
			}
		} else {
			// The part is not an array index
			part = strings.TrimSuffix(strings.TrimPrefix(part, "["), "]")
			if value.Kind() == reflect.Map {
				v := value.MapIndex(reflect.ValueOf(part))
				if !v.IsValid() {
					err = fmt.Errorf("could not get value for key: %s", key)
					break
				}
				value = v
			} else if r, ok := value.Interface().(*Resource); ok {
				if len(parts) == 1 {
					return template2.ResourceRef{
						ResourceKey: part,
						Property:    refProperty,
						Type:        template2.ResourceRefTypeTemplate,
					}, nil
				}
				// if the parent is a resource, children are implicitly properties of the resource.
				// lastValidValue is updated so a failed lookup can be retried (flat) against them.
				lastValidValue = reflect.ValueOf(r.Properties)
				value, err = reflectutil.GetField(lastValidValue, part)
				if err != nil {
					err = fmt.Errorf("could not get field: %w", err)
					break
				}
			} else if u, ok := value.Interface().(model.URN); ok {
				c, found := ce.Constructs.Get(u)
				if !found {
					err = fmt.Errorf("could not get construct: %s", u)
					break
				}
				lastValidValue = reflect.ValueOf(c)
				value, err = reflectutil.GetField(lastValidValue, part)
				if err != nil {
					err = fmt.Errorf("could not get field: %w", err)
					break
				}
			} else {
				var rVal reflect.Value
				rVal, err = reflectutil.GetField(value, part)
				if err != nil {
					err = fmt.Errorf("could not get field: %w", err)
					break
				}
				value = rVal
			}
		}
		if err != nil {
			break
		}
		if i == len(parts)-1 {
			return value.Interface(), nil
		}
		lastValidValue = value
		lastValidIndex = i
	}
	if lastValidIndex > -1 {
		// Retry the unresolved remainder as one flat key against the deepest value
		// reached, keeping the '#' reference so it is not lost on the retry.
		remaining := strings.Join(parts[lastValidIndex+1:], ".")
		if refProperty != "" {
			remaining += "#" + refProperty
		}
		return ce.getValueFromSource(lastValidValue.Interface(), remaining, true)
	}
	if !value.IsValid() {
		return nil, err
	}
	return value.Interface(), err
}
package template
import (
"github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/k2/constructs/template/property"
"gopkg.in/yaml.v3"
)
// BindingTemplate is the YAML-decoded template for a binding between two
// construct types (From and To). resourceOrder is not decoded from YAML. It is
// filled by UnmarshalYAML with the document order of the "resources" section.
type BindingTemplate struct {
	From          property.ConstructType      `yaml:"from"`
	To            property.ConstructType      `yaml:"to"`
	Priority      int                         `yaml:"priority"`
	Inputs        *Properties                 `yaml:"inputs"`
	Outputs       map[string]OutputTemplate   `yaml:"outputs"`
	InputRules    []InputRuleTemplate         `yaml:"input_rules"`
	Resources     map[string]ResourceTemplate `yaml:"resources"`
	Edges         []EdgeTemplate              `yaml:"edges"`
	resourceOrder []string
}

// GetInput returns the binding input at path, as resolved by property.GetProperty
// over the binding's converted inputs.
func (bt *BindingTemplate) GetInput(path string) property.Property {
	inputMap := bt.Inputs.propertyMap
	return property.GetProperty(inputMap, path)
}

// ForEachInput walks the input properties of a construct template,
// including nested properties, and calls the given function for each input.
// If the function returns an error, the walk will stop and return that error.
// If the function returns [ErrStopWalk], the walk will stop and return nil.
func (bt *BindingTemplate) ForEachInput(c construct.Properties, f func(property.Property) error) error {
	in := bt.Inputs
	return in.ForEach(c, f)
}

// ResourcesIterator returns a fresh iterator over Resources in YAML document order.
func (bt *BindingTemplate) ResourcesIterator() Iterator[string, ResourceTemplate] {
	var it Iterator[string, ResourceTemplate]
	it.source = bt.Resources
	it.order = bt.resourceOrder
	return it
}

// UnmarshalYAML decodes a binding template, records the "resources" key order, and
// replaces every absent section with an empty, non-nil value.
func (bt *BindingTemplate) UnmarshalYAML(value *yaml.Node) error {
	// Decode through a method-less alias type so this method is not re-entered.
	type rawBindingTemplate BindingTemplate
	var raw rawBindingTemplate
	if err := value.Decode(&raw); err != nil {
		return err
	}
	// A missing "resources" section is not an error; the order is simply left nil.
	raw.resourceOrder, _ = captureYAMLKeyOrder(value, "resources")

	if raw.Inputs == nil {
		raw.Inputs = NewProperties(nil)
	}
	if raw.Resources == nil {
		raw.Resources = map[string]ResourceTemplate{}
	}
	if raw.Edges == nil {
		raw.Edges = []EdgeTemplate{}
	}
	if raw.Outputs == nil {
		raw.Outputs = map[string]OutputTemplate{}
	}
	if raw.InputRules == nil {
		raw.InputRules = []InputRuleTemplate{}
	}
	*bt = BindingTemplate(raw)
	return nil
}
package template
import (
"errors"
"fmt"
"github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/k2/constructs/template/property"
"gopkg.in/yaml.v3"
"regexp"
)
type (
// ConstructTemplate is the YAML-decoded definition of a construct type.
// It is decoded via ConstructTemplate.UnmarshalYAML, which also records the
// document order of the "resources" section in resourceOrder.
ConstructTemplate struct {
Id property.ConstructType `yaml:"id"`
Version string `yaml:"version"`
Description string `yaml:"description"`
// Resources holds the template's resource entries, keyed as written in YAML.
Resources map[string]ResourceTemplate `yaml:"resources"`
Edges []EdgeTemplate `yaml:"edges"`
// Inputs holds the input definitions, already converted to property implementations.
Inputs *Properties `yaml:"inputs"`
Outputs map[string]OutputTemplate `yaml:"outputs"`
InputRules []InputRuleTemplate `yaml:"input_rules"`
// resourceOrder is the YAML key order of Resources; it is not decoded directly.
resourceOrder []string
}
// ResourceTemplate describes one resource entry of a template.
ResourceTemplate struct {
Type string `yaml:"type"`
Name string `yaml:"name"`
Namespace string `yaml:"namespace"`
Properties map[string]any `yaml:"properties"`
}
// EdgeTemplate describes an edge between two resources. From and To are decoded
// by EdgeTemplate.UnmarshalYAML: a string containing an interpolation expression
// becomes a ResourceRefTypeInterpolated ref, anything else a ResourceRefTypeTemplate ref.
EdgeTemplate struct {
From ResourceRef `yaml:"from"`
To ResourceRef `yaml:"to"`
Data construct.EdgeData `yaml:"data"`
}
// OutputTemplate describes one named output of a template.
OutputTemplate struct {
Name string `yaml:"name"`
Description string `yaml:"description"`
Value any `yaml:"value"`
}
// InputRuleTemplate is a conditional rule. Exactly one of If (with Then/Else) or
// ForEach (with Do) must be set; InputRuleTemplate.UnmarshalYAML enforces this.
InputRuleTemplate struct {
If string `yaml:"if"`
Then *ConditionalExpressionTemplate `yaml:"then"`
Else *ConditionalExpressionTemplate `yaml:"else"`
ForEach string `yaml:"for_each"`
Do *ConditionalExpressionTemplate `yaml:"do"`
Prefix string `yaml:"prefix"`
}
// ConditionalExpressionTemplate is the body of a rule branch (Then/Else/Do): the
// resources, edges, outputs, and nested rules it contributes.
ConditionalExpressionTemplate struct {
Resources map[string]ResourceTemplate `yaml:"resources"`
Edges []EdgeTemplate `yaml:"edges"`
Outputs map[string]OutputTemplate `yaml:"outputs"`
Rules []InputRuleTemplate `yaml:"rules"`
// resourceOrder is the YAML key order of Resources; it is not decoded directly.
resourceOrder []string
}
// ValidationTemplate holds validation constraints decoded from YAML.
// NOTE(review): no use of this type is visible in this chunk — confirm where it is applied.
ValidationTemplate struct {
MinLength int `yaml:"min_length"`
MaxLength int `yaml:"max_length"`
MinValue int `yaml:"min_value"`
MaxValue int `yaml:"max_value"`
Pattern string `yaml:"pattern"`
Enum []string `yaml:"enum"`
UniqueValues bool `yaml:"unique_values"`
}
)
var (
	// interpolationPattern matches an interpolation expression `${<prefix>:<key>}`,
	// capturing the prefix (group 1) and the key (group 2). It is unanchored, so
	// it also finds expressions embedded inside a larger string.
	interpolationPattern = regexp.MustCompile(`\$\{([^:]+):([^}]+)}`)
)
// UnmarshalYAML decodes an edge whose "from"/"to" are plain strings. Each side
// becomes a ResourceRef: interpolated when the string contains an interpolation
// expression, otherwise a plain template reference.
func (e *EdgeTemplate) UnmarshalYAML(value *yaml.Node) error {
	var raw struct {
		From string             `yaml:"from"`
		To   string             `yaml:"to"`
		Data construct.EdgeData `yaml:"data"`
	}
	if err := value.Decode(&raw); err != nil {
		return err
	}

	refFor := func(key string) ResourceRef {
		if interpolationPattern.MatchString(key) {
			return ResourceRef{ResourceKey: key, Type: ResourceRefTypeInterpolated}
		}
		return ResourceRef{ResourceKey: key, Type: ResourceRefTypeTemplate}
	}

	e.From = refFor(raw.From)
	e.To = refFor(raw.To)
	e.Data = raw.Data
	return nil
}
// ResourcesIterator returns a fresh iterator over the template's Resources,
// visiting them in the order they appear in the YAML source.
func (ct *ConstructTemplate) ResourcesIterator() Iterator[string, ResourceTemplate] {
	var it Iterator[string, ResourceTemplate]
	it.source, it.order = ct.Resources, ct.resourceOrder
	return it
}
// UnmarshalYAML decodes a construct template and records the "resources" key order.
// Absent sections are replaced with empty, non-nil values, matching
// BindingTemplate.UnmarshalYAML, so callers can range over them without nil checks.
func (ct *ConstructTemplate) UnmarshalYAML(value *yaml.Node) error {
	// Decode through a method-less alias type so this method is not re-entered.
	type constructTemplate ConstructTemplate
	var template constructTemplate
	if err := value.Decode(&template); err != nil {
		return err
	}
	// A missing "resources" section is not an error; resourceOrder is left nil.
	resourceOrder, _ := captureYAMLKeyOrder(value, "resources")
	template.resourceOrder = resourceOrder
	if template.Inputs == nil {
		template.Inputs = NewProperties(nil)
	}
	if template.Resources == nil {
		template.Resources = make(map[string]ResourceTemplate)
	}
	if template.Edges == nil {
		template.Edges = make([]EdgeTemplate, 0)
	}
	if template.Outputs == nil {
		template.Outputs = make(map[string]OutputTemplate)
	}
	if template.InputRules == nil {
		template.InputRules = make([]InputRuleTemplate, 0)
	}
	*ct = ConstructTemplate(template)
	return nil
}
// captureYAMLKeyOrder returns, in document order, the keys of the mapping stored
// under sectionKey in rootNode. rootNode is expected to be a mapping node (its
// Content alternates key and value nodes), as received by an UnmarshalYAML method.
//
// If the section value is an alias (e.g. `resources: *shared`), the keys of the
// aliased node are returned. A section with no content (e.g. null) yields an
// empty order. An error is returned only when sectionKey is absent.
func captureYAMLKeyOrder(rootNode *yaml.Node, sectionKey string) ([]string, error) {
	// Step over key/value pairs; the i+1 bound guards against a dangling key
	// (e.g. when rootNode is not actually a mapping node).
	for i := 0; i+1 < len(rootNode.Content); i += 2 {
		if rootNode.Content[i].Value != sectionKey {
			continue
		}
		section := rootNode.Content[i+1]
		// Follow aliases so the order matches the anchored mapping that was decoded.
		if section.Kind == yaml.AliasNode && section.Alias != nil {
			section = section.Alias
		}
		var resourceOrder []string
		for j := 0; j < len(section.Content); j += 2 {
			resourceOrder = append(resourceOrder, section.Content[j].Value)
		}
		return resourceOrder, nil
	}
	return nil, fmt.Errorf("could not find key: %s", sectionKey)
}
// Iterator walks the entries of source in the order given by order.
// Keys listed in order but absent from source are skipped; keys present in
// source but not listed in order are never visited.
type Iterator[K comparable, V any] struct {
	source map[K]V
	order  []K
	index  int
}

// Next returns the next key/value pair whose key exists in source and advances
// the iterator. The boolean result is false once order is exhausted.
func (r *Iterator[K, V]) Next() (K, V, bool) {
	for r.index < len(r.order) {
		key := r.order[r.index]
		r.index++
		if val, ok := r.source[key]; ok {
			return key, val, true
		}
	}
	var zeroK K
	var zeroV V
	return zeroK, zeroV, false
}

// IterFunc is the callback invoked by Iterator.ForEach for each entry.
type IterFunc[K comparable, V any] func(K, V) error

// StopIteration may be returned from an IterFunc to end Iterator.ForEach early.
var StopIteration = errors.New("stop iteration")

// ForEach calls f for each remaining entry (see Next). Returning StopIteration
// (or an error wrapping it) ends the walk. Any other error returned by f is
// ignored and the walk continues.
func (r *Iterator[K, V]) ForEach(f IterFunc[K, V]) {
	for key, val, ok := r.Next(); ok; key, val, ok = r.Next() {
		if err := f(key, val); errors.Is(err, StopIteration) {
			return
		}
	}
}

// Reset rewinds the iterator to the start of order.
func (r *Iterator[K, V]) Reset() {
	r.index = 0
}
// BindingDirection says which side of a binding a construct template is on.
type BindingDirection string

// Binding direction values. They are untyped string constants, so they compare
// directly against BindingDirection values and still convert to plain strings.
const (
	BindingDirectionFrom, BindingDirectionTo = "from", "to"
)
// GetBindingTemplate loads the binding template between this construct and other
// via LoadBindingTemplate. For BindingDirectionFrom this construct is the "from"
// side; any other direction places it on the "to" side.
func (ct *ConstructTemplate) GetBindingTemplate(direction BindingDirection, other property.ConstructType) (BindingTemplate, error) {
	from, to := other, ct.Id
	if direction == BindingDirectionFrom {
		from, to = ct.Id, other
	}
	return LoadBindingTemplate(ct.Id, from, to)
}
// UnmarshalYAML decodes a rule branch body and records the YAML order of its
// "resources" section. A missing section just leaves the order nil.
func (cet *ConditionalExpressionTemplate) UnmarshalYAML(value *yaml.Node) error {
	// Decode through a method-less alias type so this method is not re-entered.
	type rawConditional ConditionalExpressionTemplate
	var raw rawConditional
	if err := value.Decode(&raw); err != nil {
		return err
	}
	order, _ := captureYAMLKeyOrder(value, "resources")
	raw.resourceOrder = order
	*cet = ConditionalExpressionTemplate(raw)
	return nil
}
// ResourcesIterator returns a fresh iterator over the branch's Resources in
// YAML document order.
func (cet *ConditionalExpressionTemplate) ResourcesIterator() Iterator[string, ResourceTemplate] {
	var it Iterator[string, ResourceTemplate]
	it.source, it.order = cet.Resources, cet.resourceOrder
	return it
}
// UnmarshalYAML decodes an input rule. Exactly one of `if` or `for_each` must be
// set. An `if` rule may not carry `do`, and a `for_each` rule may not carry
// `then`/`else`. Only the fields of the chosen form (plus Prefix) are assigned.
func (irt *InputRuleTemplate) UnmarshalYAML(value *yaml.Node) error {
	type rawInputRule InputRuleTemplate
	var raw rawInputRule
	if err := value.Decode(&raw); err != nil {
		return err
	}

	hasIf, hasForEach := raw.If != "", raw.ForEach != ""
	switch {
	case hasIf == hasForEach:
		// neither or both forms given
		return fmt.Errorf("invalid InputRuleTemplate: must have either If-Then-Else or ForEach-Do")
	case hasIf:
		if raw.Do != nil {
			return fmt.Errorf("invalid InputRuleTemplate: cannot mix If-Then-Else with ForEach-Do")
		}
		irt.If, irt.Then, irt.Else = raw.If, raw.Then, raw.Else
	default:
		if raw.Then != nil || raw.Else != nil {
			return fmt.Errorf("invalid InputRuleTemplate: cannot mix ForEach-Do with If-Then-Else")
		}
		irt.ForEach, irt.Do = raw.ForEach, raw.Do
	}
	irt.Prefix = raw.Prefix
	return nil
}
// GetInput returns the input at path, as resolved by property.GetProperty over the
// template's converted inputs.
func (ct *ConstructTemplate) GetInput(path string) property.Property {
	inputMap := ct.Inputs.propertyMap
	return property.GetProperty(inputMap, path)
}

// ForEachInput walks the input properties of a construct template,
// including nested properties, and calls the given function for each input.
// If the function returns an error, the walk will stop and return that error.
// If the function returns [ErrStopWalk], the walk will stop and return nil.
func (ct *ConstructTemplate) ForEachInput(c construct.Properties, f func(property.Property) error) error {
	in := ct.Inputs
	return in.ForEach(c, f)
}
package inputs
import (
"fmt"
"reflect"
"regexp"
"strings"
"github.com/klothoplatform/klotho/pkg/k2/constructs/template/properties"
"github.com/klothoplatform/klotho/pkg/k2/constructs/template/property"
"gopkg.in/yaml.v3"
)
type (
// InputTemplateMap defines the structure of properties defined in YAML as a part of a template.
// Keys are property names; InputTemplateMap.UnmarshalYAML copies each key into the
// entry's Name and Path and derives dotted Paths for nested Properties.
InputTemplateMap map[string]*InputTemplate
// InputTemplate defines the structure of a property defined in YAML as a part of a template.
// these fields must be exactly the union of all the fields in the different property types.
// InputTemplate.Convert copies each field, by name and via reflection, onto the concrete
// property type; a field is skipped when that type has no settable field of the same name.
InputTemplate struct {
// Name is the property's key in its enclosing InputTemplateMap (set during unmarshalling)
Name string `json:"name" yaml:"name"`
// Type defines the type of the property
Type string `json:"type" yaml:"type"`
// Description defines the description of the property
Description string `json:"description" yaml:"description"`
// DefaultValue defines the default value of the property
DefaultValue any `json:"default_value" yaml:"default_value"`
// Required defines whether the property is required
Required bool `json:"required" yaml:"required"`
// Properties defines the sub properties of a key_value_list, map, list, or set
Properties InputTemplateMap `json:"properties" yaml:"properties"`
// MinLength defines the minimum length of a string, list, set, or map (number of entries)
MinLength *int `yaml:"min_length"`
// MaxLength defines the maximum length of a string, list, set, or map (number of entries)
MaxLength *int `yaml:"max_length"`
// MinValue defines the minimum value of an int or float
MinValue *float64 `yaml:"min_value"`
// MaxValue defines the maximum value of an int or float
MaxValue *float64 `yaml:"max_value"`
// UniqueItems defines whether the items in a list or set must be unique
UniqueItems *bool `yaml:"unique_items"`
// UniqueKeys defines whether the keys in a map must be unique (default true)
UniqueKeys *bool `yaml:"unique_keys"`
// SanitizeTmpl is a go template to sanitize user input when setting the property
SanitizeTmpl string `yaml:"sanitize"`
// AllowedValues defines an enumeration of allowed values for a string, int, float, or bool
AllowedValues []string `yaml:"allowed_values"`
// KeyProperty is the property of the keys in a key_value_list or map
KeyProperty *InputTemplate `yaml:"key_property"`
// ValueProperty is the property of the values in a key_value_list or map
ValueProperty *InputTemplate `yaml:"value_property"`
// ItemProperty is the property of the items in a list or set
ItemProperty *InputTemplate `yaml:"item_property"`
// Path is the path to the property in the template
// this field is derived and is not part of the yaml
Path string `json:"-" yaml:"-"`
}
// PropertyType is the base name of a property type, e.g. "string" or "map";
// see the *PropertyType values and GetTypeInfo.
PropertyType string
// FieldConverterFunc copies the InputTemplate field value val (from template p) onto
// the converted property kp when a plain reflect assignment is not possible; see fieldConversion.
FieldConverterFunc func(val reflect.Value, p *InputTemplate, kp property.Property) error
)
// Base property type names. An InputTemplate's Type is either one of these or a
// parameterized form such as "list(string)" or "map(string,int)"; InitializeProperty
// dispatches on them after GetTypeInfo extracts the base name.
var (
	StringPropertyType       = PropertyType("string")
	IntPropertyType          = PropertyType("int")
	FloatPropertyType        = PropertyType("float")
	BoolPropertyType         = PropertyType("bool")
	MapPropertyType          = PropertyType("map")
	ListPropertyType         = PropertyType("list")
	SetPropertyType          = PropertyType("set")
	AnyPropertyType          = PropertyType("any")
	PathPropertyType         = PropertyType("path")
	KeyValueListPropertyType = PropertyType("key_value_list")
	ConstructPropertyType    = PropertyType("construct")
)
// UnmarshalYAML decodes a map of input templates. Each entry's Name and Path are
// set to its map key, and nested sub-properties get dotted Paths via setChildPaths.
// An entry with no definition (e.g. `foo:` with an empty value, which decodes to a
// nil template) is rejected with an error instead of causing a nil dereference.
func (p *InputTemplateMap) UnmarshalYAML(n *yaml.Node) error {
	// Decode through a method-less alias type so this method is not re-entered.
	type h InputTemplateMap
	var p2 h
	if err := n.Decode(&p2); err != nil {
		return err
	}
	// tmpl (not "property") avoids shadowing the imported property package.
	for name, tmpl := range p2 {
		if tmpl == nil {
			return fmt.Errorf("input property %s has no definition", name)
		}
		tmpl.Name = name
		tmpl.Path = name
		setChildPaths(tmpl, name)
	}
	*p = InputTemplateMap(p2)
	return nil
}
// Convert converts every template in the map into its property implementation.
// Conversion continues past failures; all failure messages are combined
// (newline-separated) into the returned error, and the successfully converted
// entries are still returned in the map.
func (p *InputTemplateMap) Convert() (property.PropertyMap, error) {
	var errs error
	props := property.PropertyMap{}
	for name, prop := range *p {
		converted, err := prop.Convert()
		if err != nil {
			// Only chain once there is something to chain: fmt.Errorf("%w", nil)
			// would render the first failure as "%!w(<nil>)".
			if errs == nil {
				errs = err
			} else {
				errs = fmt.Errorf("%w\n%s", errs, err.Error())
			}
			continue
		}
		props[name] = converted
	}
	return props, errs
}
// Convert builds the concrete property.Property described by this template.
//
// InitializeProperty(p.Type) chooses the implementation. Each InputTemplate field is
// then copied, by name and via reflection, onto the same-named settable field of that
// implementation. Nil pointers/interfaces and empty slices/arrays are skipped.
//
// Sub-templates are converted recursively:
//   - Properties: every entry, stored as a property.PropertyMap;
//   - KeyProperty/ValueProperty: only allowed for "map" and "key_value_list" types;
//   - ItemProperty: the list/set item template.
//
// When the parent type declares element types (e.g. "map(string, int)" or
// "list(string)"), an element template without a Type inherits the declared type.
// A conflicting Type is an error. NOTE: the inherited Type is written back onto
// the element template.
//
// Fields that cannot be assigned directly use the converters in fieldConversion.
// Any other unassignable field is an error.
func (p *InputTemplate) Convert() (property.Property, error) {
	propertyType, err := InitializeProperty(p.Type)
	if err != nil {
		return nil, err
	}
	propertyType.Details().Path = p.Path
	srcVal := reflect.ValueOf(p).Elem()
	dstVal := reflect.ValueOf(propertyType).Elem()
	for i := 0; i < srcVal.NumField(); i++ {
		srcField := srcVal.Field(i)
		fieldName := srcVal.Type().Field(i).Name
		dstField := dstVal.FieldByName(fieldName)
		if !dstField.IsValid() || !dstField.CanSet() {
			continue
		}
		// Unset optional fields carry no configuration; keep the destination's zero value.
		switch srcField.Kind() {
		case reflect.Ptr, reflect.Interface:
			if srcField.IsNil() {
				continue
			}
		case reflect.Array, reflect.Slice:
			if srcField.Len() == 0 {
				continue
			}
		}

		// Handle sub properties so we can recurse down the tree
		switch fieldName {
		case "Properties":
			propMap := srcField.Interface().(InputTemplateMap)
			var errs error
			props := property.PropertyMap{}
			for name, prop := range propMap {
				converted, convErr := prop.Convert()
				if convErr != nil {
					// Avoid fmt.Errorf("%w", nil), which renders as "%!w(<nil>)".
					if errs == nil {
						errs = convErr
					} else {
						errs = fmt.Errorf("%w\n%s", errs, convErr.Error())
					}
					continue
				}
				props[name] = converted
			}
			if errs != nil {
				return nil, fmt.Errorf("could not convert sub properties: %w", errs)
			}
			dstField.Set(reflect.ValueOf(props))
			continue
		case "KeyProperty", "ValueProperty":
			var typePrefix string
			switch {
			case strings.HasPrefix(p.Type, "map"):
				typePrefix = "map("
			case strings.HasPrefix(p.Type, "key_value_list"):
				typePrefix = "key_value_list("
			default:
				return nil, fmt.Errorf("property must be 'map' or 'key_value_list' (was %s) for %s", p.Type, fieldName)
			}
			elemProp := srcField.Interface().(*InputTemplate)
			if err := inferKeyValueElementType(p.Type, typePrefix, fieldName, elemProp); err != nil {
				return nil, err
			}
			converted, err := elemProp.Convert()
			if err != nil {
				return nil, fmt.Errorf("could not convert %s: %w", fieldName, err)
			}
			srcField = reflect.ValueOf(converted)
		case "ItemProperty":
			elemProp := srcField.Interface().(*InputTemplate)
			// e.g. "list(string)" / "set( string )": the item template inherits "string"
			if strings.Contains(p.Type, "(") {
				itemType := strings.TrimSpace(strings.TrimSuffix(
					strings.TrimPrefix(strings.TrimPrefix(p.Type, "list("), "set("),
					")",
				))
				if elemProp.Type != "" && elemProp.Type != itemType {
					return nil, fmt.Errorf("item property type must be %s (was %s)", itemType, elemProp.Type)
				} else if elemProp.Type == "" {
					elemProp.Type = itemType
				}
			}
			converted, err := elemProp.Convert()
			if err != nil {
				return nil, fmt.Errorf("could not convert %s: %w", fieldName, err)
			}
			srcField = reflect.ValueOf(converted)
		}

		if srcField.Type().AssignableTo(dstField.Type()) {
			dstField.Set(srcField)
			continue
		}
		if dstField.Kind() == reflect.Ptr && srcField.Kind() == reflect.Ptr {
			srcElem, dstElem := srcField.Type().Elem(), dstField.Type().Elem()
			switch {
			case srcElem.AssignableTo(dstElem):
				// The pointer types themselves are not assignable (handled above), so
				// dstField.Set(srcField) would panic; copy the pointee instead.
				dstField.Set(reflect.New(dstElem))
				dstField.Elem().Set(srcField.Elem())
				continue
			case srcElem.ConvertibleTo(dstElem):
				// set dest field to a pointer of the converted value
				dstField.Set(reflect.New(dstElem))
				dstField.Elem().Set(srcField.Elem().Convert(dstElem))
				continue
			}
		}
		if conversion, found := fieldConversion[fieldName]; found {
			if err := conversion(srcField, p, propertyType); err != nil {
				return nil, err
			}
			continue
		}
		return nil, fmt.Errorf(
			"could not assign %s#%s (%s) to field in %T (%s)",
			p.Path, fieldName, srcField.Type(), propertyType, dstField.Type(),
		)
	}
	return propertyType, nil
}

// inferKeyValueElementType reconciles a key/value element template with the element
// types declared on its parent type (e.g. "map(string, int)" with typePrefix "map(").
// fieldName "KeyProperty" selects the key position; any other name selects the value.
// An untyped element inherits the declared type, and a differently typed element is
// an error. Whitespace around the declared types is ignored.
func inferKeyValueElementType(parentType, typePrefix, fieldName string, elem *InputTemplate) error {
	keyType, valueType, hasElementTypes := strings.Cut(
		strings.TrimSuffix(strings.TrimPrefix(parentType, typePrefix), ")"),
		",",
	)
	if !hasElementTypes {
		return nil
	}
	expected, role := strings.TrimSpace(valueType), "value"
	if fieldName == "KeyProperty" {
		expected, role = strings.TrimSpace(keyType), "key"
	}
	switch elem.Type {
	case "":
		elem.Type = expected
	case expected:
		// already consistent with the parent
	default:
		return fmt.Errorf("%s property type must be %s (was %s)", role, expected, elem.Type)
	}
	return nil
}
// setChildPaths recursively stamps each nested sub-property of property with its
// map key as Name and a dotted Path rooted at currPath (e.g. "parent.child").
func setChildPaths(property *InputTemplate, currPath string) {
	for childName, child := range property.Properties {
		childPath := currPath + "." + childName
		child.Name, child.Path = childName, childPath
		setChildPaths(child, childPath)
	}
}
// Clone returns a new map whose entries are InputTemplate.Clone copies of p's
// entries. The result is non-nil even when p is nil.
func (p InputTemplateMap) Clone() InputTemplateMap {
	cloned := make(InputTemplateMap, len(p))
	for name, tmpl := range p {
		cloned[name] = tmpl.Clone()
	}
	return cloned
}
// Clone returns a copy of the template that can be mutated independently.
// Convert, for example, writes inherited Types onto KeyProperty, ValueProperty and
// ItemProperty, so these sub-templates are cloned. Properties and AllowedValues
// are copied as well. DefaultValue and the scalar option pointers (MinLength,
// MaxValue, UniqueItems, ...) are still shared with the original.
// Cloning a nil template returns nil.
func (p *InputTemplate) Clone() *InputTemplate {
	if p == nil {
		return nil
	}
	cloned := *p
	cloned.Properties = make(InputTemplateMap, len(p.Properties))
	for k, v := range p.Properties {
		cloned.Properties[k] = v.Clone()
	}
	// Clone is nil-safe, so unset sub-templates stay nil.
	cloned.KeyProperty = p.KeyProperty.Clone()
	cloned.ValueProperty = p.ValueProperty.Clone()
	cloned.ItemProperty = p.ItemProperty.Clone()
	if p.AllowedValues != nil {
		cloned.AllowedValues = make([]string, len(p.AllowedValues))
		copy(cloned.AllowedValues, p.AllowedValues)
	}
	return &cloned
}
// fieldConversion maps an InputTemplate field name to a converter. InputTemplate.Convert
// uses a converter when the field cannot be assigned to the property type directly.
var fieldConversion = map[string]FieldConverterFunc{
	// SanitizeTmpl: compile the user-supplied template text with
	// property.NewSanitizationTmpl and store the result in the property's
	// SanitizeTmpl field. An empty string leaves the field untouched.
	"SanitizeTmpl": func(val reflect.Value, p *InputTemplate, kp property.Property) error {
		text, isString := val.Interface().(string)
		switch {
		case !isString:
			return fmt.Errorf("invalid sanitize template")
		case text == "":
			return nil
		}
		compiled, err := property.NewSanitizationTmpl(kp.Details().Name, text)
		if err != nil {
			return err
		}
		target := reflect.ValueOf(kp).Elem().FieldByName("SanitizeTmpl")
		target.Set(reflect.ValueOf(compiled))
		return nil
	},
}
// InitializeProperty returns a new property implementation for the type string
// ptype, such as "string", "list(int)", "map(string,any)",
// "key_value_list(string,string)", or "construct(...)".
//
// Parameterized container types have their element properties initialized
// recursively. For key_value_list, the element properties are named "Key" and
// "Value". construct(...) arguments are parsed with ConstructType.FromString as
// the allowed construct types.
//
// An error is returned for an empty ptype, a malformed type string, the wrong
// number of type arguments, an invalid element or construct type, or an unknown
// base type.
func InitializeProperty(ptype string) (property.Property, error) {
	if ptype == "" {
		return nil, fmt.Errorf("property does not have a type")
	}
	baseType, typeArgs, err := GetTypeInfo(ptype)
	if err != nil {
		return nil, err
	}
	switch baseType {
	case MapPropertyType:
		if len(typeArgs) == 0 {
			return &properties.MapProperty{}, nil
		}
		if len(typeArgs) != 2 {
			return nil, fmt.Errorf("invalid number of arguments for map property type: %s", ptype)
		}
		keyVal, err := InitializeProperty(typeArgs[0])
		if err != nil {
			return nil, err
		}
		valProp, err := InitializeProperty(typeArgs[1])
		if err != nil {
			return nil, err
		}
		return &properties.MapProperty{KeyProperty: keyVal, ValueProperty: valProp}, nil
	case ListPropertyType:
		if len(typeArgs) == 0 {
			return &properties.ListProperty{}, nil
		}
		if len(typeArgs) != 1 {
			return nil, fmt.Errorf("invalid number of arguments for list property type: %s", ptype)
		}
		itemProp, err := InitializeProperty(typeArgs[0])
		if err != nil {
			return nil, err
		}
		return &properties.ListProperty{ItemProperty: itemProp}, nil
	case SetPropertyType:
		if len(typeArgs) == 0 {
			return &properties.SetProperty{}, nil
		}
		if len(typeArgs) != 1 {
			return nil, fmt.Errorf("invalid number of arguments for set property type: %s", ptype)
		}
		itemProp, err := InitializeProperty(typeArgs[0])
		if err != nil {
			return nil, err
		}
		return &properties.SetProperty{ItemProperty: itemProp}, nil
	case KeyValueListPropertyType:
		if len(typeArgs) == 0 {
			return &properties.KeyValueListProperty{}, nil
		}
		if len(typeArgs) != 2 {
			return nil, fmt.Errorf("invalid number of arguments for %s property type: %s", KeyValueListPropertyType, ptype)
		}
		// Check each error before touching the result: a failed InitializeProperty
		// returns a nil property, and calling Details() on it would panic.
		keyProp, err := InitializeProperty(typeArgs[0])
		if err != nil {
			return nil, err
		}
		keyProp.Details().Name = "Key"
		valProp, err := InitializeProperty(typeArgs[1])
		if err != nil {
			return nil, err
		}
		valProp.Details().Name = "Value"
		return &properties.KeyValueListProperty{KeyProperty: keyProp, ValueProperty: valProp}, nil
	case ConstructPropertyType:
		var allowedTypes []property.ConstructType
		for _, t := range typeArgs {
			var id property.ConstructType
			if err := id.FromString(t); err != nil {
				return nil, fmt.Errorf("invalid construct type %s: %w", t, err)
			}
			allowedTypes = append(allowedTypes, id)
		}
		return &properties.ConstructProperty{AllowedTypes: allowedTypes}, nil
	case AnyPropertyType:
		return &properties.AnyProperty{}, nil
	case StringPropertyType:
		return &properties.StringProperty{}, nil
	case IntPropertyType:
		return &properties.IntProperty{}, nil
	case FloatPropertyType:
		return &properties.FloatProperty{}, nil
	case BoolPropertyType:
		return &properties.BoolProperty{}, nil
	case PathPropertyType:
		return &properties.PathProperty{}, nil
	default:
		return nil, fmt.Errorf("unknown property type '%s'", baseType)
	}
}
// funcRegex splits a property type string into its base name (group 1) and an
// optional parenthesized argument list (group 2). For example, "map(string,int)"
// yields "map" and "string,int". Nested parentheses are not matched.
var funcRegex = regexp.MustCompile(`^(\w+)(?:\(([^)]*)\))?$`)

// argRegex matches each comma-separated argument inside a type argument list.
var argRegex = regexp.MustCompile(`[^,]+`)

// GetTypeInfo parses a property type string such as "list(string)" into its base
// type and its whitespace-trimmed type arguments. A bare type ("string") yields no
// arguments. Strings not of the form name or name(args) are rejected.
func GetTypeInfo(t string) (propType PropertyType, args []string, err error) {
	m := funcRegex.FindStringSubmatch(t)
	if m == nil {
		return "", nil, fmt.Errorf("invalid property type %s", t)
	}
	for _, raw := range argRegex.FindAllString(m[2], -1) {
		args = append(args, strings.TrimSpace(raw))
	}
	return PropertyType(m[1]), args, nil
}
package template
import (
"github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/k2/constructs/template/inputs"
"github.com/klothoplatform/klotho/pkg/k2/constructs/template/property"
"gopkg.in/yaml.v3"
)
// Properties is a collection of converted template input properties, keyed by
// name and backed by a property.PropertyMap.
type Properties struct {
	propertyMap property.PropertyMap
}

// NewProperties wraps properties; a nil map is replaced with a new empty one.
func NewProperties(properties property.PropertyMap) *Properties {
	p := &Properties{propertyMap: properties}
	if p.propertyMap == nil {
		p.propertyMap = make(property.PropertyMap)
	}
	return p
}

// Clone returns a new Properties backed by PropertyMap.Clone of p's map.
func (p *Properties) Clone() property.Properties {
	cloned := p.propertyMap.Clone()
	return &Properties{propertyMap: cloned}
}

// ForEach delegates to the underlying PropertyMap's ForEach.
func (p *Properties) ForEach(c construct.Properties, f func(p property.Property) error) error {
	m := p.propertyMap
	return m.ForEach(c, f)
}

// Get delegates to the underlying PropertyMap's Get.
func (p *Properties) Get(key string) (property.Property, bool) {
	m := p.propertyMap
	return m.Get(key)
}

// Set delegates to the underlying PropertyMap's Set.
func (p *Properties) Set(key string, value property.Property) {
	m := p.propertyMap
	m.Set(key, value)
}

// Remove delegates to the underlying PropertyMap's Remove.
func (p *Properties) Remove(key string) {
	m := p.propertyMap
	m.Remove(key)
}

// AsMap exposes the underlying map itself (not a copy).
func (p *Properties) AsMap() map[string]property.Property {
	m := p.propertyMap
	return m
}

// UnmarshalYAML decodes node as an inputs.InputTemplateMap, converts it into
// property implementations, and merges them into p, overwriting same-named entries.
func (p *Properties) UnmarshalYAML(node *yaml.Node) error {
	// Create the map up front so p is usable even if decoding fails.
	if p.propertyMap == nil {
		p.propertyMap = make(property.PropertyMap)
	}
	templates := make(inputs.InputTemplateMap)
	if err := node.Decode(&templates); err != nil {
		return err
	}
	converted, err := templates.Convert()
	if err != nil {
		return err
	}
	for name, prop := range converted {
		p.propertyMap[name] = prop
	}
	return nil
}
package properties
import (
"fmt"
"github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/k2/constructs/template/property"
)
// AnyProperty is an untyped property. Values are stored as given, and Parse
// handles template strings, maps, and lists recursively.
type AnyProperty struct {
	SharedPropertyFields
	property.PropertyDetails
}

// SetProperty stores value, unchanged, at the property's path.
func (a *AnyProperty) SetProperty(properties construct.Properties, value any) error {
	path := a.Path
	return properties.SetProperty(path, value)
}

// AppendProperty appends value at the property's path, or sets it when the
// current value is nil.
// NOTE(review): unlike BoolProperty.RemoveProperty, an ErrPropertyDoesNotExist from
// GetProperty is returned here rather than treated as unset — confirm this is intended.
func (a *AnyProperty) AppendProperty(properties construct.Properties, value any) error {
	current, err := properties.GetProperty(a.Path)
	switch {
	case err != nil:
		return err
	case current == nil:
		return a.SetProperty(properties, value)
	default:
		return properties.AppendProperty(a.Path, value)
	}
}

// RemoveProperty removes value from the property's path.
func (a *AnyProperty) RemoveProperty(properties construct.Properties, value any) error {
	path := a.Path
	return properties.RemoveProperty(path, value)
}

// Details returns a pointer to the embedded property details.
func (a *AnyProperty) Details() *property.PropertyDetails {
	return &a.PropertyDetails
}

// Clone returns a shallow copy of the property.
func (a *AnyProperty) Clone() property.Property {
	dup := new(AnyProperty)
	*dup = *a
	return dup
}

// GetDefaultValue parses DefaultValue (see Parse), or returns nil when unset.
func (a *AnyProperty) GetDefaultValue(ctx property.ExecutionContext, data any) (any, error) {
	if a.DefaultValue != nil {
		return a.Parse(a.DefaultValue, ctx, data)
	}
	return nil, nil
}

// Parse executes string values as templates, keeping the raw string when that
// fails. It parses map[string]any and []any values element-wise as untyped, and
// returns any other value as-is.
func (a *AnyProperty) Parse(value any, ctx property.ExecutionContext, data any) (any, error) {
	switch v := value.(type) {
	case string:
		// check if its any other template string
		var result any
		if err := ctx.ExecuteUnmarshal(v, data, &result); err == nil {
			return result, nil
		}
	case map[string]any:
		m := MapProperty{KeyProperty: &StringProperty{}, ValueProperty: &AnyProperty{}}
		return m.Parse(v, ctx, data)
	case []any:
		l := ListProperty{ItemProperty: &AnyProperty{}}
		return l.Parse(v, ctx, data)
	}
	return value, nil
}

// ZeroValue is nil for an untyped property.
func (a *AnyProperty) ZeroValue() any {
	return nil
}

// Contains delegates to the string, map, or list Contains logic based on value's
// dynamic type; any other type never contains anything.
func (a *AnyProperty) Contains(value any, contains any) bool {
	switch v := value.(type) {
	case string:
		s := StringProperty{}
		return s.Contains(v, contains)
	case map[string]any:
		m := MapProperty{KeyProperty: &StringProperty{}, ValueProperty: &AnyProperty{}}
		return m.Contains(v, contains)
	case []any:
		l := ListProperty{ItemProperty: &AnyProperty{}}
		return l.Contains(v, contains)
	}
	return false
}

// Type returns the type name "any".
func (a *AnyProperty) Type() string {
	return "any"
}

// Validate only enforces Required; any non-nil value is accepted.
func (a *AnyProperty) Validate(properties construct.Properties, value any) error {
	if value == nil && a.Required {
		return fmt.Errorf(property.ErrRequiredProperty, a.Path)
	}
	return nil
}

// SubProperties is always nil: an untyped property has no declared children.
func (a *AnyProperty) SubProperties() property.PropertyMap {
	return nil
}
package properties
import (
"errors"
"fmt"
"github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/k2/constructs/template/property"
)
type (
BoolProperty struct {
SharedPropertyFields
property.PropertyDetails
}
)
func (b *BoolProperty) SetProperty(properties construct.Properties, value any) error {
if val, ok := value.(bool); ok {
return properties.SetProperty(b.Path, val)
} else if val, ok := value.(construct.PropertyRef); ok {
return properties.SetProperty(b.Path, val)
}
return fmt.Errorf("invalid bool value %v", value)
}
// AppendProperty simply sets the value, because a scalar bool has nothing to append to.
func (b *BoolProperty) AppendProperty(properties construct.Properties, value any) error {
	return b.SetProperty(properties, value)
}

// RemoveProperty removes the value stored at b.Path. A missing or nil
// property is a no-op, and other lookup errors are returned.
func (b *BoolProperty) RemoveProperty(properties construct.Properties, value any) error {
	current, err := properties.GetProperty(b.Path)
	switch {
	case errors.Is(err, construct.ErrPropertyDoesNotExist):
		return nil
	case err != nil:
		return err
	case current == nil:
		return nil
	}
	return properties.RemoveProperty(b.Path, value)
}
// Clone returns a shallow copy of the property.
func (b *BoolProperty) Clone() property.Property {
	cp := new(BoolProperty)
	*cp = *b
	return cp
}

// Details returns a pointer to the embedded PropertyDetails so callers can
// read or modify them in place.
func (b *BoolProperty) Details() *property.PropertyDetails {
	details := &b.PropertyDetails
	return details
}

// GetDefaultValue parses the configured DefaultValue. It returns (nil, nil)
// when no default is configured.
func (b *BoolProperty) GetDefaultValue(ctx property.ExecutionContext, data any) (any, error) {
	if def := b.DefaultValue; def != nil {
		return b.Parse(def, ctx, data)
	}
	return nil, nil
}
// Parse converts value to a bool.
//
// A string is rendered as a template and unmarshalled into a bool; the decoded
// bool is returned together with any error from that step. A native bool
// passes through unchanged. Any other type is rejected.
func (b *BoolProperty) Parse(value any, ctx property.ExecutionContext, data any) (any, error) {
	switch v := value.(type) {
	case string:
		var parsed bool
		err := ctx.ExecuteUnmarshal(v, data, &parsed)
		return parsed, err
	case bool:
		return v, nil
	}
	return nil, fmt.Errorf("invalid bool value %v", value)
}
func (b *BoolProperty) ZeroValue() any {
return false
}
func (b *BoolProperty) Contains(value any, contains any) bool {
return false
}
func (b *BoolProperty) Type() string {
return "bool"
}
func (b *BoolProperty) Validate(properties construct.Properties, value any) error {
if value == nil {
if b.Required {
return fmt.Errorf(property.ErrRequiredProperty, b.Path)
}
return nil
}
if _, ok := value.(bool); !ok {
return fmt.Errorf("invalid bool value %v", value)
}
return nil
}
func (b *BoolProperty) SubProperties() property.PropertyMap {
return nil
}
package properties
import (
"errors"
"fmt"
"github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/k2/constructs/template/property"
"github.com/klothoplatform/klotho/pkg/k2/model"
)
// ConstructTemplateIdList is a set of construct types. It is used to restrict
// which constructs a property may reference.
type ConstructTemplateIdList []property.ConstructType

// MatchesAny reports whether urn converts to a construct type that appears in
// the list. A URN that cannot be converted never matches.
func (l ConstructTemplateIdList) MatchesAny(urn model.URN) bool {
	var candidate property.ConstructType
	if err := candidate.FromURN(urn); err != nil {
		return false
	}
	for i := range l {
		if l[i] == candidate {
			return true
		}
	}
	return false
}
type (
// ConstructProperty is a property whose value is a model.URN referencing a
// construct.
ConstructProperty struct {
// AllowedTypes, when non-empty, restricts which construct types a value may
// reference (checked by Parse and Validate).
AllowedTypes ConstructTemplateIdList
SharedPropertyFields
property.PropertyDetails
}
)
// SetProperty stores value at r.Path. Only model.URN values are accepted.
func (r *ConstructProperty) SetProperty(properties construct.Properties, value any) error {
	urn, isURN := value.(model.URN)
	if !isURN {
		return fmt.Errorf("invalid construct URN %v", value)
	}
	return properties.SetProperty(r.Path, urn)
}

// AppendProperty behaves like SetProperty, because a construct reference is a single value.
func (r *ConstructProperty) AppendProperty(properties construct.Properties, value any) error {
	return r.SetProperty(properties, value)
}
// RemoveProperty clears the construct reference at r.Path. It only does so
// when value is a model.URN equal to the stored URN.
//
// A missing or nil property is a no-op. A stored value or argument that is
// not a model.URN is an error, and so is a URN that differs from the stored one.
func (r *ConstructProperty) RemoveProperty(properties construct.Properties, value any) error {
	stored, err := properties.GetProperty(r.Path)
	if err != nil {
		if errors.Is(err, construct.ErrPropertyDoesNotExist) {
			return nil
		}
		return err
	}
	if stored == nil {
		return nil
	}
	storedURN, isURN := stored.(model.URN)
	if !isURN {
		return fmt.Errorf("error attempting to remove construct property: invalid property value %v", stored)
	}
	targetURN, isURN := value.(model.URN)
	if !isURN {
		return fmt.Errorf("error attempting to remove construct property: invalid construct value %v", value)
	}
	if !storedURN.Equals(targetURN) {
		return fmt.Errorf("error attempting to remove construct property: construct value %v does not match property value %v", value, stored)
	}
	return properties.RemoveProperty(r.Path, value)
}
// Details returns a pointer to the embedded PropertyDetails.
func (r *ConstructProperty) Details() *property.PropertyDetails {
	details := &r.PropertyDetails
	return details
}

// Clone returns a shallow copy of the property. The AllowedTypes slice is shared.
func (r *ConstructProperty) Clone() property.Property {
	cp := new(ConstructProperty)
	*cp = *r
	return cp
}

// GetDefaultValue parses the configured DefaultValue. It returns (nil, nil)
// when there is none.
func (r *ConstructProperty) GetDefaultValue(ctx property.ExecutionContext, data any) (any, error) {
	if def := r.DefaultValue; def != nil {
		return r.Parse(def, ctx, data)
	}
	return nil, nil
}
func (r *ConstructProperty) Parse(value any, ctx property.ExecutionContext, data any) (any, error) {
if val, ok := value.(string); ok {
urn, err := ExecuteUnmarshalAsURN(ctx, val, data)
if err != nil {
return nil, fmt.Errorf("invalid construct URN %v", val)
}
if !urn.IsResource() || urn.Type != "construct" {
return nil, fmt.Errorf("invalid construct URN %v", urn)
}
if len(r.AllowedTypes) > 0 && !r.AllowedTypes.MatchesAny(urn) {
return nil, fmt.Errorf("construct value %v does not match allowed types %s", value, r.AllowedTypes)
}
return urn, err
}
if val, ok := value.(map[string]interface{}); ok {
id := model.URN{
AccountID: val["account"].(string),
Project: val["project"].(string),
Environment: val["environment"].(string),
Application: val["application"].(string),
Type: val["type"].(string),
Subtype: val["subtype"].(string),
ParentResourceID: val["parentResourceId"].(string),
ResourceID: val["resourceId"].(string),
}
if len(r.AllowedTypes) > 0 && !r.AllowedTypes.MatchesAny(id) {
return nil, fmt.Errorf("construct value %v does not match type %s", value, r.AllowedTypes)
}
return id, nil
}
if val, ok := value.(model.URN); ok {
if len(r.AllowedTypes) > 0 && !r.AllowedTypes.MatchesAny(val) {
return nil, fmt.Errorf("construct value %v does not match type %s", value, r.AllowedTypes)
}
return val, nil
}
return nil, fmt.Errorf("invalid construct value %v", value)
}
// ZeroValue returns an empty model.URN.
func (r *ConstructProperty) ZeroValue() any {
	var zero model.URN
	return zero
}

// Contains reports whether value and contains are both model.URNs and are equal.
func (r *ConstructProperty) Contains(value any, contains any) bool {
	have, haveURN := value.(model.URN)
	want, wantURN := contains.(model.URN)
	return haveURN && wantURN && have.Equals(want)
}
// Type returns "construct", or "construct(t1, t2, ...)" listing the allowed
// construct types when AllowedTypes is non-empty.
func (r *ConstructProperty) Type() string {
	if len(r.AllowedTypes) == 0 {
		return "construct"
	}
	joined, sep := "", ""
	for _, t := range r.AllowedTypes {
		joined += sep + t.String()
		sep = ", "
	}
	return fmt.Sprintf("construct(%s)", joined)
}
// Validate accepts nil unless the property is Required. Otherwise the value
// must be a model.URN and, when AllowedTypes is non-empty, must match one of
// those types.
func (r *ConstructProperty) Validate(properties construct.Properties, value any) error {
	if value == nil {
		if r.Required {
			return fmt.Errorf(property.ErrRequiredProperty, r.Path)
		}
		return nil
	}
	id, isURN := value.(model.URN)
	switch {
	case !isURN:
		return fmt.Errorf("invalid construct URN %v", value)
	case len(r.AllowedTypes) > 0 && !r.AllowedTypes.MatchesAny(id):
		return fmt.Errorf("value %v does not match allowed types %s", value, r.AllowedTypes)
	default:
		return nil
	}
}

// SubProperties returns nil because a construct reference has no nested properties.
func (r *ConstructProperty) SubProperties() property.PropertyMap {
	return nil
}
package properties
import (
"errors"
"fmt"
"github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/k2/constructs/template/property"
)
type (
// FloatProperty is a numeric property. SetProperty stores numbers as float64.
FloatProperty struct {
// MinValue, when non-nil, is the inclusive lower bound checked by Validate.
MinValue *float64
// MaxValue, when non-nil, is the inclusive upper bound checked by Validate.
MaxValue *float64
SharedPropertyFields
property.PropertyDetails
}
)
// SetProperty stores value at f.Path.
//
// float64 and construct.PropertyRef values are stored unchanged. float32 and
// int values are widened to float64 first. Any other type is rejected.
func (f *FloatProperty) SetProperty(properties construct.Properties, value any) error {
	var stored any
	switch v := value.(type) {
	case float64, construct.PropertyRef:
		stored = v
	case float32:
		stored = float64(v)
	case int:
		stored = float64(v)
	default:
		return fmt.Errorf("invalid float value %v", value)
	}
	return properties.SetProperty(f.Path, stored)
}

// AppendProperty behaves like SetProperty, because a float is a single value.
func (f *FloatProperty) AppendProperty(properties construct.Properties, value any) error {
	return f.SetProperty(properties, value)
}
// RemoveProperty removes the value stored at f.Path. A missing or nil
// property is a no-op, and other lookup errors are returned.
func (f *FloatProperty) RemoveProperty(properties construct.Properties, value any) error {
	current, err := properties.GetProperty(f.Path)
	switch {
	case errors.Is(err, construct.ErrPropertyDoesNotExist):
		return nil
	case err != nil:
		return err
	case current == nil:
		return nil
	}
	return properties.RemoveProperty(f.Path, value)
}
// Details returns a pointer to the embedded PropertyDetails.
func (f *FloatProperty) Details() *property.PropertyDetails {
	details := &f.PropertyDetails
	return details
}

// Clone returns a shallow copy of the property. The bound pointers are shared.
func (f *FloatProperty) Clone() property.Property {
	cp := new(FloatProperty)
	*cp = *f
	return cp
}

// GetDefaultValue parses the configured DefaultValue. It returns (nil, nil)
// when there is none.
func (f *FloatProperty) GetDefaultValue(ctx property.ExecutionContext, data any) (any, error) {
	if def := f.DefaultValue; def != nil {
		return f.Parse(def, ctx, data)
	}
	return nil, nil
}
// Parse converts value to a float64. This is the representation SetProperty
// stores and Validate requires.
//
// A string is rendered as a template and unmarshalled into a float64.
// float64 values pass through unchanged. float32 and int values are widened
// to float64. Any other type is rejected.
func (f *FloatProperty) Parse(value any, ctx property.ExecutionContext, data any) (any, error) {
	switch val := value.(type) {
	case string:
		// Decode straight into float64. Decoding into float32 would lose
		// precision and produce a type that Validate rejects.
		var result float64
		if err := ctx.ExecuteUnmarshal(val, data, &result); err != nil {
			return nil, err
		}
		return result, nil
	case float64:
		return val, nil
	case float32:
		return float64(val), nil
	case int:
		return float64(val), nil
	}
	return nil, fmt.Errorf("invalid float value %v", value)
}
// ZeroValue returns float64(0).
func (f *FloatProperty) ZeroValue() any {
	var zero float64
	return zero
}

// Contains is not meaningful for a scalar float and always reports false.
func (f *FloatProperty) Contains(value any, contains any) bool {
	return false
}

// Type returns the type name "float".
func (f *FloatProperty) Type() string {
	const name = "float"
	return name
}
// Validate accepts nil unless the property is Required. Otherwise the value
// must be a float64 within the inclusive [MinValue, MaxValue] bounds, where
// each bound applies only when set.
func (f *FloatProperty) Validate(properties construct.Properties, value any) error {
	if value == nil {
		if f.Required {
			return fmt.Errorf(property.ErrRequiredProperty, f.Path)
		}
		return nil
	}
	num, isFloat := value.(float64)
	switch {
	case !isFloat:
		return fmt.Errorf("invalid float value %v", value)
	case f.MinValue != nil && num < *f.MinValue:
		return fmt.Errorf("float value %f is less than lower bound %f", value, *f.MinValue)
	case f.MaxValue != nil && num > *f.MaxValue:
		return fmt.Errorf("float value %f is greater than upper bound %f", value, *f.MaxValue)
	}
	return nil
}

// SubProperties returns nil because a float has no nested properties.
func (f *FloatProperty) SubProperties() property.PropertyMap {
	return nil
}
package properties
import (
"errors"
"fmt"
"math"
"github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/k2/constructs/template/property"
)
type (
// IntProperty is an integer-valued property. Values are stored as Go int.
IntProperty struct {
// MinValue, when non-nil, is the inclusive lower bound checked by Validate.
MinValue *int
// MaxValue, when non-nil, is the inclusive upper bound checked by Validate.
MaxValue *int
SharedPropertyFields
property.PropertyDetails
}
)
func (i *IntProperty) SetProperty(properties construct.Properties, value any) error {
if val, ok := value.(int); ok {
return properties.SetProperty(i.Path, val)
} else if val, ok := value.(construct.PropertyRef); ok {
return properties.SetProperty(i.Path, val)
}
return fmt.Errorf("invalid int value %v", value)
}
func (i *IntProperty) AppendProperty(properties construct.Properties, value any) error {
return i.SetProperty(properties, value)
}
// RemoveProperty removes the value stored at i.Path. A missing or nil
// property is a no-op, and other lookup errors are returned.
func (i *IntProperty) RemoveProperty(properties construct.Properties, value any) error {
	current, err := properties.GetProperty(i.Path)
	switch {
	case errors.Is(err, construct.ErrPropertyDoesNotExist):
		return nil
	case err != nil:
		return err
	case current == nil:
		return nil
	}
	return properties.RemoveProperty(i.Path, value)
}
// Details returns a pointer to the embedded PropertyDetails.
func (i *IntProperty) Details() *property.PropertyDetails {
	details := &i.PropertyDetails
	return details
}

// Clone returns a shallow copy of the property. The bound pointers are shared.
func (i *IntProperty) Clone() property.Property {
	cp := new(IntProperty)
	*cp = *i
	return cp
}

// GetDefaultValue parses the configured DefaultValue. It returns (nil, nil)
// when there is none.
func (i *IntProperty) GetDefaultValue(ctx property.ExecutionContext, data any) (any, error) {
	if def := i.DefaultValue; def != nil {
		return i.Parse(def, ctx, data)
	}
	return nil, nil
}
// Parse converts value to an int.
//
// A string is rendered as a template and unmarshalled into an int. An int
// passes through unchanged. A float32 or float64 is accepted only when it is
// finite, within 1e-7 of a whole number (in either direction), and inside the
// int range. It is converted to that whole number. Any other input is rejected.
func (i *IntProperty) Parse(value any, ctx property.ExecutionContext, data any) (any, error) {
	// floatToInt guards the float->int conversion. In Go, converting NaN, ±Inf
	// or an out-of-range float to int yields an implementation-defined value.
	floatToInt := func(f float64) (any, error) {
		const epsilon = 0.0000001
		if math.IsNaN(f) || math.IsInf(f, 0) {
			return nil, fmt.Errorf("cannot convert non-finite float to int: %f", f)
		}
		// Round rather than truncate, so the tolerance also applies to values
		// just below a whole number (e.g. 2.99999999).
		rounded := math.Round(f)
		if math.Abs(f-rounded) > epsilon {
			return nil, fmt.Errorf("cannot convert non-integral float to int: %f", f)
		}
		// -float64(math.MinInt) is exactly one past math.MaxInt.
		if rounded < float64(math.MinInt) || rounded >= -float64(math.MinInt) {
			return nil, fmt.Errorf("cannot convert out-of-range float to int: %f", f)
		}
		return int(rounded), nil
	}

	switch val := value.(type) {
	case string:
		var result int
		if err := ctx.ExecuteUnmarshal(val, data, &result); err != nil {
			return nil, err
		}
		return result, nil
	case int:
		return val, nil
	case float32:
		return floatToInt(float64(val))
	case float64:
		return floatToInt(val)
	}
	return nil, fmt.Errorf("invalid int value %v", value)
}
// ZeroValue returns 0.
func (i *IntProperty) ZeroValue() any {
	var zero int
	return zero
}

// Contains is not meaningful for a scalar int and always reports false.
func (i *IntProperty) Contains(value any, contains any) bool {
	return false
}

// Type returns the type name "int".
func (i *IntProperty) Type() string {
	const name = "int"
	return name
}
// Validate accepts nil unless the property is Required. Otherwise the value
// must be an int within the inclusive [MinValue, MaxValue] bounds, where each
// bound applies only when set.
func (i *IntProperty) Validate(properties construct.Properties, value any) error {
	if value == nil {
		if i.Required {
			return fmt.Errorf(property.ErrRequiredProperty, i.Path)
		}
		return nil
	}
	n, isInt := value.(int)
	switch {
	case !isInt:
		return fmt.Errorf("invalid int value %v", value)
	case i.MinValue != nil && n < *i.MinValue:
		return fmt.Errorf("int value %v is less than lower bound %d", value, *i.MinValue)
	case i.MaxValue != nil && n > *i.MaxValue:
		return fmt.Errorf("int value %v is greater than upper bound %d", value, *i.MaxValue)
	}
	return nil
}

// SubProperties returns nil because an int has no nested properties.
func (i *IntProperty) SubProperties() property.PropertyMap {
	return nil
}
package properties
import (
"errors"
"fmt"
"reflect"
"github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/k2/constructs/template/property"
)
type (
// KeyValueListProperty stores a list of key/value pairs. Each pair is a
// map[string]any whose two keys are the names of KeyProperty and
// ValueProperty. Inputs given as a map[string]any are converted to this list
// form (see mapToList).
KeyValueListProperty struct {
// MinLength and MaxLength, when non-nil, bound the number of pairs (checked by Validate).
MinLength *int
MaxLength *int
// KeyProperty parses and validates each pair's key. Its Details().Name is the
// field name used for the key.
KeyProperty property.Property
// ValueProperty parses and validates each pair's value. Its Details().Name is
// the field name used for the value.
ValueProperty property.Property
SharedPropertyFields
property.PropertyDetails
}
// KeyValuePair is a key/value pair with JSON field names "key" and "value".
// NOTE(review): not referenced by the KeyValueListProperty methods in this
// file, which use map[string]any pairs instead.
KeyValuePair struct {
Key any `json:"key"`
Value any `json:"value"`
}
)
// SetProperty normalizes value (a list of pairs or a map) into the list-of-pairs
// form and stores it at kvl.Path.
func (kvl *KeyValueListProperty) SetProperty(properties construct.Properties, value any) error {
	pairs, err := kvl.mapToList(value)
	if err == nil {
		err = properties.SetProperty(kvl.Path, pairs)
	}
	return err
}
func (kvl *KeyValueListProperty) AppendProperty(properties construct.Properties, value any) error {
list, err := kvl.mapToList(value)
if err != nil {
return err
}
propVal, err := properties.GetProperty(kvl.Path)
if err != nil && !errors.Is(err, construct.ErrPropertyDoesNotExist) {
return err
}
if propVal == nil {
return properties.SetProperty(kvl.Path, list)
}
existingList, ok := propVal.([]any)
if !ok {
return fmt.Errorf("invalid existing property value %v", propVal)
}
return properties.SetProperty(kvl.Path, append(existingList, list...))
}
// RemoveProperty drops every stored pair that is deeply equal to one of the
// pairs in value, then writes the filtered list back.
//
// A missing or nil property is a no-op. A stored value that is not a []any is
// an error.
func (kvl *KeyValueListProperty) RemoveProperty(properties construct.Properties, value any) error {
	current, err := properties.GetProperty(kvl.Path)
	switch {
	case errors.Is(err, construct.ErrPropertyDoesNotExist):
		return nil
	case err != nil:
		return err
	case current == nil:
		return nil
	}
	existing, isList := current.([]any)
	if !isList {
		return fmt.Errorf("invalid existing property value %v", current)
	}
	toRemove, err := kvl.mapToList(value)
	if err != nil {
		return err
	}
	kept := make([]any, 0, len(existing))
	for _, pair := range existing {
		if kvl.containsKeyValuePair(toRemove, pair) {
			continue
		}
		kept = append(kept, pair)
	}
	return properties.SetProperty(kvl.Path, kept)
}
// Details returns a pointer to the embedded PropertyDetails.
func (kvl *KeyValueListProperty) Details() *property.PropertyDetails {
	details := &kvl.PropertyDetails
	return details
}

// Clone copies the property and deep-clones its key and value properties
// when they are set.
func (kvl *KeyValueListProperty) Clone() property.Property {
	cp := *kvl
	if k := kvl.KeyProperty; k != nil {
		cp.KeyProperty = k.Clone()
	}
	if v := kvl.ValueProperty; v != nil {
		cp.ValueProperty = v.Clone()
	}
	return &cp
}

// GetDefaultValue parses the configured DefaultValue. It returns (nil, nil)
// when there is none.
func (kvl *KeyValueListProperty) GetDefaultValue(ctx property.ExecutionContext, data any) (any, error) {
	if def := kvl.DefaultValue; def != nil {
		return kvl.Parse(def, ctx, data)
	}
	return nil, nil
}
// Parse normalizes value into a list of pairs. It then parses each pair's key
// with KeyProperty and each pair's value with ValueProperty.
//
// The result is a new []any of map[string]any pairs, keyed by the key and
// value property names. Any element that is not a map[string]any is an error.
func (kvl *KeyValueListProperty) Parse(value any, ctx property.ExecutionContext, data any) (any, error) {
	pairs, err := kvl.mapToList(value)
	if err != nil {
		return nil, err
	}
	parsed := make([]any, 0, len(pairs))
	for _, entry := range pairs {
		pair, isMap := entry.(map[string]any)
		if !isMap {
			return nil, fmt.Errorf("invalid key-value pair %v", entry)
		}
		keyName := kvl.KeyPropertyName()
		parsedKey, err := kvl.KeyProperty.Parse(pair[keyName], ctx, data)
		if err != nil {
			return nil, fmt.Errorf("error parsing key: %w", err)
		}
		valueName := kvl.ValuePropertyName()
		parsedValue, err := kvl.ValueProperty.Parse(pair[valueName], ctx, data)
		if err != nil {
			return nil, fmt.Errorf("error parsing value: %w", err)
		}
		parsed = append(parsed, map[string]any{
			keyName:   parsedKey,
			valueName: parsedValue,
		})
	}
	return parsed, nil
}
// KeyPropertyName returns the field name under which each pair stores its key.
func (kvl *KeyValueListProperty) KeyPropertyName() string {
	details := kvl.KeyProperty.Details()
	return details.Name
}

// ValuePropertyName returns the field name under which each pair stores its value.
func (kvl *KeyValueListProperty) ValuePropertyName() string {
	details := kvl.ValueProperty.Details()
	return details.Name
}

// ZeroValue returns nil: an unset key/value list has no value.
func (kvl *KeyValueListProperty) ZeroValue() any {
	return nil
}
// Contains reports whether at least one pair from contains appears (deeply
// equal) in value. It reports false if either argument cannot be normalized
// into a list of pairs.
func (kvl *KeyValueListProperty) Contains(value any, contains any) bool {
	haystack, err := kvl.mapToList(value)
	if err != nil {
		return false
	}
	needles, err := kvl.mapToList(contains)
	if err != nil {
		return false
	}
	for _, needle := range needles {
		if kvl.containsKeyValuePair(haystack, needle) {
			return true
		}
	}
	return false
}

// Type returns "keyvaluelist(<keyType>,<valueType>)".
func (kvl *KeyValueListProperty) Type() string {
	keyType := kvl.KeyProperty.Type()
	valueType := kvl.ValueProperty.Type()
	return fmt.Sprintf("keyvaluelist(%s,%s)", keyType, valueType)
}
// Validate accepts nil unless the property is Required. Otherwise it
// normalizes value into pairs and checks the pair count against
// MinLength/MaxLength.
//
// It then validates every pair's key and value with KeyProperty and
// ValueProperty. Per-pair problems are joined into a single returned error.
func (kvl *KeyValueListProperty) Validate(properties construct.Properties, value any) error {
	if value == nil {
		if !kvl.Required {
			return nil
		}
		return fmt.Errorf(property.ErrRequiredProperty, kvl.Path)
	}
	pairs, err := kvl.mapToList(value)
	if err != nil {
		return err
	}
	switch n := len(pairs); {
	case kvl.MinLength != nil && n < *kvl.MinLength:
		return fmt.Errorf("list value %v is too short. min length is %d", value, *kvl.MinLength)
	case kvl.MaxLength != nil && n > *kvl.MaxLength:
		return fmt.Errorf("list value %v is too long. max length is %d", value, *kvl.MaxLength)
	}
	var errs error
	for _, entry := range pairs {
		pair, isMap := entry.(map[string]any)
		if !isMap {
			errs = errors.Join(errs, fmt.Errorf("invalid key-value pair %v", entry))
			continue
		}
		k := pair[kvl.KeyPropertyName()]
		if err := kvl.KeyProperty.Validate(properties, k); err != nil {
			errs = errors.Join(errs, fmt.Errorf("invalid key %v: %w", k, err))
		}
		v := pair[kvl.ValuePropertyName()]
		if err := kvl.ValueProperty.Validate(properties, v); err != nil {
			errs = errors.Join(errs, fmt.Errorf("invalid value %v: %w", v, err))
		}
	}
	return errs
}

// SubProperties returns nil. Pair structure is described by
// KeyProperty/ValueProperty instead.
func (kvl *KeyValueListProperty) SubProperties() property.PropertyMap {
	return nil
}
func (kvl *KeyValueListProperty) mapToList(value any) ([]any, error) {
switch v := value.(type) {
case []any:
return v, nil
case map[string]any:
result := make([]any, 0, len(v))
for key, val := range v {
result = append(result, map[string]any{
kvl.KeyPropertyName(): key,
kvl.ValuePropertyName(): val,
})
}
return result, nil
default:
return nil, fmt.Errorf("invalid input type for KeyValueListProperty: %T", value)
}
}
// containsKeyValuePair reports whether list holds an element deeply equal to item.
func (kvl *KeyValueListProperty) containsKeyValuePair(list []any, item any) bool {
	for i := range list {
		if reflect.DeepEqual(list[i], item) {
			return true
		}
	}
	return false
}
package properties
import (
"errors"
"fmt"
"reflect"
"strings"
"github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/k2/constructs/template/property"
"github.com/klothoplatform/klotho/pkg/collectionutil"
)
type (
// ListProperty is a list-valued property stored as []any. Elements are
// described either by ItemProperty or, for lists of objects, by the
// sub-properties in Properties.
ListProperty struct {
// MinLength and MaxLength, when non-nil, bound the list length (checked by Validate).
MinLength *int
MaxLength *int
// ItemProperty describes each element. When nil, Type() reports plain "list".
ItemProperty property.Property
// Properties, when non-empty, makes Parse treat each element as an object
// with these sub-properties.
Properties property.PropertyMap
SharedPropertyFields
property.PropertyDetails
}
)
// SetProperty stores value at l.Path. Only []any values are accepted.
func (l *ListProperty) SetProperty(properties construct.Properties, value any) error {
	items, isList := value.([]any)
	if !isList {
		return fmt.Errorf("invalid list value %v", value)
	}
	return properties.SetProperty(l.Path, items)
}
// AppendProperty appends value to the list at l.Path, creating an empty list
// first if none is stored.
//
// When the element type is known and is not itself a list, a slice or array
// value is flattened: each element is appended individually and the errors
// are joined. Otherwise value is appended as a single element.
func (l *ListProperty) AppendProperty(properties construct.Properties, value any) error {
	current, err := properties.GetProperty(l.Path)
	if err != nil && !errors.Is(err, construct.ErrPropertyDoesNotExist) {
		return err
	}
	if current == nil {
		if err := l.SetProperty(properties, []any{}); err != nil {
			return err
		}
	}
	flatten := l.ItemProperty != nil && !strings.HasPrefix(l.ItemProperty.Type(), "list")
	if flatten {
		rv := reflect.ValueOf(value)
		if k := rv.Kind(); k == reflect.Slice || k == reflect.Array {
			var errs error
			for idx := 0; idx < rv.Len(); idx++ {
				if err := properties.AppendProperty(l.Path, rv.Index(idx).Interface()); err != nil {
					errs = errors.Join(errs, err)
				}
			}
			return errs
		}
	}
	return properties.AppendProperty(l.Path, value)
}
// RemoveProperty removes value from the list at l.Path. A missing or nil
// property is a no-op.
//
// Slice and array values are flattened the same way as in AppendProperty:
// each element is removed individually and the errors are joined.
func (l *ListProperty) RemoveProperty(properties construct.Properties, value any) error {
	current, err := properties.GetProperty(l.Path)
	switch {
	case errors.Is(err, construct.ErrPropertyDoesNotExist):
		return nil
	case err != nil:
		return err
	case current == nil:
		return nil
	}
	flatten := l.ItemProperty != nil && !strings.HasPrefix(l.ItemProperty.Type(), "list")
	if flatten {
		rv := reflect.ValueOf(value)
		if k := rv.Kind(); k == reflect.Slice || k == reflect.Array {
			var errs error
			for idx := 0; idx < rv.Len(); idx++ {
				if err := properties.RemoveProperty(l.Path, rv.Index(idx).Interface()); err != nil {
					errs = errors.Join(errs, err)
				}
			}
			return errs
		}
	}
	return properties.RemoveProperty(l.Path, value)
}
// Details returns a pointer to the embedded PropertyDetails.
func (l *ListProperty) Details() *property.PropertyDetails {
	details := &l.PropertyDetails
	return details
}

// Clone copies the property and deep-clones ItemProperty and Properties when
// they are set.
func (l *ListProperty) Clone() property.Property {
	cp := *l
	if l.ItemProperty != nil {
		cp.ItemProperty = l.ItemProperty.Clone()
	}
	if l.Properties != nil {
		cp.Properties = l.Properties.Clone()
	}
	return &cp
}

// GetDefaultValue parses the configured DefaultValue. It returns (nil, nil)
// when there is none.
func (l *ListProperty) GetDefaultValue(ctx property.ExecutionContext, data any) (any, error) {
	if def := l.DefaultValue; def != nil {
		return l.Parse(def, ctx, data)
	}
	return nil, nil
}
// Parse converts value into a []any, parsing each element.
//
// value may be a []any or a string template that renders to a list. Elements
// are parsed in one of three ways:
//   - when Properties is non-empty, as objects with those sub-properties;
//   - otherwise with ItemProperty;
//   - for an untyped list (nil ItemProperty), with AnyProperty semantics.
//
// An empty input yields a nil []any.
func (l *ListProperty) Parse(value any, ctx property.ExecutionContext, data any) (any, error) {
	items, ok := value.([]any)
	if !ok {
		// before we fail, check to see if the entire value is a template
		strVal, isString := value.(string)
		if !isString {
			return nil, fmt.Errorf("invalid list value %v", value)
		}
		var rendered []any
		if err := ctx.ExecuteUnmarshal(strVal, data, &rendered); err != nil {
			return nil, fmt.Errorf("invalid list value %v: %w", value, err)
		}
		items = rendered
	}

	// Pick the element parser once. An untyped list ("list" per Type()) must
	// not dereference a nil ItemProperty.
	var itemProp property.Property
	switch {
	case len(l.Properties) != 0:
		itemProp = &MapProperty{Properties: l.Properties}
	case l.ItemProperty != nil:
		itemProp = l.ItemProperty
	default:
		itemProp = &AnyProperty{}
	}

	var result []any
	for _, item := range items {
		parsed, err := itemProp.Parse(item, ctx, data)
		if err != nil {
			return nil, err
		}
		result = append(result, parsed)
	}
	return result, nil
}
// ZeroValue returns nil: an unset list has no value.
func (l *ListProperty) ZeroValue() any {
	return nil
}

// Contains reports whether value (which must be a []any) contains contains.
//
// When contains is itself a []any, any element deeply equal across the two
// lists counts. Otherwise the single value is looked up with
// collectionutil.Contains.
func (l *ListProperty) Contains(value any, contains any) bool {
	haystack, isList := value.([]any)
	if !isList {
		return false
	}
	needles, isList := contains.([]any)
	if !isList {
		return collectionutil.Contains(haystack, contains)
	}
	for _, have := range haystack {
		for _, want := range needles {
			if reflect.DeepEqual(have, want) {
				return true
			}
		}
	}
	return false
}

// Type returns "list(<itemType>)", or plain "list" when no ItemProperty is set.
func (l *ListProperty) Type() string {
	if l.ItemProperty == nil {
		return "list"
	}
	return fmt.Sprintf("list(%s)", l.ItemProperty.Type())
}
// Validate checks a list value against the property definition.
//
// A nil value is accepted unless the property is Required. Otherwise the value
// must be a []any whose length respects MinLength/MaxLength.
//
// When ItemProperty is set, each element is validated with it. When
// ItemProperty is nil, each element must be a map[string]any; a non-map
// element returns an error immediately. Each sub-property whose
// Details().Name is present in the element then validates that entry.
//
// Errors other than *property.SanitizeError are joined and returned. If there
// are none but some element (or sub-property entry) reported a
// *property.SanitizeError, a *property.SanitizeError is returned whose
// Sanitized field holds the corrected list.
//
// NOTE(review): with ItemProperty nil and no sub-properties (an untyped list),
// every non-map element is rejected. In the sub-property path only entries
// matching a sub-property are copied into validIndex, so other keys are
// dropped from the sanitized list. Confirm both are intended.
func (l *ListProperty) Validate(properties construct.Properties, value any) error {
if value == nil {
if l.Required {
return fmt.Errorf(property.ErrRequiredProperty, l.Path)
}
return nil
}
listVal, ok := value.([]any)
if !ok {
return fmt.Errorf("invalid list value %v", value)
}
if l.MinLength != nil {
if len(listVal) < *l.MinLength {
return fmt.Errorf("list value %v is too short. min length is %d", value, *l.MinLength)
}
}
if l.MaxLength != nil {
if len(listVal) > *l.MaxLength {
return fmt.Errorf("list value %v is too long. max length is %d", value, *l.MaxLength)
}
}
// validList mirrors listVal with each element's accepted (possibly sanitized)
// form. It is only returned when something was sanitized.
validList := make([]any, len(listVal))
var errs error
hasSanitized := false
for i, v := range listVal {
if l.ItemProperty != nil {
err := l.ItemProperty.Validate(properties, v)
if err != nil {
// A SanitizeError is not a failure: it carries a corrected value to use instead.
var sanitizeErr *property.SanitizeError
if errors.As(err, &sanitizeErr) {
validList[i] = sanitizeErr.Sanitized
hasSanitized = true
} else {
errs = errors.Join(errs, err)
}
} else {
validList[i] = v
}
} else {
vmap, ok := v.(map[string]any)
if !ok {
return fmt.Errorf("invalid value for list index %d in sub properties validation: expected map[string]any got %T", i, v)
}
validIndex := make(map[string]any)
for _, prop := range l.SubProperties() {
val, ok := vmap[prop.Details().Name]
if !ok {
// Absent sub-properties are not validated here.
continue
}
err := prop.Validate(properties, val)
if err != nil {
var sanitizeErr *property.SanitizeError
if errors.As(err, &sanitizeErr) {
validIndex[prop.Details().Name] = sanitizeErr.Sanitized
hasSanitized = true
} else {
errs = errors.Join(errs, err)
}
} else {
validIndex[prop.Details().Name] = val
}
}
validList[i] = validIndex
}
}
// Hard validation errors take precedence over sanitization results.
if errs != nil {
return errs
}
if hasSanitized {
return &property.SanitizeError{
Input: listVal,
Sanitized: validList,
}
}
return nil
}
// SubProperties returns the per-element object sub-properties, if any.
func (l *ListProperty) SubProperties() property.PropertyMap {
	props := l.Properties
	return props
}

// Item returns the property describing each list element. It may be nil.
func (l *ListProperty) Item() property.Property {
	item := l.ItemProperty
	return item
}
package properties
import (
"errors"
"fmt"
"reflect"
"github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/k2/constructs/template/property"
)
type (
// MapProperty is a map-valued property stored as map[string]any. It can
// describe an object with named sub-properties (Properties), a homogeneous
// map whose keys and values are described by KeyProperty and ValueProperty,
// or both.
MapProperty struct {
// MinLength and MaxLength, when non-nil, bound the number of entries (checked by Validate).
MinLength *int
MaxLength *int
// KeyProperty and ValueProperty describe every entry of a homogeneous map.
// Both must be set for per-entry parsing to happen in Parse.
KeyProperty property.Property
ValueProperty property.Property
// Properties holds the named sub-properties of an object-style map.
Properties property.PropertyMap
SharedPropertyFields
property.PropertyDetails
}
)
// SetProperty stores value at m.Path. Only map[string]any values are accepted.
func (m *MapProperty) SetProperty(properties construct.Properties, value any) error {
	mv, isMap := value.(map[string]any)
	if !isMap {
		return fmt.Errorf("invalid properties value %v", value)
	}
	return properties.SetProperty(m.Path, mv)
}

// AppendProperty delegates directly to properties.AppendProperty at m.Path.
func (m *MapProperty) AppendProperty(properties construct.Properties, value any) error {
	err := properties.AppendProperty(m.Path, value)
	return err
}
// RemoveProperty removes entries from the map stored at m.Path.
//
// When value is a map[string]any, each stored entry whose key is present and
// whose value is deeply equal is deleted in place, and the map is written back.
// Any other value is passed to properties.RemoveProperty. A missing or nil
// property is a no-op. A stored value that is not a map[string]any is an error.
func (m *MapProperty) RemoveProperty(properties construct.Properties, value any) error {
	current, err := properties.GetProperty(m.Path)
	switch {
	case errors.Is(err, construct.ErrPropertyDoesNotExist):
		return nil
	case err != nil:
		return err
	case current == nil:
		return nil
	}
	stored, isMap := current.(map[string]any)
	if !isMap {
		return fmt.Errorf("error attempting to remove map property: invalid property value %v", current)
	}
	toRemove, isMap := value.(map[string]any)
	if !isMap {
		return properties.RemoveProperty(m.Path, value)
	}
	for key, want := range toRemove {
		if have, present := stored[key]; present && reflect.DeepEqual(have, want) {
			delete(stored, key)
		}
	}
	return properties.SetProperty(m.Path, stored)
}
// Details returns a pointer to the embedded PropertyDetails.
func (m *MapProperty) Details() *property.PropertyDetails {
	details := &m.PropertyDetails
	return details
}

// Clone copies the property and deep-clones the key, value and sub-properties
// when they are set.
func (m *MapProperty) Clone() property.Property {
	cp := *m
	if m.KeyProperty != nil {
		cp.KeyProperty = m.KeyProperty.Clone()
	}
	if m.ValueProperty != nil {
		cp.ValueProperty = m.ValueProperty.Clone()
	}
	if m.Properties != nil {
		cp.Properties = m.Properties.Clone()
	}
	return &cp
}

// GetDefaultValue parses the configured DefaultValue. It returns (nil, nil)
// when there is none.
func (m *MapProperty) GetDefaultValue(ctx property.ExecutionContext, data any) (any, error) {
	if def := m.DefaultValue; def != nil {
		return m.Parse(def, ctx, data)
	}
	return nil, nil
}
// Parse converts value into a map[string]any.
//
// value may be a map[string]any, a construct.Properties, or a string template.
// A rendered template map is returned as-is, without further sub-property or
// key/value parsing.
//
// When Properties is non-empty, each declared sub-property is parsed from
// value or, if absent, filled from its default (nil defaults are skipped).
// Every sub-property failure is collected and returned as one joined error.
//
// When both KeyProperty and ValueProperty are set, every entry of value is
// also parsed with them, and each parsed key must be a string.
func (m *MapProperty) Parse(value any, ctx property.ExecutionContext, data any) (any, error) {
	result := map[string]any{}
	mapVal, ok := value.(map[string]any)
	if !ok {
		// before we fail, check to see if the entire value is a template
		if strVal, isString := value.(string); isString {
			err := ctx.ExecuteUnmarshal(strVal, data, &result)
			return result, err
		}
		mapVal, ok = value.(construct.Properties)
		if !ok {
			return nil, fmt.Errorf("invalid map value %v", value)
		}
	}

	// If we are an object with sub properties then we know that we need to get
	// the type of our sub properties to determine how we are parsed into a value.
	if len(m.Properties) != 0 {
		var errs error
		for key, prop := range m.Properties {
			raw, found := mapVal[key]
			if !found {
				def, err := prop.GetDefaultValue(ctx, data)
				if err != nil {
					errs = errors.Join(errs, fmt.Errorf("unable to get default value for sub property %s: %w", key, err))
					continue
				}
				if def != nil {
					result[key] = def
				}
				continue
			}
			parsed, err := prop.Parse(raw, ctx, data)
			if err != nil {
				errs = errors.Join(errs, fmt.Errorf("unable to parse value for sub property %s: %w", key, err))
				continue
			}
			result[key] = parsed
		}
		// Surface sub-property failures rather than returning a silently
		// incomplete map.
		if errs != nil {
			return nil, errs
		}
	}

	if m.KeyProperty == nil || m.ValueProperty == nil {
		return result, nil
	}
	// Else we are a set type of map and can just loop over the values
	for key, v := range mapVal {
		keyVal, err := m.KeyProperty.Parse(key, ctx, data)
		if err != nil {
			return nil, err
		}
		val, err := m.ValueProperty.Parse(v, ctx, data)
		if err != nil {
			return nil, err
		}
		keyStr, isString := keyVal.(string)
		if !isString {
			// %T names the offending key type; %s on a non-string would print a
			// "%!s(...)" artifact.
			return nil, fmt.Errorf("invalid key type %T for map property type %s", keyVal, m.Type())
		}
		result[keyStr] = val
	}
	return result, nil
}
// ZeroValue returns nil: an unset map has no value.
func (m *MapProperty) ZeroValue() any {
	return nil
}

// Contains reports whether value and contains are both map[string]any and
// share at least one entry value.
//
// It first looks for an identical key/value pair, then compares every value of
// one map against every value of the other with reflect.DeepEqual. Any other
// input types report false.
func (m *MapProperty) Contains(value any, contains any) bool {
	mapVal, ok := value.(map[string]any)
	if !ok {
		return false
	}
	containsMap, ok := contains.(map[string]any)
	if !ok {
		return false
	}
	// Fast path: a matching key/value pair. The key must be present AND its
	// value equal, mirroring the pair check in RemoveProperty. A bare key match,
	// or a nil value for a key absent from mapVal, is not containment.
	for k, v := range containsMap {
		if val, found := mapVal[k]; found && reflect.DeepEqual(val, v) {
			return true
		}
	}
	for _, v := range mapVal {
		for _, cv := range containsMap {
			if reflect.DeepEqual(v, cv) {
				return true
			}
		}
	}
	return false
}
// Type returns "map(<keyType>,<valueType>)" when both key and value
// properties are set, and plain "map" otherwise.
func (m *MapProperty) Type() string {
	if m.KeyProperty == nil || m.ValueProperty == nil {
		return "map"
	}
	return fmt.Sprintf("map(%s,%s)", m.KeyProperty.Type(), m.ValueProperty.Type())
}
// Validate checks a map value against the property definition.
//
// A nil value is accepted unless the property is Required. Otherwise the value
// must be a map[string]any whose size respects MinLength/MaxLength. When
// KeyProperty and/or ValueProperty are set, each key and/or value is validated
// with them. Sub-properties in Properties are not validated here.
//
// Errors other than *property.SanitizeError are joined and returned. If there
// are none but some key or value reported a *property.SanitizeError, a
// *property.SanitizeError carrying the sanitized map is returned.
//
// NOTE(review): a sanitized key is asserted to string without a check, so this
// panics if KeyProperty's sanitizer yields another type. Also, two keys that
// sanitize to the same string collapse into one entry in validMap.
func (m *MapProperty) Validate(properties construct.Properties, value any) error {
if value == nil {
if m.Required {
return fmt.Errorf(property.ErrRequiredProperty, m.Path)
}
return nil
}
mapVal, ok := value.(map[string]any)
if !ok {
return fmt.Errorf("invalid map value %v", value)
}
if m.MinLength != nil {
if len(mapVal) < *m.MinLength {
return fmt.Errorf("map value %v is too short. min length is %d", value, *m.MinLength)
}
}
if m.MaxLength != nil {
if len(mapVal) > *m.MaxLength {
return fmt.Errorf("map value %v is too long. max length is %d", value, *m.MaxLength)
}
}
var errs error
hasSanitized := false
// validMap receives every entry in its accepted (possibly sanitized) form. It
// is only returned when something was sanitized.
validMap := make(map[string]any)
// Only validate values if it's a primitive map, otherwise let the sub properties handle their own validation
for k, v := range mapVal {
if m.KeyProperty != nil {
// A SanitizeError is not a failure: it carries a corrected key to use instead.
var sanitizeErr *property.SanitizeError
if err := m.KeyProperty.Validate(properties, k); errors.As(err, &sanitizeErr) {
k = sanitizeErr.Sanitized.(string)
hasSanitized = true
} else if err != nil {
errs = errors.Join(errs, fmt.Errorf("invalid key %v for map property type %s: %w", k, m.KeyProperty.Type(), err))
}
}
if m.ValueProperty != nil {
var sanitizeErr *property.SanitizeError
if err := m.ValueProperty.Validate(properties, v); errors.As(err, &sanitizeErr) {
v = sanitizeErr.Sanitized
hasSanitized = true
} else if err != nil {
errs = errors.Join(errs, fmt.Errorf("invalid value %v for map property type %s: %w", v, m.ValueProperty.Type(), err))
}
}
validMap[k] = v
}
// Hard validation errors take precedence over sanitization results.
if errs != nil {
return errs
}
if hasSanitized {
return &property.SanitizeError{
Input: mapVal,
Sanitized: validMap,
}
}
return nil
}
// SubProperties returns the named sub-properties of an object-style map, if any.
func (m *MapProperty) SubProperties() property.PropertyMap {
	props := m.Properties
	return props
}

// Key returns the property describing map keys. It may be nil.
func (m *MapProperty) Key() property.Property {
	key := m.KeyProperty
	return key
}

// Value returns the property describing map values. It may be nil.
func (m *MapProperty) Value() property.Property {
	val := m.ValueProperty
	return val
}
package properties
import (
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/k2/constructs/template/property"
"github.com/klothoplatform/klotho/pkg/collectionutil"
)
type (
// PathProperty is a string property holding a filesystem path. SetProperty
// and Parse resolve non-empty relative paths to absolute ones via resolvePath.
PathProperty struct {
// SanitizeTmpl, when non-nil, is used by Validate to check the value.
SanitizeTmpl *property.SanitizeTmpl
// AllowedValues, when non-empty, restricts the value to this set (checked by Validate).
AllowedValues []string
SharedPropertyFields
property.PropertyDetails
// RelativeTo is the base directory for relative paths. Empty means the
// current working directory.
RelativeTo string
}
)
func (p *PathProperty) SetProperty(properties construct.Properties, value any) error {
strVal, ok := value.(string)
if !ok {
return fmt.Errorf("value %v is not a string", value)
}
if strVal == "" {
return properties.SetProperty(p.Path, "")
}
path, err := resolvePath(strVal, p.RelativeTo)
if err != nil {
return err
}
return properties.SetProperty(p.Path, path)
}
func (p *PathProperty) AppendProperty(properties construct.Properties, value any) error {
return p.SetProperty(properties, value)
}
// RemoveProperty clears the path at p.Path, passing nil as the value to remove
// regardless of value. A missing or nil property is a no-op.
func (p *PathProperty) RemoveProperty(properties construct.Properties, value any) error {
	current, err := properties.GetProperty(p.Path)
	switch {
	case errors.Is(err, construct.ErrPropertyDoesNotExist):
		return nil
	case err != nil:
		return err
	case current == nil:
		return nil
	}
	return properties.RemoveProperty(p.Path, nil)
}
// Details returns a pointer to the embedded PropertyDetails.
func (p *PathProperty) Details() *property.PropertyDetails {
	details := &p.PropertyDetails
	return details
}

// Clone returns a shallow copy of the property. The slice and template
// pointers are shared.
func (p *PathProperty) Clone() property.Property {
	cp := new(PathProperty)
	*cp = *p
	return cp
}

// GetDefaultValue parses the configured DefaultValue. It falls back to the
// empty-string zero value when none is configured.
func (p *PathProperty) GetDefaultValue(ctx property.ExecutionContext, data any) (any, error) {
	if def := p.DefaultValue; def != nil {
		return p.Parse(def, ctx, data)
	}
	return p.ZeroValue(), nil
}
// Parse converts value into an absolute path string.
//
// A string is rendered as a template, and the rendered text is used as the
// path. Ints, floats and bools are formatted with %v. An empty result is
// returned as "". Anything else is resolved against RelativeTo (or the
// working directory) by resolvePath.
func (p *PathProperty) Parse(value any, ctx property.ExecutionContext, data any) (any, error) {
	strVal := ""
	switch val := value.(type) {
	case string:
		// Keep the rendered template output. Overwriting it with the raw
		// template text would silently discard the rendering.
		if err := ctx.ExecuteUnmarshal(val, data, &strVal); err != nil {
			return nil, err
		}
	case int, int32, int64, float32, float64, bool:
		strVal = fmt.Sprintf("%v", val)
	default:
		return nil, fmt.Errorf("could not parse string property: invalid string value %v (%[1]T)", value)
	}
	if strVal == "" {
		return "", nil
	}
	return resolvePath(strVal, p.RelativeTo)
}
// ZeroValue is the empty path, "".
func (p *PathProperty) ZeroValue() any {
	return ""
}

// Contains reports whether value and contains are both strings and contains
// occurs as a substring of value.
func (p *PathProperty) Contains(value any, contains any) bool {
	haystack, okValue := value.(string)
	needle, okNeedle := contains.(string)
	return okValue && okNeedle && strings.Contains(haystack, needle)
}

// Type reports the underlying property type name; paths are plain strings.
func (p *PathProperty) Type() string {
	return "string"
}
// Validate accepts nil unless the property is required; otherwise value must
// be a string, must be one of AllowedValues (when any are configured), and
// must pass SanitizeTmpl (when set).
func (p *PathProperty) Validate(properties construct.Properties, value any) error {
	if value == nil {
		if !p.Required {
			return nil
		}
		return fmt.Errorf(property.ErrRequiredProperty, p.Path)
	}
	s, isString := value.(string)
	if !isString {
		return fmt.Errorf("value %v is not a string", value)
	}
	allowed := len(p.AllowedValues) == 0 || collectionutil.Contains(p.AllowedValues, s)
	if !allowed {
		return fmt.Errorf("value %s is not allowed. allowed values are %s", s, p.AllowedValues)
	}
	if p.SanitizeTmpl == nil {
		return nil
	}
	return p.SanitizeTmpl.Check(s)
}

// SubProperties is always nil: a path has no nested properties.
func (p *PathProperty) SubProperties() property.PropertyMap {
	return nil
}
func resolvePath(path string, basePath string) (string, error) {
// If the path is absolute, return it as is
if filepath.IsAbs(path) {
return path, nil
}
// Otherwise, make it relative to the base path or the current working directory
if basePath == "" {
var err error
basePath, err = os.Getwd()
if err != nil {
return "", fmt.Errorf("could not get working directory")
}
}
abs, err := filepath.Abs(filepath.Join(basePath, path))
if err != nil {
return "", fmt.Errorf("could not resolve path %s: %w", path, err)
}
return abs, nil
}
package properties
import (
"bytes"
"encoding"
"encoding/json"
"fmt"
"reflect"
"strconv"
"strings"
"text/template"
"github.com/klothoplatform/klotho/pkg/k2/constructs/template/property"
"github.com/klothoplatform/klotho/pkg/k2/model"
"github.com/klothoplatform/klotho/pkg/templateutils"
)
type (
// DefaultExecutionContext renders property templates with text/template
// (plus the common template functions from templateutils) and decodes the
// rendered output with UnmarshalAny.
DefaultExecutionContext struct{}
)
// UnmarshalFunc decodes the contents of data into the value pointed to by v.
type UnmarshalFunc func(data *bytes.Buffer, v any) error
// ExecuteUnmarshal parses tmpl as a text/template with the common template
// functions available, executes it against data, and decodes the rendered
// output into v using d.Unmarshal.
func (d DefaultExecutionContext) ExecuteUnmarshal(tmpl string, data any, v any) error {
	funcs := templateutils.WithCommonFuncs(template.FuncMap{})
	parsed, err := template.New("tmpl").Funcs(funcs).Parse(tmpl)
	if err != nil {
		return err
	}
	return ExecuteTemplateUnmarshal(parsed, data, v, d.Unmarshal)
}

// Unmarshal decodes data into v by delegating to UnmarshalAny.
func (d DefaultExecutionContext) Unmarshal(data *bytes.Buffer, v any) error {
	return UnmarshalAny(data, v)
}
// ExecuteTemplateUnmarshal executes the [template.Template], t, using data and
// unmarshals the rendered output into v with unmarshal.
//
// Template execution errors are returned unchanged. Decoding errors are
// wrapped with %w, together with the rendered text and the target type, so
// callers can still inspect the cause with errors.Is / errors.As.
func ExecuteTemplateUnmarshal(
	t *template.Template,
	data any,
	v any,
	unmarshal UnmarshalFunc,
) error {
	buf := new(bytes.Buffer)
	if err := t.Execute(buf, data); err != nil {
		return err
	}
	// Capture the rendered text before decoding: stream-based unmarshal
	// functions (e.g. UnmarshalJSON) drain the buffer as they read.
	rendered := buf.String()
	if err := unmarshal(buf, v); err != nil {
		return fmt.Errorf("cannot decode template result '%s' into %T: %w", rendered, v, err)
	}
	return nil
}
// UnmarshalJSON decodes the first JSON value read from data into
// outputRefValue, consuming data as it reads.
func UnmarshalJSON(data *bytes.Buffer, outputRefValue any) error {
	return json.NewDecoder(data).Decode(outputRefValue)
}
// UnmarshalAny decodes the template result into a primitive or a struct that implements [encoding.TextUnmarshaler].
// As a fallback, it tries to unmarshal the result using [json.Unmarshal].
// If v is a pointer, it will be set to the decoded value.
//
// Leading and trailing whitespace is trimmed before any decoding, so
// templates don't need `{{-` / `-}}` trim markers. Targets are handled in
// this order:
//   - *string, *[]byte: the trimmed text
//   - *bool: true unless the text is "", "<no value>" or "false" (case-insensitive)
//   - *int, *float64, *float32: parsed with strconv
//   - encoding.TextUnmarshaler: UnmarshalText on the trimmed text
//   - non-nil pointer to a string-kinded type (e.g. `type MyString string`)
//     or to a type a string is assignable to (e.g. *any): the trimmed text
//   - anything else: json.Unmarshal of the trimmed text, which also reports
//     nil or non-pointer targets as an error instead of panicking
func UnmarshalAny(data *bytes.Buffer, v any) error {
	bstr := strings.TrimSpace(data.String())
	switch value := v.(type) {
	case *string:
		*value = bstr
		return nil
	case *[]byte:
		*value = []byte(bstr)
		return nil
	case *bool:
		result := strings.ToLower(bstr)
		// If the input (eg 'field') is nil and the 'if' statement just uses '{{ inputs "field" }}',
		// then the string result will be '<no value>'.
		// Make sure we don't interpret that as a true condition.
		*value = result != "" && result != "<no value>" && result != "false"
		return nil
	case *int:
		i, err := strconv.Atoi(bstr)
		if err != nil {
			return err
		}
		*value = i
		return nil
	case *float64:
		f, err := strconv.ParseFloat(bstr, 64)
		if err != nil {
			return err
		}
		*value = f
		return nil
	case *float32:
		f, err := strconv.ParseFloat(bstr, 32)
		if err != nil {
			return err
		}
		*value = float32(f)
		return nil
	case encoding.TextUnmarshaler:
		// notably, this handles `construct.ResourceId` and `construct.IaCValue`
		return value.UnmarshalText([]byte(bstr))
	}

	// Reflection-based assignment is only safe on a non-nil pointer; Elem/Set
	// would panic otherwise. Other targets fall through to json.Unmarshal.
	if rv := reflect.ValueOf(v); rv.Kind() == reflect.Pointer && !rv.IsNil() {
		target := rv.Elem()
		switch {
		case target.Kind() == reflect.String:
			// named string types such as `type MyString string`
			target.SetString(bstr)
			return nil
		case reflect.TypeOf(bstr).AssignableTo(target.Type()):
			// interface targets such as *any
			target.Set(reflect.ValueOf(bstr))
			return nil
		}
	}
	return json.Unmarshal([]byte(bstr), v)
}
// ExecuteUnmarshalAsURN renders tmpl against data with ctx and decodes the
// result into a model.URN. A zero URN is reported as an error; the decoded
// (zero) URN is still returned alongside it.
func ExecuteUnmarshalAsURN(ctx property.ExecutionContext, tmpl string, data any) (model.URN, error) {
	var urn model.URN
	if err := ctx.ExecuteUnmarshal(tmpl, data, &urn); err != nil {
		return urn, err
	}
	if !urn.IsZero() {
		return urn, nil
	}
	return urn, fmt.Errorf("selector '%s' is zero", tmpl)
}
package properties
import (
"errors"
"fmt"
"reflect"
"github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/k2/constructs/template/property"
"github.com/klothoplatform/klotho/pkg/set"
)
type (
// SetProperty describes a property whose value is stored as a
// set.HashedSet[string, any].
SetProperty struct {
// MinLength, when non-nil, is the minimum set size enforced by Validate.
MinLength *int
// MaxLength, when non-nil, is the maximum set size enforced by Validate.
MaxLength *int
// ItemProperty describes each element; used by Parse (when Properties is
// empty), Type and Validate.
ItemProperty property.Property
// Properties, when non-empty, makes Parse treat each element as a map with
// these sub-properties; also returned by SubProperties.
Properties property.PropertyMap
SharedPropertyFields
property.PropertyDetails
}
)
func (s *SetProperty) SetProperty(properties construct.Properties, value any) error {
switch val := value.(type) {
case set.HashedSet[string, any]:
return properties.SetProperty(s.Path, val)
}
if val, ok := value.(set.HashedSet[string, any]); ok {
return properties.SetProperty(s.Path, val)
}
return fmt.Errorf("invalid set value %v", value)
}
// AppendProperty adds value to the set stored at s.Path. When nothing is
// stored yet and value is itself a set, it is stored directly via
// SetProperty; otherwise the append is delegated to properties.AppendProperty.
func (s *SetProperty) AppendProperty(properties construct.Properties, value any) error {
	existing, err := properties.GetProperty(s.Path)
	if err != nil && !errors.Is(err, construct.ErrPropertyDoesNotExist) {
		return err
	}
	if initial, isSet := value.(set.HashedSet[string, any]); isSet && existing == nil {
		return s.SetProperty(properties, initial)
	}
	return properties.AppendProperty(s.Path, value)
}
// RemoveProperty removes every element of value (which must be a
// set.HashedSet[string, any]) from the set stored at s.Path and stores the
// result back. A missing or nil stored value is a no-op.
func (s *SetProperty) RemoveProperty(properties construct.Properties, value any) error {
	stored, err := properties.GetProperty(s.Path)
	switch {
	case errors.Is(err, construct.ErrPropertyDoesNotExist):
		return nil
	case err != nil:
		return err
	case stored == nil:
		return nil
	}
	current, ok := stored.(set.HashedSet[string, any])
	if !ok {
		return errors.New("invalid set value")
	}
	toRemove, ok := value.(set.HashedSet[string, any])
	if !ok {
		return fmt.Errorf("invalid set value %v", value)
	}
	for _, item := range toRemove.ToSlice() {
		current.Remove(item)
	}
	return s.SetProperty(properties, current)
}
// Details returns a pointer to the embedded property details.
func (s *SetProperty) Details() *property.PropertyDetails {
	return &s.PropertyDetails
}

// Clone copies s, replacing ItemProperty and Properties (when set) with
// their own Clone results.
func (s *SetProperty) Clone() property.Property {
	clone := *s
	if s.ItemProperty != nil {
		clone.ItemProperty = s.ItemProperty.Clone()
	}
	if s.Properties != nil {
		clone.Properties = s.Properties.Clone()
	}
	return &clone
}
// GetDefaultValue returns nil when no default is configured, otherwise the
// parsed DefaultValue.
func (s *SetProperty) GetDefaultValue(ctx property.ExecutionContext, data any) (any, error) {
	if s.DefaultValue == nil {
		return nil, nil
	}
	return s.Parse(s.DefaultValue, ctx, data)
}

// Parse converts value into a set.HashedSet[string, any].
//
// Accepted inputs are an existing set, a []any, nil (yielding an empty set),
// or a string that is rendered as a template and decoded into a list; any
// other type is an error. Each element is then parsed as a map of
// s.Properties when sub-properties are defined, otherwise with
// s.ItemProperty.
func (s *SetProperty) Parse(value any, ctx property.ExecutionContext, data any) (any, error) {
	result := set.HashedSet[string, any]{
		// elements are keyed by their default %v formatting
		Hasher: func(elem any) string {
			return fmt.Sprintf("%v", elem)
		},
		Less: func(a, b string) bool {
			return a < b
		},
	}

	var vals []any
	switch v := value.(type) {
	case set.HashedSet[string, any]:
		vals = v.ToSlice()
	case []any:
		vals = v
	case string:
		// the whole value may be a template that renders to a list
		if err := ctx.ExecuteUnmarshal(v, data, &vals); err != nil {
			return nil, err
		}
	case nil:
		// nothing to parse; yields an empty set
	default:
		return nil, fmt.Errorf("could not parse set property: invalid set value %v (%[1]T)", value)
	}

	// Without sub-properties every element goes through ItemProperty; fail
	// clearly instead of dereferencing a nil property.
	if len(s.Properties) == 0 && s.ItemProperty == nil && len(vals) > 0 {
		return nil, fmt.Errorf("could not parse set property %s: no item property defined", s.Path)
	}

	for _, v := range vals {
		var (
			parsed any
			err    error
		)
		if len(s.Properties) != 0 {
			m := MapProperty{Properties: s.Properties}
			parsed, err = m.Parse(v, ctx, data)
		} else {
			parsed, err = s.ItemProperty.Parse(v, ctx, data)
		}
		if err != nil {
			return nil, err
		}
		result.Add(parsed)
	}
	return result, nil
}
// ZeroValue is nil: an unset set has no value.
func (s *SetProperty) ZeroValue() any {
	return nil
}

// Contains reports whether value is a set.HashedSet[string, any] holding an
// element that is reflect.DeepEqual to contains.
func (s *SetProperty) Contains(value any, contains any) bool {
	hs, ok := value.(set.HashedSet[string, any])
	if !ok {
		return false
	}
	found := false
	for _, elem := range hs.M {
		if found = reflect.DeepEqual(contains, elem); found {
			break
		}
	}
	return found
}

// Type returns "set", or "set(<item type>)" when an item property exists.
func (s *SetProperty) Type() string {
	if s.ItemProperty == nil {
		return "set"
	}
	return "set(" + s.ItemProperty.Type() + ")"
}
// Validate checks value against the set's constraints.
//
// A nil value is only an error when the property is required. Otherwise
// value must be a set.HashedSet[string, any] whose size respects MinLength
// and MaxLength. When ItemProperty is defined, each element is validated
// with it. Hard failures are joined and returned (wrapped with %w so
// callers can inspect them). If the only problems are sanitization
// suggestions, a *property.SanitizeError is returned whose Sanitized field
// is a set (with the same Hasher and Less) containing the sanitized
// elements.
func (s *SetProperty) Validate(properties construct.Properties, value any) error {
	if value == nil {
		if s.Required {
			return fmt.Errorf(property.ErrRequiredProperty, s.Path)
		}
		return nil
	}
	setVal, ok := value.(set.HashedSet[string, any])
	if !ok {
		return fmt.Errorf("could not validate set property: invalid set value %v", value)
	}
	if s.MinLength != nil && setVal.Len() < *s.MinLength {
		return fmt.Errorf("value %v is too short. minimum length is %d", setVal.M, *s.MinLength)
	}
	if s.MaxLength != nil && setVal.Len() > *s.MaxLength {
		return fmt.Errorf("value %v is too long. maximum length is %d", setVal.M, *s.MaxLength)
	}
	// Only validate values if its a primitive list, otherwise let the sub properties handle their own validation
	if s.ItemProperty == nil {
		return nil
	}
	var errs error
	hasSanitized := false
	// Keep the input's ordering function so the sanitized set behaves like
	// sets produced by Parse.
	validSet := set.HashedSet[string, any]{Hasher: setVal.Hasher, Less: setVal.Less}
	for _, item := range setVal.ToSlice() {
		err := s.ItemProperty.Validate(properties, item)
		if err == nil {
			validSet.Add(item)
			continue
		}
		var sanitizeErr *property.SanitizeError
		if errors.As(err, &sanitizeErr) {
			validSet.Add(sanitizeErr.Sanitized)
			hasSanitized = true
			continue
		}
		errs = errors.Join(errs, fmt.Errorf("invalid item %v: %w", item, err))
	}
	if errs != nil {
		return errs
	}
	if hasSanitized {
		return &property.SanitizeError{
			Input:     setVal,
			Sanitized: validSet,
		}
	}
	return nil
}
// SubProperties returns the per-element sub-property definitions,
// s.Properties (nil when none are configured).
func (s *SetProperty) SubProperties() property.PropertyMap {
	props := s.Properties
	return props
}

// Item returns the property describing each element of the set.
func (s *SetProperty) Item() property.Property {
	item := s.ItemProperty
	return item
}
package properties
import (
"bytes"
"fmt"
"text/template"
"github.com/klothoplatform/klotho/pkg/construct"
)
type (
// SharedPropertyFields holds configuration embedded by the property types
// in this package (e.g. StringProperty, SetProperty).
SharedPropertyFields struct {
// DefaultValue is the raw, unparsed default; GetDefaultValue passes it to
// the property's Parse method.
DefaultValue any `json:"default_value" yaml:"default_value"`
// ValidityChecks are template-based checks; presumably run during
// validation by code outside this file — confirm.
ValidityChecks []PropertyValidityCheck
}
// PropertyValidityCheck wraps a template executed by Validate; any
// non-empty output is treated as a validation failure message.
PropertyValidityCheck struct {
template *template.Template
}
// ValidityCheckData is the data passed to a PropertyValidityCheck template.
ValidityCheckData struct {
Properties construct.Properties `json:"properties" yaml:"properties"`
Value any `json:"value" yaml:"value"`
}
)
// Validate executes the check's template with the candidate value and the
// supplied properties. Empty output means the value is valid; otherwise the
// output is returned as the error detail. Template execution errors are
// returned unchanged.
func (p *PropertyValidityCheck) Validate(value any, properties construct.Properties) error {
	var out bytes.Buffer
	input := ValidityCheckData{Properties: properties, Value: value}
	if err := p.template.Execute(&out, input); err != nil {
		return err
	}
	if msg := out.String(); msg != "" {
		return fmt.Errorf("invalid value %v: %s", value, msg)
	}
	return nil
}
package properties
import (
"errors"
"fmt"
"strings"
"github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/k2/constructs/template/property"
"github.com/klothoplatform/klotho/pkg/collectionutil"
)
type (
// StringProperty is a plain string-valued property.
StringProperty struct {
// SanitizeTmpl, when non-nil, is checked against the value by Validate.
SanitizeTmpl *property.SanitizeTmpl
// AllowedValues, when non-empty, restricts Validate to these exact values.
AllowedValues []string
SharedPropertyFields
property.PropertyDetails
}
)
func (str *StringProperty) SetProperty(properties construct.Properties, value any) error {
if val, ok := value.(string); ok {
return properties.SetProperty(str.Path, val)
} else if val, ok := value.(construct.PropertyRef); ok {
return properties.SetProperty(str.Path, val)
}
return fmt.Errorf("could not set string property: invalid string value %v", value)
}
func (str *StringProperty) AppendProperty(properties construct.Properties, value any) error {
return str.SetProperty(properties, value)
}
// RemoveProperty unsets the value at str.Path; the value argument is
// ignored. It is a no-op when nothing (or nil) is stored.
func (str *StringProperty) RemoveProperty(properties construct.Properties, value any) error {
	current, err := properties.GetProperty(str.Path)
	switch {
	case errors.Is(err, construct.ErrPropertyDoesNotExist):
		return nil
	case err != nil:
		return err
	case current == nil:
		return nil
	}
	return properties.RemoveProperty(str.Path, nil)
}

// Details returns a pointer to the embedded property details.
func (str *StringProperty) Details() *property.PropertyDetails {
	return &str.PropertyDetails
}

// Clone returns a shallow copy of str.
func (str *StringProperty) Clone() property.Property {
	dup := new(StringProperty)
	*dup = *str
	return dup
}
// GetDefaultValue returns nil when no default is configured, otherwise the
// parsed DefaultValue.
func (str *StringProperty) GetDefaultValue(ctx property.ExecutionContext, data any) (any, error) {
	if str.DefaultValue == nil {
		return nil, nil
	}
	return str.Parse(str.DefaultValue, ctx, data)
}

// Parse converts value into a string. String values are rendered as
// templates with ctx against data; integers, floats and booleans are
// formatted with %v. Any other type is an error. On error the returned value
// is nil rather than the unrendered template text.
func (str *StringProperty) Parse(value any, ctx property.ExecutionContext, data any) (any, error) {
	switch val := value.(type) {
	case string:
		var rendered string
		if err := ctx.ExecuteUnmarshal(val, data, &rendered); err != nil {
			return nil, err
		}
		return rendered, nil
	case int, int32, int64, float32, float64, bool:
		return fmt.Sprintf("%v", val), nil
	}
	return nil, fmt.Errorf("could not parse string property: invalid string value %v (%[1]T)", value)
}
func (str *StringProperty) ZeroValue() any {
return ""
}
func (str *StringProperty) Contains(value any, contains any) bool {
vString, ok := value.(string)
if !ok {
return false
}
cString, ok := contains.(string)
if !ok {
return false
}
return strings.Contains(vString, cString)
}
func (str *StringProperty) Type() string {
return "string"
}
// Validate accepts nil unless the property is required; otherwise value must
// be a string, must be one of AllowedValues (when configured), and must pass
// SanitizeTmpl (when set).
func (str *StringProperty) Validate(properties construct.Properties, value any) error {
	if value == nil {
		if str.Required {
			return fmt.Errorf(property.ErrRequiredProperty, str.Path)
		}
		return nil
	}
	s, ok := value.(string)
	if !ok {
		return fmt.Errorf("value %v is not a string", value)
	}
	if restricted := len(str.AllowedValues) > 0; restricted && !collectionutil.Contains(str.AllowedValues, s) {
		return fmt.Errorf("value %s is not allowed. allowed values are %s", s, str.AllowedValues)
	}
	if sanitizer := str.SanitizeTmpl; sanitizer != nil {
		return sanitizer.Check(s)
	}
	return nil
}

// SubProperties is always nil: strings have no nested properties.
func (str *StringProperty) SubProperties() property.PropertyMap {
	return nil
}
package property
import (
"fmt"
"regexp"
"strings"
"github.com/klothoplatform/klotho/pkg/k2/model"
"gopkg.in/yaml.v3"
)
// ConstructReference pairs a construct URN with a Path (presumably a
// property path within that construct — confirm against callers).
type ConstructReference struct {
URN model.URN `yaml:"urn" json:"urn"`
Path string `yaml:"path" json:"path"`
}
// ConstructType identifies a construct template as a dotted package plus a
// final name segment: "a.b.Name" has Package "a.b" and Name "Name".
type ConstructType struct {
Package string `yaml:"package"`
Name string `yaml:"name"`
}
// constructTypeRegexp matches two or more '.'-separated segments made of
// ASCII word characters or '-'; UnmarshalYAML uses it to validate input.
var constructTypeRegexp = regexp.MustCompile(`^(?:([\w-]+)\.)+([\w-]+)$`)
// UnmarshalYAML decodes a scalar "pkg.sub.Name" string into c, splitting on
// the last '.' into Package and Name. The string must match
// constructTypeRegexp.
func (c *ConstructType) UnmarshalYAML(value *yaml.Node) error {
	var raw string
	if err := value.Decode(&raw); err != nil {
		return fmt.Errorf("failed to decode construct type: %w", err)
	}
	if !constructTypeRegexp.MatchString(raw) {
		return fmt.Errorf("invalid construct type: %s", raw)
	}
	split := strings.LastIndex(raw, ".")
	c.Package, c.Name = raw[:split], raw[split+1:]
	return nil
}

// String renders c as "Package.Name".
func (c *ConstructType) String() string {
	return c.Package + "." + c.Name
}
// FromString parses id ("pkg.sub.Name") into c, splitting on the last '.'.
// At least two segments are required and none may be empty, matching the
// shape accepted by UnmarshalYAML. c is left unchanged on error.
func (c *ConstructType) FromString(id string) error {
	parts := strings.Split(id, ".")
	valid := len(parts) >= 2
	for _, part := range parts {
		if part == "" {
			valid = false
			break
		}
	}
	if !valid {
		return fmt.Errorf("invalid construct template id: %s", id)
	}
	c.Package = strings.Join(parts[:len(parts)-1], ".")
	c.Name = parts[len(parts)-1]
	return nil
}

// ParseConstructType parses id with FromString and returns the result.
func ParseConstructType(id string) (ConstructType, error) {
	var c ConstructType
	err := c.FromString(id)
	return c, err
}

// FromURN fills c from a URN of type "construct", parsing its Subtype with
// FromString.
func (c *ConstructType) FromURN(urn model.URN) error {
	if urn.Type != "construct" {
		return fmt.Errorf("invalid urn type: %s", urn.Type)
	}
	return c.FromString(urn.Subtype)
}
package property
import (
"errors"
"fmt"
"github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/set"
"reflect"
"sort"
"strings"
)
// PropertyMap maps property names to their definitions and can describe
// nested, structured values within a template. When used in a template, wrap
// it in a struct that implements the [Properties] interface.
type PropertyMap map[string]Property

// Clone returns a new map holding the Clone of every property in p.
func (p PropertyMap) Clone() PropertyMap {
	cloned := make(PropertyMap, len(p))
	for name, prop := range p {
		cloned[name] = prop.Clone()
	}
	return cloned
}

// Get returns the property stored under key and whether it was present.
func (p PropertyMap) Get(key string) (Property, bool) {
	prop, found := p[key]
	return prop, found
}

// Set stores value under key, replacing any existing entry.
func (p PropertyMap) Set(key string, value Property) {
	p[key] = value
}

// Remove deletes key from p; it is a no-op when key is absent.
func (p PropertyMap) Remove(key string) {
	delete(p, key)
}
// ForEach walks the property tree rooted at p breadth-first and calls f on
// every property, visiting keys in sorted order within each map.
//
// c supplies the current property values. Properties whose Type() starts
// with "list" or "set" are only descended into when c holds a non-nil value
// for them; their sub-properties are then visited once per stored element,
// with "path[i]" substituted for the property's path. Other properties are
// descended into through SubProperties.
//
// If f returns ErrStopWalk the walk ends and ForEach returns nil, discarding
// any errors gathered so far. Any other error from f is collected and that
// property's children are skipped; all collected errors are joined and
// returned at the end.
func (p PropertyMap) ForEach(c construct.Properties, f func(p Property) error) error {
queue := []PropertyMap{p}
var props PropertyMap
var errs error
for len(queue) > 0 {
props, queue = queue[0], queue[1:]
// Sort the keys so the visiting order is deterministic.
propKeys := make([]string, 0, len(props))
for k := range props {
propKeys = append(propKeys, k)
}
sort.Strings(propKeys)
for _, key := range propKeys {
prop := props[key]
err := f(prop)
if err != nil {
if errors.Is(err, ErrStopWalk) {
return nil
}
errs = errors.Join(errs, err)
continue
}
if strings.HasPrefix(prop.Type(), "list") || strings.HasPrefix(prop.Type(), "set") {
// NOTE: p (and err) shadow the outer names here; p is the value stored
// for prop. Lookup errors and unset values are silently skipped.
p, err := c.GetProperty(prop.Details().Path)
if err != nil || p == nil {
continue
}
// Because lists/sets will start as empty, do not recurse into their sub-properties if it's not set.
// To allow for defaults within list objects and operational rules to be run,
// we will look inside the property to see if there are values.
if strings.HasPrefix(prop.Type(), "list") {
// NOTE(review): reflect's Len panics unless p is a slice/array/map/
// string/chan — assumes list values are stored as slices.
length := reflect.ValueOf(p).Len()
for i := 0; i < length; i++ {
subProperties := make(PropertyMap)
for subK, subProp := range prop.SubProperties() {
propTemplate := subProp.Clone()
ReplacePath(propTemplate, prop.Details().Path, fmt.Sprintf("%s[%d]", prop.Details().Path, i))
subProperties[subK] = propTemplate
}
if len(subProperties) > 0 {
queue = append(queue, subProperties)
}
}
} else if strings.HasPrefix(prop.Type(), "set") {
hs, ok := p.(set.HashedSet[string, any])
if !ok {
errs = errors.Join(errs, fmt.Errorf("could not cast property to set"))
continue
}
// Set elements are addressed by their position in hs.ToSlice().
// NOTE(review): assumes ToSlice returns a stable order — confirm.
for i := range hs.ToSlice() {
subProperties := make(PropertyMap)
for subK, subProp := range prop.SubProperties() {
propTemplate := subProp.Clone()
ReplacePath(propTemplate, prop.Details().Path, fmt.Sprintf("%s[%d]", prop.Details().Path, i))
subProperties[subK] = propTemplate
}
if len(subProperties) > 0 {
queue = append(queue, subProperties)
}
}
}
} else if prop.SubProperties() != nil {
queue = append(queue, prop.SubProperties())
}
}
}
return errs
}
// GetProperty looks up the property definition addressed by a dotted path
// such as "a.b[0].c" within properties, returning nil when a segment is not
// found.
//
// Any "[...]" suffix on a segment is ignored for the lookup. At the final
// segment, a clone of the matching property is returned with its Path set to
// the full path, so index strings and map keys in the path are preserved
// without mutating the shared definition. If an intermediate property has
// no sub-properties but is a MapProperty or CollectionProperty, a clone of
// its value/item property is returned immediately (with Path = path) and any
// remaining segments are not inspected.
func GetProperty(properties PropertyMap, path string) Property {
	withPath := func(prop Property) Property {
		clone := prop.Clone()
		clone.Details().Path = path
		return clone
	}
	fields := strings.Split(path, ".")
	for i, field := range fields {
		name, _, _ := strings.Cut(field, "[")
		prop, ok := properties[name]
		if !ok {
			return nil
		}
		if i == len(fields)-1 {
			return withPath(prop)
		}
		properties = prop.SubProperties()
		if len(properties) > 0 {
			continue
		}
		if mp, ok := prop.(MapProperty); ok {
			return withPath(mp.Value())
		}
		if cp, ok := prop.(CollectionProperty); ok {
			return withPath(cp.Item())
		}
	}
	return nil
}
package property
import (
"bytes"
"crypto/sha256"
"fmt"
"regexp"
"strings"
"sync"
"text/template"
)
type (
	// SanitizeTmpl wraps a text/template that rewrites a string value into
	// its sanitized form; build one with NewSanitizationTmpl.
	SanitizeTmpl struct {
		template *template.Template
	}
	// SanitizeError is returned when a value is sanitized if the input is not valid. The Sanitized field
	// is always the same type as the Input field.
	SanitizeError struct {
		Input     any
		Sanitized any
	}
)

// NewSanitizationTmpl parses tmpl as a sanitization template named
// name+"/sanitize". The template receives the raw string value as its data
// and may use these functions:
//
//   - replace PATTERN REPLACEMENT VALUE: regexp ReplaceAllString
//   - length MIN MAX VALUE: right-pads VALUE with '0' up to MIN bytes; if it
//     is longer than MAX bytes it is truncated and suffixed with up to 8 hex
//     characters of its SHA-256 (fewer when MAX < 8) so long values sharing
//     a prefix are less likely to collide
//   - lower / upper: strings.ToLower / strings.ToUpper
//
// On a parse error it returns a nil *SanitizeTmpl alongside the error.
func NewSanitizationTmpl(name string, tmpl string) (*SanitizeTmpl, error) {
	t, err := template.New(name + "/sanitize").
		Funcs(template.FuncMap{
			"replace": func(pattern, replace, name string) (string, error) {
				re, err := regexp.Compile(pattern)
				if err != nil {
					return name, err
				}
				return re.ReplaceAllString(name, replace), nil
			},
			"length": func(min, max int, name string) string {
				if len(name) < min {
					return name + strings.Repeat("0", min-len(name))
				}
				if len(name) <= max {
					return name
				}
				if max <= 0 {
					return ""
				}
				// Reserve up to 8 characters for the hash suffix; with a very
				// small max the result is hash-only instead of slicing name
				// with a negative bound.
				hashLen := 8
				if max < hashLen {
					hashLen = max
				}
				h := sha256.New()
				fmt.Fprint(h, name)
				sum := fmt.Sprintf("%x", h.Sum(nil))
				return name[:max-hashLen] + sum[:hashLen]
			},
			"lower": strings.ToLower,
			"upper": strings.ToUpper,
		}).
		Parse(tmpl)
	if err != nil {
		return nil, err
	}
	return &SanitizeTmpl{template: t}, nil
}
// sanitizeBufs pools scratch buffers for SanitizeTmpl.Execute.
var sanitizeBufs = sync.Pool{
	New: func() any {
		return new(bytes.Buffer)
	},
}

// Execute runs the sanitization template on value and returns its output
// with surrounding whitespace trimmed. On failure it returns the original
// value together with a wrapped error.
func (t SanitizeTmpl) Execute(value string) (string, error) {
	scratch := sanitizeBufs.Get().(*bytes.Buffer)
	scratch.Reset()
	defer sanitizeBufs.Put(scratch)

	if err := t.template.Execute(scratch, value); err != nil {
		return value, fmt.Errorf("could not execute sanitize name template on %q: %w", value, err)
	}
	return strings.TrimSpace(scratch.String()), nil
}
// Check sanitizes value and returns a *SanitizeError carrying the suggested
// replacement when it differs from value. Template failures are returned
// as-is; nil means value is already clean.
func (t SanitizeTmpl) Check(value string) error {
	sanitized, err := t.Execute(value)
	switch {
	case err != nil:
		return err
	case sanitized == value:
		return nil
	default:
		return &SanitizeError{Input: value, Sanitized: sanitized}
	}
}

// Error describes the rejected input and the suggested sanitized value.
func (e SanitizeError) Error() string {
	return fmt.Sprintf("invalid value %q, suggested value: %q", e.Input, e.Sanitized)
}
package property
import (
"errors"
"strings"
)
// ErrRequiredProperty is a format string (taking the property path) used to
// report a required property that has no value.
const ErrRequiredProperty = "required property %s is not set"
// ErrStopWalk may be returned from a PropertyMap.ForEach callback to end the
// walk early; ForEach then returns nil.
var ErrStopWalk = errors.New("stop walk")
// ReplacePath rewrites the path of p and, recursively, of every sub-property
// by replacing each occurrence of original with replacement
// ([strings.ReplaceAll] semantics).
// NOTE: this mutates p in place; [Property.Clone] it first to keep the
// original untouched.
func ReplacePath(p Property, original, replacement string) {
	details := p.Details()
	details.Path = strings.ReplaceAll(details.Path, original, replacement)
	for _, sub := range p.SubProperties() {
		ReplacePath(sub, original, replacement)
	}
}
package template
import (
"fmt"
"github.com/klothoplatform/klotho/pkg/k2/model"
"github.com/klothoplatform/klotho/pkg/reflectutil"
"reflect"
"text/template"
)
type (
// ResourceRef identifies a resource by ResourceKey (presumably scoped to
// the construct ConstructURN — confirm) and optionally one of its
// properties; Type selects how it is resolved (see the ResourceRefType
// constants).
ResourceRef struct {
ConstructURN model.URN
ResourceKey string
Property string
Type ResourceRefType
}
// ResourceRefType classifies how a ResourceRef is resolved.
ResourceRefType string
// InterpolationSourceKey presumably names an interpolation source such as
// "resources" in ${resources:...} — confirm against callers.
InterpolationSourceKey string
// InterpolationSource is implemented by values that expose a PropertySource.
InterpolationSource interface {
GetPropertySource() *PropertySource
}
// PropertySource provides key-based reflective access to the fields of a
// wrapped value (see NewPropertySource and GetProperty).
PropertySource struct {
source reflect.Value
}
// TemplateFuncSupplier is implemented by types that contribute functions to
// a text/template FuncMap.
TemplateFuncSupplier interface {
GetTemplateFuncs() template.FuncMap
}
)
// Supported ResourceRefType values.
const (
// ResourceRefTypeTemplate is a reference to a resource template and will be fully resolved prior to constraint generation
// e.g., ${resources:resourceName.property} or ${resources:resourceName}
ResourceRefTypeTemplate ResourceRefType = "template"
// ResourceRefTypeIaC is a reference to an infrastructure as code resource that will be resolved by the engine
// e.g., ${resources:resourceName#property}
ResourceRefTypeIaC ResourceRefType = "iac"
// ResourceRefTypeInterpolated is an initial interpolation reference to a resource.
// An interpolated value will be evaluated during initial processing and will be converted to one of the other types.
ResourceRefTypeInterpolated ResourceRefType = "interpolated"
)
// String renders the reference as "ResourceKey#Property" for IaC references
// and as the bare ResourceKey for every other type.
func (r *ResourceRef) String() string {
	switch r.Type {
	case ResourceRefTypeIaC:
		return r.ResourceKey + "#" + r.Property
	default:
		return r.ResourceKey
	}
}
func NewPropertySource(source any) *PropertySource {
var v reflect.Value
if sv, ok := source.(reflect.Value); ok {
v = sv
} else {
v = reflect.ValueOf(source)
}
return &PropertySource{
source: v,
}
}
// GetProperty resolves key against the wrapped value with
// reflectutil.GetField. ok is false when the lookup fails or yields an
// invalid reflect.Value.
func (p *PropertySource) GetProperty(key string) (value any, ok bool) {
	field, err := reflectutil.GetField(p.source, key)
	if err == nil && field.IsValid() {
		return field.Interface(), true
	}
	return nil, false
}

// GetTypedProperty looks up key on source and converts the result to T with
// reflectutil.GetTypedValue. It returns the zero T and false when the key
// cannot be resolved.
func GetTypedProperty[T any](source *PropertySource, key string) (T, bool) {
	raw, found := source.GetProperty(key)
	if found {
		return reflectutil.GetTypedValue[T](raw)
	}
	var zero T
	return zero, false
}
package template
import (
"embed"
"fmt"
"github.com/klothoplatform/klotho/pkg/k2/constructs/template/property"
"path/filepath"
"strings"
"sync"
"gopkg.in/yaml.v3"
)
// templates embeds the construct and binding template YAML files.
//go:embed templates/*
var templates embed.FS
// Caches of parsed templates; both are read and written only while holding mu.
var (
cachedConstructs = make(map[property.ConstructType]ConstructTemplate)
cachedBindings = make(map[string]BindingTemplate)
mu sync.Mutex
)
// LoadConstructTemplate returns the construct template for id. On first use
// it reads templates/<package path>/<name>/<name>.yaml (lower-cased, with
// the leading package segment dropped) from the embedded FS and caches the
// result. Only packages prefixed with "klotho." are supported.
func LoadConstructTemplate(id property.ConstructType) (ConstructTemplate, error) {
	mu.Lock()
	defer mu.Unlock()

	if ct, ok := cachedConstructs[id]; ok {
		return ct, nil
	}
	if !strings.HasPrefix(id.Package, "klotho.") {
		return ConstructTemplate{}, fmt.Errorf("invalid package: %s", id.Package)
	}
	constructDir, err := getConstructTemplateDir(id)
	if err != nil {
		return ConstructTemplate{}, err
	}
	// embed.FS paths are always slash-separated, regardless of OS.
	templatePath := filepath.ToSlash(filepath.Join(constructDir, strings.ToLower(id.Name)+".yaml"))
	fileContent, err := templates.ReadFile(templatePath)
	if err != nil {
		return ConstructTemplate{}, fmt.Errorf("failed to read file: %w", err)
	}
	var ct ConstructTemplate
	if err := yaml.Unmarshal(fileContent, &ct); err != nil {
		return ConstructTemplate{}, fmt.Errorf("failed to unmarshal yaml: %w", err)
	}
	// Cache under the requested id so the next lookup for it hits, and under
	// the template's declared id for lookups that use that form.
	cachedConstructs[id] = ct
	cachedConstructs[ct.Id] = ct
	return ct, nil
}
// LoadBindingTemplate returns the binding template that owner defines for
// the binding from -> to, reading it from the owner's "bindings" directory
// in the embedded FS and caching it per owner and binding key. owner must
// equal from or to.
func LoadBindingTemplate(owner property.ConstructType, from property.ConstructType, to property.ConstructType) (BindingTemplate, error) {
	mu.Lock()
	defer mu.Unlock()
	if owner != from && owner != to {
		return BindingTemplate{}, fmt.Errorf("owner must be either from or to")
	}
	// binding key name depends on whether the owner is from or to
	// if the owner is from, the key is to_<to_name>
	// if the owner is to, the key is from_<from_name>
	// this is because the binding template is stored in the directory of the owner
	// and each binding may have a separate template file for both the from and to constructs
	var bindingKey string
	if owner == from {
		bindingKey = "to_" + to.String()
	} else {
		bindingKey = "from_" + from.String()
	}
	cacheKey := fmt.Sprintf("%s/%s", owner.String(), bindingKey)
	if ct, ok := cachedBindings[cacheKey]; ok {
		return ct, nil
	}
	constructDir, err := getConstructTemplateDir(owner)
	if err != nil {
		return BindingTemplate{}, err
	}
	// embed.FS paths are always slash-separated, regardless of OS.
	bindingPath := filepath.ToSlash(filepath.Join(constructDir, "bindings", bindingKey+".yaml"))
	fileContent, err := templates.ReadFile(bindingPath)
	if err != nil {
		return BindingTemplate{}, fmt.Errorf("binding template %s (%s -> %s) not found: %w", owner.String(), from.String(), to.String(), err)
	}
	var ct BindingTemplate
	if err := yaml.Unmarshal(fileContent, &ct); err != nil {
		return BindingTemplate{}, fmt.Errorf("failed to unmarshal yaml: %w", err)
	}
	cachedBindings[cacheKey] = ct
	return ct, nil
}
// getConstructTemplateDir maps a construct type to its directory in the
// embedded templates FS. The first package segment is dropped, the remaining
// segments and the name become path components, and the result is
// lower-cased and slash-separated: "klotho.aws" + "Container" ->
// "templates/aws/container". A package without a '.' is an error.
func getConstructTemplateDir(id property.ConstructType) (string, error) {
	_, rest, found := strings.Cut(id.Package, ".")
	if !found {
		return "", fmt.Errorf("invalid package: %s", id.Package)
	}
	elems := append([]string{"templates"}, strings.Split(rest, ".")...)
	elems = append(elems, id.Name)
	// embed.FS requires forward slashes on every OS.
	return strings.ToLower(filepath.ToSlash(filepath.Join(elems...))), nil
}
package initialize
import (
"context"
"embed"
"fmt"
"github.com/klothoplatform/klotho/pkg/command"
"github.com/klothoplatform/klotho/pkg/k2/cleanup"
"github.com/klothoplatform/klotho/pkg/logging"
"os"
"os/exec"
"path/filepath"
"strings"
"syscall"
"text/template"
)
// files embeds the Python program template rendered by createProgramFile.
//go:embed templates/python/infra.py.tmpl
var files embed.FS
// ApplicationRequest holds the options for Application; it is also the data
// passed to the program template.
type ApplicationRequest struct {
// Context is used for logging and for running install commands.
Context context.Context
ProjectName string
AppName string
Environment string
// OutputDirectory defaults to "." when empty.
OutputDirectory string
DefaultRegion string
// Runtime must currently be "python".
Runtime string
// ProgramFileName gets a ".py" suffix appended when it lacks one.
ProgramFileName string
// SkipInstall skips installing pipenv and the klotho Python SDK.
SkipInstall bool
}
// Application scaffolds a Python klotho application: it ensures the output
// directory (default ".") exists, renders the program template into
// ProgramFileName (adding ".py" if missing) and, unless SkipInstall is set,
// installs the klotho Python SDK with pipenv.
func Application(request ApplicationRequest) error {
	if request.Runtime != "python" {
		return fmt.Errorf("unsupported runtime: %s", request.Runtime)
	}
	outDir := request.OutputDirectory
	if outDir == "" {
		outDir = "."
	}
	if !strings.HasSuffix(request.ProgramFileName, ".py") {
		request.ProgramFileName += ".py"
	}

	if err := createOutputDirectory(outDir); err != nil {
		return err
	}
	fmt.Println("Creating Python program...")
	if err := createProgramFile(outDir, request.ProgramFileName, request); err != nil {
		return err
	}
	fmt.Printf("Created %s\n", filepath.Join(outDir, request.ProgramFileName))

	if request.SkipInstall {
		return nil
	}
	return updatePipfile(request.Context, outDir)
}
// getPipCommand returns the first of "pip3" or "pip" that is found on PATH.
func getPipCommand() (string, error) {
	for _, candidate := range []string{"pip3", "pip"} {
		if _, err := exec.LookPath(candidate); err == nil {
			return candidate, nil
		}
	}
	return "", fmt.Errorf("pip not found")
}
// updatePipfile runs `pipenv install -d klotho` in outDir to add the klotho
// Python SDK to the project's Pipfile. If pipenv is not on PATH it is first
// installed with pip.
func updatePipfile(ctx context.Context, outDir string) error {
	if _, lookErr := exec.LookPath("pipenv"); lookErr != nil {
		fmt.Println("pipenv not found, installing pipenv")
		pip, err := getPipCommand()
		if err != nil {
			return err
		}
		if err := runCommand(ctx, outDir, pip, []string{"install", "pipenv"}); err != nil {
			return fmt.Errorf("failed to install pipenv: %w", err)
		}
		fmt.Println("pipenv installed successfully")
	}

	fmt.Println("Installing klotho python SDK...")
	if err := runCommand(ctx, outDir, "pipenv", []string{"install", "-d", "klotho"}); err != nil {
		return fmt.Errorf("failed to install klotho python SDK: %w", err)
	}
	fmt.Println("klotho python SDK installed successfully")
	return nil
}
// runCommand runs name with args in dir, streaming the child's stdout and
// stderr to this process, and returns the command's error.
//
// A kill handler is registered that sends SIGTERM through
// cleanup.SignalProcessGroup; it is a no-op if the process was never
// started.
func runCommand(ctx context.Context, dir string, name string, args []string) error {
	log := logging.GetLogger(ctx).Sugar()
	cmd := exec.CommandContext(ctx, name, args...)
	cmd.Dir = dir
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	command.SetProcAttr(cmd)
	cleanup.OnKill(func(signal syscall.Signal) error {
		// cmd.Process is nil until Start succeeds; without this guard a kill
		// arriving before (or after a failed) start would panic.
		if cmd.Process == nil {
			return nil
		}
		cleanup.SignalProcessGroup(cmd.Process.Pid, syscall.SIGTERM)
		return nil
	})
	log.Debugf("Executing: %s for %v", cmd.Path, cmd.Args)
	if err := cmd.Run(); err != nil {
		log.Errorf("%s process exited with error: %v", name, err)
		return err
	}
	log.Debugf("%s process exited successfully", name)
	return nil
}
// createProgramFile renders the embedded Python program template with
// request and writes it to outDir/programFileName, truncating any existing
// file. The template is parsed before the file is created, so a bad
// template leaves no empty file behind. An error from closing the written
// file is returned, because it can mean the content was not fully written.
func createProgramFile(outDir, programFileName string, request ApplicationRequest) (err error) {
	programTemplateContent, err := files.ReadFile("templates/python/infra.py.tmpl")
	if err != nil {
		return err
	}
	program, err := template.New("program").Parse(string(programTemplateContent))
	if err != nil {
		return err
	}
	programFile, err := os.Create(filepath.Join(outDir, programFileName))
	if err != nil {
		return err
	}
	defer func() {
		if closeErr := programFile.Close(); closeErr != nil && err == nil {
			err = closeErr
		}
	}()
	return program.Execute(programFile, request)
}
// createOutputDirectory ensures outDir exists as a directory, creating any
// missing parent directories. MkdirAll is a no-op for an existing
// directory. It returns an error when something other than a directory
// already exists at outDir, or when the path cannot be inspected or created.
func createOutputDirectory(outDir string) error {
	return os.MkdirAll(outDir, 0755)
}
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.34.1
// protoc v5.27.0
// source: service.proto
package _go
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
// Version guards emitted by protoc-gen-go that statically check this
// generated code against the linked protobuf runtime. Regenerate from
// service.proto instead of editing this file.
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
// IRRequest is a generated protobuf message (msgTypes index 0) carrying a
// single Filename field (proto field 1).
type IRRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Filename string `protobuf:"bytes,1,opt,name=filename,proto3" json:"filename,omitempty"`
}
// Reset sets x to its zero value and, when the unsafe fast path is enabled,
// re-attaches its message info.
func (x *IRRequest) Reset() {
*x = IRRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_service_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
// String returns a human-readable form of x produced by protoimpl.
func (x *IRRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
// ProtoMessage marks IRRequest as a protobuf message.
func (*IRRequest) ProtoMessage() {}
// ProtoReflect returns the reflective protoreflect.Message view of x.
func (x *IRRequest) ProtoReflect() protoreflect.Message {
mi := &file_service_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use IRRequest.ProtoReflect.Descriptor instead.
func (*IRRequest) Descriptor() ([]byte, []int) {
return file_service_proto_rawDescGZIP(), []int{0}
}
// GetFilename returns x.Filename, or "" when x is nil.
func (x *IRRequest) GetFilename() string {
if x != nil {
return x.Filename
}
return ""
}
// IRReply is a generated protobuf message (msgTypes index 1) with a Message
// field (proto field 1) and a YamlPayload field (proto field 2).
type IRReply struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"`
YamlPayload string `protobuf:"bytes,2,opt,name=yaml_payload,json=yamlPayload,proto3" json:"yaml_payload,omitempty"`
}
// Reset sets x to its zero value and, when the unsafe fast path is enabled,
// re-attaches its message info.
func (x *IRReply) Reset() {
*x = IRReply{}
if protoimpl.UnsafeEnabled {
mi := &file_service_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
// String returns a human-readable form of x produced by protoimpl.
func (x *IRReply) String() string {
return protoimpl.X.MessageStringOf(x)
}
// ProtoMessage marks IRReply as a protobuf message.
func (*IRReply) ProtoMessage() {}
// ProtoReflect returns the reflective protoreflect.Message view of x.
func (x *IRReply) ProtoReflect() protoreflect.Message {
mi := &file_service_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use IRReply.ProtoReflect.Descriptor instead.
func (*IRReply) Descriptor() ([]byte, []int) {
return file_service_proto_rawDescGZIP(), []int{1}
}
// GetMessage returns x.Message, or "" when x is nil.
func (x *IRReply) GetMessage() string {
if x != nil {
return x.Message
}
return ""
}
// GetYamlPayload returns x.YamlPayload, or "" when x is nil.
func (x *IRReply) GetYamlPayload() string {
if x != nil {
return x.YamlPayload
}
return ""
}
// HealthCheckRequest is the input message of KlothoService.HealthCheck. It carries no fields.
type HealthCheckRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
}
// Reset sets x to the zero message and, when the runtime's unsafe fast path is enabled,
// re-attaches the message info at file_service_proto_msgTypes[2].
func (x *HealthCheckRequest) Reset() {
*x = HealthCheckRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_service_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
// String returns the protobuf runtime's text rendering of x.
func (x *HealthCheckRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
// ProtoMessage is an empty marker method identifying HealthCheckRequest as a protobuf message.
func (*HealthCheckRequest) ProtoMessage() {}
// ProtoReflect returns a reflective view of x backed by file_service_proto_msgTypes[2].
func (x *HealthCheckRequest) ProtoReflect() protoreflect.Message {
mi := &file_service_proto_msgTypes[2]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use HealthCheckRequest.ProtoReflect.Descriptor instead.
func (*HealthCheckRequest) Descriptor() ([]byte, []int) {
return file_service_proto_rawDescGZIP(), []int{2}
}
// HealthCheckReply is the output message of KlothoService.HealthCheck. Status is proto field 1.
type HealthCheckReply struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Status string `protobuf:"bytes,1,opt,name=status,proto3" json:"status,omitempty"`
}
// Reset sets x to the zero message and, when the runtime's unsafe fast path is enabled,
// re-attaches the message info at file_service_proto_msgTypes[3].
func (x *HealthCheckReply) Reset() {
*x = HealthCheckReply{}
if protoimpl.UnsafeEnabled {
mi := &file_service_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
// String returns the protobuf runtime's text rendering of x.
func (x *HealthCheckReply) String() string {
return protoimpl.X.MessageStringOf(x)
}
// ProtoMessage is an empty marker method identifying HealthCheckReply as a protobuf message.
func (*HealthCheckReply) ProtoMessage() {}
// ProtoReflect returns a reflective view of x backed by file_service_proto_msgTypes[3].
func (x *HealthCheckReply) ProtoReflect() protoreflect.Message {
mi := &file_service_proto_msgTypes[3]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use HealthCheckReply.ProtoReflect.Descriptor instead.
func (*HealthCheckReply) Descriptor() ([]byte, []int) {
return file_service_proto_rawDescGZIP(), []int{3}
}
// GetStatus returns x.Status, or "" when x is nil.
func (x *HealthCheckReply) GetStatus() string {
if x != nil {
return x.Status
}
return ""
}
// RegisterConstructRequest is the input message of KlothoService.RegisterConstruct.
// YamlPayload is proto field 1.
type RegisterConstructRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
YamlPayload string `protobuf:"bytes,1,opt,name=yaml_payload,json=yamlPayload,proto3" json:"yaml_payload,omitempty"`
}
// Reset sets x to the zero message and, when the runtime's unsafe fast path is enabled,
// re-attaches the message info at file_service_proto_msgTypes[4].
func (x *RegisterConstructRequest) Reset() {
*x = RegisterConstructRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_service_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
// String returns the protobuf runtime's text rendering of x.
func (x *RegisterConstructRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
// ProtoMessage is an empty marker method identifying RegisterConstructRequest as a protobuf message.
func (*RegisterConstructRequest) ProtoMessage() {}
// ProtoReflect returns a reflective view of x backed by file_service_proto_msgTypes[4].
func (x *RegisterConstructRequest) ProtoReflect() protoreflect.Message {
mi := &file_service_proto_msgTypes[4]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use RegisterConstructRequest.ProtoReflect.Descriptor instead.
func (*RegisterConstructRequest) Descriptor() ([]byte, []int) {
return file_service_proto_rawDescGZIP(), []int{4}
}
// GetYamlPayload returns x.YamlPayload, or "" when x is nil.
func (x *RegisterConstructRequest) GetYamlPayload() string {
if x != nil {
return x.YamlPayload
}
return ""
}
// RegisterConstructReply is the output message of KlothoService.RegisterConstruct.
// Message is proto field 1 and YamlPayload is field 2.
type RegisterConstructReply struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"`
YamlPayload string `protobuf:"bytes,2,opt,name=yaml_payload,json=yamlPayload,proto3" json:"yaml_payload,omitempty"`
}
// Reset sets x to the zero message and, when the runtime's unsafe fast path is enabled,
// re-attaches the message info at file_service_proto_msgTypes[5].
func (x *RegisterConstructReply) Reset() {
*x = RegisterConstructReply{}
if protoimpl.UnsafeEnabled {
mi := &file_service_proto_msgTypes[5]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
// String returns the protobuf runtime's text rendering of x.
func (x *RegisterConstructReply) String() string {
return protoimpl.X.MessageStringOf(x)
}
// ProtoMessage is an empty marker method identifying RegisterConstructReply as a protobuf message.
func (*RegisterConstructReply) ProtoMessage() {}
// ProtoReflect returns a reflective view of x backed by file_service_proto_msgTypes[5].
func (x *RegisterConstructReply) ProtoReflect() protoreflect.Message {
mi := &file_service_proto_msgTypes[5]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use RegisterConstructReply.ProtoReflect.Descriptor instead.
func (*RegisterConstructReply) Descriptor() ([]byte, []int) {
return file_service_proto_rawDescGZIP(), []int{5}
}
// GetMessage returns x.Message, or "" when x is nil.
func (x *RegisterConstructReply) GetMessage() string {
if x != nil {
return x.Message
}
return ""
}
// GetYamlPayload returns x.YamlPayload, or "" when x is nil.
func (x *RegisterConstructReply) GetYamlPayload() string {
if x != nil {
return x.YamlPayload
}
return ""
}
// File_service_proto is the descriptor of service.proto; file_service_proto_init sets it.
var File_service_proto protoreflect.FileDescriptor
// file_service_proto_rawDesc is the serialized descriptor of service.proto (package "klotho":
// six messages, the KlothoService service, and the Go package path). It must not be edited by
// hand; the bytes are consumed by the TypeBuilder in file_service_proto_init.
var file_service_proto_rawDesc = []byte{
0x0a, 0x0d, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12,
0x06, 0x6b, 0x6c, 0x6f, 0x74, 0x68, 0x6f, 0x22, 0x27, 0x0a, 0x09, 0x49, 0x52, 0x52, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x66, 0x69, 0x6c, 0x65, 0x6e, 0x61, 0x6d, 0x65,
0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x66, 0x69, 0x6c, 0x65, 0x6e, 0x61, 0x6d, 0x65,
0x22, 0x46, 0x0a, 0x07, 0x49, 0x52, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x6d,
0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65,
0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x21, 0x0a, 0x0c, 0x79, 0x61, 0x6d, 0x6c, 0x5f, 0x70, 0x61,
0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x79, 0x61, 0x6d,
0x6c, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x22, 0x14, 0x0a, 0x12, 0x48, 0x65, 0x61, 0x6c,
0x74, 0x68, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x2a,
0x0a, 0x10, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x52, 0x65, 0x70,
0x6c, 0x79, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x01, 0x20, 0x01,
0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x22, 0x3d, 0x0a, 0x18, 0x52, 0x65,
0x67, 0x69, 0x73, 0x74, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x73, 0x74, 0x72, 0x75, 0x63, 0x74, 0x52,
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x21, 0x0a, 0x0c, 0x79, 0x61, 0x6d, 0x6c, 0x5f, 0x70,
0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x79, 0x61,
0x6d, 0x6c, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x22, 0x55, 0x0a, 0x16, 0x52, 0x65, 0x67,
0x69, 0x73, 0x74, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x73, 0x74, 0x72, 0x75, 0x63, 0x74, 0x52, 0x65,
0x70, 0x6c, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x01,
0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x21, 0x0a,
0x0c, 0x79, 0x61, 0x6d, 0x6c, 0x5f, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20,
0x01, 0x28, 0x09, 0x52, 0x0b, 0x79, 0x61, 0x6d, 0x6c, 0x50, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64,
0x32, 0xdf, 0x01, 0x0a, 0x0d, 0x4b, 0x6c, 0x6f, 0x74, 0x68, 0x6f, 0x53, 0x65, 0x72, 0x76, 0x69,
0x63, 0x65, 0x12, 0x2e, 0x0a, 0x06, 0x53, 0x65, 0x6e, 0x64, 0x49, 0x52, 0x12, 0x11, 0x2e, 0x6b,
0x6c, 0x6f, 0x74, 0x68, 0x6f, 0x2e, 0x49, 0x52, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a,
0x0f, 0x2e, 0x6b, 0x6c, 0x6f, 0x74, 0x68, 0x6f, 0x2e, 0x49, 0x52, 0x52, 0x65, 0x70, 0x6c, 0x79,
0x22, 0x00, 0x12, 0x45, 0x0a, 0x0b, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x43, 0x68, 0x65, 0x63,
0x6b, 0x12, 0x1a, 0x2e, 0x6b, 0x6c, 0x6f, 0x74, 0x68, 0x6f, 0x2e, 0x48, 0x65, 0x61, 0x6c, 0x74,
0x68, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x18, 0x2e,
0x6b, 0x6c, 0x6f, 0x74, 0x68, 0x6f, 0x2e, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x43, 0x68, 0x65,
0x63, 0x6b, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x57, 0x0a, 0x11, 0x52, 0x65, 0x67,
0x69, 0x73, 0x74, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x73, 0x74, 0x72, 0x75, 0x63, 0x74, 0x12, 0x20,
0x2e, 0x6b, 0x6c, 0x6f, 0x74, 0x68, 0x6f, 0x2e, 0x52, 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, 0x72,
0x43, 0x6f, 0x6e, 0x73, 0x74, 0x72, 0x75, 0x63, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
0x1a, 0x1e, 0x2e, 0x6b, 0x6c, 0x6f, 0x74, 0x68, 0x6f, 0x2e, 0x52, 0x65, 0x67, 0x69, 0x73, 0x74,
0x65, 0x72, 0x43, 0x6f, 0x6e, 0x73, 0x74, 0x72, 0x75, 0x63, 0x74, 0x52, 0x65, 0x70, 0x6c, 0x79,
0x22, 0x00, 0x42, 0x3a, 0x5a, 0x38, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d,
0x2f, 0x6b, 0x6c, 0x6f, 0x74, 0x68, 0x6f, 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x2f,
0x6b, 0x6c, 0x6f, 0x74, 0x68, 0x6f, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x6b, 0x32, 0x2f, 0x6c, 0x61,
0x6e, 0x67, 0x75, 0x61, 0x67, 0x65, 0x5f, 0x68, 0x6f, 0x73, 0x74, 0x2f, 0x67, 0x6f, 0x62, 0x06,
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
// rawDescData starts as the raw descriptor and is replaced by its gzip-compressed form the
// first time file_service_proto_rawDescGZIP runs (guarded by rawDescOnce).
var (
file_service_proto_rawDescOnce sync.Once
file_service_proto_rawDescData = file_service_proto_rawDesc
)
// file_service_proto_rawDescGZIP returns the gzip-compressed raw descriptor, compressing it
// at most once. It backs the deprecated Descriptor() methods of the messages above.
func file_service_proto_rawDescGZIP() []byte {
file_service_proto_rawDescOnce.Do(func() {
file_service_proto_rawDescData = protoimpl.X.CompressGZIP(file_service_proto_rawDescData)
})
return file_service_proto_rawDescData
}
// file_service_proto_msgTypes holds one MessageInfo per message, in the index order listed in
// file_service_proto_goTypes (0 = IRRequest ... 5 = RegisterConstructReply).
var file_service_proto_msgTypes = make([]protoimpl.MessageInfo, 6)
// file_service_proto_goTypes maps descriptor indices to the Go message types.
var file_service_proto_goTypes = []interface{}{
(*IRRequest)(nil), // 0: klotho.IRRequest
(*IRReply)(nil), // 1: klotho.IRReply
(*HealthCheckRequest)(nil), // 2: klotho.HealthCheckRequest
(*HealthCheckReply)(nil), // 3: klotho.HealthCheckReply
(*RegisterConstructRequest)(nil), // 4: klotho.RegisterConstructRequest
(*RegisterConstructReply)(nil), // 5: klotho.RegisterConstructReply
}
// file_service_proto_depIdxs resolves each RPC's input and output type to an index into
// file_service_proto_goTypes; the trailing entries give the start offsets of each sub-list.
var file_service_proto_depIdxs = []int32{
0, // 0: klotho.KlothoService.SendIR:input_type -> klotho.IRRequest
2, // 1: klotho.KlothoService.HealthCheck:input_type -> klotho.HealthCheckRequest
4, // 2: klotho.KlothoService.RegisterConstruct:input_type -> klotho.RegisterConstructRequest
1, // 3: klotho.KlothoService.SendIR:output_type -> klotho.IRReply
3, // 4: klotho.KlothoService.HealthCheck:output_type -> klotho.HealthCheckReply
5, // 5: klotho.KlothoService.RegisterConstruct:output_type -> klotho.RegisterConstructReply
3, // [3:6] is the sub-list for method output_type
0, // [0:3] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
// init builds the service.proto descriptor when the package is loaded.
func init() { file_service_proto_init() }
// file_service_proto_init builds the file descriptor and Go type information for
// service.proto with protoimpl.TypeBuilder and stores the descriptor in File_service_proto.
// It returns immediately if that has already been done.
func file_service_proto_init() {
if File_service_proto != nil {
return
}
// Without the unsafe fast path, the runtime reaches each message's internal fields through
// these exporters: index 0 = state, 1 = sizeCache, 2 = unknownFields.
if !protoimpl.UnsafeEnabled {
file_service_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*IRRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_service_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*IRReply); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_service_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*HealthCheckRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_service_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*HealthCheckReply); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_service_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*RegisterConstructRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_service_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*RegisterConstructReply); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
// x is an empty local type used only to obtain this package's import path via reflection.
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_service_proto_rawDesc,
NumEnums: 0,
NumMessages: 6,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_service_proto_goTypes,
DependencyIndexes: file_service_proto_depIdxs,
MessageInfos: file_service_proto_msgTypes,
}.Build()
File_service_proto = out.File
// The builder has consumed these tables; clear the package-level references to them.
file_service_proto_rawDesc = nil
file_service_proto_goTypes = nil
file_service_proto_depIdxs = nil
}
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.3.0
// - protoc v5.27.0
// source: service.proto
package _go
import (
context "context"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
)
// NOTE: this file is generated by protoc-gen-go-grpc from service.proto; any edits here,
// including these comments, are lost when the file is regenerated.
//
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
// Fully-qualified method names of KlothoService. The client passes them to Invoke, and the
// server handlers report them to interceptors as UnaryServerInfo.FullMethod.
const (
KlothoService_SendIR_FullMethodName = "/klotho.KlothoService/SendIR"
KlothoService_HealthCheck_FullMethodName = "/klotho.KlothoService/HealthCheck"
KlothoService_RegisterConstruct_FullMethodName = "/klotho.KlothoService/RegisterConstruct"
)
// KlothoServiceClient is the client API for KlothoService service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type KlothoServiceClient interface {
SendIR(ctx context.Context, in *IRRequest, opts ...grpc.CallOption) (*IRReply, error)
HealthCheck(ctx context.Context, in *HealthCheckRequest, opts ...grpc.CallOption) (*HealthCheckReply, error)
RegisterConstruct(ctx context.Context, in *RegisterConstructRequest, opts ...grpc.CallOption) (*RegisterConstructReply, error)
}
// klothoServiceClient implements KlothoServiceClient on top of a gRPC connection.
type klothoServiceClient struct {
cc grpc.ClientConnInterface
}
// NewKlothoServiceClient wraps cc in a KlothoServiceClient. It does not dial; cc is used as-is.
func NewKlothoServiceClient(cc grpc.ClientConnInterface) KlothoServiceClient {
return &klothoServiceClient{cc}
}
// SendIR performs the unary SendIR call. It returns the decoded reply, or nil and the call error.
func (c *klothoServiceClient) SendIR(ctx context.Context, in *IRRequest, opts ...grpc.CallOption) (*IRReply, error) {
out := new(IRReply)
err := c.cc.Invoke(ctx, KlothoService_SendIR_FullMethodName, in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// HealthCheck performs the unary HealthCheck call. It returns the decoded reply, or nil and the call error.
func (c *klothoServiceClient) HealthCheck(ctx context.Context, in *HealthCheckRequest, opts ...grpc.CallOption) (*HealthCheckReply, error) {
out := new(HealthCheckReply)
err := c.cc.Invoke(ctx, KlothoService_HealthCheck_FullMethodName, in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// RegisterConstruct performs the unary RegisterConstruct call. It returns the decoded reply, or nil and the call error.
func (c *klothoServiceClient) RegisterConstruct(ctx context.Context, in *RegisterConstructRequest, opts ...grpc.CallOption) (*RegisterConstructReply, error) {
out := new(RegisterConstructReply)
err := c.cc.Invoke(ctx, KlothoService_RegisterConstruct_FullMethodName, in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// KlothoServiceServer is the server API for KlothoService service.
// All implementations must embed UnimplementedKlothoServiceServer
// for forward compatibility
type KlothoServiceServer interface {
SendIR(context.Context, *IRRequest) (*IRReply, error)
HealthCheck(context.Context, *HealthCheckRequest) (*HealthCheckReply, error)
RegisterConstruct(context.Context, *RegisterConstructRequest) (*RegisterConstructReply, error)
mustEmbedUnimplementedKlothoServiceServer()
}
// UnimplementedKlothoServiceServer must be embedded to have forward compatible implementations.
// Each of its methods returns a codes.Unimplemented status error.
type UnimplementedKlothoServiceServer struct {
}
func (UnimplementedKlothoServiceServer) SendIR(context.Context, *IRRequest) (*IRReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method SendIR not implemented")
}
func (UnimplementedKlothoServiceServer) HealthCheck(context.Context, *HealthCheckRequest) (*HealthCheckReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method HealthCheck not implemented")
}
func (UnimplementedKlothoServiceServer) RegisterConstruct(context.Context, *RegisterConstructRequest) (*RegisterConstructReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method RegisterConstruct not implemented")
}
func (UnimplementedKlothoServiceServer) mustEmbedUnimplementedKlothoServiceServer() {}
// UnsafeKlothoServiceServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to KlothoServiceServer will
// result in compilation errors.
type UnsafeKlothoServiceServer interface {
mustEmbedUnimplementedKlothoServiceServer()
}
// RegisterKlothoServiceServer registers srv with s under KlothoService_ServiceDesc.
func RegisterKlothoServiceServer(s grpc.ServiceRegistrar, srv KlothoServiceServer) {
s.RegisterService(&KlothoService_ServiceDesc, srv)
}
// _KlothoService_SendIR_Handler decodes an IRRequest with dec and dispatches it to
// srv.SendIR, going through interceptor when one is installed.
func _KlothoService_SendIR_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(IRRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(KlothoServiceServer).SendIR(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: KlothoService_SendIR_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(KlothoServiceServer).SendIR(ctx, req.(*IRRequest))
}
return interceptor(ctx, in, info, handler)
}
// _KlothoService_HealthCheck_Handler decodes a HealthCheckRequest with dec and dispatches it
// to srv.HealthCheck, going through interceptor when one is installed.
func _KlothoService_HealthCheck_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(HealthCheckRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(KlothoServiceServer).HealthCheck(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: KlothoService_HealthCheck_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(KlothoServiceServer).HealthCheck(ctx, req.(*HealthCheckRequest))
}
return interceptor(ctx, in, info, handler)
}
// _KlothoService_RegisterConstruct_Handler decodes a RegisterConstructRequest with dec and
// dispatches it to srv.RegisterConstruct, going through interceptor when one is installed.
func _KlothoService_RegisterConstruct_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RegisterConstructRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(KlothoServiceServer).RegisterConstruct(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: KlothoService_RegisterConstruct_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(KlothoServiceServer).RegisterConstruct(ctx, req.(*RegisterConstructRequest))
}
return interceptor(ctx, in, info, handler)
}
// KlothoService_ServiceDesc is the grpc.ServiceDesc for KlothoService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
// All three methods are unary; the service defines no streams.
var KlothoService_ServiceDesc = grpc.ServiceDesc{
ServiceName: "klotho.KlothoService",
HandlerType: (*KlothoServiceServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "SendIR",
Handler: _KlothoService_SendIR_Handler,
},
{
MethodName: "HealthCheck",
Handler: _KlothoService_HealthCheck_Handler,
},
{
MethodName: "RegisterConstruct",
Handler: _KlothoService_RegisterConstruct_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "service.proto",
}
package language_host
import (
	"context"
	"errors"
	"fmt"
	"os"
	"os/exec"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"

	pb "github.com/klothoplatform/klotho/pkg/k2/language_host/go"
	"github.com/klothoplatform/klotho/pkg/k2/model"
	"github.com/klothoplatform/klotho/pkg/logging"
)
// LanguageHost owns a language host subprocess (started via StartPythonClient) and the gRPC
// connection used to talk to it.
type LanguageHost struct {
	// debugCfg is the debug configuration given to Start; when enabled, GetIR skips its timeout.
	debugCfg DebugConfig
	// langHost is the language host process; nil until StartPythonClient succeeds.
	langHost *exec.Cmd
	// conn is the gRPC connection to the language host; nil until Start succeeds.
	conn *grpc.ClientConn
}

// Start launches the Python language host and waits for it to report its listening address.
// It then opens an (insecure) gRPC client connection to that address.
//
// Waiting for startup:
//   - With debugging disabled, Start gives up after 30 seconds.
//   - With debugging enabled there is no timeout, because the host may be paused at a
//     breakpoint before it prints its address.
//   - In both modes Start returns early if ctx is cancelled.
//
// If the process was launched but startup then fails, the process is left running; call
// Close to terminate it.
func (irs *LanguageHost) Start(ctx context.Context, debug DebugConfig, pythonPath string) (err error) {
	log := logging.GetLogger(ctx).Sugar()
	irs.debugCfg = debug

	var srvState *ServerState
	irs.langHost, srvState, err = StartPythonClient(ctx, debug, pythonPath)
	if err != nil {
		return err
	}

	log.Debug("Waiting for Python server to start")
	// A nil channel is never ready. In debug mode the select below therefore only ends when
	// the server reports in or ctx is cancelled.
	var timeout <-chan time.Time
	if !debug.Enabled() {
		timer := time.NewTimer(30 * time.Second)
		defer timer.Stop()
		timeout = timer.C
	}
	select {
	case <-srvState.Done:
	case <-timeout:
		return errors.New("timeout waiting for Python server to start")
	case <-ctx.Done():
		return fmt.Errorf("waiting for Python server to start: %w", ctx.Err())
	}
	// Done has been closed, so exactly one of Error or Address has been set.
	if srvState.Error != nil {
		return srvState.Error
	}

	irs.conn, err = grpc.NewClient(srvState.Address, grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		return fmt.Errorf("failed to connect to Python server: %w", err)
	}
	return nil
}
// NewClient returns a KlothoService client that uses the connection opened by Start.
// No new connection is dialed; every client returned shares irs.conn.
func (irs *LanguageHost) NewClient() pb.KlothoServiceClient {
	shared := irs.conn
	return pb.NewKlothoServiceClient(shared)
}
// GetIR sends req to the language host's SendIR RPC and parses the YAML payload of the reply
// into an ApplicationEnvironment. Outside debug mode the call is bounded by a 10-second timeout.
func (irs *LanguageHost) GetIR(ctx context.Context, req *pb.IRRequest) (*model.ApplicationEnvironment, error) {
	callCtx := ctx
	// No deadline while debugging: the host may be stopped at a breakpoint, or still waiting
	// for a debugger to attach.
	if !irs.debugCfg.Enabled() {
		var cancel context.CancelFunc
		callCtx, cancel = context.WithTimeout(ctx, 10*time.Second)
		defer cancel()
	}

	reply, err := irs.NewClient().SendIR(callCtx, req)
	if err != nil {
		return nil, fmt.Errorf("error sending IR request: %w", err)
	}

	payload := []byte(reply.GetYamlPayload())
	appEnv, err := model.ParseIRFile(payload)
	if err != nil {
		return nil, fmt.Errorf("error parsing IR file: %w", err)
	}
	return appEnv, nil
}
// Close closes the gRPC connection, if one was opened, and kills the language host process, if
// one was started. It is safe to call when Start failed part-way or was never called.
//
// A host process that has already exited is not reported as an error. Any remaining failures
// are combined with errors.Join.
func (irs *LanguageHost) Close() error {
	var errs []error
	if conn := irs.conn; conn != nil {
		errs = append(errs, conn.Close())
	}
	// langHost is nil when Start was never called or StartPythonClient failed; dereferencing
	// it unconditionally would panic.
	if cmd := irs.langHost; cmd != nil && cmd.Process != nil {
		// StartPythonClient reaps the process in the background. If it has already exited,
		// Kill reports os.ErrProcessDone, which is not a failure to close.
		if err := cmd.Process.Kill(); err != nil && !errors.Is(err, os.ErrProcessDone) {
			errs = append(errs, err)
		}
	}
	// errors.Join drops nil entries and returns nil when every entry is nil.
	return errors.Join(errs...)
}
package language_host
import (
"context"
_ "embed"
"errors"
"fmt"
"io"
"os"
"os/exec"
"regexp"
"strings"
"syscall"
"github.com/klothoplatform/klotho/pkg/command"
"github.com/klothoplatform/klotho/pkg/k2/cleanup"
"github.com/klothoplatform/klotho/pkg/logging"
"go.uber.org/zap"
)
// pythonLanguageHost is the source of python/python_language_host.py, embedded at build time.
// StartPythonClient writes it to a temp file and runs that file with pipenv.
//go:embed python/python_language_host.py
var pythonLanguageHost string
// ServerState is an io.Writer that watches the language host's stdout. It looks for either the
// address the host's gRPC server is listening on or a fatal startup exception.
//
// When either is seen, Address or Error is set and Done is closed; any later output is
// ignored. Readers should wait on Done before reading Address or Error.
type ServerState struct {
	Log     *zap.SugaredLogger
	Address string
	Error   error
	Done    chan struct{}

	// partial is the trailing, not-yet-newline-terminated fragment of earlier output. It is
	// kept so that a line split across two writes is still examined in full.
	partial string
}

// NewServerState returns a ServerState that logs to log and has an open Done channel.
func NewServerState(log *zap.SugaredLogger) *ServerState {
	return &ServerState{
		Log:  log,
		Done: make(chan struct{}),
	}
}

// listenOnPattern matches the line announcing the address the gRPC server listens on.
var listenOnPattern = regexp.MustCompile(`(?m)^\s*Listening on (\S+)$`)

// exceptionPattern matches a fatal startup error. The group captures everything after the
// marker through the end of the examined output.
var exceptionPattern = regexp.MustCompile(`(?s)(?:^|\n)\s*Exception occurred: (.+)$`)

// Write implements io.Writer. It always reports len(b) bytes written and a nil error, so it can
// sit inside an io.MultiWriter without causing short-write failures. It takes no locks and
// assumes writes arrive one at a time.
func (f *ServerState) Write(b []byte) (int, error) {
	if f.Address != "" || f.Error != nil {
		return len(b), nil
	}
	// Examine the new bytes together with any unterminated fragment from the previous write.
	s := f.partial + string(b)
	// Search for the address only in newline-terminated lines. "$" also matches at the end of
	// the text, so an address line cut off by a write boundary would otherwise be captured
	// truncated.
	complete := s[:strings.LastIndexByte(s, '\n')+1]

	if matches := exceptionPattern.FindStringSubmatch(s); len(matches) >= 2 {
		// A fatal error in the language host before the address was printed.
		f.Error = errors.New(strings.TrimSpace(matches[1]))
		f.Log.Debug(s)
		close(f.Done)
	} else if matches := listenOnPattern.FindStringSubmatch(complete); len(matches) >= 2 {
		// The gRPC server address.
		f.Address = matches[1]
		f.Log.Debugf("Found language host listening on %s", f.Address)
		close(f.Done)
	} else {
		// Keep only the unterminated tail; complete lines have been examined already.
		f.partial = s[len(complete):]
	}
	return len(b), nil
}
// DebugConfig describes how the language host should be started for debugging.
type DebugConfig struct {
	// Port is passed to the host as --debug-port when debugging is enabled and Port > 0.
	Port int
	// Mode turns debugging on when non-empty; it is passed to the host as --debug.
	Mode string
}

// Enabled reports whether debugging was requested, i.e. whether a debug Mode is set.
// A Port on its own does not enable debugging.
func (cfg DebugConfig) Enabled() bool {
	debugging := len(cfg.Mode) > 0
	return debugging
}
// copyToTempDir writes content to a new file in the OS temp directory and returns its path.
// The file name follows the pattern "k2_<name>*.py", where os.CreateTemp replaces "*" with a
// random string.
//
// If writing or closing fails, the partially written file is removed and "" is returned with
// the error. On success the caller owns the file and is responsible for deleting it.
func copyToTempDir(name, content string) (path string, err error) {
	f, err := os.CreateTemp("", fmt.Sprintf("k2_%s*.py", name))
	if err != nil {
		return "", fmt.Errorf("failed to create temp file: %w", err)
	}
	defer func() {
		// The caller is about to execute this file, so a failing Close (which can surface
		// deferred write errors on some filesystems) must not be reported as success.
		if cerr := f.Close(); cerr != nil && err == nil {
			err = fmt.Errorf("failed to close temp file: %w", cerr)
		}
		if err != nil {
			// Best effort: don't leave half-written scripts behind in the temp directory.
			_ = os.Remove(f.Name())
			path = ""
		}
	}()
	if _, err := f.WriteString(content); err != nil {
		return "", fmt.Errorf("failed to write to temp file: %w", err)
	}
	return f.Name(), nil
}
// StartPythonClient launches the embedded Python language host and returns the running
// command together with a ServerState that receives a copy of its stdout.
//
// The host runs as `pipenv run python <temp copy of the host script> [debug flags]`, with
// PYTHONPATH set to pythonPath. When debugConfig is enabled, the debug flags are:
//   - --debug-port <Port>, only when Port > 0
//   - --debug <Mode>
//
// A background goroutine waits for the process to exit, logs the outcome, and deletes the
// temporary script.
func StartPythonClient(ctx context.Context, debugConfig DebugConfig, pythonPath string) (*exec.Cmd, *ServerState, error) {
	log := logging.GetLogger(ctx).Sugar()
	hostPath, err := copyToTempDir("python_language_host", pythonLanguageHost)
	if err != nil {
		return nil, nil, fmt.Errorf("could not copy python language host to temp dir: %w", err)
	}
	// removeHostScript deletes the temporary script copy. A failure is only worth a debug log.
	removeHostScript := func() {
		if err := os.Remove(hostPath); err != nil {
			log.Debugf("Could not remove temporary language host %s: %v", hostPath, err)
		}
	}

	args := []string{"run", "python", hostPath}
	if debugConfig.Enabled() {
		if debugConfig.Port > 0 {
			args = append(args, "--debug-port", fmt.Sprintf("%d", debugConfig.Port))
		}
		if debugConfig.Mode != "" {
			args = append(args, "--debug", debugConfig.Mode)
		}
	}

	cmd := logging.Command(
		ctx,
		logging.CommandLogger{
			RootLogger:  log.Desugar().Named("python"),
			StdoutLevel: zap.DebugLevel,
			StderrLevel: zap.DebugLevel,
		},
		"pipenv", args...,
	)
	if cmd.Env == nil {
		cmd.Env = os.Environ()
	}
	cmd.Env = append(cmd.Env, "PYTHONPATH="+pythonPath)

	lf := NewServerState(log)
	cmd.Stdout = io.MultiWriter(cmd.Stdout, lf)
	command.SetProcAttr(cmd)

	log.Debugf("Executing: %s for %v", cmd.Path, cmd.Args)
	if err := cmd.Start(); err != nil {
		removeHostScript()
		return nil, nil, fmt.Errorf("failed to start Python client: %w", err)
	}
	log.Debug("Python client started")

	// Register the kill hook only once the process exists. Before Start, cmd.Process is nil,
	// so a kill signal arriving after a failed Start would dereference nil.
	cleanup.OnKill(func(signal syscall.Signal) error {
		cleanup.SignalProcessGroup(cmd.Process.Pid, syscall.SIGTERM)
		return nil
	})

	go func() {
		err := cmd.Wait()
		if err != nil {
			log.Debugf("Python process exited with error: %v", err)
		} else {
			log.Debug("Python process exited successfully")
		}
		// The script is no longer needed once the interpreter has exited.
		removeHostScript()
	}()
	return cmd, lf, nil
}
package model
// ConstructState is the persisted record of one construct within State.Constructs.
// Field order is the key order of the marshalled YAML (yaml.v3 emits struct fields in
// declaration order), so fields should not be reordered casually.
type ConstructState struct {
// Status is the construct's lifecycle status; allowed changes are listed in validTransitions.
Status ConstructStatus `yaml:"status,omitempty"`
// LastUpdated is a timestamp string. InitState writes RFC 3339; UpdateResourceState stores
// whatever string its caller passes.
LastUpdated string `yaml:"last_updated,omitempty"`
Inputs map[string]Input `yaml:"inputs,omitempty"`
Outputs map[string]any `yaml:"outputs,omitempty"`
Bindings []Binding `yaml:"bindings,omitempty"`
Options map[string]any `yaml:"options,omitempty"`
// DependsOn presumably lists the URNs of constructs this one depends on — confirm.
DependsOn []*URN `yaml:"dependsOn,omitempty"`
// NOTE(review): presumably identifies the Pulumi stack backing this construct — confirm.
PulumiStack UUID `yaml:"pulumi_stack,omitempty"`
// URN identifies the construct; SetConstructState keys the state map by URN.ResourceID.
URN *URN `yaml:"urn,omitempty"`
}
// ConstructStatus is a construct's lifecycle status as stored in state. The string values
// are written to the state file, so they must stay stable.
type ConstructStatus string
const (
// Create-related statuses
ConstructCreating ConstructStatus = "creating"
ConstructCreateComplete ConstructStatus = "create_complete"
ConstructCreateFailed ConstructStatus = "create_failed"
// Update-related statuses
ConstructUpdating ConstructStatus = "updating"
ConstructUpdateComplete ConstructStatus = "update_complete"
ConstructUpdateFailed ConstructStatus = "update_failed"
// Delete-related statuses
ConstructDeleting ConstructStatus = "deleting"
ConstructDeleteComplete ConstructStatus = "delete_complete"
ConstructDeleteFailed ConstructStatus = "delete_failed"
// Unknown status; validTransitions allows no transitions out of it.
ConstructUnknown ConstructStatus = "unknown"
)
// validTransitions maps each status to the statuses a construct may move to next. It is
// consulted by isValidTransition and by the IsUpdatable/IsCreatable/IsDeletable helpers.
// A status missing from the table has no allowed transitions.
var validTransitions = map[ConstructStatus][]ConstructStatus{
	// In progress: remain in progress, or finish with success or failure.
	ConstructCreating: {ConstructCreating, ConstructCreateComplete, ConstructCreateFailed},
	ConstructUpdating: {ConstructUpdating, ConstructUpdateComplete, ConstructUpdateFailed},
	ConstructDeleting: {ConstructDeleting, ConstructDeleteComplete, ConstructDeleteFailed},

	// Settled successes.
	ConstructCreateComplete: {ConstructUpdating, ConstructDeleting},
	ConstructUpdateComplete: {ConstructUpdating, ConstructDeleting},
	ConstructDeleteComplete: {ConstructCreating},

	// Failures: retry, or tear the construct down.
	ConstructCreateFailed: {ConstructCreating, ConstructDeleting},
	ConstructUpdateFailed: {ConstructUpdating, ConstructDeleting},
	ConstructDeleteFailed: {ConstructUpdating, ConstructDeleting},

	// Unknown: nothing is allowed.
	ConstructUnknown: {},
}
// IsUpdatable reports whether a construct in status may move to ConstructUpdating.
func IsUpdatable(status ConstructStatus) bool {
	next := validTransitions[status]
	for i := range next {
		if next[i] == ConstructUpdating {
			return true
		}
	}
	return false
}

// IsCreatable reports whether a construct in status may move to ConstructCreating.
func IsCreatable(status ConstructStatus) bool {
	next := validTransitions[status]
	for i := range next {
		if next[i] == ConstructCreating {
			return true
		}
	}
	return false
}

// IsDeployable reports whether a construct in status may be created or updated.
func IsDeployable(status ConstructStatus) bool {
	if IsCreatable(status) {
		return true
	}
	return IsUpdatable(status)
}

// IsDeletable reports whether a construct in status may move to ConstructDeleting.
func IsDeletable(status ConstructStatus) bool {
	next := validTransitions[status]
	for i := range next {
		if next[i] == ConstructDeleting {
			return true
		}
	}
	return false
}
// isValidTransition reports whether validTransitions allows moving from currentStatus to
// nextStatus. A status with no entry in the table has no allowed transitions.
func isValidTransition(currentStatus, nextStatus ConstructStatus) bool {
	// Use a distinct local name: the original shadowed the package-level validTransitions,
	// which made this body easy to misread.
	allowed, exists := validTransitions[currentStatus]
	if !exists {
		return false
	}
	for _, candidate := range allowed {
		if candidate == nextStatus {
			return true
		}
	}
	return false
}
// ConstructAction is a string-typed name for an operation applied to a construct.
type (
ConstructAction string
)
// The supported construct actions.
const (
ConstructActionCreate ConstructAction = "create"
ConstructActionUpdate ConstructAction = "update"
ConstructActionDelete ConstructAction = "delete"
)
package model
import (
"os"
"go.uber.org/zap"
"gopkg.in/yaml.v3"
)
// ApplicationEnvironment is the IR document produced by the language host and decoded by
// ParseIRFile. Field order is the key order of the marshalled YAML.
type ApplicationEnvironment struct {
SchemaVersion int `yaml:"schemaVersion,omitempty"`
Version int `yaml:"version,omitempty"`
ProjectURN URN `yaml:"project_urn,omitempty"`
AppURN URN `yaml:"app_urn,omitempty"`
Environment string `yaml:"environment,omitempty"`
// Constructs is keyed by a string that StateManager.InitState reuses as the state map key.
Constructs map[string]Construct `yaml:"constructs,omitempty"`
DefaultRegion string `yaml:"default_region,omitempty"`
}
// Construct is one construct as described in the IR; InitState copies its fields into a
// ConstructState.
type Construct struct {
URN *URN `yaml:"urn,omitempty"`
Version int `yaml:"version,omitempty"`
Inputs map[string]Input `yaml:"inputs,omitempty"`
Outputs map[string]any `yaml:"outputs,omitempty"`
Bindings []Binding `yaml:"bindings,omitempty"`
Options map[string]interface{} `yaml:"options,omitempty"`
DependsOn []*URN `yaml:"dependsOn,omitempty"`
}
// Input is a single named input value of a construct or binding.
type Input struct {
Value interface{} `yaml:"value,omitempty"`
// NOTE(review): presumably marks Value as encrypted/secret — confirm.
Encrypted bool `yaml:"encrypted,omitempty"`
Status InputStatus `yaml:"status,omitempty"`
DependsOn string `yaml:"dependsOn,omitempty"`
}
// InputStatus is the resolution status of an Input.
type InputStatus string
const (
InputStatusPending InputStatus = "pending"
InputStatusResolved InputStatus = "resolved"
InputStatusError InputStatus = "error"
)
// Binding attaches a construct, identified by URN, with its own inputs.
type Binding struct {
URN *URN `yaml:"urn,omitempty"`
Inputs map[string]Input `yaml:"inputs,omitempty"`
}
// ReadIRFile reads the IR YAML file at filename and decodes it with ParseIRFile. If the file
// cannot be read, it returns an empty (non-nil) ApplicationEnvironment along with the error.
func ReadIRFile(filename string) (*ApplicationEnvironment, error) {
	var (
		raw []byte
		err error
	)
	if raw, err = os.ReadFile(filename); err != nil {
		return &ApplicationEnvironment{}, err
	}
	return ParseIRFile(raw)
}
// ParseIRFile decodes YAML IR content into an ApplicationEnvironment.
//
// The returned pointer is never nil. Empty content, or a document that is just `null`, yields
// an empty ApplicationEnvironment rather than nil, so callers can dereference the result
// safely. On a decode error the error is logged and returned alongside an empty environment.
func ParseIRFile(content []byte) (*ApplicationEnvironment, error) {
	// Decode into an allocated value rather than through a nil pointer: yaml.Unmarshal leaves
	// a nil pointer untouched for empty documents, which would make this return (nil, nil).
	appEnv := &ApplicationEnvironment{}
	if err := yaml.Unmarshal(content, appEnv); err != nil {
		zap.S().Errorf("Error unmarshalling IR file: %s", err)
		return &ApplicationEnvironment{}, err
	}
	return appEnv, nil
}
package model
import (
"context"
"fmt"
"os"
"sync"
"time"
"github.com/klothoplatform/klotho/pkg/tui"
"github.com/spf13/afero"
"go.uber.org/zap"
"gopkg.in/yaml.v3"
)
// StateManager holds the in-memory deployment state and persists it as YAML in stateFile on fs.
// Methods that read or modify state take mutex.
type StateManager struct {
fs afero.Fs
stateFile string
// state is nil after LoadState finds no state file.
state *State
mutex sync.Mutex
}
// State is the YAML form of an application environment's deployment state. Field order is
// the key order of the marshalled YAML.
type State struct {
SchemaVersion int `yaml:"schemaVersion,omitempty"`
Version int `yaml:"version,omitempty"`
ProjectURN URN `yaml:"project_urn,omitempty"`
AppURN URN `yaml:"app_urn,omitempty"`
Environment string `yaml:"environment,omitempty"`
DefaultRegion string `yaml:"default_region,omitempty"`
// Constructs maps a construct key to its state. InitState uses the IR's construct map keys;
// SetConstructState uses URN.ResourceID.
Constructs map[string]ConstructState `yaml:"constructs,omitempty"`
}
// NewStateManager returns a StateManager that reads and writes stateFile on fsys. It starts
// with an empty state whose SchemaVersion and Version are 1.
func NewStateManager(fsys afero.Fs, stateFile string) *StateManager {
	initial := &State{
		SchemaVersion: 1,
		Version:       1,
		Constructs:    make(map[string]ConstructState),
	}
	sm := &StateManager{stateFile: stateFile, fs: fsys}
	sm.state = initial
	return sm
}
// CheckStateFileExists reports whether the state file exists. An error while checking is
// treated as "does not exist".
func (sm *StateManager) CheckStateFileExists() bool {
	if exists, err := afero.Exists(sm.fs, sm.stateFile); err == nil {
		return exists
	}
	return false
}
// InitState seeds the managed state from an IR:
//   - every construct in ir is recorded with status ConstructCreating and the current time
//     (RFC 3339), overwriting any entry with the same key;
//   - the project, app, environment and default region are copied over.
//
// If no state is loaded (LoadState found no state file), a fresh state with SchemaVersion and
// Version 1 is created first.
func (sm *StateManager) InitState(ir *ApplicationEnvironment) {
	sm.mutex.Lock()
	defer sm.mutex.Unlock()

	// LoadState sets state to nil when the file is missing, and a loaded state may carry a nil
	// Constructs map. Either would otherwise panic below.
	if sm.state == nil {
		sm.state = &State{SchemaVersion: 1, Version: 1}
	}
	if sm.state.Constructs == nil {
		sm.state.Constructs = make(map[string]ConstructState, len(ir.Constructs))
	}

	// One timestamp for the whole batch, so every construct seeded together agrees.
	now := time.Now().Format(time.RFC3339)
	for urn, construct := range ir.Constructs {
		sm.state.Constructs[urn] = ConstructState{
			Status:      ConstructCreating,
			LastUpdated: now,
			Inputs:      construct.Inputs,
			Outputs:     construct.Outputs,
			Bindings:    construct.Bindings,
			Options:     construct.Options,
			DependsOn:   construct.DependsOn,
			URN:         construct.URN,
		}
	}
	sm.state.ProjectURN = ir.ProjectURN
	sm.state.AppURN = ir.AppURN
	sm.state.Environment = ir.Environment
	sm.state.DefaultRegion = ir.DefaultRegion
}
// LoadState replaces the managed state with the contents of the state file.
//
// Behavior by outcome:
//   - File does not exist: state is set to nil (visible through GetState) and nil is returned.
//   - File decodes: it is decoded into a fresh State seeded with the NewStateManager defaults,
//     so entries from the previous in-memory state are never merged in.
//   - Decode error: the current state is left unchanged and the error is returned.
func (sm *StateManager) LoadState() error {
	sm.mutex.Lock()
	defer sm.mutex.Unlock()
	data, err := afero.ReadFile(sm.fs, sm.stateFile)
	if err != nil {
		if os.IsNotExist(err) {
			sm.state = nil
			return nil
		}
		return err
	}
	// Decoding straight into sm.state would panic when it is nil (after a previous load found
	// no file). It would also merge the file's maps into whatever was already in memory.
	loaded := &State{
		SchemaVersion: 1,
		Version:       1,
		Constructs:    make(map[string]ConstructState),
	}
	if err := yaml.Unmarshal(data, loaded); err != nil {
		return err
	}
	sm.state = loaded
	return nil
}
// SaveState marshals the in-memory state to YAML and writes it to the state file
// (mode 0644), replacing its previous contents.
func (sm *StateManager) SaveState() error {
	sm.mutex.Lock()
	defer sm.mutex.Unlock()

	encoded, marshalErr := yaml.Marshal(sm.state)
	if marshalErr != nil {
		return fmt.Errorf("error marshalling state: %w", marshalErr)
	}
	if writeErr := afero.WriteFile(sm.fs, sm.stateFile, encoded, 0644); writeErr != nil {
		return fmt.Errorf("error writing state: %w", writeErr)
	}
	return nil
}
// GetState returns the current state pointer. The value is shared, not copied:
// reads or writes through it after this call are not protected by the mutex.
func (sm *StateManager) GetState() *State {
	sm.mutex.Lock()
	current := sm.state
	sm.mutex.Unlock()
	return current
}
// UpdateResourceState sets the status and last-updated timestamp of the construct
// stored under name. It fails if the construct is unknown or if moving from its
// current status to status is not a valid transition.
func (sm *StateManager) UpdateResourceState(name string, status ConstructStatus, lastUpdated string) error {
	sm.mutex.Lock()
	defer sm.mutex.Unlock()

	constructs := sm.state.Constructs
	if constructs == nil {
		constructs = make(map[string]ConstructState)
		sm.state.Constructs = constructs
	}
	current, found := constructs[name]
	if !found {
		return fmt.Errorf("construct %s not found", name)
	}
	if !isValidTransition(current.Status, status) {
		return fmt.Errorf("invalid transition from %s to %s", current.Status, status)
	}
	current.Status, current.LastUpdated = status, lastUpdated
	constructs[name] = current
	return nil
}
// GetConstructState looks up the construct stored under name (its URN ResourceID).
// The boolean reports whether it was found.
func (sm *StateManager) GetConstructState(name string) (construct ConstructState, exists bool) {
	sm.mutex.Lock()
	defer sm.mutex.Unlock()
	construct, exists = sm.state.Constructs[name]
	return
}
// SetConstructState stores construct under its URN's ResourceID, overwriting any
// previous entry. construct.URN must not be nil.
func (sm *StateManager) SetConstructState(construct ConstructState) {
	key := construct.URN.ResourceID
	sm.mutex.Lock()
	defer sm.mutex.Unlock()
	sm.state.Constructs[key] = construct
}
// GetAllConstructs returns a snapshot of every construct state, keyed by resource ID.
//
// The map is a shallow copy taken while holding the mutex. Callers can therefore
// range over it while other goroutines transition constructs, without racing on
// the manager's internal map. Reference-typed fields of each ConstructState
// (e.g. URN) still alias the stored values. Use SetConstructState to persist
// changes. Returns nil if no construct map exists.
func (sm *StateManager) GetAllConstructs() map[string]ConstructState {
	sm.mutex.Lock()
	defer sm.mutex.Unlock()
	if sm.state.Constructs == nil {
		return nil
	}
	snapshot := make(map[string]ConstructState, len(sm.state.Constructs))
	for id, c := range sm.state.Constructs {
		snapshot[id] = c
	}
	return snapshot
}
// TransitionConstructState moves construct to nextStatus, stamps LastUpdated with
// the current RFC 3339 time, and stores the updated copy under its URN's ResourceID.
// The caller's value is updated in place. An invalid transition returns an error
// and changes nothing.
func (sm *StateManager) TransitionConstructState(construct *ConstructState, nextStatus ConstructStatus) error {
	sm.mutex.Lock()
	defer sm.mutex.Unlock()

	from := construct.Status
	if !isValidTransition(from, nextStatus) {
		return fmt.Errorf("invalid transition from %s to %s", from, nextStatus)
	}
	zap.L().Debug("Transitioning construct",
		zap.String("urn", construct.URN.String()),
		zap.String("from", string(from)),
		zap.String("to", string(nextStatus)),
	)
	construct.Status = nextStatus
	construct.LastUpdated = time.Now().Format(time.RFC3339)
	sm.state.Constructs[construct.URN.ResourceID] = *construct
	return nil
}
// IsOperating reports whether construct is mid-operation, i.e. its status is
// Creating, Updating or Deleting.
func (sm *StateManager) IsOperating(construct *ConstructState) bool {
	switch construct.Status {
	case ConstructCreating, ConstructUpdating, ConstructDeleting:
		return true
	default:
		return false
	}
}
func (sm *StateManager) TransitionConstructFailed(construct *ConstructState) error {
switch construct.Status {
case ConstructCreating:
return sm.TransitionConstructState(construct, ConstructCreateFailed)
case ConstructUpdating:
return sm.TransitionConstructState(construct, ConstructUpdateFailed)
case ConstructDeleting:
return sm.TransitionConstructState(construct, ConstructDeleteFailed)
default:
return fmt.Errorf("Initial state %s must be one of Creating, Updating, or Deleting", construct.Status)
}
}
func (sm *StateManager) TransitionConstructComplete(construct *ConstructState) error {
switch construct.Status {
case ConstructCreating:
return sm.TransitionConstructState(construct, ConstructCreateComplete)
case ConstructUpdating:
return sm.TransitionConstructState(construct, ConstructUpdateComplete)
case ConstructDeleting:
return sm.TransitionConstructState(construct, ConstructDeleteComplete)
default:
return fmt.Errorf("Initial state %s must be one of Creating, Updating, or Deleting", construct.Status)
}
}
// RegisterOutputValues registers the resolved output values of a construct in the state manager and resolves any inputs that depend on the provided outputs
//
// Each entry of outputs is merged into the construct's Outputs map, which is
// created if nil. When a TUI program is attached to ctx, each entry is also
// forwarded to it as a tui.OutputMessage. The inputs of every other construct are
// then scanned:
//   - an input whose DependsOn equals urn.String() is marked resolved, with the URN itself as its value;
//   - an input whose DependsOn is a key of outputs is marked resolved, with that output's value.
//
// Returns an error if no construct with urn.ResourceID is in the state. The
// mutex is held for the whole call, including the TUI sends.
func (sm *StateManager) RegisterOutputValues(ctx context.Context, urn URN, outputs map[string]any) error {
	sm.mutex.Lock()
	defer sm.mutex.Unlock()
	prog := tui.GetProgram(ctx)
	if sm.state.Constructs == nil {
		return fmt.Errorf("%s not found in state", urn.String())
	}
	construct, exists := sm.state.Constructs[urn.ResourceID]
	if !exists {
		return fmt.Errorf("%s not found in state", urn.String())
	}
	if construct.Outputs == nil {
		construct.Outputs = make(map[string]any)
	}
	for key, value := range outputs {
		construct.Outputs[key] = value
		// NOTE(review): Send is called with sm.mutex held. If it can block (e.g.
		// before the program is running), every other StateManager call waits too. Confirm.
		if prog != nil {
			prog.Send(tui.OutputMessage{
				Construct: urn.ResourceID,
				Name:      key,
				Value:     value,
			})
		}
	}
	sm.state.Constructs[urn.ResourceID] = construct
	// Propagate to dependents: every other construct's inputs are checked against this URN and its outputs.
	for _, c := range sm.state.Constructs {
		if urn.Equals(c.URN) {
			continue
		}
		updated := false
		for k, input := range c.Inputs {
			// Input depends on the construct as a whole: its value becomes the URN.
			if input.DependsOn == urn.String() {
				input.Status = InputStatusResolved
				input.Value = urn
				c.Inputs[k] = input
				updated = true
			}
			// NOTE(review): this compares DependsOn against output *names* (the keys of outputs),
			// while the check above compares it against a full URN string. Confirm which key
			// format output dependencies actually use.
			if o, ok := outputs[input.DependsOn]; ok {
				input.Value = o
				input.Status = InputStatusResolved
				c.Inputs[k] = input
				updated = true
			}
		}
		if updated {
			// Store the modified copy back; replacing an existing key during range is allowed.
			sm.state.Constructs[c.URN.ResourceID] = c
		}
	}
	return nil
}
package model
import (
"fmt"
"path/filepath"
"strings"
)
// URN represents a Unique Resource Name in the Klotho ecosystem. Its text form
// (see String and UnmarshalText) is
// urn:account:project:environment:application:type/subtype:[parent/]resource:output,
// with trailing empty components omitted.
type URN struct {
	AccountID        string `yaml:"accountId"`
	Project          string `yaml:"project"`
	Environment      string `yaml:"environment,omitempty"`
	Application      string `yaml:"application,omitempty"`
	Type             string `yaml:"type,omitempty"`
	Subtype          string `yaml:"subtype,omitempty"`
	ParentResourceID string `yaml:"parentResourceId,omitempty"`
	ResourceID       string `yaml:"resourceId,omitempty"`
	Output           string `yaml:"output,omitempty"`
}

// UrnType names the kind of entity a URN addresses, as classified by (*URN).UrnType.
type UrnType string
// Values returned by (*URN).UrnType, one per classification predicate.
const (
	// Scope hierarchy, broadest first.
	AccountUrnType                UrnType = "account"                 // IsAccount
	ProjectUrnType                UrnType = "project"                 // IsProject
	EnvironmentUrnType            UrnType = "environment"             // IsEnvironment
	ApplicationEnvironmentUrnType UrnType = "application_environment" // IsApplicationEnvironment

	// Addressable items inside an environment.
	ResourceUrnType UrnType = "resource" // IsResource
	OutputUrnType   UrnType = "output"   // IsOutput
	TypeUrnType     UrnType = "type"     // IsType
)
// ParseURN parses a URN string (with or without the leading "urn:") into a new
// URN. On failure it returns nil and the parse error.
func ParseURN(urnString string) (*URN, error) {
	parsed := new(URN)
	if err := parsed.UnmarshalText([]byte(urnString)); err != nil {
		return nil, err
	}
	return parsed, nil
}
// String returns the URN as a string
func (u URN) String() string {
var sb strings.Builder
sb.WriteString("urn:")
sb.WriteString(u.AccountID)
sb.WriteString(":")
sb.WriteString(u.Project)
sb.WriteString(":")
sb.WriteString(u.Environment)
sb.WriteString(":")
sb.WriteString(u.Application)
sb.WriteString(":")
if u.Type != "" && u.Subtype != "" {
sb.WriteString(u.Type)
sb.WriteString("/")
sb.WriteString(u.Subtype)
}
sb.WriteString(":")
if u.ParentResourceID != "" && u.ResourceID != "" {
sb.WriteString(u.ParentResourceID)
sb.WriteString("/")
sb.WriteString(u.ResourceID)
} else {
sb.WriteString(u.ResourceID)
}
sb.WriteString(":")
sb.WriteString(u.Output)
sb.WriteString(":")
// Remove trailing colons
urn := sb.String()
return strings.TrimRight(urn, ":")
}
// MarshalText implements encoding.TextMarshaler using the String form. It never fails.
func (u URN) MarshalText() ([]byte, error) {
	text := u.String()
	return []byte(text), nil
}
// UnmarshalText implements encoding.TextUnmarshaler. It accepts
// [urn:]account:project[:environment[:application[:type/subtype[:[parent/]resource[:output]]]]].
//
// Account and project are required, and at most seven components are allowed.
// A non-empty type component must contain exactly one "/". A resource component
// with exactly one "/" is split into parent and resource ID; otherwise it is
// taken verbatim as the resource ID. On error the receiver is left unchanged. On
// success it is fully replaced, so no field from its previous value survives.
func (u *URN) UnmarshalText(text []byte) error {
	parts := strings.Split(string(text), ":")
	if parts[0] == "urn" {
		parts = parts[1:]
	}
	if len(parts) < 2 {
		return fmt.Errorf("invalid URN format: missing account ID and/or project")
	} else if len(parts) > 7 {
		return fmt.Errorf("invalid URN format: too many parts")
	}

	// segment returns the i-th component, or "" when the input ends earlier.
	segment := func(i int) string {
		if i < len(parts) {
			return parts[i]
		}
		return ""
	}

	// Build the result locally and commit it only after every component is valid.
	// A failed parse then cannot leave *u half-written, and a reused receiver
	// cannot keep stale fields.
	parsed := URN{
		AccountID:   parts[0],
		Project:     parts[1],
		Environment: segment(2),
		Application: segment(3),
		Output:      segment(6),
	}
	if typ := segment(4); typ != "" {
		typeParts := strings.Split(typ, "/")
		if len(typeParts) != 2 {
			return fmt.Errorf("invalid URN type format: %s", typ)
		}
		parsed.Type, parsed.Subtype = typeParts[0], typeParts[1]
	}
	if res := segment(5); res != "" {
		if resourceParts := strings.Split(res, "/"); len(resourceParts) == 2 {
			parsed.ParentResourceID, parsed.ResourceID = resourceParts[0], resourceParts[1]
		} else {
			parsed.ResourceID = res
		}
	}
	*u = parsed
	return nil
}
// Equals reports whether other (a URN or *URN) has the same fields as u.
// Two nil pointers are equal. A nil receiver never equals a URN value, and any
// other type of argument is never equal.
func (u *URN) Equals(other any) bool {
	var rhs *URN
	switch v := other.(type) {
	case URN:
		if u == nil {
			return false
		}
		rhs = &v
	case *URN:
		rhs = v
		if u == nil || rhs == nil {
			return u == rhs
		}
	default:
		return false
	}
	return *u == *rhs
}
// IsOutput reports whether u names an output of a resource that has a parent.
// Every component except Application (which may be anything) must be non-empty.
func (u *URN) IsOutput() bool {
	for _, component := range [...]string{
		u.AccountID, u.Project, u.Environment, u.Type,
		u.Subtype, u.ParentResourceID, u.ResourceID, u.Output,
	} {
		if component == "" {
			return false
		}
	}
	return true
}
// IsResource reports whether u names a resource: account, project, environment,
// type, subtype and resource ID are set and no output is named. Application and
// ParentResourceID are not constrained.
func (u *URN) IsResource() bool {
	if u.Output != "" {
		return false
	}
	return u.AccountID != "" && u.Project != "" && u.Environment != "" &&
		u.Type != "" && u.Subtype != "" && u.ResourceID != ""
}
// IsApplicationEnvironment reports whether u names an application within an
// environment: account, project, environment and application are set, and
// every type/resource/output component is empty.
func (u *URN) IsApplicationEnvironment() bool {
	scoped := u.AccountID != "" && u.Project != "" && u.Environment != "" && u.Application != ""
	return scoped && u.Type+u.Subtype+u.ParentResourceID+u.ResourceID+u.Output == ""
}
// IsType reports whether u names only a type: Type is set while Subtype,
// ParentResourceID, ResourceID and Output are empty. Scope fields are not checked.
func (u *URN) IsType() bool {
	if u.Type == "" {
		return false
	}
	return u.Subtype+u.ParentResourceID+u.ResourceID+u.Output == ""
}
// IsEnvironment reports whether u names an environment: account, project and
// environment are set and every more specific component is empty.
func (u *URN) IsEnvironment() bool {
	hasEnvironment := u.AccountID != "" && u.Project != "" && u.Environment != ""
	rest := u.Application + u.Type + u.Subtype + u.ParentResourceID + u.ResourceID + u.Output
	return hasEnvironment && rest == ""
}
// IsProject reports whether u names a project: account and project are set and
// every more specific component is empty.
func (u *URN) IsProject() bool {
	if u.AccountID == "" || u.Project == "" {
		return false
	}
	return u.Environment+u.Application+u.Type+u.Subtype+u.ParentResourceID+u.ResourceID+u.Output == ""
}
// IsAccount reports whether u names only an account: AccountID is set and
// every other component is empty.
func (u *URN) IsAccount() bool {
	below := u.Project + u.Environment + u.Application + u.Type + u.Subtype +
		u.ParentResourceID + u.ResourceID + u.Output
	return u.AccountID != "" && below == ""
}
// UrnType classifies u by testing the Is* predicates in a fixed order (account,
// project, environment, application environment, resource, output, type). The
// first match wins. An unclassifiable URN yields "".
func (u *URN) UrnType() UrnType {
	switch {
	case u.IsAccount():
		return AccountUrnType
	case u.IsProject():
		return ProjectUrnType
	case u.IsEnvironment():
		return EnvironmentUrnType
	case u.IsApplicationEnvironment():
		return ApplicationEnvironmentUrnType
	case u.IsResource():
		return ResourceUrnType
	case u.IsOutput():
		return OutputUrnType
	case u.IsType():
		return TypeUrnType
	default:
		return ""
	}
}
// UrnPath returns the relative filesystem path of the output for a given URN
// (e.g., project/application/environment/construct)
func UrnPath(urn URN) (string, error) {
parts := []string{
urn.Project,
urn.Application,
urn.Environment,
urn.ResourceID,
}
for i, p := range parts {
if p == "" {
return filepath.Join(parts[:i]...), nil
}
}
return filepath.Join(parts...), nil
}
// Compare orders URNs lexicographically by component, from AccountID through
// Output. It returns the strings.Compare result of the first differing
// component, or 0 when all components are equal.
func (u *URN) Compare(other URN) int {
	pairs := [...][2]string{
		{u.AccountID, other.AccountID},
		{u.Project, other.Project},
		{u.Environment, other.Environment},
		{u.Application, other.Application},
		{u.Type, other.Type},
		{u.Subtype, other.Subtype},
		{u.ParentResourceID, other.ParentResourceID},
		{u.ResourceID, other.ResourceID},
		{u.Output, other.Output},
	}
	for _, p := range pairs {
		if c := strings.Compare(p[0], p[1]); c != 0 {
			return c
		}
	}
	return 0
}
// IsZero reports whether u is nil or has every component empty.
func (u *URN) IsZero() bool {
	if u == nil {
		return true
	}
	var zero URN
	return *u == zero
}
package model
import (
"fmt"
"github.com/google/uuid"
)
// UUID wraps uuid.UUID so that it can be decoded from a YAML string through the
// UnmarshalYAML method below. All uuid.UUID methods are promoted via embedding.
type UUID struct {
	uuid.UUID
}
// UnmarshalYAML decodes a YAML scalar string and parses it as a UUID.
// Both the decoding and the parsing errors are wrapped with context.
func (u *UUID) UnmarshalYAML(unmarshal func(interface{}) error) error {
	var raw string
	if decodeErr := unmarshal(&raw); decodeErr != nil {
		return fmt.Errorf("error unmarshalling YAML string: %w", decodeErr)
	}
	id, parseErr := uuid.Parse(raw)
	if parseErr != nil {
		return fmt.Errorf("error parsing UUID: %w", parseErr)
	}
	u.UUID = id
	return nil
}
package orchestration
import (
"context"
"github.com/klothoplatform/klotho/pkg/k2/model"
"github.com/klothoplatform/klotho/pkg/logging"
"github.com/klothoplatform/klotho/pkg/tui"
"go.uber.org/zap"
)
// ConstructContext derives a context scoped to one construct. Its logger carries
// a "construct" field set to the construct's ResourceID. When a TUI program is
// attached, the context also carries a progress tracker for that construct.
func ConstructContext(ctx context.Context, construct model.URN) context.Context {
	name := construct.ResourceID
	logger := logging.GetLogger(ctx).With(zap.String("construct", name))
	ctx = logging.WithLogger(ctx, logger)

	prog := tui.GetProgram(ctx)
	if prog == nil {
		return ctx
	}
	return tui.WithProgress(ctx, &tui.TuiProgress{Prog: prog, Construct: name})
}
package orchestration
import (
"context"
"errors"
"fmt"
"github.com/klothoplatform/klotho/pkg/k2/model"
"github.com/klothoplatform/klotho/pkg/k2/stack"
"github.com/klothoplatform/klotho/pkg/tui"
"github.com/spf13/afero"
)
// DownOrchestrator tears down previously deployed constructs. Its own FS field
// is declared directly on the struct, so it takes precedence over the embedded
// Orchestrator.FS when accessed as do.FS.
type DownOrchestrator struct {
	*Orchestrator
	FS afero.Fs
}

// DownRequest lists the stacks to destroy and the requested dry-run mode.
type DownRequest struct {
	StackReferences []stack.Reference
	DryRun          model.DryRun
}
// NewDownOrchestrator returns a DownOrchestrator that uses sm for construct state,
// fs for all filesystem access, and outputPath as the output root.
//
// fs is used both for the embedded Orchestrator and for DownOrchestrator.FS,
// the field RunDownCommand passes to stack teardown. Previously the latter was
// hard-coded to the OS filesystem, silently ignoring an injected fs. That was
// inconsistent with UpOrchestrator, which uses the injected fs for the same call.
func NewDownOrchestrator(sm *model.StateManager, fs afero.Fs, outputPath string) *DownOrchestrator {
	return &DownOrchestrator{
		Orchestrator: NewOrchestrator(sm, fs, outputPath),
		FS:           fs,
	}
}
// RunDownCommand destroys the stacks listed in request.
//
// Constructs already in ConstructDeleteComplete are ignored. The rest are grouped
// by sortConstructsByDependency, with edges reversed for deletion. Groups run one
// after another. Constructs within a group run concurrently, with at most
// maxConcurrency teardowns in flight; values below 1 are treated as 1. After the
// first group that reports any error, processing stops and that group's errors
// are returned joined. FinalizeState runs on return and saves the state.
//
// Dry runs are rejected: stack teardown has no preview mode yet.
func (do *DownOrchestrator) RunDownCommand(ctx context.Context, request DownRequest, maxConcurrency int) error {
	if request.DryRun > 0 {
		// TODO Stack.Destroy hard-codes the flag to "--skip-preview"
		// and doesn't have any options for "--preview-only"
		// which was added in https://github.com/pulumi/pulumi/pull/15336
		return errors.New("Dryrun not supported in Down Command yet")
	}
	// sem below is a buffered channel of this size. A size of 0 would block the
	// first send forever, and a negative size would panic in make.
	if maxConcurrency < 1 {
		maxConcurrency = 1
	}
	defer do.FinalizeState(ctx)

	sm := do.StateManager
	stackRefCache := make(map[string]stack.Reference)
	actions := make(map[model.URN]model.ConstructAction)
	var constructsToDelete []model.ConstructState

	for _, ref := range request.StackReferences {
		c, exists := sm.GetConstructState(ref.ConstructURN.ResourceID)
		if !exists {
			// This means there's a construct in our StackReferences that doesn't exist in the state
			// This should never happen as we just build StackReferences from the state
			return fmt.Errorf("construct %s not found in state", ref.ConstructURN.ResourceID)
		}
		if c.Status == model.ConstructDeleteComplete {
			continue
		}
		constructsToDelete = append(constructsToDelete, c)
		// Cache the stack reference for later use outside this loop
		stackRefCache[ref.ConstructURN.ResourceID] = ref
		actions[*c.URN] = model.ConstructActionDelete
	}

	deleteOrder, err := sortConstructsByDependency(constructsToDelete, actions)
	if err != nil {
		return fmt.Errorf("failed to determine deployment order: %w", err)
	}

	for _, group := range deleteOrder {
		for _, cURN := range group {
			action := actions[cURN]
			ctx := ConstructContext(ctx, cURN)
			prog := tui.GetProgress(ctx)
			prog.UpdateIndeterminate(fmt.Sprintf("Starting %s", action))
		}
	}

	sem := make(chan struct{}, maxConcurrency)
	for _, group := range deleteOrder {
		errChan := make(chan error, len(group))
		for _, cURN := range group {
			sem <- struct{}{}
			go func(cURN model.URN) {
				defer func() { <-sem }()

				construct, exists := sm.GetConstructState(cURN.ResourceID)
				if !exists {
					errChan <- fmt.Errorf("construct %s not found in state", cURN.ResourceID)
					return
				}
				ctx := ConstructContext(ctx, *construct.URN)
				prog := tui.GetProgress(ctx)

				// Nothing to tear down: already deleted, or never finished creating.
				if construct.Status == model.ConstructDeleteComplete || construct.Status == model.ConstructCreating {
					prog.Complete("Skipped")
					errChan <- sm.TransitionConstructState(&construct, model.ConstructDeleteComplete)
					return
				}

				if err := sm.TransitionConstructState(&construct, model.ConstructDeleting); err != nil {
					prog.Complete("Failed")
					errChan <- err
					return
				}
				stackRef := stackRefCache[construct.URN.ResourceID]
				err := stack.RunDown(ctx, do.FS, stackRef)
				if err != nil {
					prog.Complete("Failed")
					if err2 := sm.TransitionConstructFailed(&construct); err2 != nil {
						err = fmt.Errorf("%v: error transitioning construct state to delete failed: %v", err, err2)
					}
					errChan <- err
					return
				} else if err := sm.TransitionConstructComplete(&construct); err != nil {
					prog.Complete("Failed")
					errChan <- err
					return
				}
				prog.Complete("Success")
				errChan <- nil
			}(cURN)
		}

		// Every goroutine sends exactly one value, so this drains the whole group.
		var errs []error
		for i := 0; i < len(group); i++ {
			if err := <-errChan; err != nil {
				errs = append(errs, err)
			}
		}
		if len(errs) > 0 {
			return errors.Join(errs...)
		}
	}
	return nil
}
package orchestration
import (
"context"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/engine"
"github.com/klothoplatform/klotho/pkg/engine/solution"
"github.com/klothoplatform/klotho/pkg/infra/iac"
kio "github.com/klothoplatform/klotho/pkg/io"
"github.com/klothoplatform/klotho/pkg/knowledgebase/reader"
"github.com/klothoplatform/klotho/pkg/logging"
"github.com/klothoplatform/klotho/pkg/provider/aws"
"github.com/klothoplatform/klotho/pkg/templates"
"github.com/spf13/afero"
"gopkg.in/yaml.v3"
)
// InfraGenerator runs the engine on a solve request and writes the resulting
// artifacts (views, resources.yaml, policy, IaC) through FS.
type InfraGenerator struct {
	Engine *engine.Engine
	FS     afero.Fs
}

// InfraRequest pairs a solve request with the directory its artifacts go to.
type InfraRequest struct {
	engine.SolveRequest
	OutputDir string
}
// NewInfraGenerator loads the knowledge base from the embedded templates and
// returns a generator whose engine uses it and whose artifacts are written to fs.
func NewInfraGenerator(fs afero.Fs) (*InfraGenerator, error) {
	kb, err := reader.NewKBFromFs(templates.ResourceTemplates, templates.EdgeTemplates, templates.Models)
	if err != nil {
		return nil, err
	}
	gen := &InfraGenerator{FS: fs}
	gen.Engine = engine.NewEngine(kb)
	return gen, nil
}
// writeYamlFile encodes v as YAML into path. path is joined onto outDir unless it
// already starts with outDir. Parent directories are created (0755) and any
// existing file is truncated (0644). Errors from creating, encoding, flushing or
// closing the file are all returned.
func (g *InfraGenerator) writeYamlFile(outDir string, path string, v any) (err error) {
	if !strings.HasPrefix(path, outDir) {
		path = filepath.Join(outDir, path)
	}
	if err = g.FS.MkdirAll(filepath.Dir(path), 0755); err != nil {
		return err
	}
	f, err := g.FS.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
	if err != nil {
		return err
	}
	defer func() {
		// A failed Close can mean the data never reached the file. Report it
		// unless an earlier error is already being returned.
		if closeErr := f.Close(); closeErr != nil && err == nil {
			err = closeErr
		}
	}()
	enc := yaml.NewEncoder(f)
	if err = enc.Encode(v); err != nil {
		return err
	}
	// Close writes any data the encoder still buffers.
	return enc.Close()
}
// writeKFile writes kf to its path. The path is joined onto outDir unless it
// already starts with outDir. Parent directories are created (0755) and any
// existing file is truncated (0644). Errors from creating, writing or closing
// the file are all returned.
func (g *InfraGenerator) writeKFile(outDir string, kf kio.File) (err error) {
	path := kf.Path()
	if !strings.HasPrefix(path, outDir) {
		path = filepath.Join(outDir, path)
	}
	if err = g.FS.MkdirAll(filepath.Dir(path), 0755); err != nil {
		return err
	}
	f, err := g.FS.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
	if err != nil {
		return err
	}
	defer func() {
		// A failed Close can mean the data never reached the file. Report it
		// unless an earlier error is already being returned.
		if closeErr := f.Close(); closeErr != nil && err == nil {
			err = closeErr
		}
	}()
	_, err = kf.WriteTo(f)
	return err
}
// Run records req as engine_input.yaml in outDir, solves it, and writes the
// generated views, resources, policy and Pulumi IaC into outDir. It returns the
// solution.
//
// Errors are wrapped with %w so callers can inspect the cause with errors.Is/As.
// The resolve step previously used %v, which dropped the error chain.
func (g *InfraGenerator) Run(ctx context.Context, req engine.SolveRequest, outDir string) (solution.Solution, error) {
	if err := g.writeYamlFile(outDir, "engine_input.yaml", req); err != nil {
		return nil, fmt.Errorf("failed to write engine input: %w", err)
	}
	sol, err := g.resolveResources(ctx, InfraRequest{
		SolveRequest: req,
		OutputDir:    outDir,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to resolve resources: %w", err)
	}
	err = g.generateIac(iacRequest{
		PulumiAppName: "k2",
		Solution:      sol,
		OutputDir:     outDir,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to generate iac: %w", err)
	}
	return sol, nil
}
func (g *InfraGenerator) resolveResources(ctx context.Context, request InfraRequest) (solution.Solution, error) {
log := logging.GetLogger(ctx)
log.Info("Running engine")
sol, engineErr := g.Engine.Run(ctx, &request.SolveRequest)
if engineErr != nil {
return nil, fmt.Errorf("Engine failed: %w", engineErr)
}
log.Info("Generating views")
var fileWriteErrs []error
log.Info("Serializing constraints")
vizFiles, err := g.Engine.VisualizeViews(sol)
if err != nil {
return nil, fmt.Errorf("failed to generate views %w", err)
}
for _, f := range vizFiles {
fileWriteErrs = append(fileWriteErrs, g.writeKFile(request.OutputDir, f))
}
log.Info("Generating resources.yaml")
fileWriteErrs = append(fileWriteErrs, g.writeYamlFile(
request.OutputDir,
"resources.yaml",
construct.YamlGraph{Graph: sol.DataflowGraph(), Outputs: sol.Outputs()},
))
policyBytes, err := aws.DeploymentPermissionsPolicy(sol)
if err != nil {
return nil, fmt.Errorf("failed to generate deployment permissions policy: %w", err)
}
if policyBytes != nil {
fileWriteErrs = append(fileWriteErrs,
g.writeKFile(request.OutputDir, &kio.RawFile{
FPath: "aws_deployment_policy.json",
Content: policyBytes,
}),
)
}
return sol, errors.Join(fileWriteErrs...)
}
// iacRequest holds the inputs for generateIac.
type iacRequest struct {
	// PulumiAppName becomes the AppName of the Pulumi plugin config.
	PulumiAppName string
	// Solution is the solved graph to translate into IaC files.
	Solution solution.Solution
	// OutputDir is where the generated files are written.
	OutputDir string
}
func (g *InfraGenerator) generateIac(request iacRequest) error {
pulumiPlugin := iac.Plugin{
Config: &iac.PulumiConfig{AppName: request.PulumiAppName},
KB: g.Engine.Kb,
}
iacFiles, err := pulumiPlugin.Translate(request.Solution)
if err != nil {
return err
}
var fileWriteErrs []error
for _, f := range iacFiles {
fileWriteErrs = append(fileWriteErrs, g.writeKFile(request.OutputDir, f))
}
return errors.Join(fileWriteErrs...)
}
package orchestration
import (
"context"
"fmt"
"path/filepath"
"sync"
"time"
"github.com/klothoplatform/klotho/pkg/k2/constructs/graph"
"github.com/klothoplatform/klotho/pkg/k2/model"
"github.com/klothoplatform/klotho/pkg/k2/stack"
"github.com/klothoplatform/klotho/pkg/logging"
"github.com/spf13/afero"
)
// Orchestrator is the base orchestrator for the K2 platform
//
// It holds the shared construct state, the filesystem used for generated output,
// and the output root directory. EvaluateConstruct writes each construct into
// OutputDirectory/<ResourceID>.
type Orchestrator struct {
	StateManager    *model.StateManager
	FS              afero.Fs
	OutputDirectory string
	mu              sync.Mutex // guards the following fields
	// infraGenerator is created lazily by the InfraGenerator method.
	infraGenerator *InfraGenerator
}
// NewOrchestrator returns an Orchestrator over sm that writes output through fs
// under outputPath. The infra generator is created on first use.
func NewOrchestrator(sm *model.StateManager, fs afero.Fs, outputPath string) *Orchestrator {
	o := new(Orchestrator)
	o.StateManager = sm
	o.FS = fs
	o.OutputDirectory = outputPath
	return o
}
// InfraGenerator returns the orchestrator's infra generator, creating it on the
// first call. Creation happens under o.mu. A failed creation is not cached, so
// the next call tries again.
func (o *Orchestrator) InfraGenerator() (*InfraGenerator, error) {
	o.mu.Lock()
	defer o.mu.Unlock()
	if o.infraGenerator != nil {
		return o.infraGenerator, nil
	}
	gen, err := NewInfraGenerator(o.FS)
	if err != nil {
		return nil, err
	}
	o.infraGenerator = gen
	return gen, nil
}
// EvaluateConstruct turns one construct into deployable IaC.
//
// It creates OutputDirectory/<ResourceID> and evaluates the construct against
// state into an engine request tagged "k2". It then runs the infra generator
// into that directory and records the solution with the construct evaluator.
// The returned stack reference points at the directory and uses the state's
// default region.
func (uo *UpOrchestrator) EvaluateConstruct(ctx context.Context, state model.State, constructUrn model.URN) (stack.Reference, error) {
	outDir := filepath.Join(uo.OutputDirectory, constructUrn.ResourceID)
	if err := uo.FS.MkdirAll(outDir, 0755); err != nil {
		return stack.Reference{}, fmt.Errorf("error creating construct output directory: %w", err)
	}

	req, err := uo.ConstructEvaluator.Evaluate(constructUrn, state, ctx)
	if err != nil {
		return stack.Reference{}, err
	}
	req.GlobalTag = "k2" // TODO make this meaningful?

	generator, err := uo.InfraGenerator()
	if err != nil {
		return stack.Reference{}, fmt.Errorf("error getting infra generator: %w", err)
	}
	sol, err := generator.Run(ctx, req, outDir)
	if err != nil {
		return stack.Reference{}, fmt.Errorf("error running infra generator: %w", err)
	}
	uo.ConstructEvaluator.AddSolution(constructUrn, sol)

	ref := stack.Reference{
		ConstructURN: constructUrn,
		Name:         constructUrn.ResourceID,
		IacDirectory: outDir,
		AwsRegion:    uo.StateManager.GetState().DefaultRegion,
	}
	return ref, nil
}
// resolveInitialState reconciles the current state with the IR for a new up run.
// It returns the action planned for each construct that needs one.
//
// Steps:
//  1. Bump the state version.
//  2. Reject a region change while any construct is deployable-to-delete.
//  3. Reject a schema version mismatch.
//  4. Adopt the IR's region.
//  5. Mark IR constructs Creating (new or re-creatable) or Updating.
//  6. Mark constructs present only in the state Deleting, unless they are already
//     ConstructDeleteComplete.
//
// Every status change goes through the StateManager, so an invalid transition
// aborts with its error.
func (o *Orchestrator) resolveInitialState(ir *model.ApplicationEnvironment) (map[model.URN]model.ConstructAction, error) {
	actions := make(map[model.URN]model.ConstructAction)
	state := o.StateManager.GetState()

	//TODO: implement some kind of versioning check
	state.Version += 1

	// Check for default region mismatch: changing region is only allowed while
	// nothing is deployed.
	if state.DefaultRegion != ir.DefaultRegion {
		deployed := make(map[string]model.ConstructStatus)
		for k, v := range state.Constructs {
			if model.IsDeletable(v.Status) {
				deployed[k] = v.Status
			}
		}
		if len(deployed) > 0 {
			return nil, fmt.Errorf("cannot change region (%s -> %s) with deployed resources: %v", state.DefaultRegion, ir.DefaultRegion, deployed)
		}
	}

	// Check for schema version mismatch
	if state.SchemaVersion != ir.SchemaVersion {
		return nil, fmt.Errorf("state schema version mismatch")
	}

	// Both checks passed, so record the IR's region. Stack references read
	// DefaultRegion from the state. Without this, a permitted region change would
	// be validated but never take effect.
	state.DefaultRegion = ir.DefaultRegion

	for _, c := range ir.Constructs {
		var status model.ConstructStatus
		var action model.ConstructAction
		construct, exists := o.StateManager.GetConstructState(c.URN.ResourceID)
		if !exists {
			// If the construct doesn't exist in the current state, it's a create action
			action = model.ConstructActionCreate
			status = model.ConstructCreating
			construct = model.ConstructState{
				Status:      model.ConstructCreating,
				LastUpdated: time.Now().Format(time.RFC3339),
				Inputs:      c.Inputs,
				Outputs:     c.Outputs,
				Bindings:    c.Bindings,
				Options:     c.Options,
				DependsOn:   c.DependsOn,
				URN:         c.URN,
			}
		} else {
			if model.IsCreatable(construct.Status) {
				action = model.ConstructActionCreate
				status = model.ConstructCreating
			} else if model.IsUpdatable(construct.Status) {
				action = model.ConstructActionUpdate
				status = model.ConstructUpdating
			}
			construct.Inputs = c.Inputs
			construct.Outputs = c.Outputs
			construct.Bindings = c.Bindings
			construct.Options = c.Options
			construct.DependsOn = c.DependsOn
		}
		actions[*c.URN] = action
		err := o.StateManager.TransitionConstructState(&construct, status)
		if err != nil {
			return nil, err
		}
	}

	// Find deleted constructs
	for k, v := range o.StateManager.GetState().Constructs {
		if _, ok := ir.Constructs[k]; !ok {
			if v.Status == model.ConstructDeleteComplete {
				continue
			}
			actions[*v.URN] = model.ConstructActionDelete
			if !model.IsDeletable(v.Status) {
				return nil, fmt.Errorf("construct %s is not deletable", v.URN.ResourceID)
			}
			err := o.StateManager.TransitionConstructState(&v, model.ConstructDeleting)
			if err != nil {
				return nil, err
			}
		}
	}
	return actions, nil
}
// sortConstructsByDependency sorts the constructs based on their dependencies and returns the deployment order
// in the form of sequential construct groups that can be deployed in parallel.
//
// Every DependsOn entry and every binding of a construct becomes an edge from the
// construct to the other URN. For constructs whose action is delete, the edge is
// reversed: if 'a' depends on 'b' and 'a' is being deleted, the edge goes b -> a.
// The first edge that cannot be added aborts with its error.
func sortConstructsByDependency(constructs []model.ConstructState, actions map[model.URN]model.ConstructAction) ([][]model.URN, error) {
	g := graph.NewAcyclicGraph()
	for _, c := range constructs {
		_ = g.AddVertex(*c.URN)
	}

	// link adds the dependency edge between c and other, honoring delete reversal.
	link := func(c model.ConstructState, other model.URN) error {
		if actions[*c.URN] == model.ConstructActionDelete {
			return g.AddEdge(other, *c.URN)
		}
		return g.AddEdge(*c.URN, other)
	}

	for _, c := range constructs {
		for _, dep := range c.DependsOn {
			if err := link(c, *dep); err != nil {
				return nil, err
			}
		}
		for _, b := range c.Bindings {
			if err := link(c, *b.URN); err != nil {
				return nil, err
			}
		}
	}
	return graph.ResolveDeploymentGroups(g)
}
// FinalizeState moves every construct still mid-operation (Creating, Updating or
// Deleting) to its matching failed status, then saves the state. Errors are
// logged rather than returned, so this is safe to defer.
func (o *Orchestrator) FinalizeState(ctx context.Context) {
	log := logging.GetLogger(ctx).Sugar()
	sm := o.StateManager
	for _, cs := range sm.GetState().Constructs {
		if !sm.IsOperating(&cs) {
			continue
		}
		if err := sm.TransitionConstructFailed(&cs); err != nil {
			log.Errorf("Error transitioning construct state: %v", err)
		}
	}
	if err := sm.SaveState(); err != nil {
		log.Errorf("Error saving state: %v", err)
	}
}
package orchestration
import (
"context"
"errors"
"fmt"
"path/filepath"
"time"
"github.com/klothoplatform/klotho/pkg/engine/debug"
"github.com/klothoplatform/klotho/pkg/k2/constructs"
pb "github.com/klothoplatform/klotho/pkg/k2/language_host/go"
"github.com/klothoplatform/klotho/pkg/k2/model"
"github.com/klothoplatform/klotho/pkg/k2/stack"
"github.com/klothoplatform/klotho/pkg/logging"
"github.com/klothoplatform/klotho/pkg/tui"
"github.com/spf13/afero"
"go.uber.org/zap"
"golang.org/x/sync/semaphore"
"gopkg.in/yaml.v3"
)
// UpOrchestrator deploys the constructs of an application environment.
type UpOrchestrator struct {
	*Orchestrator
	// LanguageHostClient is used to hand stack outputs back to the language host for resolution.
	LanguageHostClient pb.KlothoServiceClient
	// StackStateManager records the stack state returned by each pulumi up, keyed by construct URN.
	StackStateManager *stack.StateManager
	// ConstructEvaluator turns constructs into engine solve requests and tracks their solutions.
	ConstructEvaluator *constructs.ConstructEvaluator
}
// NewUpOrchestrator wires an UpOrchestrator with a fresh stack state manager and
// a construct evaluator built on sm. It fails if the evaluator cannot be created.
func NewUpOrchestrator(
	sm *model.StateManager, languageHostClient pb.KlothoServiceClient, fs afero.Fs, outputPath string,
) (*UpOrchestrator, error) {
	stackStates := stack.NewStateManager()
	evaluator, err := constructs.NewConstructEvaluator(sm, stackStates)
	if err != nil {
		return nil, err
	}
	up := &UpOrchestrator{
		Orchestrator:       NewOrchestrator(sm, fs, outputPath),
		LanguageHostClient: languageHostClient,
		StackStateManager:  stackStates,
		ConstructEvaluator: evaluator,
	}
	return up, nil
}
// RunUpCommand deploys ir.
//
// It plans actions with resolveInitialState, orders the affected constructs into
// dependency groups, and marks each as pending in the TUI. Groups then execute
// in order. Constructs within a group run concurrently, each holding one unit
// of sem. After the first group that reports errors, those errors are returned
// joined. For non-dry runs, FinalizeState runs on return.
func (uo *UpOrchestrator) RunUpCommand(
	ctx context.Context, ir *model.ApplicationEnvironment, dryRun model.DryRun, sem *semaphore.Weighted,
) error {
	uo.ConstructEvaluator.DryRun = dryRun
	if dryRun == model.DryRunNone {
		// Dry runs must not create or update the state file, so only real runs finalize.
		defer uo.FinalizeState(ctx)
	}

	actions, err := uo.resolveInitialState(ir)
	if err != nil {
		return fmt.Errorf("error resolving initial state: %w", err)
	}

	stateByID := uo.StateManager.GetState().Constructs
	pending := make([]model.ConstructState, 0, len(actions))
	for urn := range actions {
		pending = append(pending, stateByID[urn.ResourceID])
	}

	groups, err := sortConstructsByDependency(pending, actions)
	if err != nil {
		return fmt.Errorf("failed to determine deployment order: %w", err)
	}

	for _, group := range groups {
		for _, urn := range group {
			prog := tui.GetProgress(ConstructContext(ctx, urn))
			prog.UpdateIndeterminate(fmt.Sprintf("Pending %s", actions[urn]))
		}
	}

	for _, group := range groups {
		// Buffered so that every construct in the group can report without blocking.
		results := make(chan error, len(group))
		for _, urn := range group {
			if acquireErr := sem.Acquire(ctx, 1); acquireErr != nil {
				results <- fmt.Errorf("error acquiring semaphore: %w", acquireErr)
				continue
			}
			go func(urn model.URN) {
				defer sem.Release(1)
				c, ok := uo.StateManager.GetConstructState(urn.ResourceID)
				if !ok {
					results <- fmt.Errorf("construct %s not found in state", urn.ResourceID)
					return
				}
				results <- uo.executeAction(ctx, c, actions[urn], dryRun)
			}(urn)
		}

		var errs []error
		for range group {
			if e := <-results; e != nil {
				errs = append(errs, e)
			}
		}
		if len(errs) > 0 {
			return errors.Join(errs...)
		}
	}
	return nil
}
// placeholderOutputs sends placeholder values to TUI for cases where they cannot be taken from the state when
// running in dry run modes.
//
// The construct's own Outputs are used when non-empty. Otherwise the outputs come
// from its solution: a referenced output is shown as "<ref>" and a literal output
// as its value. Nothing is sent if the construct is unknown or no TUI program is
// attached to ctx.
func (uo *UpOrchestrator) placeholderOutputs(ctx context.Context, cURN model.URN) {
	c, ok := uo.ConstructEvaluator.Constructs.Get(cURN)
	if !ok {
		return
	}
	prog := tui.GetProgram(ctx)
	if prog == nil {
		return
	}

	values := c.Outputs
	if len(values) == 0 && c.Solution != nil {
		solOutputs := c.Solution.Outputs()
		values = make(map[string]any, len(solOutputs))
		for name, out := range solOutputs {
			if out.Ref.IsZero() {
				values[name] = out.Value
			} else {
				values[name] = fmt.Sprintf("<%s>", out.Ref)
			}
		}
	}

	for name, value := range values {
		prog.Send(tui.OutputMessage{
			Construct: cURN.ResourceID,
			Name:      name,
			Value:     value,
		})
	}
}
// executeAction carries out one construct's planned action and reports its
// progress to the construct's TUI progress tracker.
//
// Delete actions:
//   - constructs whose status is not deletable are skipped;
//   - in dry-run mode nothing further is done;
//   - otherwise the stack is torn down from the construct's output directory.
//
// All other actions first evaluate the construct into IaC. Then, depending on
// dryRun, the construct is:
//   - previewed with pulumi;
//   - type-checked with tsc;
//   - left as generated files only; or
//   - deployed with pulumi up, after which its outputs are propagated to the state
//     manager, the construct evaluator and the language host.
//
// The deferred handler inspects the named result err to pick the final progress
// message ("Success", "Success (dry run)", "Failed" or "Skipped"). A panic is
// reported as "Failed" and then re-raised.
func (uo *UpOrchestrator) executeAction(ctx context.Context, c model.ConstructState, action model.ConstructAction, dryRun model.DryRun) (err error) {
	sm := uo.StateManager
	log := logging.GetLogger(ctx).Sugar()
	outDir := filepath.Join(uo.OutputDirectory, c.URN.ResourceID)
	ctx = ConstructContext(ctx, *c.URN)
	// Give each construct its own debug subdirectory when debugging is enabled.
	if debugDir := debug.GetDebugDir(ctx); debugDir != "" {
		ctx = debug.WithDebugDir(ctx, filepath.Join(debugDir, c.URN.ResourceID))
	}
	prog := tui.GetProgress(ctx)
	prog.UpdateIndeterminate(fmt.Sprintf("Starting %s", action))
	skipped := false
	defer func() {
		r := recover()
		msg := "Success"
		if err != nil || r != nil {
			msg = "Failed"
		} else if dryRun > 0 {
			msg += " (dry run)"
		}
		if skipped && err == nil {
			msg = "Skipped"
		}
		prog.Complete(msg)
		if r != nil {
			panic(r)
		}
	}()
	if action == model.ConstructActionDelete {
		// NOTE(review): resolveInitialState has already moved deleted constructs to
		// ConstructDeleting. This path therefore relies on IsDeletable(Deleting) and
		// the Deleting→Deleting transition both being accepted. Confirm.
		if !model.IsDeletable(c.Status) {
			skipped = true
			log.Debugf("Skipping construct %s, status is %s", c.URN.ResourceID, c.Status)
			return nil
		}
		if dryRun > 0 {
			log.Infof("Dry run: Skipping pulumi down for deleted construct %s", c.URN.ResourceID)
			return nil
		}
		// Mark as deleting
		if err = sm.TransitionConstructState(&c, model.ConstructDeleting); err != nil {
			return err
		}
		err = stack.RunDown(ctx, uo.FS, stack.Reference{
			ConstructURN: *c.URN,
			Name:         c.URN.ResourceID,
			IacDirectory: outDir,
			AwsRegion:    sm.GetState().DefaultRegion,
		})
		if err != nil {
			if err2 := sm.TransitionConstructFailed(&c); err2 != nil {
				log.Errorf("Error transitioning construct state: %v", err2)
			}
			return fmt.Errorf("error running pulumi down command: %w", err)
		}
		// Mark as deleted
		return sm.TransitionConstructComplete(&c)
	}
	// Only proceed if the construct is deployable
	if !model.IsDeployable(c.Status) {
		skipped = true
		log.Debugf("Skipping construct %s, status is %s", c.URN.ResourceID, c.Status)
		return nil
	}
	// Evaluate the construct
	stackRef, err := uo.EvaluateConstruct(ctx, *uo.StateManager.GetState(), *c.URN)
	if err != nil {
		return fmt.Errorf("error evaluating construct: %w", err)
	}
	// Dry-run modes stop here. They show placeholder outputs and register an empty
	// output set, so that dependents' inputs are resolved against this construct's URN.
	switch dryRun {
	case model.DryRunPreview:
		_, err = stack.RunPreview(ctx, uo.FS, stackRef)
		uo.placeholderOutputs(ctx, *c.URN)
		if err != nil {
			return fmt.Errorf("error running pulumi preview command: %w", err)
		}
		err = sm.RegisterOutputValues(ctx, stackRef.ConstructURN, map[string]any{})
		return err
	case model.DryRunCompile:
		err = stack.InstallDependencies(ctx, stackRef.IacDirectory)
		if err != nil {
			return err
		}
		// Type-check the generated Pulumi program without emitting JavaScript.
		cmd := logging.Command(ctx,
			logging.CommandLogger{
				RootLogger:  log.Desugar().Named("pulumi.tsc"),
				StdoutLevel: zap.DebugLevel,
				StderrLevel: zap.DebugLevel,
			},
			"tsc", "--noEmit", "index.ts",
		)
		cmd.Dir = stackRef.IacDirectory
		// This err shadows the named result within this case. The deferred handler
		// still sees the failure because every return below is explicit.
		err := cmd.Run()
		uo.placeholderOutputs(ctx, *c.URN)
		if err != nil {
			return fmt.Errorf("error running tsc: %w", err)
		}
		return sm.RegisterOutputValues(ctx, stackRef.ConstructURN, map[string]any{})
	case model.DryRunFileOnly:
		// file already written, nothing left to do
		uo.placeholderOutputs(ctx, *c.URN)
		return sm.RegisterOutputValues(ctx, stackRef.ConstructURN, map[string]any{})
	}
	// Run pulumi up command for the construct
	upResult, stackState, err := stack.RunUp(ctx, uo.FS, stackRef)
	if err != nil {
		if err2 := sm.TransitionConstructFailed(&c); err2 != nil {
			log.Errorf("Error transitioning construct state: %v", err2)
		}
		return fmt.Errorf("error running pulumi up command: %w", err)
	}
	uo.StackStateManager.ConstructStackState[stackRef.ConstructURN] = *stackState
	err = sm.RegisterOutputValues(ctx, stackRef.ConstructURN, stackState.Outputs)
	if err != nil {
		return fmt.Errorf("error registering output values: %w", err)
	}
	// Update construct state based on the up result
	err = stack.UpdateConstructStateFromUpResult(sm, stackRef, upResult)
	if err != nil {
		return err
	}
	// Resolve pending output values by calling the language host
	resolvedOutputs, err := uo.resolveOutputValues(stackRef, *stackState)
	if err != nil {
		return fmt.Errorf("error resolving output values: %w", err)
	}
	uo.ConstructEvaluator.RegisterOutputValues(stackRef.ConstructURN, stackState.Outputs)
	return sm.RegisterOutputValues(ctx, stackRef.ConstructURN, resolvedOutputs)
}
// resolveOutputValues sends the stack's outputs, keyed by the construct URN and encoded as YAML,
// to the language host through RegisterConstruct, and returns the YAML-decoded map from the
// host's response.
//
// The RPC runs on a fresh background context bounded to 300 seconds.
// NOTE(review): it is not derived from the caller's context, so cancelling the caller does not
// cancel this call — confirm that is intended.
func (uo *UpOrchestrator) resolveOutputValues(stackReference stack.Reference, stackState stack.State) (map[string]any, error) {
	payload, err := yaml.Marshal(map[string]map[string]any{
		stackReference.ConstructURN.String(): stackState.Outputs,
	})
	if err != nil {
		return nil, err
	}

	rpcCtx, cancel := context.WithTimeout(context.Background(), 300*time.Second)
	defer cancel()

	request := &pb.RegisterConstructRequest{YamlPayload: string(payload)}
	resp, err := uo.LanguageHostClient.RegisterConstruct(rpcCtx, request)
	if err != nil {
		return nil, err
	}

	var resolved map[string]any
	if err := yaml.Unmarshal([]byte(resp.GetYamlPayload()), &resolved); err != nil {
		return nil, err
	}
	return resolved, nil
}
package stack
import (
"context"
"fmt"
pulumi "github.com/pulumi/pulumi/sdk/v3"
"io"
"os"
"path/filepath"
"strings"
"github.com/klothoplatform/klotho/pkg/k2/model"
"github.com/klothoplatform/klotho/pkg/logging"
"github.com/klothoplatform/klotho/pkg/tui"
"github.com/pulumi/pulumi/sdk/v3/go/auto"
"github.com/pulumi/pulumi/sdk/v3/go/auto/optdestroy"
"github.com/pulumi/pulumi/sdk/v3/go/auto/optpreview"
"github.com/pulumi/pulumi/sdk/v3/go/auto/optup"
"github.com/pulumi/pulumi/sdk/v3/go/common/tokens"
"github.com/pulumi/pulumi/sdk/v3/go/common/workspace"
"github.com/spf13/afero"
"go.uber.org/zap"
)
// Reference identifies a construct's Pulumi stack and the inputs needed to operate on it.
type Reference struct {
// ConstructURN is the URN of the construct whose stack this is.
ConstructURN model.URN
// Name is used as the Pulumi stack name.
Name string
// IacDirectory is the directory holding the stack's Pulumi program; npm and pulumi run there.
IacDirectory string
// AwsRegion is applied as the stack's `aws:region` config value.
AwsRegion string
}
// Initialize creates or selects the Pulumi stack stackName for the program in stackDirectory,
// under the Pulumi project projectName.
//
// Pulumi is configured as follows:
//   - PULUMI_HOME is ~/.k2/pulumi.
//   - A local file backend lives at ~/.k2/pulumi/state; both directories are created if missing.
//   - The passphrase secrets provider is used with an empty passphrase.
//   - The Pulumi CLI is rooted at ~/.k2/pulumi/versions/<Pulumi SDK version>.
func Initialize(ctx context.Context, fs afero.Fs, projectName string, stackName string, stackDirectory string) (StackInterface, error) {
	homeDir, err := os.UserHomeDir()
	if err != nil {
		return nil, fmt.Errorf("Failed to get user home directory: %w", err)
	}

	// MkdirAll succeeds without changes when the directory already exists, so no existence
	// check is needed beforehand.
	pulumiHomeDir := filepath.Join(homeDir, ".k2", "pulumi")
	if err := fs.MkdirAll(pulumiHomeDir, 0755); err != nil {
		return nil, fmt.Errorf("Failed to create pulumi home directory: %w", err)
	}
	stateDir := filepath.Join(pulumiHomeDir, "state")
	if err := fs.MkdirAll(stateDir, 0755); err != nil {
		return nil, fmt.Errorf("Failed to create stack state directory: %w", err)
	}

	proj := auto.Project(workspace.Project{
		// Honour the caller's project name instead of a hard-coded one.
		Name:    tokens.PackageName(projectName),
		Runtime: workspace.NewProjectRuntimeInfo("nodejs", nil),
		Backend: &workspace.ProjectBackend{
			URL: "file://" + stateDir,
		},
	})
	secretsProvider := auto.SecretsProvider("passphrase")
	envvars := auto.EnvVars(map[string]string{
		"PULUMI_CONFIG_PASSPHRASE": "",
	})
	pulumiCmd, err := auto.NewPulumiCommand(&auto.PulumiCommandOptions{
		Root: filepath.Join(pulumiHomeDir, "versions", pulumi.Version.String()),
	})
	if err != nil {
		return nil, err
	}
	s, err := auto.UpsertStackLocalSource(ctx, stackName, stackDirectory, proj, envvars, auto.PulumiHome(pulumiHomeDir), secretsProvider, auto.Pulumi(pulumiCmd))
	if err != nil {
		return nil, fmt.Errorf("Failed to create or select stack: %w", err)
	}
	return &s, nil
}
// RunUp deploys the stack described by stackReference with `pulumi up`.
//
// Steps:
//  1. Create or select the stack.
//  2. Install its npm dependencies.
//  3. Set `aws:region`.
//  4. Run an update with refresh enabled. Progress output goes to the "pulumi.up" logger and
//     engine events go to the progress UI.
//
// On success it returns the update result and the stack state read back via GetState. If the
// update succeeded but reading the state fails, the update result is still returned, together
// with a nil state and the wrapped error.
func RunUp(ctx context.Context, fs afero.Fs, stackReference Reference) (*auto.UpResult, *State, error) {
	log := logging.GetLogger(ctx).Named("pulumi.up").Sugar()
	stackName := stackReference.Name
	stackDirectory := stackReference.IacDirectory
	s, err := Initialize(ctx, fs, "myproject", stackName, stackDirectory)
	if err != nil {
		return nil, nil, fmt.Errorf("Failed to create or select stack: %w", err)
	}
	log.Debugf("Created/Selected stack %q", stackName)
	err = InstallDependencies(ctx, stackDirectory)
	if err != nil {
		return nil, nil, fmt.Errorf("Failed to install dependencies: %w", err)
	}
	// set stack configuration specifying the AWS region to deploy
	err = s.SetConfig(ctx, "aws:region", auto.ConfigValue{Value: stackReference.AwsRegion})
	if err != nil {
		return nil, nil, fmt.Errorf("Failed to set stack configuration: %w", err)
	}
	log.Debug("Starting update")
	upResult, err := s.Up(
		ctx,
		optup.ProgressStreams(logging.NewLoggerWriter(log.Desugar(), zap.InfoLevel)),
		optup.EventStreams(Events(ctx, "Deploying")),
		optup.Refresh(),
	)
	if err != nil {
		return nil, nil, fmt.Errorf("Failed to update stack: %w", err)
	}
	log.Infof("Successfully deployed stack %s", stackName)
	stackState, err := GetState(ctx, s)
	if err != nil {
		// The update itself succeeded, so keep its result; the state is not meaningful.
		return &upResult, nil, fmt.Errorf("Failed to get stack state: %w", err)
	}
	return &upResult, &stackState, nil
}
// RunPreview previews the stack described by stackReference with `pulumi preview`.
//
// It creates or selects the stack, installs its npm dependencies and sets `aws:region`, then
// previews with refresh enabled. Progress output goes to the "pulumi.preview" logger and engine
// events go to the progress UI.
//
// Failure handling:
//   - Compilation errors, runtime errors and 409 stack-creation errors are returned as an error
//     carrying only the first line of Pulumi's message.
//   - Any other preview failure is logged as a warning and reported as (nil, nil), so that
//     further previewing can proceed.
func RunPreview(ctx context.Context, fs afero.Fs, stackReference Reference) (*auto.PreviewResult, error) {
	log := logging.GetLogger(ctx).Named("pulumi.preview").Sugar()
	stackName := stackReference.Name
	stackDirectory := stackReference.IacDirectory
	s, err := Initialize(ctx, fs, "myproject", stackName, stackDirectory)
	if err != nil {
		return nil, fmt.Errorf("Failed to create or select stack: %w", err)
	}
	// Debug level, consistent with RunUp and RunDown.
	log.Debugf("Created/Selected stack %q", stackName)
	err = InstallDependencies(ctx, stackDirectory)
	if err != nil {
		return nil, fmt.Errorf("Failed to install dependencies: %w", err)
	}
	// set stack configuration specifying the AWS region to deploy
	err = s.SetConfig(ctx, "aws:region", auto.ConfigValue{Value: stackReference.AwsRegion})
	if err != nil {
		return nil, fmt.Errorf("Failed to set stack configuration: %w", err)
	}
	log.Debug("Starting preview")
	previewResult, err := s.Preview(
		ctx,
		optpreview.ProgressStreams(logging.NewLoggerWriter(log.Desugar(), zap.InfoLevel)),
		optpreview.EventStreams(Events(ctx, "Previewing")),
		optpreview.Refresh(),
	)
	if err != nil {
		// Use the first line only, the rest of it is redundant with the first line or the live logging already shown
		firstLine, _, _ := strings.Cut(err.Error(), "\n")
		if auto.IsCompilationError(err) || auto.IsRuntimeError(err) || auto.IsCreateStack409Error(err) {
			return nil, fmt.Errorf("Failed to preview stack: %s", firstLine)
		}
		log.Warnf("Failed to preview stack %s: %s", stackName, firstLine)
		// Don't return an error for preview failures so that further previewing can proceed
		return nil, nil
	}
	log.Infof("Successfully previewed stack %s", stackName)
	return &previewResult, nil
}
// RunDown tears down the stack described by stackReference and then deletes the stack itself.
//
// It creates or selects the stack and sets `aws:region`. It then runs `pulumi destroy` with
// refresh enabled, sending progress output to the "pulumi.destroy" logger and engine events to
// the progress UI. Finally it removes the stack from the workspace. npm dependencies are not
// installed here.
func RunDown(ctx context.Context, fs afero.Fs, stackReference Reference) error {
	log := logging.GetLogger(ctx).Named("pulumi.destroy").Sugar()
	name, dir := stackReference.Name, stackReference.IacDirectory

	s, err := Initialize(ctx, fs, "myproject", name, dir)
	if err != nil {
		return fmt.Errorf("Failed to create or select stack: %w", err)
	}
	log.Debugf("Created/Selected stack %q", name)

	// The provider still needs the region to locate the resources it is deleting.
	region := auto.ConfigValue{Value: stackReference.AwsRegion}
	if err := s.SetConfig(ctx, "aws:region", region); err != nil {
		return fmt.Errorf("Failed to set stack configuration: %w", err)
	}

	log.Debug("Starting destroy")
	destroyOpts := []optdestroy.Option{
		optdestroy.ProgressStreams(logging.NewLoggerWriter(log.Desugar(), zap.InfoLevel)),
		optdestroy.EventStreams(Events(ctx, "Destroying")),
		optdestroy.Refresh(),
	}
	if _, err := s.Destroy(ctx, destroyOpts...); err != nil {
		return fmt.Errorf("Failed to destroy stack: %w", err)
	}
	log.Infof("Successfully destroyed stack %s", name)

	log.Infof("Removing stack %s", name)
	if err := s.Workspace().RemoveStack(ctx, name); err != nil {
		return fmt.Errorf("Failed to remove stack: %w", err)
	}
	return nil
}
// InstallDependencies runs `npm install` in stackDirectory.
//
// npm's stdout is logged at debug level under the "npm" logger. Both of npm's output streams are
// also fed to an NpmProgress, which drives the context's progress indicator.
//
// npm runs with `--loglevel silly` because NpmProgress parses those log lines, and with
// `--no-fund --no-audit` to skip the funding and audit reports.
func InstallDependencies(ctx context.Context, stackDirectory string) error {
	progress := tui.GetProgress(ctx)
	logger := logging.GetLogger(ctx).Named("npm").Sugar()

	logger.Debugf("Installing pulumi dependencies in %s", stackDirectory)
	progress.UpdateIndeterminate("Installing pulumi packages")

	cmdLogger := logging.CommandLogger{
		RootLogger:  logger.Desugar(),
		StdoutLevel: zap.DebugLevel,
	}
	// loglevel silly is required for the NpmProgress to capture all logs
	cmd := logging.Command(ctx, cmdLogger, "npm", "install", "--loglevel", "silly", "--no-fund", "--no-audit")

	tracker := &NpmProgress{Progress: progress}
	cmd.Stdout = io.MultiWriter(cmd.Stdout, tracker)
	cmd.Stderr = io.MultiWriter(cmd.Stderr, tracker)
	cmd.Dir = stackDirectory

	return cmd.Run()
}
package stack
import (
"bufio"
"bytes"
"strings"
"github.com/klothoplatform/klotho/pkg/tui"
)
// NpmProgress is an io.Writer that parses npm output produced with `--loglevel silly`. It reports
// "Installing pulumi packages" progress as completed packages versus discovered packages.
//
// NOTE(review): InstallDependencies tees both npm's stdout and stderr into the same NpmProgress.
// os/exec copies distinct Stdout/Stderr writers on separate goroutines, so Write may run
// concurrently, and these counters are not synchronized. Confirm whether a mutex is warranted.
type NpmProgress struct {
	Progress     tui.Progress
	packageCount int
	completed    int
}

// Write parses npm's stdout and stderr to drive setting the progress. It uses loglevel silly output
// and string parsing, so it's not very robust. Each call's bytes are scanned on their own, so a
// log line split across two writes may not be recognized.
//
// Parsing is best-effort. Write always reports all of b as consumed and never returns an error,
// because an error here would abort the io.MultiWriter it is teed into.
func (p *NpmProgress) Write(b []byte) (n int, err error) {
	scan := bufio.NewScanner(bytes.NewReader(b))
	// Allow a single line to be as long as the whole chunk, so an unusually long npm log line
	// cannot stop the scan with bufio.ErrTooLong.
	scan.Buffer(nil, len(b)+1)
	for scan.Scan() {
		line := scan.Text()
		switch {
		case strings.HasPrefix(line, "npm sill tarball no local data"):
			// not in the `npm` cache, need to go fetch it
			p.packageCount++
		case strings.HasPrefix(line, "npm http fetch"):
			// downloading the package
			p.packageCount++
			p.completed++
		case strings.HasPrefix(line, "npm sill ADD"):
			// added the package to node_modules
			p.completed++
			// if the package was in the npm cache, it'll skip straight to the ADD
			// so just add it to the package count
			if p.completed > p.packageCount {
				p.packageCount = p.completed
			}
		}
	}
	if p.packageCount > 0 {
		p.Progress.Update("Installing pulumi packages", p.completed, p.packageCount)
	}
	return len(b), nil
}
package stack
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"strings"
"github.com/klothoplatform/klotho/pkg/logging"
"github.com/klothoplatform/klotho/pkg/tui"
"github.com/pulumi/pulumi/sdk/v3/go/auto/events"
"github.com/pulumi/pulumi/sdk/v3/go/common/apitype"
"go.uber.org/zap"
)
// PulumiProgress is an io.Writer that scans Pulumi's textual progress output. It reports
// "Deploying stack" progress as completed versus started resource operations.
type PulumiProgress struct {
	Progress tui.Progress
	complete int
	total    int
}

// Write classifies each trimmed line of b as follows:
//   - A line mentioning "creating" or "deleting" counts as a started operation, or as a
//     completed one if it also mentions "failed".
//   - A line mentioning "created" or "deleted" counts as completed.
//
// It then updates the progress. Parsing is best-effort: Write always reports all of b as
// consumed and never returns an error, so it cannot abort a writer it is teed into
// (for example via io.MultiWriter).
func (p *PulumiProgress) Write(b []byte) (n int, err error) {
	scan := bufio.NewScanner(bytes.NewReader(b))
	// Allow a single line to be as long as the whole chunk, so a long line cannot stop the scan
	// with bufio.ErrTooLong.
	scan.Buffer(nil, len(b)+1)
	for scan.Scan() {
		line := strings.TrimSpace(scan.Text())
		switch {
		case strings.Contains(line, "creating"), strings.Contains(line, "deleting"):
			if strings.Contains(line, "failed") {
				p.complete++
			} else {
				p.total++
			}
		case strings.Contains(line, "created"), strings.Contains(line, "deleted"):
			p.complete++
		}
	}
	p.Progress.Update("Deploying stack", p.complete, p.total)
	return len(b), nil
}
// Events returns a channel for Pulumi's EventStreams options and starts a goroutine that drains it.
//
// Every event is JSON-encoded and logged at debug level on the "pulumi.events" logger. The
// progress from tui.GetProgress(ctx) is updated under the status "<action> stack":
//   - On the prelude event it is set to indeterminate.
//   - Otherwise it is set to the sum of each known resource's stage, out of stageDone points per
//     resource.
//
// The goroutine returns when ctx is done or the channel is closed (presumably by the Pulumi
// automation API when the operation finishes; confirm).
func Events(ctx context.Context, action string) chan<- events.EngineEvent {
	// Per-resource progress stages. Each resource seen contributes up to stageDone points.
	const (
		stagePending    = 0 // refresh pre-event seen: the resource is known but not yet refreshed
		stageRefreshed  = 1 // refresh outputs received
		stageInProgress = 2 // pre-event seen for a non-refresh step
		stageDone       = 3 // outputs received for a non-refresh step
	)
	ech := make(chan events.EngineEvent)
	go func() {
		log := logging.GetLogger(ctx).Named("pulumi.events").Sugar()
		progress := tui.GetProgress(ctx)
		status := fmt.Sprintf("%s stack", action)
		// resourceStatus maps each resource URN to its current stage (one of the constants above).
		resourceStatus := make(map[string]int)
		var buf bytes.Buffer
		enc := json.NewEncoder(&buf)
		for {
			select {
			case <-ctx.Done():
				return
			case e, ok := <-ech:
				if !ok {
					return
				}
				buf.Reset()
				if err := enc.Encode(e); err != nil {
					// Errorw keeps zap.Error as a structured field. The plain sugared Error would
					// fmt.Sprint the Field struct into the message.
					log.Errorw("Failed to encode pulumi event", zap.Error(err))
					continue
				}
				log.Debugf("Pulumi event: %s", strings.TrimSpace(buf.String()))
				switch {
				case e.PreludeEvent != nil:
					progress.UpdateIndeterminate(status)
				case e.ResourcePreEvent != nil:
					pre := e.ResourcePreEvent
					if pre.Metadata.Op == apitype.OpRefresh {
						resourceStatus[pre.Metadata.URN] = stagePending
					} else {
						resourceStatus[pre.Metadata.URN] = stageInProgress
					}
				case e.ResOutputsEvent != nil:
					out := e.ResOutputsEvent
					if out.Metadata.Op == apitype.OpRefresh {
						resourceStatus[out.Metadata.URN] = stageRefreshed
					} else {
						resourceStatus[out.Metadata.URN] = stageDone
					}
				}
				current, total := 0, 0
				for _, stage := range resourceStatus {
					total += stageDone
					current += stage
				}
				if total > 0 {
					progress.Update(status, current, total)
				}
			}
		}
	}()
	return ech
}
package stack
import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/k2/model"
"github.com/pulumi/pulumi/sdk/v3/go/auto"
"github.com/pulumi/pulumi/sdk/v3/go/auto/optdestroy"
"github.com/pulumi/pulumi/sdk/v3/go/auto/optpreview"
"github.com/pulumi/pulumi/sdk/v3/go/auto/optup"
"github.com/pulumi/pulumi/sdk/v3/go/common/apitype"
"go.uber.org/zap"
)
// State is a snapshot of a Pulumi stack, as assembled by GetState.
type State struct {
// Version is the version reported by the stack export.
Version int
// Deployment is the exported deployment decoded as an apitype.DeploymentV3.
Deployment apitype.DeploymentV3
// Outputs holds the values of the stack's `$outputs` output.
Outputs map[string]any
// Resources maps resource ids (resolved through the `$urns` output) to their deployment
// entries. GetState leaves out `pulumi:`- and `docker:`-typed resources.
Resources map[construct.ResourceId]apitype.ResourceV3
}
// StackInterface is the subset of Pulumi automation-API stack operations used by this package.
// Initialize returns one backed by an auto.Stack. GetState accepts any implementation, which
// allows substitutes.
type StackInterface interface {
Export(ctx context.Context) (apitype.UntypedDeployment, error)
Up(ctx context.Context, opts ...optup.Option) (auto.UpResult, error)
Preview(ctx context.Context, opts ...optpreview.Option) (auto.PreviewResult, error)
Destroy(ctx context.Context, opts ...optdestroy.Option) (auto.DestroyResult, error)
SetConfig(ctx context.Context, key string, value auto.ConfigValue) error
Workspace() auto.Workspace
Outputs(ctx context.Context) (auto.OutputMap, error)
}
// GetState reads back a stack's outputs and exported deployment, and indexes the deployment's
// resources by resource id.
//
// Stack outputs come from the `$outputs` output, and the URN-to-id mapping from the `$urns`
// output. These resources are skipped:
//   - resources of `pulumi:` or `docker:` types;
//   - resources whose URN has no id mapping (logged as a warning);
//   - resources whose id fails to parse (logged as a warning).
func GetState(ctx context.Context, stack StackInterface) (State, error) {
	outputMap, err := stack.Outputs(ctx)
	if err != nil {
		return State{}, err
	}
	outputs, err := GetStackOutputs(outputMap)
	if err != nil {
		return State{}, err
	}
	idsByURN, err := GetResourceIdByURNMap(outputMap)
	if err != nil {
		return State{}, err
	}
	exported, err := stack.Export(ctx)
	if err != nil {
		return State{}, err
	}

	var deployment apitype.DeploymentV3
	if err := json.Unmarshal(exported.Deployment, &deployment); err != nil {
		return State{}, err
	}
	zap.S().Debugf("unmarshalled state: %v", deployment)

	byID := make(map[construct.ResourceId]apitype.ResourceV3)
	for _, res := range deployment.Resources {
		typeName := string(res.URN.QualifiedType())
		if strings.HasPrefix(typeName, "pulumi:") || strings.HasPrefix(typeName, "docker:") {
			// Known non-cloud / Pulumi-internal resources (e.g. the Stack or a Provider).
			continue
		}
		rawID, found := idsByURN[string(res.URN)]
		if !found {
			zap.S().Warnf("could not find resource id for urn %s", res.URN)
			continue
		}
		var id construct.ResourceId
		if err := id.Parse(rawID); err != nil {
			zap.S().Warnf("could not parse resource id %s: %v", rawID, err)
			continue
		}
		byID[id] = res
	}

	return State{
		Version:    exported.Version,
		Deployment: deployment,
		Outputs:    outputs,
		Resources:  byID,
	}, nil
}
// GetStackOutputs returns a shallow copy of the map held in the stack's `$outputs` output.
// It errors if that output is missing or is not a map[string]any.
func GetStackOutputs(rawOutputs auto.OutputMap) (map[string]any, error) {
	entry, found := rawOutputs["$outputs"]
	if !found {
		return nil, fmt.Errorf("$outputs not found in stack outputs")
	}
	values, isMap := entry.Value.(map[string]any)
	if !isMap {
		return nil, fmt.Errorf("failed to parse stack outputs")
	}
	copied := make(map[string]any, len(values))
	for k, v := range values {
		copied[k] = v
	}
	return copied, nil
}
// GetResourceIdByURNMap inverts the stack's `$urns` output, which maps resource id to URN, into
// a map from URN to resource id. Entries whose URN is not a string are skipped with a warning.
// It errors if `$urns` is missing or is not a map[string]any.
func GetResourceIdByURNMap(rawOutputs auto.OutputMap) (map[string]string, error) {
	entry, found := rawOutputs["$urns"]
	if !found {
		return nil, fmt.Errorf("$urns not found in stack outputs")
	}
	urnByID, isMap := entry.Value.(map[string]any)
	if !isMap {
		return nil, fmt.Errorf("failed to parse URNs")
	}
	idByURN := make(map[string]string)
	for id, raw := range urnByID {
		urn, isString := raw.(string)
		if !isString {
			zap.S().Warnf("could not convert urn %v to string", raw)
			continue
		}
		idByURN[urn] = id
	}
	return idByURN, nil
}
// UpdateConstructStateFromUpResult records the outcome of a `pulumi up` for the construct
// referenced by stackReference.
//
// It transitions the construct to the status determineNextStatus picks from its current status
// and summary.Summary.Result. It then stamps LastUpdated with the current time in RFC 3339 and
// stores the construct back into sm.
//
// It returns an error if the construct is not in state, summary is nil, or the transition is
// rejected. The transition error is wrapped so callers can inspect it.
func UpdateConstructStateFromUpResult(sm *model.StateManager, stackReference Reference, summary *auto.UpResult) error {
	constructName := stackReference.ConstructURN.ResourceID
	c, exists := sm.GetConstructState(constructName)
	if !exists {
		return fmt.Errorf("construct %s not found in state", constructName)
	}
	if summary == nil {
		return fmt.Errorf("no up result provided for construct %s", constructName)
	}
	nextStatus := determineNextStatus(c.Status, summary.Summary.Result)
	if err := sm.TransitionConstructState(&c, nextStatus); err != nil {
		return fmt.Errorf("failed to transition construct state: %w", err)
	}
	c.LastUpdated = time.Now().Format(time.RFC3339)
	sm.SetConstructState(c)
	return nil
}
// determineNextStatus maps a construct's in-progress status and a Pulumi update result to the
// construct's next status.
//
// For Creating, Updating or Deleting it returns the matching Complete status when result is
// "succeeded" and the matching Failed status otherwise. Any other current status yields
// ConstructUnknown.
func determineNextStatus(currentStatus model.ConstructStatus, result string) model.ConstructStatus {
	succeeded := result == "succeeded"
	outcome := func(onSuccess, onFailure model.ConstructStatus) model.ConstructStatus {
		if succeeded {
			return onSuccess
		}
		return onFailure
	}
	switch currentStatus {
	case model.ConstructCreating:
		return outcome(model.ConstructCreateComplete, model.ConstructCreateFailed)
	case model.ConstructUpdating:
		return outcome(model.ConstructUpdateComplete, model.ConstructUpdateFailed)
	case model.ConstructDeleting:
		return outcome(model.ConstructDeleteComplete, model.ConstructDeleteFailed)
	}
	return model.ConstructUnknown
}
// StateManager holds the Pulumi stack State recorded for each construct, keyed by construct URN.
// It performs no locking of its own.
type StateManager struct {
ConstructStackState map[model.URN]State
}
// NewStateManager returns a StateManager with an empty, ready-to-use ConstructStackState map.
func NewStateManager() *StateManager {
	sm := &StateManager{}
	sm.ConstructStackState = make(map[model.URN]State)
	return sm
}
// GetResourceState looks up resource id in the stack state recorded for construct urn. The bool
// is false when either the construct's state or the resource is missing.
func (sm *StateManager) GetResourceState(urn model.URN, id construct.ResourceId) (apitype.ResourceV3, bool) {
	if stackState, ok := sm.ConstructStackState[urn]; ok {
		res, found := stackState.Resources[id]
		return res, found
	}
	return apitype.ResourceV3{}, false
}
package knowledgebase
import (
"bytes"
"encoding"
"encoding/json"
"fmt"
"github.com/klothoplatform/klotho/pkg/templateutils"
"reflect"
"regexp"
"strconv"
"strings"
"text/template"
"github.com/dominikbraun/graph"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/graph_addons"
"go.uber.org/zap"
)
type (
// DynamicValueContext is used to scope the Graph into the template functions
DynamicValueContext struct {
// Graph is the resource graph queried by the template functions.
Graph construct.Graph
// KnowledgeBase supplies resource templates, e.g. for functionality lookups during walks.
KnowledgeBase TemplateKB
// resultJson, when true, makes ExecuteTemplateDecode JSON-decode the raw rendered output
// instead of trimming it and decoding by target type.
resultJson bool
}
// DynamicContext is what template evaluation needs: the graph, the knowledge base, and the
// ability to execute a template and decode its result.
DynamicContext interface {
DAG() construct.Graph
KB() TemplateKB
ExecuteDecode(tmpl string, data DynamicValueData, value interface{}) error
}
// DynamicValueData provides the resource or edge to the templates as
// `{{ .Self }}` for resources
// `{{ .Source }}` and `{{ .Target }}` for edges
// `{{ .Tag }}` for the global tag
// `{{ .EdgeData }}` for the edge data (see [construct.EdgeData])
DynamicValueData struct {
Resource construct.ResourceId
// Edge is nil when the template is evaluated for a resource rather than an edge.
Edge *construct.Edge
// Path is presumably the property path being resolved (usable with pathAncestor) — confirm.
Path construct.PropertyPath
GlobalTag string
}
)
// DAG returns the resource graph that templates in this context are evaluated against.
func (ctx DynamicValueContext) DAG() construct.Graph {
	g := ctx.Graph
	return g
}
// KB returns the knowledge base backing this context.
func (ctx DynamicValueContext) KB() TemplateKB {
	kb := ctx.KnowledgeBase
	return kb
}
// TemplateFunctions returns the function map for dynamic value templates. The map contains:
//   - graph queries bound to this context (upstream/downstream searches, shortestPath);
//   - field helpers (fieldValue, hasField, fieldRef);
//   - property-path helpers and toJson;
//   - the package helpers firstId, filterIds and sanitizeName.
//
// The map is passed through templateutils.WithCommonFuncs.
func (ctx DynamicValueContext) TemplateFunctions() template.FuncMap {
	funcs := template.FuncMap{
		// upstream queries
		"hasUpstream":     ctx.HasUpstream,
		"upstream":        ctx.Upstream,
		"layeredUpstream": ctx.LayeredUpstream,
		"allUpstream":     ctx.AllUpstream,

		// downstream queries
		"hasDownstream":     ctx.HasDownstream,
		"downstream":        ctx.Downstream,
		"layeredDownstream": ctx.LayeredDownstream,
		"closestDownstream": ctx.ClosestDownstream,
		"allDownstream":     ctx.AllDownstream,
		"shortestPath":      ctx.ShortestPath,

		// resource fields
		"fieldValue": ctx.FieldValue,
		"hasField":   ctx.HasField,
		"fieldRef":   ctx.FieldRef,

		// property paths and serialization
		"pathAncestor":       ctx.PathAncestor,
		"pathAncestorExists": ctx.PathAncestorExists,
		"toJson":             ctx.toJson,

		// id helpers
		"firstId":      firstId,
		"filterIds":    filterIds,
		"sanitizeName": sanitizeName,
	}
	return templateutils.WithCommonFuncs(funcs)
}
// Parse parses tmpl as a text/template named "config", with TemplateFunctions available.
func (ctx DynamicValueContext) Parse(tmpl string) (*template.Template, error) {
	return template.New("config").Funcs(ctx.TemplateFunctions()).Parse(tmpl)
}
// ExecuteDecodeAsResourceId executes tmpl with data through ctx and decodes the result as a
// resource id.
//
// A zero id is rejected with an error. Used as a selector, it would match every resource and
// add arbitrary dependencies.
func ExecuteDecodeAsResourceId(ctx DynamicContext, tmpl string, data DynamicValueData) (construct.ResourceId, error) {
	var selector construct.ResourceId
	if err := ctx.ExecuteDecode(tmpl, data, &selector); err != nil {
		return selector, err
	}
	if !selector.IsZero() {
		return selector, nil
	}
	return selector, fmt.Errorf("selector '%s' is zero", tmpl)
}
// ExecuteDecode parses tmpl with this context's functions, executes it with data, and decodes
// the result into value (see ExecuteTemplateDecode).
func (ctx DynamicValueContext) ExecuteDecode(tmpl string, data DynamicValueData, value interface{}) error {
	parsed, parseErr := ctx.Parse(tmpl)
	if parseErr != nil {
		return parseErr
	}
	return ctx.ExecuteTemplateDecode(parsed, data, value)
}
// ExecuteTemplateDecode executes the parsed template t with data and decodes the rendered text
// into value, which must be a non-nil pointer.
//
// When the context has resultJson set, the raw output is JSON-decoded into value. Otherwise the
// output is whitespace-trimmed, so templates need not use `{{-` / `-}}`, and decoded according
// to the target:
//   - *string, *[]byte, *bool, *int, *float64 or *float32;
//   - any encoding.TextUnmarshaler;
//   - a pointer to a type the trimmed string is assignable to (e.g. *any);
//   - a pointer to a named string type (e.g. `type MyString string`);
//   - otherwise, JSON.
func (ctx DynamicValueContext) ExecuteTemplateDecode(
	t *template.Template,
	data DynamicValueData,
	value interface{},
) error {
	buf := new(bytes.Buffer)
	if err := t.Execute(buf, data); err != nil {
		return err
	}
	if ctx.resultJson {
		dec := json.NewDecoder(buf)
		return dec.Decode(value)
	}
	// trim the spaces so you don't have to sprinkle the templates with `{{-` and `-}}` (the `-` trims spaces)
	bstr := strings.TrimSpace(buf.String())
	switch value := value.(type) {
	case *string:
		*value = bstr
		return nil
	case *[]byte:
		*value = []byte(bstr)
		return nil
	case *bool:
		b, err := strconv.ParseBool(bstr)
		if err != nil {
			return err
		}
		*value = b
		return nil
	case *int:
		i, err := strconv.Atoi(bstr)
		if err != nil {
			return err
		}
		*value = i
		return nil
	case *float64:
		f, err := strconv.ParseFloat(bstr, 64)
		if err != nil {
			return err
		}
		*value = f
		return nil
	case *float32:
		f, err := strconv.ParseFloat(bstr, 32)
		if err != nil {
			return err
		}
		*value = float32(f)
		return nil
	case encoding.TextUnmarshaler:
		// notably, this handles `construct.ResourceId` and `construct.IaCValue`
		return value.UnmarshalText([]byte(bstr))
	}

	// Reflection below needs a settable target; guard instead of panicking in Elem/Set.
	target := reflect.ValueOf(value)
	if target.Kind() != reflect.Pointer || target.IsNil() {
		return fmt.Errorf("cannot decode template result '%s' into %T: target must be a non-nil pointer", buf, value)
	}
	elem := target.Elem()
	result := reflect.ValueOf(bstr)
	switch {
	case result.Type().AssignableTo(elem.Type()):
		// e.g. *any
		elem.Set(result)
		return nil
	case elem.Kind() == reflect.String:
		// named string types like `type MyString string` are not assignable from string, only convertible
		elem.Set(result.Convert(elem.Type()))
		return nil
	}
	if err := json.Unmarshal([]byte(bstr), value); err == nil {
		return nil
	}
	return fmt.Errorf("cannot decode template result '%s' into %T", buf, value)
}
// ResolveConfig resolves a string-valued config.Value and returns the updated config.
//
// The string is treated as a template and decoded into a value of the same type as the field
// config.Field on the vertex for data.Resource. Configs with non-string values are returned
// unchanged. It errors if the vertex lookup fails, the field does not exist, or decoding fails.
func (ctx DynamicValueContext) ResolveConfig(config Configuration, data DynamicValueData) (Configuration, error) {
	tmpl, isTemplate := config.Value.(string)
	if !isTemplate {
		return config, nil
	}
	res, err := ctx.Graph.Vertex(data.Resource)
	if err != nil {
		return config, err
	}
	fieldVal := reflect.ValueOf(res).Elem().FieldByName(config.Field)
	if !fieldVal.IsValid() {
		return config, fmt.Errorf("field %s not found on resource %s when trying to ResolveConfig", config.Field, data.Resource)
	}
	decoded := reflect.New(fieldVal.Type())
	if err := ctx.ExecuteDecode(tmpl, data, decoded.Interface()); err != nil {
		return config, err
	}
	config.Value = decoded.Elem().Interface()
	return config, nil
}
// Self returns the resource being evaluated, for use as `{{ .Self }}`. It errors when no
// resource is set.
func (data DynamicValueData) Self() (construct.ResourceId, error) {
	if !data.Resource.IsZero() {
		return data.Resource, nil
	}
	return construct.ResourceId{}, fmt.Errorf("no .Self is set")
}
// Source returns the source resource of the edge being evaluated, for use as `{{ .Source }}`.
// It errors, rather than panicking, when no edge is set or the edge's source is the zero id.
func (data DynamicValueData) Source() (construct.ResourceId, error) {
	if data.Edge == nil || data.Edge.Source.IsZero() {
		return construct.ResourceId{}, fmt.Errorf("no .Source is set")
	}
	return data.Edge.Source, nil
}
// Target returns the target resource of the edge being evaluated, for use as `{{ .Target }}`.
// It errors, rather than panicking, when no edge is set or the edge's target is the zero id.
func (data DynamicValueData) Target() (construct.ResourceId, error) {
	if data.Edge == nil || data.Edge.Target.IsZero() {
		return construct.ResourceId{}, fmt.Errorf("no .Target is set")
	}
	return data.Edge.Target, nil
}
// Tag returns the global tag, for use as `{{ .Tag }}`.
func (data DynamicValueData) Tag() string {
	tag := data.GlobalTag
	return tag
}
// Log is a template helper for debugging. It is not meant to remain in shipped templates.
//
// It formats message with args and logs it through the global zap logger, annotated with the
// current resource and/or edge when set. The level is "debug", "info", "warn" or "error",
// case-insensitive; any other value is logged as warn. It returns "" so it can be called inline
// without changing the template's output, e.g. `{{ .Log "debug" "integration: %s" $integration }}`.
func (data DynamicValueData) Log(level string, message string, args ...interface{}) string {
	logger := zap.L()
	if !data.Resource.IsZero() {
		logger = logger.With(zap.String("resource", data.Resource.String()))
	}
	if data.Edge != nil {
		logger = logger.With(zap.String("edge", data.Edge.Source.String()+" -> "+data.Edge.Target.String()))
	}
	sugar := logger.Sugar()
	switch strings.ToLower(level) {
	case "debug":
		sugar.Debugf(message, args...)
	case "info":
		sugar.Infof(message, args...)
	case "error":
		sugar.Errorf(message, args...)
	default: // "warn" and any unrecognized level
		sugar.Warnf(message, args...)
	}
	return ""
}
// EdgeData returns the data of the edge being evaluated, for use as `{{ .EdgeData }}`:
//   - a copy of the construct.EdgeData stored in the edge properties, when present;
//   - an empty EdgeData for an edge with a non-zero source but no such data;
//   - nil when no edge is set, instead of panicking on the nil pointer, or when the edge has a
//     zero source and no data.
func (data DynamicValueData) EdgeData() *construct.EdgeData {
	if data.Edge == nil {
		return nil
	}
	if d, ok := data.Edge.Properties.Data.(construct.EdgeData); ok {
		return &d
	}
	if !data.Edge.Source.IsZero() {
		// default edge data to an empty struct
		return &construct.EdgeData{}
	}
	return nil
}
// TemplateArgToRID converts a template function argument into a resource id. It accepts:
//   - a construct.ResourceId;
//   - a construct.Resource or non-nil *construct.Resource (its ID is used);
//   - a string in the form accepted by ResourceId.UnmarshalText.
//
// Any other type, or a nil *construct.Resource, is an error.
func TemplateArgToRID(arg any) (construct.ResourceId, error) {
	switch arg := arg.(type) {
	case construct.ResourceId:
		return arg, nil
	case construct.Resource:
		return arg.ID, nil
	case *construct.Resource:
		// Graph vertices are resource pointers, so accept them as well as values.
		if arg == nil {
			return construct.ResourceId{}, fmt.Errorf("invalid argument: nil %T", arg)
		}
		return arg.ID, nil
	case string:
		var resId construct.ResourceId
		err := resId.UnmarshalText([]byte(arg))
		return resId, err
	}
	return construct.ResourceId{}, fmt.Errorf("invalid argument type %T", arg)
}
// upstream walks up from resource and returns the first resource matching selector. It returns
// the zero id when nothing matches. The walk does not continue past resources whose
// functionality (GetFunctionality) is anything other than Unknown.
func (ctx DynamicValueContext) upstream(selector any, resource construct.ResourceId) (construct.ResourceId, error) {
	selId, err := TemplateArgToRID(selector)
	if err != nil {
		return construct.ResourceId{}, err
	}
	var found construct.ResourceId
	walkErr := graph_addons.WalkUp(ctx.Graph, resource, func(path graph_addons.Path[construct.ResourceId], _ error) error {
		current := path[len(path)-1]
		switch {
		case selId.Matches(current):
			found = current
			return graph_addons.StopWalk
		case GetFunctionality(ctx.KB(), current) != Unknown:
			return graph_addons.SkipPath
		}
		return nil
	})
	return found, walkErr
}
// HasUpstream reports whether a resource matching `selector` is upstream of `resource`.
// The walk stops at resources with a known (non-Unknown) functionality.
func (ctx DynamicValueContext) HasUpstream(selector any, resource construct.ResourceId) (bool, error) {
	match, err := ctx.upstream(selector, resource)
	found := err == nil && !match.IsZero()
	return found, err
}
// Upstream returns the first resource that matches `selector` which is upstream of `resource`.
// The walk stops at resources with a known (non-Unknown) functionality. It errors if no resource
// matches.
func (ctx DynamicValueContext) Upstream(selector any, resource construct.ResourceId) (construct.ResourceId, error) {
	match, err := ctx.upstream(selector, resource)
	switch {
	case err != nil:
		return construct.ResourceId{}, err
	case match.IsZero():
		return match, fmt.Errorf("no upstream resource of '%s' found matching selector '%s'", resource, selector)
	}
	return match, nil
}
// LayeredUpstream returns the first resource matching `selector` that is upstream of `resource`.
// The walk is bounded by the dependency layer named `layer`, as enforced by layerWalkFunc.
// Unlike Upstream, it returns the zero id (and no error) when nothing matches.
func (ctx DynamicValueContext) LayeredUpstream(
	selector any,
	resource construct.ResourceId,
	layer string,
) (construct.ResourceId, error) {
	selId, err := TemplateArgToRID(selector)
	if err != nil {
		return construct.ResourceId{}, err
	}
	layerFn, err := layerWalkFunc(ctx.Graph, ctx.KnowledgeBase, resource, DependencyLayer(layer), nil)
	if err != nil {
		return construct.ResourceId{}, err
	}
	var found construct.ResourceId
	walkErr := graph_addons.WalkUp(ctx.Graph, resource, func(path graph_addons.Path[construct.ResourceId], nerr error) error {
		if last := path[len(path)-1]; selId.Matches(last) {
			found = last
			return graph_addons.StopWalk
		}
		return layerFn(path, nerr)
	})
	if walkErr != nil {
		return construct.ResourceId{}, walkErr
	}
	return found, nil
}
// AllUpstream returns every transitive upstream resource of `resource` (all dependency layers)
// that matches `selector`.
// nolint: lll
func (ctx DynamicValueContext) AllUpstream(selector any, resource construct.ResourceId) (construct.ResourceList, error) {
	selId, err := TemplateArgToRID(selector)
	if err != nil {
		return nil, err
	}
	candidates, err := Upstream(ctx.Graph, ctx.KnowledgeBase, resource, AllDepsLayer)
	if err != nil {
		return []construct.ResourceId{}, err
	}
	selected := make([]construct.ResourceId, 0, len(candidates))
	for _, candidate := range candidates {
		if !selId.Matches(candidate) {
			continue
		}
		selected = append(selected, candidate)
	}
	return selected, nil
}
// downstream walks down from resource and returns the first resource matching selector. It
// returns the zero id when nothing matches. The walk does not continue past resources whose
// functionality (GetFunctionality) is anything other than Unknown.
func (ctx DynamicValueContext) downstream(selector any, resource construct.ResourceId) (construct.ResourceId, error) {
	selId, err := TemplateArgToRID(selector)
	if err != nil {
		return construct.ResourceId{}, err
	}
	var found construct.ResourceId
	walkErr := graph_addons.WalkDown(ctx.Graph, resource, func(path graph_addons.Path[construct.ResourceId], _ error) error {
		current := path[len(path)-1]
		switch {
		case selId.Matches(current):
			found = current
			return graph_addons.StopWalk
		case GetFunctionality(ctx.KB(), current) != Unknown:
			return graph_addons.SkipPath
		}
		return nil
	})
	return found, walkErr
}
// LayeredDownstream returns the first resource matching `selector` that is downstream of
// `resource`. The walk is bounded by the dependency layer named `layer`, as enforced by
// layerWalkFunc. Unlike Downstream, it returns the zero id (and no error) when nothing matches.
func (ctx DynamicValueContext) LayeredDownstream(
	selector any,
	resource construct.ResourceId,
	layer string,
) (construct.ResourceId, error) {
	selId, err := TemplateArgToRID(selector)
	if err != nil {
		return construct.ResourceId{}, err
	}
	layerFn, err := layerWalkFunc(ctx.Graph, ctx.KnowledgeBase, resource, DependencyLayer(layer), nil)
	if err != nil {
		return construct.ResourceId{}, err
	}
	var found construct.ResourceId
	walkErr := graph_addons.WalkDown(ctx.Graph, resource, func(path graph_addons.Path[construct.ResourceId], nerr error) error {
		if last := path[len(path)-1]; selId.Matches(last) {
			found = last
			return graph_addons.StopWalk
		}
		return layerFn(path, nerr)
	})
	if walkErr != nil {
		return construct.ResourceId{}, walkErr
	}
	return found, nil
}
// HasDownstream reports whether a resource matching `selector` is downstream of `resource`.
// The walk stops at resources with a known (non-Unknown) functionality.
func (ctx DynamicValueContext) HasDownstream(selector any, resource construct.ResourceId) (bool, error) {
	match, err := ctx.downstream(selector, resource)
	found := err == nil && !match.IsZero()
	return found, err
}
// Downstream returns the first resource that matches `selector` which is downstream of
// `resource`. The walk stops at resources with a known (non-Unknown) functionality. It errors
// if no resource matches.
// nolint: lll
func (ctx DynamicValueContext) Downstream(selector any, resource construct.ResourceId) (construct.ResourceId, error) {
	match, err := ctx.downstream(selector, resource)
	switch {
	case err != nil:
		return construct.ResourceId{}, err
	case match.IsZero():
		return match, fmt.Errorf("no downstream resource of '%s' found matching selector '%s'", resource, selector)
	}
	return match, nil
}
// ClosestDownstream returns the first resource matching `selector` found by a breadth-first
// search of the graph starting at `resource`. When nothing matches, it returns the zero id
// together with whatever error graph.BFS reports.
//
// NOTE(review): graph.BFS presumably visits `resource` itself first, so `resource` is returned
// when it matches its own selector. Confirm this is intended.
func (ctx DynamicValueContext) ClosestDownstream(selector any, resource construct.ResourceId) (construct.ResourceId, error) {
	selId, err := TemplateArgToRID(selector)
	if err != nil {
		return construct.ResourceId{}, err
	}
	var closest construct.ResourceId
	err = graph.BFS(ctx.Graph, resource, func(id construct.ResourceId) bool {
		if !selId.Matches(id) {
			return false
		}
		closest = id
		return true // stop at the first match
	})
	return closest, err
}
// AllDownstream returns every transitive downstream resource of `resource` (all dependency
// layers) that matches `selector`.
// nolint: lll
func (ctx DynamicValueContext) AllDownstream(selector any, resource construct.ResourceId) (construct.ResourceList, error) {
	selId, err := TemplateArgToRID(selector)
	if err != nil {
		return nil, err
	}
	candidates, err := Downstream(ctx.Graph, ctx.KnowledgeBase, resource, AllDepsLayer)
	if err != nil {
		return []construct.ResourceId{}, err
	}
	selected := make([]construct.ResourceId, 0, len(candidates))
	for _, candidate := range candidates {
		if !selId.Matches(candidate) {
			continue
		}
		selected = append(selected, candidate)
	}
	return selected, nil
}
// ShortestPath returns the resource ids along the shortest path from source to destination.
// Ties are broken deterministically by construct.ResourceIdLess. Both endpoints may be anything
// TemplateArgToRID accepts; source is converted first.
func (ctx DynamicValueContext) ShortestPath(source, destination any) (construct.ResourceList, error) {
	var ends [2]construct.ResourceId
	for i, arg := range []any{source, destination} {
		id, err := TemplateArgToRID(arg)
		if err != nil {
			return nil, err
		}
		ends[i] = id
	}
	return graph.ShortestPathStable(ctx.Graph, ends[0], ends[1], construct.ResourceIdLess)
}
// FieldValue returns the value of property `field` on `resource`, which may be anything
// TemplateArgToRID accepts. It errors when the resource is not in the graph or the property is
// missing or nil. Underlying lookup errors are wrapped so callers can inspect them, and no
// partial value is returned alongside an error.
func (ctx DynamicValueContext) FieldValue(field string, resource any) (any, error) {
	resId, err := TemplateArgToRID(resource)
	if err != nil {
		return nil, err
	}
	r, err := ctx.Graph.Vertex(resId)
	if err != nil {
		return nil, fmt.Errorf("resource '%s' not found: %w", resId, err)
	}
	if r == nil {
		return nil, fmt.Errorf("resource '%s' not found", resId)
	}
	val, err := r.GetProperty(field)
	if err != nil {
		return nil, fmt.Errorf("field '%s' not found on resource '%s': %w", field, resId, err)
	}
	if val == nil {
		return nil, fmt.Errorf("field '%s' not found on resource '%s'", field, resId)
	}
	return val, nil
}
// HasField reports whether `resource` has a non-nil, readable value for `field`.
// It returns an error only when the resource argument cannot be parsed or the resource is
// not present in the graph. A property read failure is reported as false, not as an error.
func (ctx DynamicValueContext) HasField(field string, resource any) (bool, error) {
	id, err := TemplateArgToRID(resource)
	if err != nil {
		return false, err
	}
	res, err := ctx.Graph.Vertex(id)
	if err != nil || res == nil {
		return false, fmt.Errorf("resource '%s' not found", id)
	}
	value, getErr := res.GetProperty(field)
	return getErr == nil && value != nil, nil
}
// FieldRef returns a reference to `field` on `resource` as a construct.PropertyRef.
// The resource is not looked up in the graph; only the argument is parsed.
func (ctx DynamicValueContext) FieldRef(field string, resource any) (construct.PropertyRef, error) {
	var ref construct.PropertyRef
	id, err := TemplateArgToRID(resource)
	if err != nil {
		return ref, err
	}
	ref.Resource = id
	ref.Property = field
	return ref, nil
}
// toJson renders value as a JSON string. It is intended for complex values that do not
// implement a text marshaler and so cannot be emitted directly by templates.
// On failure it returns "" and the json.Marshal error.
func (ctx DynamicValueContext) toJson(value any) (string, error) {
	// json.Marshal returns nil bytes on error, so the conversion below yields "" in that case.
	encoded, err := json.Marshal(value)
	return string(encoded), err
}
// PathAncestor returns the string form of the ancestor of path that is `depth` levels up.
// A depth of 0 returns path itself. Negative depths, and depths that would remove every
// element of path, are errors.
func (ctx DynamicValueContext) PathAncestor(path construct.PropertyPath, depth int) (string, error) {
	switch {
	case depth < 0:
		return "", fmt.Errorf("depth must be >= 0")
	case depth == 0:
		return path.String(), nil
	case depth >= len(path):
		return "", fmt.Errorf("depth %d is greater than path length %d", depth, len(path))
	}
	ancestor := path[:len(path)-depth]
	return ancestor.String(), nil
}
// PathAncestorExists reports whether path has an ancestor `depth` levels up.
// depth must be non-negative, matching PathAncestor, which rejects negative depths.
// path must also have more than depth elements.
func (ctx DynamicValueContext) PathAncestorExists(path construct.PropertyPath, depth int) bool {
	return depth >= 0 && len(path) > depth
}
// filterIds returns the entries of ids that match selector, preserving their order.
// The result is non-nil even when nothing matches. An error is returned only if the selector
// cannot be parsed.
func filterIds(selector any, ids []construct.ResourceId) ([]construct.ResourceId, error) {
	sel, err := TemplateArgToRID(selector)
	if err != nil {
		return nil, err
	}
	kept := make([]construct.ResourceId, 0, len(ids))
	for i := range ids {
		if !sel.Matches(ids[i]) {
			continue
		}
		kept = append(kept, ids[i])
	}
	return kept, nil
}
// firstId returns the first entry of ids that matches selector. It errors when the selector
// cannot be parsed, when ids is empty, or when no entry matches.
func firstId(selector any, ids []construct.ResourceId) (construct.ResourceId, error) {
	var none construct.ResourceId
	sel, err := TemplateArgToRID(selector)
	if err != nil {
		return none, err
	}
	if len(ids) == 0 {
		return none, fmt.Errorf("no ids")
	}
	for i := range ids {
		if sel.Matches(ids[i]) {
			return ids[i], nil
		}
	}
	return none, fmt.Errorf("no ids match selector")
}
// invalidNameCharacters matches every character that may not appear in a resource name:
// anything other than ASCII letters, digits, '_', '.', '/', '-', ':', '[' and ']'.
// It is meant to be the inverse of the resource name pattern in the construct package.
var invalidNameCharacters = regexp.MustCompile(`[^a-zA-Z0-9_./\-:\[\]]`)

// sanitizeName returns name with every character matched by invalidNameCharacters removed.
func sanitizeName(name string) string {
	const nothing = ""
	return invalidNameCharacters.ReplaceAllLiteralString(name, nothing)
}
package knowledgebase
import (
"fmt"
construct "github.com/klothoplatform/klotho/pkg/construct"
"gopkg.in/yaml.v3"
)
type (
// EdgeTemplate describes the knowledge-base rules for an edge from the Source resource type
// to the Target resource type.
EdgeTemplate struct {
Source construct.ResourceId `yaml:"source"`
Target construct.ResourceId `yaml:"target"`
// AlwaysProcess signals that the edge should always be processed, even if the source and target exist in the input graph.
// Currently we don't check edges for operational rules if they previously existed and this flag is set to false.
AlwaysProcess bool `yaml:"always_process"`
// DirectEdgeOnly signals that the edge cannot be used within constructing other paths
// and can only be used as a direct edge
DirectEdgeOnly bool `yaml:"direct_edge_only"`
// DeploymentOrderReversed is specified when the edge is in the opposite direction of the deployment order
DeploymentOrderReversed bool `yaml:"deployment_order_reversed"`
// DeletionDependent is used to specify edges which should not influence the deletion criteria of a resource.
// A true value specifies that the target being deleted is dependent on the source and does not need to depend on
// satisfaction of the deletion criteria to attempt to delete the true source of the edge.
DeletionDependent bool `yaml:"deletion_dependent"`
// Unique describes the cardinality of this edge type; see type [Unique].
Unique Unique `yaml:"unique"`
// OperationalRules are the operational rules attached to this edge.
OperationalRules []OperationalRule `yaml:"operational_rules"`
// EdgeWeightMultiplier presumably scales this edge's weight during path selection — TODO confirm against its users.
EdgeWeightMultiplier float32 `yaml:"edge_weight_multiplier"`
// Classification lists the classifications attached to this edge.
Classification []string `yaml:"classification"`
// NoIac presumably marks the edge as producing no infrastructure-as-code output — confirm against its users.
NoIac bool `json:"no_iac" yaml:"no_iac"`
}
// MultiEdgeTemplate is shorthand for declaring many edges that share one Resource. It expands
// (see EdgeTemplatesFromMulti) into one Source -> Resource edge per entry in Sources and one
// Resource -> Target edge per entry in Targets, each carrying the shared settings below.
// Unlike EdgeTemplate it has no AlwaysProcess field.
MultiEdgeTemplate struct {
Resource construct.ResourceId `yaml:"resource"`
Sources []construct.ResourceId `yaml:"sources"`
Targets []construct.ResourceId `yaml:"targets"`
// DirectEdgeOnly signals that the edge cannot be used within constructing other paths
// and can only be used as a direct edge
DirectEdgeOnly bool `yaml:"direct_edge_only"`
// DeploymentOrderReversed is specified when the edge is in the opposite direction of the deployment order
DeploymentOrderReversed bool `yaml:"deployment_order_reversed"`
// DeletionDependent is used to specify edges which should not influence the deletion criteria of a resource.
// A true value specifies that the target being deleted is dependent on the source and does not need to depend on
// satisfaction of the deletion criteria to attempt to delete the true source of the edge.
DeletionDependent bool `yaml:"deletion_dependent"`
// Unique describes the cardinality of each generated edge; see type [Unique].
Unique Unique `yaml:"unique"`
// OperationalRules are attached to every generated edge.
OperationalRules []OperationalRule `yaml:"operational_rules"`
// EdgeWeightMultiplier is copied to every generated edge; see EdgeTemplate.EdgeWeightMultiplier.
EdgeWeightMultiplier float32 `yaml:"edge_weight_multiplier"`
// Classification is copied to every generated edge.
Classification []string `yaml:"classification"`
// NoIac is intended for every generated edge; see EdgeTemplate.NoIac.
NoIac bool `json:"no_iac" yaml:"no_iac"`
}
// Unique is used to specify whether the source or target of an edge must only have a single edge of this type
// - Source=false & Target=false (default) indicates that S->T is a many-to-many relationship
// (for example, Lambda -> DynamoDB)
// - Source=true & Target=false indicates that S->T is a one-to-many relationship
// (for example, SQS -> Event Source Mapping)
// - Source=false & Target=true indicates that S->T is a many-to-one relationship
// (for example, Event Source Mapping -> Lambda)
// - Source=true & Target=true indicates that S->T is a one-to-one relationship
// (for example, RDS Proxy -> Proxy Target Group)
Unique struct {
// Source indicates whether the source must only have a single edge of this type.
Source bool `yaml:"source"`
// Target indicates whether the target must only have a single edge of this type.
Target bool `yaml:"target"`
}
)
// EdgeTemplatesFromMulti expands a MultiEdgeTemplate into individual EdgeTemplates.
// It produces one Source -> multi.Resource edge for each entry in multi.Sources, followed by one
// multi.Resource -> Target edge for each entry in multi.Targets. Every generated edge inherits
// all of the multi template's shared settings, including NoIac. The OperationalRules and
// Classification slices are shared, not copied. A template with no sources and no targets
// yields nil.
func EdgeTemplatesFromMulti(multi MultiEdgeTemplate) []EdgeTemplate {
	// Build the shared settings once so the sources and targets loops cannot drift apart.
	// NoIac in particular must be propagated.
	base := EdgeTemplate{
		DirectEdgeOnly:          multi.DirectEdgeOnly,
		DeploymentOrderReversed: multi.DeploymentOrderReversed,
		DeletionDependent:       multi.DeletionDependent,
		Unique:                  multi.Unique,
		OperationalRules:        multi.OperationalRules,
		EdgeWeightMultiplier:    multi.EdgeWeightMultiplier,
		Classification:          multi.Classification,
		NoIac:                   multi.NoIac,
	}
	var templates []EdgeTemplate
	for _, source := range multi.Sources {
		edge := base
		edge.Source = source
		edge.Target = multi.Resource
		templates = append(templates, edge)
	}
	for _, target := range multi.Targets {
		edge := base
		edge.Source = multi.Resource
		edge.Target = target
		templates = append(templates, edge)
	}
	return templates
}
// UnmarshalYAML decodes a Unique from any of three YAML forms:
//   - a mapping with optional `source` and `target` booleans;
//   - a boolean, which sets both Source and Target to that value;
//   - one of the strings "one_to_one", "one_to_many", "many_to_one" or "many_to_many"
//     (hyphens are also accepted in place of underscores).
//
// Any other input returns an error.
func (u *Unique) UnmarshalYAML(n *yaml.Node) error {
	// helper has Unique's fields but not this method, so decoding into it does not recurse.
	type helper Unique
	var h helper
	if err := n.Decode(&h); err == nil {
		*u = Unique(h)
		return nil
	}
	// Try the boolean form before the string form. yaml.v3 decodes any scalar, including
	// `true`/`false`, into a string, so checking strings first would reject booleans as
	// invalid strings.
	var b bool
	if err := n.Decode(&b); err == nil {
		u.Source = b
		u.Target = b
		return nil
	}
	var str string
	if err := n.Decode(&str); err == nil {
		switch str {
		case "one_to_one", "one-to-one":
			u.Source = true
			u.Target = true
		case "one_to_many", "one-to-many":
			u.Source = true
			u.Target = false
		case "many_to_one", "many-to-one":
			u.Source = false
			u.Target = true
		case "many_to_many", "many-to-many":
			u.Source = false
			u.Target = false
		default:
			return fmt.Errorf("invalid 'unique' string: %s", str)
		}
		return nil
	}
	return fmt.Errorf("could not decode 'unique' field")
}
// CanAdd returns whether the edge source -> target can be added based on the uniqueness rules.
// "Type" below means provider + type, ignoring namespace and name.
//   - "many-to-many" always returns true.
//   - "one-to-many" returns true if target has no edge from a different resource of source's type.
//   - "many-to-one" returns true if source has no edge to a different resource of target's type.
//   - "one-to-one" requires both of the above.
func (u Unique) CanAdd(edges []construct.Edge, source, target construct.ResourceId) bool {
	if u.Source {
		sourceType := construct.ResourceId{Provider: source.Provider, Type: source.Type}
		for _, e := range edges {
			// Reject if target is already fed by another resource of the source's type.
			if e.Target == target && sourceType.Matches(e.Source) && e.Source != source {
				return false
			}
		}
	}
	if u.Target {
		targetType := construct.ResourceId{Provider: target.Provider, Type: target.Type}
		for _, e := range edges {
			// Reject if source already points at another resource of the target's type.
			if e.Source == source && targetType.Matches(e.Target) && e.Target != target {
				return false
			}
		}
	}
	return true
}
package knowledgebase
import (
"bytes"
"errors"
"fmt"
"strings"
"text/template"
construct "github.com/klothoplatform/klotho/pkg/construct"
)
type (
// Consumption declares the values a resource template emits to, and consumes from, other
// resources. ConsumeFromResource pairs Consumed objects with Emitted objects that share a Model.
Consumption struct {
Emitted []ConsumptionObject `json:"emitted" yaml:"emitted"`
Consumed []ConsumptionObject `json:"consumed" yaml:"consumed"`
}
// ConsumptionObject is one emitted or consumed value.
ConsumptionObject struct {
// Model names the model that links emitters and consumers; Emit looks it up via KB().GetModel.
Model string `json:"model" yaml:"model"`
// Value is passed to the model's GetObjectValue when emitting.
Value any `json:"value" yaml:"value"`
// Resource is an optional template that is decoded into a resource ID. On the consumer side it
// selects the resource that receives the value instead of the consumer itself.
Resource string `json:"resource" yaml:"resource"`
// PropertyPath is the property the consumed value is appended to.
PropertyPath string `json:"property_path" yaml:"property_path"`
// Converter is an optional text/template (with `add`/`sub` helpers) applied to the value before consumption.
Converter string `json:"converter" yaml:"converter"`
}
// DelayedConsumption is a consumption ConsumeFromResource did not apply immediately because the
// target property had no current value.
DelayedConsumption struct {
Value any
Resource construct.ResourceId
PropertyPath string
}
)
// sanitizeForConsumption validates val against propTmpl for resource.
//   - If validation passes, val is returned unchanged.
//   - If validation fails with a *SanitizeError, its sanitized value is returned with no error.
//   - Any other validation error is returned alongside the original val.
func sanitizeForConsumption(ctx DynamicContext, resource *construct.Resource, propTmpl Property, val any) (any, error) {
	err := propTmpl.Validate(resource, val, ctx)
	if err == nil {
		return val, nil
	}
	var sanErr *SanitizeError
	if errors.As(err, &sanErr) {
		return sanErr.Sanitized, nil
	}
	return val, err
}
// ConsumeFromResource applies the values emitter emits to consumer. Each Consumed object of the
// consumer's template is paired with each Emitted object of the emitter's template that has the
// same Model.
//
// For each pair, in order:
//   - the emitted value is computed;
//   - the value is routed to a different resource if the consumed object's Resource template
//     is set;
//   - the value is converted (when a Converter is set) and sanitized against the target property.
//
// If the target property's current value is nil (or could not be read), the consumption is
// returned as a DelayedConsumption. Otherwise it is appended to the property via Consume.
// Failures for individual pairs are joined into the returned error and do not stop the
// remaining pairs.
func ConsumeFromResource(consumer, emitter *construct.Resource, ctx DynamicContext) ([]DelayedConsumption, error) {
	consumerTemplate, err := ctx.KB().GetResourceTemplate(consumer.ID)
	if err != nil {
		return nil, err
	}
	emitterTemplate, err := ctx.KB().GetResourceTemplate(emitter.ID)
	if err != nil {
		return nil, err
	}
	var errs error
	// addErr records a failure for one consume/emit pair.
	addErr := func(consume ConsumptionObject, emit ConsumptionObject, err error) {
		if err == nil {
			return
		}
		errs = errors.Join(errs, fmt.Errorf(
			"error consuming %s from emitter %s: %w",
			consume.PropertyPath, emit.PropertyPath, err,
		))
	}
	delays := []DelayedConsumption{}
	for _, consume := range consumerTemplate.Consumption.Consumed {
		for _, emit := range emitterTemplate.Consumption.Emitted {
			if consume.Model != emit.Model {
				continue
			}
			val, err := emit.Emit(ctx, emitter.ID)
			if err != nil {
				addErr(consume, emit, err)
				continue
			}
			// By default the consumer receives the value; consume.Resource can redirect it.
			id := consumer.ID
			if consume.Resource != "" {
				data := DynamicValueData{Resource: consumer.ID}
				if err := ctx.ExecuteDecode(consume.Resource, data, &id); err != nil {
					addErr(consume, emit, err)
					continue
				}
			}
			consumeTmpl, err := ctx.KB().GetResourceTemplate(id)
			if err != nil {
				addErr(consume, emit, err)
				continue
			}
			resource, err := ctx.DAG().Vertex(id)
			if err != nil {
				addErr(consume, emit, err)
				continue
			}
			// We ignore the error here because if we can't get the property we will attempt to
			// apply it as a constraint later on.
			pval, _ := resource.GetProperty(consume.PropertyPath)
			if consume.Converter != "" {
				val, err = consume.Convert(val, id, ctx)
				if err != nil {
					addErr(consume, emit, err)
					continue
				}
			}
			// GetProperty may return nil when the template has no such property; calling methods
			// on it would panic.
			prop := consumeTmpl.GetProperty(consume.PropertyPath)
			if prop == nil {
				addErr(consume, emit, fmt.Errorf("property %s not found", consume.PropertyPath))
				continue
			}
			val, err = sanitizeForConsumption(ctx, resource, prop, val)
			if err != nil {
				addErr(consume, emit, err)
				continue
			}
			if pval == nil {
				delays = append(delays, DelayedConsumption{
					Value:        val,
					Resource:     id,
					PropertyPath: consume.PropertyPath,
				})
				continue
			}
			if err := consume.Consume(val, ctx, resource); err != nil {
				addErr(consume, emit, err)
			}
		}
	}
	return delays, errs
}
// HasConsumedFromResource reports whether consumer has already consumed a value emitted by
// emitter. Each Consumed object of the consumer's template is paired with each Emitted object of
// the emitter's template that has the same Model.
//
// A pair counts as consumed when the target property is set and contains the emitted value,
// after the same Convert and sanitizeForConsumption steps ConsumeFromResource applies. In order
// to return true, only one pair has to be consumed.
//
// If no pair shares a Model, the result is true (there is nothing to consume). Otherwise, when no
// pair is consumed, the result is false together with any errors hit while evaluating pairs.
func HasConsumedFromResource(consumer, emitter *construct.Resource, ctx DynamicContext) (bool, error) {
	consumerTemplate, err := ctx.KB().GetResourceTemplate(consumer.ID)
	if err != nil {
		return false, err
	}
	emitterTemplate, err := ctx.KB().GetResourceTemplate(emitter.ID)
	if err != nil {
		return false, err
	}
	noEmittedMatches := true
	var errs error
	for _, consume := range consumerTemplate.Consumption.Consumed {
		for _, emit := range emitterTemplate.Consumption.Emitted {
			if consume.Model != emit.Model {
				continue
			}
			noEmittedMatches = false
			val, err := emit.Emit(ctx, emitter.ID)
			if err != nil {
				errs = errors.Join(errs, err)
				continue
			}
			id := consumer.ID
			if consume.Resource != "" {
				data := DynamicValueData{Resource: consumer.ID}
				if err := ctx.ExecuteDecode(consume.Resource, data, &id); err != nil {
					errs = errors.Join(errs, err)
					continue
				}
			}
			resource, err := ctx.DAG().Vertex(id)
			if err != nil {
				errs = errors.Join(errs, err)
				continue
			}
			// An unset (or unreadable) property cannot contain the emitted value.
			pval, _ := resource.GetProperty(consume.PropertyPath)
			if pval == nil {
				continue
			}
			if consume.Converter != "" {
				val, err = consume.Convert(val, id, ctx)
				if err != nil {
					errs = errors.Join(errs, err)
					continue
				}
			}
			rt, err := ctx.KB().GetResourceTemplate(resource.ID)
			if err != nil {
				errs = errors.Join(errs, err)
				continue
			}
			prop := rt.GetProperty(consume.PropertyPath)
			if prop == nil {
				errs = errors.Join(errs, fmt.Errorf("property %s not found", consume.PropertyPath))
				continue
			}
			val, err = sanitizeForConsumption(ctx, resource, prop, val)
			if err != nil {
				errs = errors.Join(errs, err)
				continue
			}
			if prop.Contains(pval, val) {
				return true, nil
			}
		}
	}
	// Return the collected errors instead of dropping them, so callers can tell "not consumed"
	// apart from "could not evaluate". errs is always nil when noEmittedMatches is true.
	return noEmittedMatches, errs
}
// Convert renders c.Converter, a text/template with `add` and `sub` int helpers, against value.
// The whitespace-trimmed output is then transformed into the value type of the property at
// c.PropertyPath on res. The transform runs twice: the first pass turns the rendered string
// into the property's input type; the second re-runs the conversion so strings that denote
// property refs/ids become their structured form.
// On any failure the best value available so far is returned with the error.
func (c *ConsumptionObject) Convert(value any, res construct.ResourceId, ctx DynamicContext) (any, error) {
	switch {
	case c.Converter == "":
		return value, fmt.Errorf("no converter specified")
	case c.PropertyPath == "":
		return value, fmt.Errorf("no property path specified")
	}
	helpers := template.FuncMap{
		"sub": func(a int, b int) int { return a - b },
		"add": func(a int, b int) int { return a + b },
	}
	tmpl, err := template.New("config").Funcs(helpers).Parse(c.Converter)
	if err != nil {
		return value, err
	}
	var rendered bytes.Buffer
	if err = tmpl.Execute(&rendered, value); err != nil {
		return value, err
	}
	text := strings.TrimSpace(rendered.String())
	data := DynamicValueData{Resource: res}
	converted, err := TransformToPropertyValue(res, c.PropertyPath, text, ctx, data)
	if err != nil {
		return converted, err
	}
	return TransformToPropertyValue(res, c.PropertyPath, converted, ctx, data)
}
// Emit computes the value this object emits on behalf of resource.
//
// If c.Resource is set, it is rendered as a template against resource and decoded into a
// resource ID, which replaces resource as the emitting resource. The value is then obtained
// from the model named by c.Model via GetObjectValue.
// It errors when Value or Model is unset, when the model is unknown, or when decoding or value
// lookup fails.
func (c *ConsumptionObject) Emit(ctx DynamicContext, resource construct.ResourceId) (any, error) {
	// A missing `value` key leaves Value nil rather than "", so treat both as unset.
	if c.Value == nil || c.Value == "" {
		return nil, fmt.Errorf("no value specified")
	}
	if c.Model == "" {
		return nil, fmt.Errorf("no model specified")
	}
	if c.Resource != "" {
		// Decode into &resource so the rendered ID actually replaces the emitting resource;
		// decoding into a copy cannot update it.
		err := ctx.ExecuteDecode(c.Resource, DynamicValueData{Resource: resource}, &resource)
		if err != nil {
			return nil, err
		}
	}
	model := ctx.KB().GetModel(c.Model)
	if model == nil {
		return nil, fmt.Errorf("model %s not found", c.Model)
	}
	val, err := model.GetObjectValue(c.Value, ctx, DynamicValueData{Resource: resource})
	if err != nil {
		return nil, err
	}
	return val, nil
}
// Consume appends val to the property at c.PropertyPath on resource. The property template is
// taken from resource's resource template. It errors when the resource template cannot be found,
// when it has no property at c.PropertyPath, or when appending fails.
func (c *ConsumptionObject) Consume(val any, ctx DynamicContext, resource *construct.Resource) error {
	rt, err := ctx.KB().GetResourceTemplate(resource.ID)
	if err != nil {
		return err
	}
	propTmpl := rt.GetProperty(c.PropertyPath)
	if propTmpl == nil {
		// Guard against a nil property template instead of panicking on the method call below.
		return fmt.Errorf("property %s not found", c.PropertyPath)
	}
	return propTmpl.AppendProperty(resource, val)
}
package knowledgebase
import (
"errors"
"fmt"
"reflect"
"github.com/dominikbraun/graph"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/graph_addons"
"github.com/klothoplatform/klotho/pkg/set"
)
type (
// DependencyLayer represents how far away from a target resource the [Upstream]/[Downstream]
// methods search for resources to return.
//   - ResourceLocalLayer (layer 1) represents any unique resources the target resource needs to be
//     operational, transitively (operational side effects, per IsOperationalResourceSideEffect).
//   - ResourceDirectLayer represents only the resources connected to the target resource by a single edge.
//   - ResourceGlueLayer (layer 2) represents all upstream/downstream resources that represent glue.
//     This will not include any other functional resources and will stop searching paths
//     once a functional resource is reached.
//   - FirstFunctionalLayer (layer 3) represents all upstream/downstream resources that represent glue and
//     the first functional resource in other paths from the target resource.
//   - AllDepsLayer (layer 4) represents all transitive upstream/downstream resources.
DependencyLayer string
)
const (
// ResourceLocalLayer (layer 1) follows only resources that are operational side effects of the
// previous resource on the path.
ResourceLocalLayer DependencyLayer = "local"
// ResourceDirectLayer (layer 2) returns only resources sharing a single edge with the target.
ResourceDirectLayer DependencyLayer = "direct"
// ResourceGlueLayer (layer 2) returns glue (Unknown functionality) resources and stops at functional ones.
ResourceGlueLayer DependencyLayer = "glue"
// FirstFunctionalLayer (layer 3) returns glue resources plus the first functional resource reached on each path.
FirstFunctionalLayer DependencyLayer = "first"
// AllDepsLayer (layer 4) returns every transitive dependency.
AllDepsLayer DependencyLayer = "all"
)
// resourceLocal returns a walk function for ResourceLocalLayer. Each newly reached resource is
// added to ids if it is an operational side effect of the resource just before it on the path.
// Otherwise the path is pruned. Side-effect check errors are joined with the walk's error.
// rid is not used by the returned function.
func resourceLocal(
	dag construct.Graph,
	kb TemplateKB,
	rid construct.ResourceId,
	ids set.Set[construct.ResourceId],
) graph_addons.WalkGraphFunc[construct.ResourceId] {
	return func(path graph_addons.Path[construct.ResourceId], nerr error) error {
		n := len(path)
		if n < 2 {
			// The start vertex alone is never a side effect of itself.
			return nil
		}
		// Pruned paths are never extended, so only the newest hop needs checking.
		parent, current := path[n-2], path[n-1]
		isSideEffect, err := IsOperationalResourceSideEffect(dag, kb, parent, current)
		switch {
		case err != nil:
			return errors.Join(nerr, err)
		case !isSideEffect:
			return graph_addons.SkipPath
		}
		ids.Add(current)
		return nil
	}
}
// resourceDirect returns a walk function for ResourceDirectLayer. It records each reached
// resource into ids (when ids is non-nil) and prunes every path immediately, so only direct
// neighbours are visited. dag is not used by the returned function.
func resourceDirect(
	dag construct.Graph,
	ids set.Set[construct.ResourceId],
) graph_addons.WalkGraphFunc[construct.ResourceId] {
	return func(path graph_addons.Path[construct.ResourceId], nerr error) error {
		neighbour := path[len(path)-1]
		if ids == nil {
			return graph_addons.SkipPath
		}
		ids.Add(neighbour)
		return graph_addons.SkipPath
	}
}
// resourceGlue returns a walk function for ResourceGlueLayer. It records glue resources (Unknown
// functionality) into ids (when ids is non-nil) and keeps walking through them. Paths are pruned
// at functional resources, which are not recorded.
func resourceGlue(
	kb TemplateKB,
	ids set.Set[construct.ResourceId],
) graph_addons.WalkGraphFunc[construct.ResourceId] {
	return func(path graph_addons.Path[construct.ResourceId], nerr error) error {
		last := path[len(path)-1]
		if GetFunctionality(kb, last) != Unknown {
			return graph_addons.SkipPath
		}
		if ids != nil {
			ids.Add(last)
		}
		return nil
	}
}
// firstFunctional returns a walk function for FirstFunctionalLayer. Every reached resource is
// recorded into ids (when ids is non-nil). The walk continues through glue (Unknown
// functionality) and stops after the first functional resource on each path.
func firstFunctional(
	kb TemplateKB,
	ids set.Set[construct.ResourceId],
) graph_addons.WalkGraphFunc[construct.ResourceId] {
	return func(path graph_addons.Path[construct.ResourceId], nerr error) error {
		last := path[len(path)-1]
		if ids != nil {
			ids.Add(last)
		}
		isGlue := GetFunctionality(kb, last) == Unknown
		if isGlue {
			return nil
		}
		return graph_addons.SkipPath
	}
}
// allDeps returns a walk function for AllDepsLayer. It records every reached resource into ids
// (when ids is non-nil) and never prunes a path, so the full transitive set is collected.
func allDeps(
	ids set.Set[construct.ResourceId],
) graph_addons.WalkGraphFunc[construct.ResourceId] {
	// No private bookkeeping is needed: ids is the only output. The previous write-only
	// resourceSet was dead state that only consumed memory.
	return func(path graph_addons.Path[construct.ResourceId], nerr error) error {
		id := path[len(path)-1]
		if ids != nil {
			ids.Add(id)
		}
		return nil
	}
}
// DependenciesSkipEdgeLayer returns an edge predicate for use in calls to
// [construct.DownstreamDependencies] and [construct.UpstreamDependencies]. The predicate returns
// true for edges that should be skipped for the given layer:
//   - ResourceLocalLayer skips edges whose target is not an operational side effect of rid,
//     including when that check errors.
//   - ResourceGlueLayer skips edges whose target is functional (not Unknown).
//   - FirstFunctionalLayer keeps edges out of rid and out of glue resources, and otherwise skips
//     edges into functional resources.
//   - Every other layer (AllDepsLayer, ResourceDirectLayer, unknown values) skips nothing.
func DependenciesSkipEdgeLayer(
	dag construct.Graph,
	kb TemplateKB,
	rid construct.ResourceId,
	layer DependencyLayer,
) func(construct.Edge) bool {
	switch layer {
	case ResourceLocalLayer:
		return func(e construct.Edge) bool {
			isSideEffect, err := IsOperationalResourceSideEffect(dag, kb, rid, e.Target)
			return !(err == nil && isSideEffect)
		}
	case ResourceGlueLayer:
		return func(e construct.Edge) bool {
			return GetFunctionality(kb, e.Target) != Unknown
		}
	case FirstFunctionalLayer:
		return func(e construct.Edge) bool {
			// Edges out of rid (likely functional itself) and out of glue are always interesting.
			if e.Source == rid || GetFunctionality(kb, e.Source) == Unknown {
				return false
			}
			// The source is functional here, so only continue into glue targets.
			return GetFunctionality(kb, e.Target) != Unknown
		}
	}
	return construct.DontSkipEdges
}
// Downstream returns the resources downstream of rid in dag, limited to the given layer.
// ResourceDirectLayer is answered from the edge list: the targets of edges whose source is rid,
// in edge-list order. Every other known layer walks dag downward from rid with the matching walk
// function and returns the collected set as a slice. An unknown layer is an error.
func Downstream(dag construct.Graph, kb TemplateKB, rid construct.ResourceId, layer DependencyLayer) ([]construct.ResourceId, error) {
	if layer == ResourceDirectLayer {
		// Cheaper than a walk: direct dependencies are exactly the targets of rid's out-edges.
		edges, err := dag.Edges()
		if err != nil {
			return nil, err
		}
		var direct []construct.ResourceId
		for _, e := range edges {
			if e.Source == rid {
				direct = append(direct, e.Target)
			}
		}
		return direct, nil
	}
	collected := set.Set[construct.ResourceId]{}
	var walk graph_addons.WalkGraphFunc[construct.ResourceId]
	switch layer {
	case ResourceLocalLayer:
		walk = resourceLocal(dag, kb, rid, collected)
	case ResourceGlueLayer:
		walk = resourceGlue(kb, collected)
	case FirstFunctionalLayer:
		walk = firstFunctional(kb, collected)
	case AllDepsLayer:
		walk = allDeps(collected)
	default:
		return nil, fmt.Errorf("unknown layer %s", layer)
	}
	err := graph_addons.WalkDown(dag, rid, walk)
	return collected.ToSlice(), err
}
// DownstreamFunctional walks dag downward from resource through glue (Unknown functionality)
// resources. It returns, in walk order, each functional resource reached; it does not continue
// past a functional resource.
func DownstreamFunctional(dag construct.Graph, kb TemplateKB, resource construct.ResourceId) ([]construct.ResourceId, error) {
	var functional []construct.ResourceId
	visit := func(path graph_addons.Path[construct.ResourceId], nerr error) error {
		last := path[len(path)-1]
		if GetFunctionality(kb, last) == Unknown {
			return nil
		}
		functional = append(functional, last)
		return graph_addons.SkipPath
	}
	err := graph_addons.WalkDown(dag, resource, visit)
	return functional, err
}
// Upstream returns the resources upstream of rid in dag, limited to the given layer.
// ResourceDirectLayer is answered from the edge list: the sources of edges whose target is rid,
// in edge-list order. Every other known layer walks dag upward from rid with the matching walk
// function and returns the collected set as a slice. An unknown layer is an error.
func Upstream(dag construct.Graph, kb TemplateKB, rid construct.ResourceId, layer DependencyLayer) ([]construct.ResourceId, error) {
	if layer == ResourceDirectLayer {
		// Cheaper than a walk: direct dependents are exactly the sources of rid's in-edges.
		edges, err := dag.Edges()
		if err != nil {
			return nil, err
		}
		var direct []construct.ResourceId
		for _, e := range edges {
			if e.Target == rid {
				direct = append(direct, e.Source)
			}
		}
		return direct, nil
	}
	collected := set.Set[construct.ResourceId]{}
	var walk graph_addons.WalkGraphFunc[construct.ResourceId]
	switch layer {
	case ResourceLocalLayer:
		walk = resourceLocal(dag, kb, rid, collected)
	case ResourceGlueLayer:
		walk = resourceGlue(kb, collected)
	case FirstFunctionalLayer:
		walk = firstFunctional(kb, collected)
	case AllDepsLayer:
		walk = allDeps(collected)
	default:
		return nil, fmt.Errorf("unknown layer %s", layer)
	}
	err := graph_addons.WalkUp(dag, rid, walk)
	return collected.ToSlice(), err
}
// layerWalkFunc returns the walk function that implements layer, recording matches into result.
// If result is nil, a fresh empty set is used instead; the caller then cannot observe the matches.
// An unknown layer is an error.
func layerWalkFunc(
	dag construct.Graph,
	kb TemplateKB,
	rid construct.ResourceId,
	layer DependencyLayer,
	result set.Set[construct.ResourceId],
) (graph_addons.WalkGraphFunc[construct.ResourceId], error) {
	if result == nil {
		result = set.Set[construct.ResourceId]{}
	}
	var walk graph_addons.WalkGraphFunc[construct.ResourceId]
	switch layer {
	case ResourceLocalLayer:
		walk = resourceLocal(dag, kb, rid, result)
	case ResourceDirectLayer:
		walk = resourceDirect(dag, result)
	case ResourceGlueLayer:
		walk = resourceGlue(kb, result)
	case FirstFunctionalLayer:
		walk = firstFunctional(kb, result)
	case AllDepsLayer:
		walk = allDeps(result)
	default:
		return nil, fmt.Errorf("unknown layer %s", layer)
	}
	return walk, nil
}
// UpstreamFunctional walks dag upward from resource through glue (Unknown functionality)
// resources. It returns, in walk order, each functional resource reached; it does not continue
// past a functional resource.
func UpstreamFunctional(dag construct.Graph, kb TemplateKB, resource construct.ResourceId) ([]construct.ResourceId, error) {
	var functional []construct.ResourceId
	visit := func(path graph_addons.Path[construct.ResourceId], nerr error) error {
		last := path[len(path)-1]
		if GetFunctionality(kb, last) == Unknown {
			return nil
		}
		functional = append(functional, last)
		return graph_addons.SkipPath
	}
	err := graph_addons.WalkUp(dag, resource, visit)
	return functional, err
}
// IsOperationalResourceSideEffect reports whether sideEffect is an operational side effect of rid.
//
// It considers each property of rid's template whose operational rule has at least one resource
// selector. sideEffect qualifies through a property when all of the following hold:
//  1. one of the rule step's resource selectors matches sideEffect (see the note on selector
//     matching below);
//  2. dag has a path between the two in the rule's direction: sideEffect -> rid for
//     DirectionUpstream, rid -> sideEffect otherwise;
//  3. rid's current value for the property is sideEffect's ID or a PropertyRef to it, or is an
//     array/slice containing one of those.
//
// Errors are returned when rid's template or either vertex cannot be found, or when resolving a
// property path or matching a selector fails. A missing path or an unset property is not an error.
func IsOperationalResourceSideEffect(dag construct.Graph, kb TemplateKB, rid, sideEffect construct.ResourceId) (bool, error) {
template, err := kb.GetResourceTemplate(rid)
if err != nil {
return false, fmt.Errorf("error cheecking %s is side effect of %s: %w", sideEffect, rid, err)
}
sideEffectResource, err := dag.Vertex(sideEffect)
if err != nil {
return false, fmt.Errorf("could not find side effect resource %s: %w", sideEffect, err)
}
resource, err := dag.Vertex(rid)
if err != nil {
return false, fmt.Errorf("could not find resource %s: %w", rid, err)
}
dynCtx := DynamicValueContext{Graph: dag, KnowledgeBase: kb}
isSideEffect := false
// Visit rid's properties; the callback returns ErrStopWalk as soon as a match is found.
// NOTE(review): the final `return isSideEffect, err` assumes LoopProperties does not surface
// ErrStopWalk as an error — confirm.
err = template.LoopProperties(resource, func(property Property) error {
ruleSatisfied := false
rule := property.Details().OperationalRule
// Properties without a resource-producing operational rule cannot create side effects.
if rule == nil || len(rule.Step.Resources) == 0 {
return nil
}
path, err := resource.PropertyPath(property.Details().Path)
if err != nil {
return fmt.Errorf(
"error checking if %s is side effect of %s in property %s: %w",
sideEffect, rid, property.Details().Name, err,
)
}
data := DynamicValueData{Resource: rid, Path: path}
step := rule.Step
// We only check if the resource selector is a match in terms of properties and classifications (not the actual id)
// We do this because if we have explicit ids in the selector and someone changes the id of a side effect resource
// we would no longer think it is a side effect since the id would no longer match.
// To combat this we just check against type
for j, resourceSelector := range step.Resources {
if match, err := resourceSelector.IsMatch(dynCtx, data, sideEffectResource); match {
ruleSatisfied = true
break
} else if err != nil {
return fmt.Errorf(
"error checking if %s is side effect of %s in property %s, resource %d: %w",
sideEffect, rid, property.Details().Name, j, err,
)
}
}
if !ruleSatisfied {
return nil
}
// If the side effect resource fits the rule we then perform 2 more checks
// 1. is there a path in the direction of the rule
// 2. Is the property set with the resource that we are checking for
// Path-lookup errors are treated as "no path" and simply move on to the next property.
if step.Direction == DirectionUpstream {
resources, err := graph.ShortestPathStable(dag, sideEffect, rid, construct.ResourceIdLess)
if len(resources) == 0 || err != nil {
return nil
}
} else {
resources, err := graph.ShortestPathStable(dag, rid, sideEffect, construct.ResourceIdLess)
if len(resources) == 0 || err != nil {
return nil
}
}
propertyVal, err := resource.GetProperty(property.Details().Path)
if err != nil || propertyVal == nil {
return nil
}
// The property may hold a single reference or a collection of references.
val := reflect.ValueOf(propertyVal)
if val.Kind() == reflect.Array || val.Kind() == reflect.Slice {
for i := 0; i < val.Len(); i++ {
if arrId, ok := val.Index(i).Interface().(construct.ResourceId); ok && arrId == sideEffect {
isSideEffect = true
return ErrStopWalk
} else if ref, ok := val.Index(i).Interface().(construct.PropertyRef); ok && ref.Resource == sideEffect {
isSideEffect = true
return ErrStopWalk
}
}
} else {
if val.IsZero() {
return nil
}
if valId, ok := val.Interface().(construct.ResourceId); ok && valId == sideEffect {
isSideEffect = true
return ErrStopWalk
} else if ref, ok := val.Interface().(construct.PropertyRef); ok && ref.Resource == sideEffect {
isSideEffect = true
return ErrStopWalk
}
}
return nil
})
return isSideEffect, err
}
package knowledgebase
import (
"errors"
"fmt"
"sort"
"text/template"
"github.com/dominikbraun/graph"
construct "github.com/klothoplatform/klotho/pkg/construct"
"go.uber.org/zap"
)
//go:generate mockgen --source=./kb.go -destination=./template_kb_mock_test.go -package=knowledgebase
//go:generate mockgen --source=./kb.go -destination=../engine/operational_eval/template_kb_mock_test.go -package=operational_eval
type (
// TemplateKB is the interface to the knowledge base of resource and edge templates.
// *KnowledgeBase implements it; mocks are generated from it by the mockgen directives in this file.
TemplateKB interface {
// ListResources returns every resource template (KnowledgeBase sorts them by qualified type name).
ListResources() []*ResourceTemplate
// GetModel returns the named model, or nil when it is unknown (for KnowledgeBase).
GetModel(model string) *Model
// Edges returns the template-level edges as source/target template pairs.
Edges() ([]graph.Edge[*ResourceTemplate], error)
AddResourceTemplate(template *ResourceTemplate) error
AddEdgeTemplate(template *EdgeTemplate) error
// GetResourceTemplate looks up the template for id's qualified type name.
GetResourceTemplate(id construct.ResourceId) (*ResourceTemplate, error)
// GetEdgeTemplate returns nil when no edge template exists between the two types.
GetEdgeTemplate(from, to construct.ResourceId) *EdgeTemplate
// HasDirectPath reports whether an edge template connects the two types directly.
HasDirectPath(from, to construct.ResourceId) bool
// HasFunctionalPath reports whether the shortest template path between the types passes only
// through glue (Unknown functionality) templates.
HasFunctionalPath(from, to construct.ResourceId) bool
// AllPaths returns every template path between the two types.
AllPaths(from, to construct.ResourceId) ([][]*ResourceTemplate, error)
GetAllowedNamespacedResourceIds(ctx DynamicValueContext, resourceId construct.ResourceId) ([]construct.ResourceId, error)
GetClassification(id construct.ResourceId) Classification
GetResourcesNamespaceResource(resource *construct.Resource) (construct.ResourceId, error)
GetResourcePropertyType(resource construct.ResourceId, propertyName string) string
GetPathSatisfactionsFromEdge(source, target construct.ResourceId) ([]EdgePathSatisfaction, error)
}
// KnowledgeBase is a struct that represents the object which contains the knowledge of how to make resources operational
KnowledgeBase struct {
// underlying holds one vertex per resource template, keyed by qualified type name, with
// *EdgeTemplate values stored as edge data.
underlying Graph
// Models holds the models returned by GetModel, keyed by model name.
Models map[string]*Model
}
// EdgePathSatisfaction describes how one path requirement of an edge is satisfied.
// NOTE(review): field semantics are inferred from names only — confirm against GetPathSatisfactionsFromEdge.
EdgePathSatisfaction struct {
// Signals if the classification is derived from the target or not
// we need this to know how to construct the edge we are going to run expansion on if we have resource values in the classification
// NOTE(review): the comment above reads like it describes a boolean, but Classification is a string — likely stale; confirm.
Classification string
Source PathSatisfactionRoute
Target PathSatisfactionRoute
}
// ValueOrTemplate holds either a literal Value or a parsed Template.
// NOTE(review): inferred from the field names — confirm which callers set which field.
ValueOrTemplate struct {
Value any
Template *template.Template
}
// Graph is the template graph type: *ResourceTemplate vertices keyed by qualified type name.
Graph = graph.Graph[string, *ResourceTemplate]
)
// Edge weights assigned by AddEdgeTemplate, chosen from the functionality of the edge's endpoints.
// NOTE(review): NewKB builds the graph without graph.Weighted(); confirm where these weights are consumed.
const (
// glueEdgeWeight is used when both the source and target templates have Unknown functionality.
glueEdgeWeight = 0
// defaultEdgeWeight is used when the source template is functional.
defaultEdgeWeight = 1
// functionalBoundaryEdgeWeight is used when a glue source points at a functional target.
functionalBoundaryEdgeWeight = 10000
)
// NewKB returns an empty KnowledgeBase. It is backed by a directed graph whose vertices are
// resource templates keyed by their qualified type name. Models is left nil.
func NewKB() *KnowledgeBase {
	templateKey := func(t *ResourceTemplate) string {
		return t.Id().QualifiedTypeName()
	}
	kb := &KnowledgeBase{}
	kb.underlying = graph.New[string, *ResourceTemplate](templateKey, graph.Directed())
	return kb
}
// Graph exposes the underlying template graph. It is shared, not copied, so mutations affect kb.
func (kb *KnowledgeBase) Graph() Graph {
	g := kb.underlying
	return g
}
// GetModel returns the model registered under the given name, or nil if there is none.
func (kb *KnowledgeBase) GetModel(model string) *Model {
	if m, ok := kb.Models[model]; ok {
		return m
	}
	return nil
}
// ListResources returns a list of all resources in the knowledge base.
// The returned list of resource templates is sorted by the templates' fully qualified type name.
// It panics if the underlying graph cannot be read, since the interface has no error return.
func (kb *KnowledgeBase) ListResources() []*ResourceTemplate {
	// The predecessor map has one key per vertex, which makes it a convenient vertex listing.
	byName, err := kb.underlying.PredecessorMap()
	if err != nil {
		panic(err)
	}
	names := make([]string, 0, len(byName))
	for name := range byName {
		names = append(names, name)
	}
	sort.Strings(names)
	var templates []*ResourceTemplate
	for _, name := range names {
		tmpl, err := kb.underlying.Vertex(name)
		if err != nil {
			panic(err)
		}
		templates = append(templates, tmpl)
	}
	return templates
}
// Edges returns every edge of the template graph with both endpoints resolved to their
// templates. Edge properties (data, weight) are not included.
func (kb *KnowledgeBase) Edges() ([]graph.Edge[*ResourceTemplate], error) {
	raw, err := kb.underlying.Edges()
	if err != nil {
		return nil, err
	}
	var resolved []graph.Edge[*ResourceTemplate]
	for _, e := range raw {
		source, err := kb.underlying.Vertex(e.Source)
		if err != nil {
			return nil, err
		}
		target, err := kb.underlying.Vertex(e.Target)
		if err != nil {
			return nil, err
		}
		resolved = append(resolved, graph.Edge[*ResourceTemplate]{Source: source, Target: target})
	}
	return resolved, nil
}
// AddResourceTemplate adds template as a vertex of the template graph, keyed by its qualified
// type name. Any error from the graph (for example a duplicate vertex) is returned as-is.
func (kb *KnowledgeBase) AddResourceTemplate(template *ResourceTemplate) error {
	err := kb.underlying.AddVertex(template)
	return err
}
// AddEdgeTemplate adds an edge between the templates of template.Source and template.Target,
// storing template as the edge data. Both resource templates must already be present. The edge
// weight is chosen from the endpoints' functionality:
//   - defaultEdgeWeight when the source is functional;
//   - glueEdgeWeight when both endpoints are glue;
//   - functionalBoundaryEdgeWeight when glue points at a functional target.
func (kb *KnowledgeBase) AddEdgeTemplate(template *EdgeTemplate) error {
	srcKey := template.Source.QualifiedTypeName()
	dstKey := template.Target.QualifiedTypeName()
	sourceTmpl, err := kb.underlying.Vertex(srcKey)
	if err != nil {
		return fmt.Errorf("could not find source template: %w", err)
	}
	targetTmpl, err := kb.underlying.Vertex(dstKey)
	if err != nil {
		return fmt.Errorf("could not find target template: %w", err)
	}
	var weight int
	switch {
	case sourceTmpl.GetFunctionality() != Unknown:
		weight = defaultEdgeWeight
	case targetTmpl.GetFunctionality() == Unknown:
		weight = glueEdgeWeight
	default:
		weight = functionalBoundaryEdgeWeight
	}
	return kb.underlying.AddEdge(srcKey, dstKey, graph.EdgeData(template), graph.EdgeWeight(weight))
}
// GetResourceTemplate returns the template registered for id's qualified type name.
// Name and namespace of id are irrelevant to the lookup.
func (kb *KnowledgeBase) GetResourceTemplate(id construct.ResourceId) (*ResourceTemplate, error) {
	key := id.QualifiedTypeName()
	return kb.underlying.Vertex(key)
}
// GetEdgeTemplate returns the *EdgeTemplate stored on the from -> to edge of the template graph.
// It returns nil when there is no such edge or the edge carries no *EdgeTemplate data.
func (kb *KnowledgeBase) GetEdgeTemplate(from, to construct.ResourceId) *EdgeTemplate {
	edge, err := kb.underlying.Edge(from.QualifiedTypeName(), to.QualifiedTypeName())
	if err != nil {
		// No edge means no edge template; this is not treated as an error.
		return nil
	}
	// A nil or foreign payload fails the assertion and yields nil.
	tmpl, _ := edge.Properties.Data.(*EdgeTemplate)
	return tmpl
}
// HasDirectPath reports whether the template graph has an edge straight from from's type to to's type.
func (kb *KnowledgeBase) HasDirectPath(from, to construct.ResourceId) bool {
	if _, err := kb.underlying.Edge(from.QualifiedTypeName(), to.QualifiedTypeName()); err != nil {
		return false
	}
	return true
}
// HasFunctionalPath reports whether the (deterministically tie-broken) shortest template path
// from from's type to to's type passes only through glue (Unknown functionality) templates.
// Identical types always qualify. Unreachable targets and lookup errors yield false; other
// path-finding errors are also logged. It panics if a vertex on the found path is missing.
func (kb *KnowledgeBase) HasFunctionalPath(from, to construct.ResourceId) bool {
	fromType, toType := from.QualifiedTypeName(), to.QualifiedTypeName()
	if fromType == toType {
		// For resources that can reference themselves, such as aws:api_resource
		return true
	}
	byName := func(a, b string) bool { return a < b }
	path, err := graph.ShortestPathStable(kb.underlying, fromType, toType, byName)
	switch {
	case errors.Is(err, graph.ErrTargetNotReachable):
		return false
	case err != nil:
		zap.S().Errorf(
			"error in finding shortes path from %s to %s: %v",
			fromType, toType, err,
		)
		return false
	}
	// Only the interior of the path matters; the endpoints may be functional.
	for _, id := range path[1 : len(path)-1] {
		tmpl, err := kb.underlying.Vertex(id)
		if err != nil {
			panic(err)
		}
		if tmpl.GetFunctionality() != Unknown {
			return false
		}
	}
	return true
}
// AllPaths returns every path in the template graph from the type of `from` to the type
// of `to`, with each path expressed as the sequence of resource templates along it.
//
// A template that cannot be resolved for a vertex on a path is reported as an error
// rather than silently appearing as a nil entry.
func (kb *KnowledgeBase) AllPaths(from, to construct.ResourceId) ([][]*ResourceTemplate, error) {
	paths, err := graph.AllPathsBetween(kb.underlying, from.QualifiedTypeName(), to.QualifiedTypeName())
	if err != nil {
		return nil, err
	}
	resources := make([][]*ResourceTemplate, len(paths))
	for i, path := range paths {
		resources[i] = make([]*ResourceTemplate, len(path))
		for j, id := range path {
			tmpl, err := kb.underlying.Vertex(id)
			if err != nil {
				return nil, fmt.Errorf("could not find resource template %s: %w", id, err)
			}
			resources[i][j] = tmpl
		}
	}
	return resources, nil
}
// GetAllowedNamespacedResourceIds lists the resource ids that may act as the namespace
// for resourceId, based on the operational rule of its template's namespaced property.
//
// For each resource selector in the rule's step:
//   - a non-empty Selector is decoded into a single id, which is kept if its template
//     has all of the selector's classifications;
//   - otherwise, when Classifications is set, every known resource template with those
//     classifications contributes its id.
//
// A template without a namespaced property, or a property without an operational rule,
// yields an empty (nil) result.
func (kb *KnowledgeBase) GetAllowedNamespacedResourceIds(ctx DynamicValueContext, resourceId construct.ResourceId) ([]construct.ResourceId, error) {
	tmpl, err := kb.GetResourceTemplate(resourceId)
	if err != nil {
		return nil, fmt.Errorf("could not find resource template for %s: %w", resourceId, err)
	}
	var allowed []construct.ResourceId
	nsProp := tmpl.GetNamespacedProperty()
	if nsProp == nil {
		return allowed, nil
	}
	rule := nsProp.Details().OperationalRule
	if rule == nil {
		return allowed, nil
	}
	for _, sel := range rule.Step.Resources {
		switch {
		case sel.Selector != "":
			id, err := ExecuteDecodeAsResourceId(ctx, sel.Selector, DynamicValueData{Resource: resourceId})
			if err != nil {
				return nil, err
			}
			selTmpl, err := kb.GetResourceTemplate(id)
			if err != nil {
				return nil, err
			}
			if selTmpl.ResourceContainsClassifications(sel.Classifications) {
				allowed = append(allowed, id)
			}
		case sel.Classifications != nil:
			for _, candidate := range kb.ListResources() {
				if candidate.ResourceContainsClassifications(sel.Classifications) {
					allowed = append(allowed, candidate.Id())
				}
			}
		}
	}
	return allowed, nil
}
// GetFunctionality returns the functionality of id's resource template, or Unknown when
// the template cannot be found (lookup errors are deliberately ignored).
func GetFunctionality(kb TemplateKB, id construct.ResourceId) Functionality {
	if tmpl, _ := kb.GetResourceTemplate(id); tmpl != nil {
		return tmpl.GetFunctionality()
	}
	return Unknown
}
// GetClassification returns the classification of id's resource template, or the zero
// Classification when the template cannot be found (lookup errors are ignored).
func (kb *KnowledgeBase) GetClassification(id construct.ResourceId) Classification {
	if tmpl, _ := kb.GetResourceTemplate(id); tmpl != nil {
		return tmpl.Classification
	}
	return Classification{}
}
// GetResourcesNamespaceResource returns the id stored in resource's namespaced property.
//
// It returns the zero ResourceId (and no error) when the resource's template has no
// namespaced property or the property is unset. An error is returned when the template
// cannot be found, the property cannot be read, or it holds something other than a
// construct.ResourceId.
func (kb *KnowledgeBase) GetResourcesNamespaceResource(resource *construct.Resource) (construct.ResourceId, error) {
	template, err := kb.GetResourceTemplate(resource.ID)
	if err != nil {
		return construct.ResourceId{}, err
	}
	namespaceProperty := template.GetNamespacedProperty()
	if namespaceProperty == nil {
		return construct.ResourceId{}, nil
	}
	ns, err := resource.GetProperty(namespaceProperty.Details().Name)
	if err != nil {
		return construct.ResourceId{}, err
	}
	if ns == nil {
		return construct.ResourceId{}, nil
	}
	id, ok := ns.(construct.ResourceId)
	if !ok {
		// %v/%T rather than %s: the value may be any type, not necessarily a Stringer.
		return construct.ResourceId{}, fmt.Errorf("namespace property does not contain a ResourceId, got %v (%[1]T)", ns)
	}
	return id, nil
}
// GetResourcePropertyType returns the Type() of the template property on resource's
// template whose details name equals propertyName. It returns "" when the template
// cannot be found or no property has that name.
func (kb *KnowledgeBase) GetResourcePropertyType(resource construct.ResourceId, propertyName string) string {
	tmpl, err := kb.GetResourceTemplate(resource)
	if err != nil {
		return ""
	}
	for _, prop := range tmpl.Properties {
		if prop.Details().Name != propertyName {
			continue
		}
		return prop.Type()
	}
	return ""
}
// TransformToPropertyValue converts value into the representation expected by the
// property named propertyName on resource's template, using that property's Parse.
// It is used for values coming from config templates, and for any interface value
// about to be set on a resource.
//
// A nil value maps to the property's ZeroValue. Errors are returned when the template
// or property cannot be found, or when parsing fails.
func TransformToPropertyValue(
	resource construct.ResourceId,
	propertyName string,
	value interface{},
	ctx DynamicContext,
	data DynamicValueData,
) (interface{}, error) {
	tmpl, err := ctx.KB().GetResourceTemplate(resource)
	if err != nil {
		return nil, err
	}
	prop := tmpl.GetProperty(propertyName)
	if prop == nil {
		return nil, fmt.Errorf("could not find property %s on resource %s", propertyName, resource)
	}
	if value == nil {
		return prop.ZeroValue(), nil
	}
	parsed, err := prop.Parse(value, ctx, data)
	if err != nil {
		return nil, fmt.Errorf(
			"could not parse value %v for property %s on resource %s: %w",
			value, prop.Details().Name, resource, err,
		)
	}
	return parsed, nil
}
// TransformAllPropertyValues re-parses the set value of every property listed in each
// resource's template, visiting resources in construct.TopologicalSort order, and
// writes the parsed value back onto the resource.
//
// Errors are accumulated rather than returned eagerly, as follows:
//   - a missing template skips the whole resource;
//   - an unresolvable property path skips just that property;
//   - a transform or set failure abandons the remaining properties of that resource.
//
// All accumulated errors are returned joined (nil when there were none).
func TransformAllPropertyValues(ctx DynamicValueContext) error {
	ids, err := construct.TopologicalSort(ctx.DAG())
	if err != nil {
		return err
	}
	resources, err := construct.ResolveIds(ctx.DAG(), ids)
	if err != nil {
		return err
	}
	var errs error
	for _, res := range resources {
		tmpl, tmplErr := ctx.KB().GetResourceTemplate(res.ID)
		if tmplErr != nil {
			errs = errors.Join(errs, tmplErr)
			continue
		}
		data := DynamicValueData{Resource: res.ID}
		for name := range tmpl.Properties {
			path, pathErr := res.PropertyPath(name)
			if pathErr != nil {
				errs = errors.Join(errs, pathErr)
				continue
			}
			// A read error is treated the same as an unset value: nothing to transform.
			current, _ := path.Get()
			if current == nil {
				continue
			}
			parsed, xformErr := TransformToPropertyValue(res.ID, name, current, ctx, data)
			if xformErr != nil {
				errs = errors.Join(errs, fmt.Errorf("error transforming %s#%s: %w", res.ID, name, xformErr))
				// Nothing follows the property loop, so breaking moves on to the next resource.
				break
			}
			if setErr := path.Set(parsed); setErr != nil {
				errs = errors.Join(errs, fmt.Errorf("errors setting %s#%s: %w", res.ID, name, setErr))
				break
			}
		}
	}
	return errs
}
package kbtesting
import (
"errors"
"fmt"
"strings"
"testing"
"github.com/dominikbraun/graph"
"github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/knowledgebase"
)
// StringToGraphElement parses e into a graph element for test knowledge bases.
//
// A string without "->" that has exactly one ':' becomes a *ResourceTemplate with that
// QualifiedTypeName. Otherwise e is parsed as a construct.Path of at least two ids, each
// with a provider and type, and returned as the []*EdgeTemplate linking consecutive
// ids. When neither interpretation works, the collected errors are returned joined.
func StringToGraphElement(e string) (any, error) {
	var errs []error
	if !strings.Contains(e, "->") {
		if parts := strings.Split(e, ":"); len(parts) == 2 {
			return &knowledgebase.ResourceTemplate{QualifiedTypeName: e}, nil
		}
		errs = append(errs, fmt.Errorf("invalid resource ID %q", e))
	}

	var path construct.Path
	pathErr := path.Parse(e)
	if len(path) < 2 {
		if pathErr == nil {
			pathErr = fmt.Errorf("path must have at least two elements (got %d)", len(path))
		}
		errs = append(errs, pathErr)
		return nil, errors.Join(errs...)
	}

	edges := make([]*knowledgebase.EdgeTemplate, 0, len(path)-1)
	for i, id := range path {
		if id.Provider == "" || id.Type == "" {
			return nil, fmt.Errorf("missing provider or type in path element %d", i)
		}
		if i > 0 {
			edges = append(edges, &knowledgebase.EdgeTemplate{Source: path[i-1], Target: id})
		}
	}
	return edges, nil
}
// AddElement is a utility function for adding an element to a graph. See [MakeKB] for more information on supported
// element types. Returns whether adding the element failed.
//
// Resource templates that already exist are left alone. Edge templates also add empty
// resource templates for their endpoints when those are missing. Failures are
// reported on t with t.Errorf; the test is not stopped.
func AddElement(t *testing.T, g knowledgebase.Graph, e any) (failed bool) {
	t.Helper()
	must := func(err error) {
		t.Helper()
		if err != nil {
			t.Error(err)
			failed = true
		}
	}
	if estr, ok := e.(string); ok {
		parsed, err := StringToGraphElement(estr)
		if err != nil {
			// Report the original string: on error the parsed element is nil.
			t.Errorf("invalid element %q: parse errors: %v", estr, err)
			return true
		}
		e = parsed
	}
	addIfMissing := func(res *knowledgebase.ResourceTemplate) {
		t.Helper()
		err := g.AddVertex(res)
		if err != nil && !errors.Is(err, graph.ErrVertexAlreadyExists) {
			t.Errorf("could not add vertex %s: %v", res.QualifiedTypeName, err)
			failed = true
		}
	}
	// addEdge ensures both endpoint templates exist, then adds the edge carrying et.
	addEdge := func(et *knowledgebase.EdgeTemplate) {
		t.Helper()
		addIfMissing(&knowledgebase.ResourceTemplate{QualifiedTypeName: et.Source.QualifiedTypeName()})
		addIfMissing(&knowledgebase.ResourceTemplate{QualifiedTypeName: et.Target.QualifiedTypeName()})
		must(g.AddEdge(et.Source.QualifiedTypeName(), et.Target.QualifiedTypeName(), graph.EdgeData(et)))
	}
	switch e := e.(type) {
	case knowledgebase.ResourceTemplate:
		addIfMissing(&e)
	case *knowledgebase.ResourceTemplate:
		addIfMissing(e)
	case knowledgebase.EdgeTemplate:
		addEdge(&e)
	case *knowledgebase.EdgeTemplate:
		addEdge(e)
	case []*knowledgebase.EdgeTemplate:
		for _, edge := range e {
			addEdge(edge)
		}
	default:
		t.Errorf("invalid element of type %T", e)
		return true
	}
	return failed
}
// MakeKB is a utility function for creating a KnowledgeBase from a list of elements which can be of types:
//   - ResourceTemplate, *ResourceTemplate : adds the given resource template
//   - EdgeTemplate, *EdgeTemplate : adds the given edge template (and empty templates for missing endpoints)
//   - []*EdgeTemplate : adds all the edges in the list
//   - string : parses the string as either a QualifiedTypeName or a path of QualifiedTypeNames and adds as empty templates
//
// Every element is attempted so that all problems are reported at once; if any element
// failed, the test is stopped with t.FailNow.
func MakeKB(t *testing.T, elements ...any) *knowledgebase.KnowledgeBase {
	t.Helper()
	kb := knowledgebase.NewKB()
	failed := false
	for i, e := range elements {
		if AddElement(t, kb.Graph(), e) {
			t.Errorf("failed to add element[%d] (%v) to graph", i, e)
			failed = true
		}
	}
	if failed {
		// Fail now because if the graph didn't parse correctly, then the rest of the test is likely to fail
		t.FailNow()
	}
	return kb
}
package knowledgebase
import (
	"errors"
	"fmt"
)
type (
// Model describes a reusable value shape. GetObjectValue requires exactly one of
// Properties or Property to be set.
Model struct {
Name string `json:"name" yaml:"name"`
// Properties, when set, describes an object (map) whose entries are parsed individually.
Properties Properties `json:"properties" yaml:"properties"`
// Property, when set, describes a single value that is parsed directly by this property.
Property Property `json:"property" yaml:"property"`
}
)
// GetObjectValue returns the value of the object as the model type.
//
// Exactly one of m.Properties or m.Property must be set:
//   - With Properties, val must be a map[string]any. Each model property is read from
//     val by its Details().Name; when absent, the property's default value is used.
//     When there is no default, a "not found" error is recorded. The parsed values are
//     returned in a map keyed by the model's property keys. Per-property errors are
//     joined and returned together with the map of the values that did parse.
//   - With Property, val is parsed directly by that property.
func (m *Model) GetObjectValue(val any, ctx DynamicContext, data DynamicValueData) (any, error) {
	if m.Properties != nil && m.Property != nil {
		return nil, fmt.Errorf("model has both properties and a property")
	}
	if m.Properties == nil && m.Property == nil {
		return nil, fmt.Errorf("model has neither properties nor a property")
	}
	if m.Property != nil {
		value, err := m.Property.Parse(val, ctx, data)
		if err != nil {
			return nil, err
		}
		return value, nil
	}

	valMap, ok := val.(map[string]any)
	if !ok {
		return nil, fmt.Errorf("value for model object is not a map (got %T)", val)
	}
	// getVal parses p's entry from valMap, falling back to p's default value.
	getVal := func(p Property) (any, error) {
		name := p.Details().Name
		if raw, found := valMap[name]; found {
			return p.Parse(raw, ctx, data)
		}
		defaultVal, err := p.GetDefaultValue(ctx, data)
		if err != nil {
			return nil, err
		}
		if defaultVal == nil {
			return nil, fmt.Errorf("property %s not found", name)
		}
		return defaultVal, nil
	}

	obj := make(map[string]any, len(m.Properties))
	var errs error
	for name, prop := range m.Properties {
		parsed, err := getVal(prop)
		if err != nil {
			// errors.Join skips the nil initial value, unlike the previous "%s\n%s" formatting.
			errs = errors.Join(errs, err)
			continue
		}
		obj[name] = parsed
	}
	return obj, errs
}
package knowledgebase
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"reflect"
construct "github.com/klothoplatform/klotho/pkg/construct"
"gopkg.in/yaml.v3"
)
type (
// OperationalRule is a conditionally applied group of operational steps and
// configuration rules. OperationalRule.Hash digests the JSON encoding of this struct,
// so the json tags and field order below feed into that hash.
OperationalRule struct {
// If is a condition guarding the rule (NOTE(review): presumably a template expression
// evaluated by the engine; the evaluator is not visible here).
If string `json:"if" yaml:"if"`
Steps []OperationalStep `json:"steps" yaml:"steps"`
ConfigurationRules []ConfigurationRule `json:"configuration_rules" yaml:"configuration_rules"`
}
// EdgeRule mirrors OperationalRule for edges, with EdgeOperationalStep steps.
EdgeRule struct {
If string `json:"if" yaml:"if"`
Steps []EdgeOperationalStep `json:"steps" yaml:"steps"`
ConfigurationRules []ConfigurationRule `json:"configuration_rules" yaml:"configuration_rules"`
}
// AdditionalRule is a conditional list of operational steps without configuration
// rules. AdditionalRule.Hash digests the JSON encoding of this struct.
AdditionalRule struct {
If string `json:"if" yaml:"if"`
Steps []OperationalStep `json:"steps" yaml:"steps"`
}
// PropertyRule is a conditional rule with a single Step and a Value
// (NOTE(review): presumably the type behind a property's OperationalRule — confirm).
PropertyRule struct {
If string `json:"if" yaml:"if"`
Step OperationalStep `json:"step" yaml:"step"`
Value any `json:"value" yaml:"value"`
}
// EdgeOperationalStep is an OperationalStep used by edge rules. Its own Resource field
// shadows the embedded OperationalStep.Resource.
EdgeOperationalStep struct {
Resource string `json:"resource" yaml:"resource"`
OperationalStep
}
// OperationalStep defines checks that must pass and actions which must be carried out to make a resource operational
OperationalStep struct {
Resource string `json:"resource" yaml:"resource"`
// Direction defines the direction of the rule. The direction options are upstream or downstream
Direction Direction `json:"direction" yaml:"direction"`
// Resources defines the resource types that the rule should be enforced on. Resource types must be specified if classifications is not specified
Resources []ResourceSelector `json:"resources" yaml:"resources"`
// NumNeeded defines the number of resources that must satisfy the rule
NumNeeded int `json:"num_needed" yaml:"num_needed"`
// FailIfMissing fails if the step is not satisfied when being evaluated. If this flag is set, the step cannot create dependencies
FailIfMissing bool `json:"fail_if_missing" yaml:"fail_if_missing"`
// Unique defines if the resource that is created should be unique
Unique bool `json:"unique" yaml:"unique"`
// UsePropertyRef, when non-empty, names the property whose reference the rule should set
// on the field instead of the resource itself
UsePropertyRef string `json:"use_property_ref" yaml:"use_property_ref"`
// SelectionOperator defines how the rule should select a resource if one does not exist
SelectionOperator SelectionOperator `json:"selection_operator" yaml:"selection_operator"`
}
// ConfigurationRule applies a Configuration to the resource named by Resource.
ConfigurationRule struct {
Resource string `json:"resource" yaml:"resource"`
Config Configuration `json:"configuration" yaml:"configuration"`
}
// Configuration defines how to act on any intrinsic values of a resource to make it operational
Configuration struct {
// Field defines a field that should be set on the resource
Field string `json:"field" yaml:"field"`
// Value defines the value that should be set on the resource
Value any `json:"value" yaml:"value"`
}
// ResourceSelector selects candidate resources by type (via Selector), property
// values, and classifications. In YAML a bare string is accepted as the Selector.
ResourceSelector struct {
// Selector is decoded into a list of resource ids (see ExtractResourceIds); when
// empty, every resource template in the knowledge base is a candidate.
Selector string `json:"selector" yaml:"selector"`
// Properties lists property values a matching resource must have (see matches).
Properties map[string]any `json:"properties" yaml:"properties"`
// NumPreferred defines the amount of resources that should be preferred to satisfy the selector.
// This number is only used if num needed on the step is not met
NumPreferred int `json:"num_preferred" yaml:"num_preferred"`
// Classifications defines the classifications that the rule should be enforced on. Classifications must be specified if resource types is not specified
Classifications []string `json:"classifications" yaml:"classifications"`
}
// Direction defines the direction of the rule. The direction options are upstream or downstream
Direction string
// SelectionOperator names the strategy for choosing among candidate resources
// (spread, cluster, or the empty string for closest).
SelectionOperator string
)
// Rule directions. Direction.Edge turns them into edges: upstream means the
// dependency points at the resource, downstream means the resource points at the
// dependency.
const (
	DirectionUpstream   Direction = "upstream"
	DirectionDownstream Direction = "downstream"
)

// Selection operators for choosing among candidate resources. Closest is the empty
// string, i.e. the zero value of SelectionOperator.
const (
	SpreadSelectionOperator  SelectionOperator = "spread"
	ClusterSelectionOperator SelectionOperator = "cluster"
	ClosestSelectionOperator SelectionOperator = ""
)
// hashStruct returns the hex-encoded SHA-256 digest of s's JSON encoding, as written
// by json.Encoder (including its trailing newline). It panics when s cannot be
// encoded, because every caller passes a JSON-encodable type.
func hashStruct(s any) string {
	digest := sha256.New()
	if err := json.NewEncoder(digest).Encode(s); err != nil {
		// All types that use this function should be able to be encoded to JSON.
		// If this panic occurs, there's a programming error.
		panic(fmt.Errorf("error hashing struct %v (%[1]T): %w", s, err))
	}
	sum := digest.Sum(nil)
	return hex.EncodeToString(sum)
}
// Hash returns a stable identifier for the rule: the hex SHA-256 digest of its JSON encoding.
func (rule AdditionalRule) Hash() string {
	digest := hashStruct(rule)
	return digest
}

// Hash returns a stable identifier for the rule: the hex SHA-256 digest of its JSON encoding.
func (rule OperationalRule) Hash() string {
	digest := hashStruct(rule)
	return digest
}
// Edge returns the edge between resource and dep implied by d. With DirectionUpstream
// dep points to resource. Any other value, including DirectionDownstream, makes
// resource point to dep.
func (d Direction) Edge(resource, dep construct.ResourceId) construct.SimpleEdge {
	src, dst := resource, dep
	if d == DirectionUpstream {
		src, dst = dep, resource
	}
	return construct.SimpleEdge{Source: src, Target: dst}
}
// IsMatch reports whether res satisfies the selector: its type must be one of the
// selector's candidate types, and every selector property must equal the corresponding
// property on res.
func (p ResourceSelector) IsMatch(ctx DynamicValueContext, data DynamicValueData, res *construct.Resource) (bool, error) {
	const allowEmpty = false
	return p.matches(ctx, data, res, allowEmpty)
}

// CanUse reports whether res could be used to satisfy the selector. Unlike [IsMatch],
// properties that are unset on res are treated as compatible. This lets empty
// resources created during path expansion be configured by other resources' selectors.
func (p ResourceSelector) CanUse(ctx DynamicValueContext, data DynamicValueData, res *construct.Resource) (bool, error) {
	const allowEmpty = true
	return p.matches(ctx, data, res, allowEmpty)
}
// matches implements IsMatch and CanUse. res matches when all of the following hold:
//   - its provider/type matches one of the ids extracted from the selector;
//   - its template has all of the selector's classifications;
//   - each selector property, transformed to the property's type, deep-equals the
//     value on res. When allowEmpty is true, a property that is unset (nil) on res
//     is also accepted.
func (p ResourceSelector) matches(
	ctx DynamicValueContext,
	data DynamicValueData,
	res *construct.Resource,
	allowEmpty bool,
) (bool, error) {
	ids, err := p.ExtractResourceIds(ctx, data)
	if err != nil {
		return false, fmt.Errorf("error extracting resource ids in resource selector: %w", err)
	}
	matchesType := false
	for _, id := range ids {
		// We only check if the resource selector is a match in terms of properties and classifications (not the actual id)
		// We do this because if we have explicit ids in the selector and someone changes the id of a side effect resource
		// we would no longer think it is a side effect since the id would no longer match.
		// To combat this we just check against type
		sel := construct.ResourceId{Provider: id.Provider, Type: id.Type}
		if sel.Matches(res.ID) {
			matchesType = true
			break
		}
	}
	if !matchesType {
		return false, nil
	}
	template, err := ctx.KB().GetResourceTemplate(res.ID)
	if err != nil {
		return false, fmt.Errorf("error getting resource template in resource selector: %w", err)
	}
	// Classifications apply to the resource as a whole, so check them once, regardless
	// of whether the selector has any properties (previously this only ran inside the
	// property loop and was skipped for property-less selectors).
	if !template.ResourceContainsClassifications(p.Classifications) {
		return false, nil
	}
	for k, v := range p.Properties {
		property, err := res.GetProperty(k)
		if err != nil {
			return false, err
		}
		selectorPropertyVal, err := TransformToPropertyValue(res.ID, k, v, ctx, data)
		if err != nil {
			return false, fmt.Errorf("error transforming property value in resource selector: %w", err)
		}
		if !reflect.DeepEqual(property, selectorPropertyVal) && !(allowEmpty && property == nil) {
			return false, nil
		}
	}
	return true, nil
}
// UnmarshalYAML decodes a ResourceSelector from its full mapping form, or, when that
// fails, from the shorthand where the node is just the selector string.
func (p *ResourceSelector) UnmarshalYAML(n *yaml.Node) error {
	// plain has the same fields but no UnmarshalYAML method, avoiding infinite recursion.
	type plain ResourceSelector
	var decoded plain
	if err := n.Decode(&decoded); err != nil {
		var selector string
		if err := n.Decode(&selector); err != nil {
			return fmt.Errorf("error decoding resource selector: %w", err)
		}
		decoded.Selector = selector
	}
	*p = ResourceSelector(decoded)
	return nil
}
// ExtractResourceIds returns the candidate resource ids for the selector that carry all
// of the selector's classifications.
//
// When Selector is set it is decoded (via ctx.ExecuteDecode) into a resource list, and
// a decode failure is returned immediately. Otherwise every resource template in the
// knowledge base is a candidate. Template lookup failures for individual candidates
// are joined into the returned error, and the ids that did qualify are still returned.
func (p ResourceSelector) ExtractResourceIds(ctx DynamicValueContext, data DynamicValueData) (ids construct.ResourceList, errs error) {
	var selectors construct.ResourceList
	if p.Selector != "" {
		if err := ctx.ExecuteDecode(p.Selector, data, &selectors); err != nil {
			return nil, err
		}
	} else {
		for _, res := range ctx.KB().ListResources() {
			selectors = append(selectors, res.Id())
		}
	}
	for _, id := range selectors {
		resTmpl, err := ctx.KB().GetResourceTemplate(id)
		switch {
		case err != nil:
			errs = errors.Join(errs, err)
		case resTmpl == nil:
			errs = errors.Join(errs, fmt.Errorf("could not find resource template for %s", id))
		case resTmpl.ResourceContainsClassifications(p.Classifications):
			ids = append(ids, id)
		}
	}
	return ids, errs
}
package knowledgebase
import (
"fmt"
"strings"
construct "github.com/klothoplatform/klotho/pkg/construct"
"gopkg.in/yaml.v3"
)
type (
// PathSatisfaction describes how a resource template participates in path expansion.
// GetPathSatisfactionsFromEdge pairs the source template's AsSource routes with the
// target template's AsTarget routes that share a Classification.
PathSatisfaction struct {
AsTarget []PathSatisfactionRoute `json:"as_target" yaml:"as_target"`
AsSource []PathSatisfactionRoute `json:"as_source" yaml:"as_source"`
// DenyClassifications lists classifications for which this resource cannot be included in paths during expansion.
// NOTE(review): unlike the sibling fields this has no json tag, so JSON uses "DenyClassifications".
DenyClassifications []string `yaml:"deny_classifications"`
}
// PathSatisfactionRoute is a single route. In YAML it may also be written as the
// shorthand string "classification" or "classification#property_reference".
PathSatisfactionRoute struct {
Classification string `json:"classification" yaml:"classification"`
PropertyReference string `json:"property_reference" yaml:"property_reference"`
// Validity is lower-cased when decoded from a YAML mapping (see UnmarshalYAML).
Validity PathSatisfactionValidityOperation `json:"validity" yaml:"validity"`
// Script may not be combined with PropertyReference in the YAML mapping form.
Script string `json:"script" yaml:"script"`
}
// PathSatisfactionValidityOperation names a validity check for a route (e.g. DownstreamOperation).
PathSatisfactionValidityOperation string
)
// DownstreamOperation is the "downstream" validity operation. Validity values decoded
// from YAML mappings are lower-cased, so any capitalization of it maps here.
const DownstreamOperation PathSatisfactionValidityOperation = "downstream"
// UnmarshalYAML decodes a PathSatisfactionRoute from either its mapping form or the
// scalar shorthand "classification[#property.reference]". Everything after the first
// '#' is the property reference.
//
// In the mapping form Validity is lower-cased, and setting both PropertyReference and
// Script is an error. Nodes that are neither a valid mapping nor a scalar produce a
// decode error instead of being silently treated as an empty route.
func (p *PathSatisfactionRoute) UnmarshalYAML(n *yaml.Node) error {
	// h has the same fields but no UnmarshalYAML method, avoiding infinite recursion.
	type h PathSatisfactionRoute
	var p2 h
	if err := n.Decode(&p2); err != nil {
		if n.Kind != yaml.ScalarNode {
			return fmt.Errorf("error decoding path satisfaction route: %w", err)
		}
		classification, ref, _ := strings.Cut(n.Value, "#")
		*p = PathSatisfactionRoute{Classification: classification, PropertyReference: ref}
		return nil
	}
	p2.Validity = PathSatisfactionValidityOperation(strings.ToLower(string(p2.Validity)))
	*p = PathSatisfactionRoute(p2)
	if p.PropertyReference != "" && p.Script != "" {
		return fmt.Errorf("path satisfaction route cannot have both property reference and script")
	}
	return nil
}
// GetPathSatisfactionsFromEdge combines the source template's AsSource routes with the
// target template's AsTarget routes:
//   - every source route is paired with each target route of the same Classification;
//   - a source route with no partner is emitted on its own;
//   - every target route that was never paired is emitted on its own afterwards.
//
// If neither side defines any routes, a single empty EdgePathSatisfaction is returned.
func (kb *KnowledgeBase) GetPathSatisfactionsFromEdge(source, target construct.ResourceId) ([]EdgePathSatisfaction, error) {
	srcTmpl, err := kb.GetResourceTemplate(source)
	if err != nil {
		return nil, err
	}
	dstTmpl, err := kb.GetResourceTemplate(target)
	if err != nil {
		return nil, err
	}
	sources := srcTmpl.PathSatisfaction.AsSource
	targets := dstTmpl.PathSatisfaction.AsTarget

	result := []EdgePathSatisfaction{}
	pairedTargets := make(map[PathSatisfactionRoute]struct{})
	for _, src := range sources {
		paired := false
		for _, dst := range targets {
			if dst.Classification != src.Classification {
				continue
			}
			result = append(result, EdgePathSatisfaction{
				Classification: src.Classification,
				Source:         src,
				Target:         dst,
			})
			pairedTargets[dst] = struct{}{}
			paired = true
		}
		if !paired {
			result = append(result, EdgePathSatisfaction{Classification: src.Classification, Source: src})
		}
	}
	for _, dst := range targets {
		if _, seen := pairedTargets[dst]; seen {
			continue
		}
		result = append(result, EdgePathSatisfaction{Classification: dst.Classification, Target: dst})
	}
	if len(result) == 0 {
		result = append(result, EdgePathSatisfaction{})
	}
	return result, nil
}
// PropertyReferenceChangesBoundary reports whether the route's property reference moves
// the path boundary. That is the case only when a PropertyReference is set and no
// Validity operation is given.
func (v PathSatisfactionRoute) PropertyReferenceChangesBoundary() bool {
	return v.Validity == "" && v.PropertyReference != ""
}
package properties
import (
"fmt"
construct "github.com/klothoplatform/klotho/pkg/construct"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
)
type (
// AnyProperty is a property whose value type is not fixed. Parse and Contains dispatch
// on the dynamic type of the value (string, map[string]any, []any, anything else).
AnyProperty struct {
SharedPropertyFields
knowledgebase.PropertyDetails
}
)
// SetProperty stores value at the property's path on resource, without type checks.
func (a *AnyProperty) SetProperty(resource *construct.Resource, value any) error {
	path := a.Path
	return resource.SetProperty(path, value)
}

// AppendProperty appends value to whatever is stored at the property's path, or sets
// it outright when nothing is stored there yet.
func (a *AnyProperty) AppendProperty(resource *construct.Resource, value any) error {
	current, err := resource.GetProperty(a.Path)
	switch {
	case err != nil:
		return err
	case current == nil:
		return a.SetProperty(resource, value)
	default:
		return resource.AppendProperty(a.Path, value)
	}
}

// RemoveProperty removes value from the property's path on resource.
func (a *AnyProperty) RemoveProperty(resource *construct.Resource, value any) error {
	path := a.Path
	return resource.RemoveProperty(path, value)
}
// Details returns a pointer to the embedded property details, so edits apply in place.
func (a *AnyProperty) Details() *knowledgebase.PropertyDetails {
	details := &a.PropertyDetails
	return details
}

// Clone returns a shallow copy of the property.
func (a *AnyProperty) Clone() knowledgebase.Property {
	copied := *a
	return &copied
}

// GetDefaultValue parses the configured DefaultValue, or returns (nil, nil) when none is set.
func (a *AnyProperty) GetDefaultValue(ctx knowledgebase.DynamicContext, data knowledgebase.DynamicValueData) (any, error) {
	if a.DefaultValue != nil {
		return a.Parse(a.DefaultValue, ctx, data)
	}
	return nil, nil
}
// Parse converts value into the most specific representation it can find:
//   - strings are tried, in order, as a resource id, a property reference, and any
//     other template decoded into an untyped value. When all three fail, the string
//     is returned unchanged.
//   - map[string]any values are parsed as maps of any-typed values.
//   - []any values are parsed as lists of any-typed items.
//   - every other value is returned unchanged.
func (a *AnyProperty) Parse(value any, ctx knowledgebase.DynamicContext, data knowledgebase.DynamicValueData) (any, error) {
	switch v := value.(type) {
	case string:
		resProp := ResourceProperty{}
		if id, err := resProp.Parse(v, ctx, data); err == nil {
			return id, nil
		}
		if ref, err := ParsePropertyRef(v, ctx, data); err == nil {
			return ref, nil
		}
		var decoded any
		if err := ctx.ExecuteDecode(v, data, &decoded); err == nil {
			return decoded, nil
		}
	case map[string]any:
		mapProp := MapProperty{KeyProperty: &StringProperty{}, ValueProperty: &AnyProperty{}}
		return mapProp.Parse(v, ctx, data)
	case []any:
		listProp := ListProperty{ItemProperty: &AnyProperty{}}
		return listProp.Parse(v, ctx, data)
	}
	return value, nil
}
// ZeroValue returns nil: an untyped property has no other zero value.
func (a *AnyProperty) ZeroValue() any {
	return nil
}

// Contains reports whether value contains `contains`, delegating to the string, map or
// list property implementations according to value's dynamic type. Any other type
// never contains anything.
func (a *AnyProperty) Contains(value any, contains any) bool {
	switch v := value.(type) {
	case string:
		strProp := StringProperty{}
		return strProp.Contains(v, contains)
	case map[string]any:
		mapProp := MapProperty{KeyProperty: &StringProperty{}, ValueProperty: &AnyProperty{}}
		return mapProp.Contains(v, contains)
	case []any:
		listProp := ListProperty{ItemProperty: &AnyProperty{}}
		return listProp.Contains(v, contains)
	default:
		return false
	}
}

// Type returns the type name "any".
func (a *AnyProperty) Type() string {
	const name = "any"
	return name
}
// Validate only checks presence, since any value type is acceptable. An unset
// deploy-time property on a non-imported resource passes. Otherwise an unset Required
// property is an error.
func (a *AnyProperty) Validate(resource *construct.Resource, value any, ctx knowledgebase.DynamicContext) error {
	if value != nil {
		return nil
	}
	if a.DeployTime && !resource.Imported {
		return nil
	}
	if a.Required {
		return fmt.Errorf(knowledgebase.ErrRequiredProperty, a.Path, resource.ID)
	}
	return nil
}

// SubProperties returns nil: an untyped property declares no nested properties.
func (a *AnyProperty) SubProperties() knowledgebase.Properties {
	return nil
}
package properties
import (
"fmt"
construct "github.com/klothoplatform/klotho/pkg/construct"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
)
type (
// BoolProperty is a property holding a bool, or a construct.PropertyRef (accepted by
// SetProperty and Parse) to be resolved elsewhere.
BoolProperty struct {
SharedPropertyFields
knowledgebase.PropertyDetails
}
)
// SetProperty stores value at the property's path when it is a bool or a
// construct.PropertyRef; any other type is rejected.
func (b *BoolProperty) SetProperty(resource *construct.Resource, value any) error {
	switch v := value.(type) {
	case bool:
		return resource.SetProperty(b.Path, v)
	case construct.PropertyRef:
		return resource.SetProperty(b.Path, v)
	}
	return fmt.Errorf("invalid bool value %v", value)
}

// AppendProperty behaves like SetProperty: a scalar bool has nothing to append to.
func (b *BoolProperty) AppendProperty(resource *construct.Resource, value any) error {
	return b.SetProperty(resource, value)
}

// RemoveProperty removes value from the property's path, doing nothing when it is unset.
func (b *BoolProperty) RemoveProperty(resource *construct.Resource, value any) error {
	current, err := resource.GetProperty(b.Path)
	if err != nil || current == nil {
		return err
	}
	return resource.RemoveProperty(b.Path, value)
}
// Clone returns a shallow copy of the property.
func (b *BoolProperty) Clone() knowledgebase.Property {
	copied := *b
	return &copied
}

// Details returns a pointer to the embedded property details, so edits apply in place.
func (b *BoolProperty) Details() *knowledgebase.PropertyDetails {
	details := &b.PropertyDetails
	return details
}

// GetDefaultValue parses the configured DefaultValue, or returns (nil, nil) when none is set.
func (b *BoolProperty) GetDefaultValue(ctx knowledgebase.DynamicContext, data knowledgebase.DynamicValueData) (any, error) {
	if b.DefaultValue != nil {
		return b.Parse(b.DefaultValue, ctx, data)
	}
	return nil, nil
}
// Parse converts value into a bool:
//   - strings are evaluated with ctx.ExecuteDecode into a bool, and the decode error
//     (if any) is returned alongside the result;
//   - bools are returned as-is;
//   - property references are returned as construct.PropertyRef values.
//
// Anything else is an error.
func (b *BoolProperty) Parse(value any, ctx knowledgebase.DynamicContext, data knowledgebase.DynamicValueData) (any, error) {
	switch v := value.(type) {
	case string:
		var decoded bool
		err := ctx.ExecuteDecode(v, data, &decoded)
		return decoded, err
	case bool:
		return v, nil
	}
	if ref, err := ParsePropertyRef(value, ctx, data); err == nil {
		return ref, nil
	}
	return nil, fmt.Errorf("invalid bool value %v", value)
}
// ZeroValue returns false, the zero value of a bool.
func (b *BoolProperty) ZeroValue() any {
	var zero bool
	return zero
}

// Contains always reports false: a bool has no elements.
func (b *BoolProperty) Contains(value any, contains any) bool {
	return false
}

// Type returns the type name "bool".
func (b *BoolProperty) Type() string {
	const name = "bool"
	return name
}
// Validate checks that value is a bool. An unset deploy-time property on a
// non-imported resource passes, and an unset Required property is an error.
func (b *BoolProperty) Validate(resource *construct.Resource, value any, ctx knowledgebase.DynamicContext) error {
	if value == nil {
		if b.DeployTime && !resource.Imported {
			return nil
		}
		if b.Required {
			return fmt.Errorf(knowledgebase.ErrRequiredProperty, b.Path, resource.ID)
		}
		return nil
	}
	if _, isBool := value.(bool); !isBool {
		return fmt.Errorf("invalid bool value %v", value)
	}
	return nil
}

// SubProperties returns nil: a bool has no nested properties.
func (b *BoolProperty) SubProperties() knowledgebase.Properties {
	return nil
}
package properties
import (
"fmt"
construct "github.com/klothoplatform/klotho/pkg/construct"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
)
type (
// FloatProperty is a property holding a float64 (or a construct.PropertyRef).
FloatProperty struct {
// MinValue, when non-nil, is the inclusive lower bound enforced by Validate.
MinValue *float64
// MaxValue, when non-nil, is the inclusive upper bound enforced by Validate.
MaxValue *float64
SharedPropertyFields
knowledgebase.PropertyDetails
}
)
// SetProperty stores value at the property's path when it is a float64 or a
// construct.PropertyRef; any other type is rejected.
func (f *FloatProperty) SetProperty(resource *construct.Resource, value any) error {
	switch v := value.(type) {
	case float64:
		return resource.SetProperty(f.Path, v)
	case construct.PropertyRef:
		return resource.SetProperty(f.Path, v)
	}
	return fmt.Errorf("invalid float value %v", value)
}

// AppendProperty behaves like SetProperty: a scalar float has nothing to append to.
func (f *FloatProperty) AppendProperty(resource *construct.Resource, value any) error {
	return f.SetProperty(resource, value)
}

// RemoveProperty removes value from the property's path, doing nothing when it is unset.
func (f *FloatProperty) RemoveProperty(resource *construct.Resource, value any) error {
	current, err := resource.GetProperty(f.Path)
	if err != nil || current == nil {
		return err
	}
	return resource.RemoveProperty(f.Path, value)
}
// Details returns a pointer to the embedded property details, so edits apply in place.
func (f *FloatProperty) Details() *knowledgebase.PropertyDetails {
	details := &f.PropertyDetails
	return details
}

// Clone returns a shallow copy of the property; the bound pointers are shared.
func (f *FloatProperty) Clone() knowledgebase.Property {
	copied := *f
	return &copied
}

// GetDefaultValue parses the configured DefaultValue, or returns (nil, nil) when none is set.
func (f *FloatProperty) GetDefaultValue(ctx knowledgebase.DynamicContext, data knowledgebase.DynamicValueData) (any, error) {
	if f.DefaultValue != nil {
		return f.Parse(f.DefaultValue, ctx, data)
	}
	return nil, nil
}
// Parse converts value into a float64, or into a construct.PropertyRef when value is a
// property reference.
//
// Strings are evaluated with ctx.ExecuteDecode into a float64. float32 and int values
// are widened to float64. Always producing float64 matters because SetProperty and
// Validate accept only float64; previously strings decoded into float32 and float32
// values passed through unchanged, so both were later rejected.
func (f *FloatProperty) Parse(value any, ctx knowledgebase.DynamicContext, data knowledgebase.DynamicValueData) (any, error) {
	switch v := value.(type) {
	case string:
		var result float64
		err := ctx.ExecuteDecode(v, data, &result)
		return result, err
	case float32:
		return float64(v), nil
	case float64:
		return v, nil
	case int:
		return float64(v), nil
	}
	if ref, err := ParsePropertyRef(value, ctx, data); err == nil {
		return ref, nil
	}
	return nil, fmt.Errorf("invalid float value %v", value)
}
// ZeroValue returns float64(0).
func (f *FloatProperty) ZeroValue() any {
	return float64(0)
}

// Contains always reports false: a float has no elements.
func (f *FloatProperty) Contains(value any, contains any) bool {
	return false
}

// Type returns the type name "float".
func (f *FloatProperty) Type() string {
	const name = "float"
	return name
}
// Validate checks that value is a float64 within the optional inclusive
// [MinValue, MaxValue] bounds. An unset deploy-time property on a non-imported
// resource passes, and an unset Required property is an error.
func (f *FloatProperty) Validate(resource *construct.Resource, value any, ctx knowledgebase.DynamicContext) error {
	if value == nil {
		if f.DeployTime && !resource.Imported {
			return nil
		}
		if f.Required {
			return fmt.Errorf(knowledgebase.ErrRequiredProperty, f.Path, resource.ID)
		}
		return nil
	}
	num, isFloat := value.(float64)
	if !isFloat {
		return fmt.Errorf("invalid float value %v", value)
	}
	switch {
	case f.MinValue != nil && num < *f.MinValue:
		return fmt.Errorf("float value %f is less than lower bound %f", value, *f.MinValue)
	case f.MaxValue != nil && num > *f.MaxValue:
		return fmt.Errorf("float value %f is greater than upper bound %f", value, *f.MaxValue)
	}
	return nil
}

// SubProperties returns nil: a float has no nested properties.
func (f *FloatProperty) SubProperties() knowledgebase.Properties {
	return nil
}
package properties
import (
"fmt"
"math"
construct "github.com/klothoplatform/klotho/pkg/construct"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
)
type (
// IntProperty is a property holding an int (or a construct.PropertyRef).
IntProperty struct {
// MinValue, when non-nil, is the inclusive lower bound enforced by Validate.
MinValue *int
// MaxValue, when non-nil, is the inclusive upper bound enforced by Validate.
MaxValue *int
SharedPropertyFields
knowledgebase.PropertyDetails
}
)
// SetProperty stores value at the property's path when it is an int or a
// construct.PropertyRef; any other type is rejected.
func (i *IntProperty) SetProperty(resource *construct.Resource, value any) error {
	switch v := value.(type) {
	case int:
		return resource.SetProperty(i.Path, v)
	case construct.PropertyRef:
		return resource.SetProperty(i.Path, v)
	}
	return fmt.Errorf("invalid int value %v", value)
}

// AppendProperty behaves like SetProperty: a scalar int has nothing to append to.
func (i *IntProperty) AppendProperty(resource *construct.Resource, value any) error {
	return i.SetProperty(resource, value)
}

// RemoveProperty removes value from the property's path, doing nothing when it is unset.
func (i *IntProperty) RemoveProperty(resource *construct.Resource, value any) error {
	current, err := resource.GetProperty(i.Path)
	if err != nil || current == nil {
		return err
	}
	return resource.RemoveProperty(i.Path, value)
}
// Details returns a pointer to the embedded property details, so edits apply in place.
func (i *IntProperty) Details() *knowledgebase.PropertyDetails {
	details := &i.PropertyDetails
	return details
}

// Clone returns a shallow copy of the property; the bound pointers are shared.
func (i *IntProperty) Clone() knowledgebase.Property {
	copied := *i
	return &copied
}

// GetDefaultValue parses the configured DefaultValue, or returns (nil, nil) when none is set.
func (i *IntProperty) GetDefaultValue(ctx knowledgebase.DynamicContext, data knowledgebase.DynamicValueData) (any, error) {
	if i.DefaultValue != nil {
		return i.Parse(i.DefaultValue, ctx, data)
	}
	return nil, nil
}
// Parse converts value into an int, or into a construct.PropertyRef when value is a
// property reference.
//
// Accepted inputs:
//   - strings, evaluated with ctx.ExecuteDecode into an int;
//   - ints, returned as-is;
//   - float32/float64 values (e.g. whole numbers decoded from JSON), accepted only
//     when within a small tolerance of an integer and rounded to it. NaN, ±Inf and
//     values outside the range of int are rejected.
//
// On failure the returned value is nil.
func (i *IntProperty) Parse(value any, ctx knowledgebase.DynamicContext, data knowledgebase.DynamicValueData) (any, error) {
	// epsilon is the tolerated distance between a float and the integer it represents.
	const epsilon = 0.0000001
	fromFloat := func(fv float64) (any, error) {
		if math.IsNaN(fv) || math.IsInf(fv, 0) {
			return nil, fmt.Errorf("cannot convert non-finite float to int: %f", fv)
		}
		// Round rather than truncate, so values just below an integer (2.99999999) get
		// the same tolerance as values just above one (3.00000001).
		rounded := math.Round(fv)
		if math.Abs(fv-rounded) > epsilon {
			return nil, fmt.Errorf("cannot convert non-integral float to int: %f", fv)
		}
		// -float64(math.MinInt) is exactly 2^63 (or 2^31): one past the largest int.
		if rounded < float64(math.MinInt) || rounded >= -float64(math.MinInt) {
			return nil, fmt.Errorf("float %f is out of range for int", fv)
		}
		return int(rounded), nil
	}
	switch v := value.(type) {
	case string:
		var result int
		err := ctx.ExecuteDecode(v, data, &result)
		return result, err
	case int:
		return v, nil
	case float32:
		return fromFloat(float64(v))
	case float64:
		return fromFloat(v)
	}
	if ref, err := ParsePropertyRef(value, ctx, data); err == nil {
		return ref, nil
	}
	return nil, fmt.Errorf("invalid int value %v", value)
}
// ZeroValue returns the int zero value, 0.
func (i *IntProperty) ZeroValue() any {
	var zero int
	return zero
}
// Contains always reports false; containment has no meaning for a scalar int.
func (i *IntProperty) Contains(value any, contains any) bool { return false }
// Type returns the type name used for this property, "int".
func (i *IntProperty) Type() string { return "int" }
// Validate checks that value is acceptable for this property on resource.
//
// A nil value passes for deploy-time properties on non-imported resources.
// Otherwise a nil value fails only if the property is Required. A non-nil
// value must be an int or a construct.PropertyRef. A reference is resolved
// with ValidatePropertyRef and the resolved int is checked instead; if the
// reference resolves to nil, there is nothing to check yet. The int is then
// checked against the optional MinValue/MaxValue bounds, both inclusive.
func (i *IntProperty) Validate(resource *construct.Resource, value any, ctx knowledgebase.DynamicContext) error {
	if i.DeployTime && value == nil && !resource.Imported {
		return nil
	}
	if value == nil {
		if i.Required {
			return fmt.Errorf(knowledgebase.ErrRequiredProperty, i.Path, resource.ID)
		}
		return nil
	}
	intVal, ok := value.(int)
	if !ok {
		// SetProperty and Parse both accept property references, so resolve
		// them here instead of rejecting them outright.
		propertyRef, isRef := value.(construct.PropertyRef)
		if !isRef {
			return fmt.Errorf("invalid int value %v", value)
		}
		refVal, err := ValidatePropertyRef(propertyRef, i.Type(), ctx)
		if err != nil {
			return err
		}
		if refVal == nil {
			return nil
		}
		intVal, ok = refVal.(int)
		if !ok {
			return fmt.Errorf("invalid int value %v", value)
		}
	}
	if i.MinValue != nil && intVal < *i.MinValue {
		return fmt.Errorf("int value %v is less than lower bound %d", intVal, *i.MinValue)
	}
	if i.MaxValue != nil && intVal > *i.MaxValue {
		return fmt.Errorf("int value %v is greater than upper bound %d", intVal, *i.MaxValue)
	}
	return nil
}
// SubProperties returns nil: an int has no nested properties.
func (i *IntProperty) SubProperties() knowledgebase.Properties {
	var none knowledgebase.Properties
	return none
}
package properties
import (
"errors"
"fmt"
"reflect"
"strings"
"github.com/klothoplatform/klotho/pkg/collectionutil"
construct "github.com/klothoplatform/klotho/pkg/construct"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
)
type (
// ListProperty describes an ordered, list-valued property. Elements are
// described either by ItemProperty or, when Properties is non-empty, as
// map-shaped objects whose fields are described by Properties.
ListProperty struct {
// MinLength and MaxLength are optional inclusive bounds on the number of
// elements, enforced by Validate.
MinLength *int
MaxLength *int
// ItemProperty describes each element of the list.
ItemProperty knowledgebase.Property
// Properties describes the fields of map-shaped elements (see Parse/Validate).
Properties knowledgebase.Properties
SharedPropertyFields
knowledgebase.PropertyDetails
}
)
// SetProperty replaces the whole list at the property's path. Only []any
// values are accepted.
func (l *ListProperty) SetProperty(resource *construct.Resource, value any) error {
	val, ok := value.([]any)
	if !ok {
		return fmt.Errorf("invalid list value %v", value)
	}
	return resource.SetProperty(l.Path, val)
}
// AppendProperty appends value to the list at the property's path. If nothing
// is stored yet, it first stores an empty list.
//
// When the item type is not itself a list and value is a slice or array, each
// element is appended individually and the per-element errors are joined.
func (l *ListProperty) AppendProperty(resource *construct.Resource, value any) error {
	current, err := resource.GetProperty(l.Path)
	if err != nil {
		return err
	}
	if current == nil {
		if err := l.SetProperty(resource, []any{}); err != nil {
			return err
		}
	}

	// A slice argument is the item itself when items are lists; otherwise it
	// is a batch of items to flatten.
	flatten := l.ItemProperty != nil && !strings.HasPrefix(l.ItemProperty.Type(), "list")
	rv := reflect.ValueOf(value)
	if !flatten || (rv.Kind() != reflect.Slice && rv.Kind() != reflect.Array) {
		return resource.AppendProperty(l.Path, value)
	}

	var errs error
	for idx := 0; idx < rv.Len(); idx++ {
		if appendErr := resource.AppendProperty(l.Path, rv.Index(idx).Interface()); appendErr != nil {
			errs = errors.Join(errs, appendErr)
		}
	}
	return errs
}
// RemoveProperty removes value from the list at the property's path and is a
// no-op when nothing is stored.
//
// When the item type is not itself a list and value is a slice or array, each
// element is removed individually and the per-element errors are joined.
func (l *ListProperty) RemoveProperty(resource *construct.Resource, value any) error {
	current, err := resource.GetProperty(l.Path)
	if err != nil {
		return err
	}
	if current == nil {
		return nil
	}

	flatten := l.ItemProperty != nil && !strings.HasPrefix(l.ItemProperty.Type(), "list")
	rv := reflect.ValueOf(value)
	if !flatten || (rv.Kind() != reflect.Slice && rv.Kind() != reflect.Array) {
		return resource.RemoveProperty(l.Path, value)
	}

	var errs error
	for idx := 0; idx < rv.Len(); idx++ {
		if removeErr := resource.RemoveProperty(l.Path, rv.Index(idx).Interface()); removeErr != nil {
			errs = errors.Join(errs, removeErr)
		}
	}
	return errs
}
// Details returns a pointer to the embedded PropertyDetails.
func (l *ListProperty) Details() *knowledgebase.PropertyDetails {
	details := &l.PropertyDetails
	return details
}
// Clone returns a copy of the list property with ItemProperty and Properties
// deep-cloned. Other pointer fields (MinLength/MaxLength) are shared.
func (l *ListProperty) Clone() knowledgebase.Property {
	clone := *l
	if l.ItemProperty != nil {
		clone.ItemProperty = l.ItemProperty.Clone()
	}
	if l.Properties != nil {
		clone.Properties = l.Properties.Clone()
	}
	return &clone
}
// GetDefaultValue parses the configured DefaultValue with Parse and returns
// (nil, nil) when no default is configured.
func (l *ListProperty) GetDefaultValue(ctx knowledgebase.DynamicContext, data knowledgebase.DynamicValueData) (any, error) {
	if def := l.DefaultValue; def != nil {
		return l.Parse(def, ctx, data)
	}
	return nil, nil
}
// Parse converts value into a []any, parsing every element.
//
// Accepted inputs:
//   - nil: yields a nil list.
//   - []any, or any other slice/array (converted element-wise via reflection).
//   - string: rendered as a template and decoded into a []any first.
//
// Any other input type is rejected with an error.
//
// When Properties is non-empty, each element is parsed as an object by a
// MapProperty built from those properties. Otherwise ItemProperty parses each
// element. A list with neither keeps its elements unchanged.
func (l *ListProperty) Parse(value any, ctx knowledgebase.DynamicContext, data knowledgebase.DynamicValueData) (any, error) {
	var items []any
	switch v := value.(type) {
	case nil:
		// nothing to parse
	case []any:
		items = v
	case string:
		// the entire value may be a template that renders to a list
		if err := ctx.ExecuteDecode(v, data, &items); err != nil {
			return nil, fmt.Errorf("invalid list value %v: %w", value, err)
		}
	default:
		rv := reflect.ValueOf(value)
		if rv.Kind() != reflect.Slice && rv.Kind() != reflect.Array {
			return nil, fmt.Errorf("invalid list value %v", value)
		}
		items = make([]any, rv.Len())
		for idx := range items {
			items[idx] = rv.Index(idx).Interface()
		}
	}

	var result []any
	for _, item := range items {
		var (
			parsed any
			err    error
		)
		switch {
		case len(l.Properties) != 0:
			m := MapProperty{Properties: l.Properties}
			parsed, err = m.Parse(item, ctx, data)
		case l.ItemProperty != nil:
			parsed, err = l.ItemProperty.Parse(item, ctx, data)
		default:
			// Untyped list: nothing describes the items, keep them as-is.
			parsed = item
		}
		if err != nil {
			return nil, err
		}
		result = append(result, parsed)
	}
	return result, nil
}
// ZeroValue returns nil: an unset list has no zero-value representation.
func (l *ListProperty) ZeroValue() any { return nil }
// Contains reports whether the list value shares at least one element with
// contains. A contains that is not a []any is looked up as a single element.
func (l *ListProperty) Contains(value any, contains any) bool {
	haystack, ok := value.([]any)
	if !ok {
		return false
	}
	needles, isList := contains.([]any)
	if !isList {
		return collectionutil.Contains(haystack, contains)
	}
	for _, item := range haystack {
		for _, needle := range needles {
			if reflect.DeepEqual(item, needle) {
				return true
			}
		}
	}
	return false
}
// Type returns "list(<item type>)" when ItemProperty is set, otherwise "list".
func (l *ListProperty) Type() string {
	if l.ItemProperty == nil {
		return "list"
	}
	return "list(" + l.ItemProperty.Type() + ")"
}
// Validate checks that value is a valid list for this property on resource.
//
// Nil-value rules:
//   - a nil value passes for deploy-time properties on non-imported resources;
//   - otherwise a nil value fails only if the property is Required.
//
// A non-nil value must be a []any within the optional MinLength/MaxLength
// bounds (inclusive). Each element is then validated:
//   - with ItemProperty when it is set;
//   - otherwise, when Properties is non-empty, as a map whose known fields are
//     validated by the matching sub-property;
//   - otherwise (untyped list) it is accepted as-is.
//
// If validation only produced sanitize errors, a *knowledgebase.SanitizeError
// carrying the sanitized list is returned. Other errors are joined and
// returned.
func (l *ListProperty) Validate(resource *construct.Resource, value any, ctx knowledgebase.DynamicContext) error {
	if l.DeployTime && value == nil && !resource.Imported {
		return nil
	}
	if value == nil {
		if l.Required {
			return fmt.Errorf(knowledgebase.ErrRequiredProperty, l.Path, resource.ID)
		}
		return nil
	}
	listVal, ok := value.([]any)
	if !ok {
		return fmt.Errorf("invalid list value %v", value)
	}
	if l.MinLength != nil && len(listVal) < *l.MinLength {
		return fmt.Errorf("list value %v is too short. min length is %d", value, *l.MinLength)
	}
	if l.MaxLength != nil && len(listVal) > *l.MaxLength {
		return fmt.Errorf("list value %v is too long. max length is %d", value, *l.MaxLength)
	}

	validList := make([]any, len(listVal))
	var errs error
	hasSanitized := false
	for i, v := range listVal {
		switch {
		case l.ItemProperty != nil:
			err := l.ItemProperty.Validate(resource, v, ctx)
			var sanitizeErr *knowledgebase.SanitizeError
			switch {
			case err == nil:
				validList[i] = v
			case errors.As(err, &sanitizeErr):
				validList[i] = sanitizeErr.Sanitized
				hasSanitized = true
			default:
				errs = errors.Join(errs, err)
			}
		case len(l.Properties) == 0:
			// Untyped list: nothing describes the items, keep them as-is.
			validList[i] = v
		default:
			vmap, ok := v.(map[string]any)
			if !ok {
				return fmt.Errorf("invalid value for list index %d in sub properties validation: expected map[string]any got %T", i, v)
			}
			// Start from a copy of the item so keys without a sub-property
			// definition survive in the sanitized output instead of being
			// silently dropped.
			validIndex := make(map[string]any, len(vmap))
			for k, fieldVal := range vmap {
				validIndex[k] = fieldVal
			}
			for _, prop := range l.SubProperties() {
				name := prop.Details().Name
				fieldVal, found := vmap[name]
				if !found {
					continue
				}
				err := prop.Validate(resource, fieldVal, ctx)
				if err == nil {
					continue
				}
				var sanitizeErr *knowledgebase.SanitizeError
				if errors.As(err, &sanitizeErr) {
					validIndex[name] = sanitizeErr.Sanitized
					hasSanitized = true
				} else {
					errs = errors.Join(errs, err)
				}
			}
			validList[i] = validIndex
		}
	}
	if errs != nil {
		return errs
	}
	if hasSanitized {
		return &knowledgebase.SanitizeError{
			Input:     listVal,
			Sanitized: validList,
		}
	}
	return nil
}
// SubProperties returns the field definitions for map-shaped list elements.
func (l *ListProperty) SubProperties() knowledgebase.Properties {
	props := l.Properties
	return props
}
// Item returns the property describing each list element (may be nil).
func (l *ListProperty) Item() knowledgebase.Property { return l.ItemProperty }
package properties
import (
"errors"
"fmt"
"reflect"
construct "github.com/klothoplatform/klotho/pkg/construct"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
)
type (
// MapProperty describes a string-keyed, map-valued property. It is used
// either as a typed map (KeyProperty + ValueProperty), as an object with
// named fields (Properties), or untyped.
MapProperty struct {
// MinLength and MaxLength are optional inclusive bounds on the number of
// entries, enforced by Validate.
MinLength *int
MaxLength *int
// KeyProperty and ValueProperty describe each key and value. Parse
// converts entries only when both are set.
KeyProperty   knowledgebase.Property
ValueProperty knowledgebase.Property
// Properties describes the named fields of an object-shaped map.
Properties knowledgebase.Properties
SharedPropertyFields
knowledgebase.PropertyDetails
}
)
// SetProperty replaces the whole map at the property's path. Only
// map[string]any values are accepted.
func (m *MapProperty) SetProperty(resource *construct.Resource, value any) error {
	if val, ok := value.(map[string]any); ok {
		return resource.SetProperty(m.Path, val)
	}
	// The message previously said "resource value" (copy-paste from
	// ResourceProperty), which misdescribed the failing input.
	return fmt.Errorf("invalid map value %v", value)
}
// AppendProperty delegates merging value into the map at the property's path
// to resource.AppendProperty.
func (m *MapProperty) AppendProperty(resource *construct.Resource, value any) error {
	path := m.Path
	return resource.AppendProperty(path, value)
}
// RemoveProperty removes entries from the map stored at the property's path.
//
// When value is itself a map[string]any, only entries whose key AND value both
// match are deleted. The stored map is mutated in place and then written back.
// Any other value is handed to resource.RemoveProperty unchanged. It is a
// no-op when nothing is stored.
func (m *MapProperty) RemoveProperty(resource *construct.Resource, value any) error {
	current, err := resource.GetProperty(m.Path)
	if err != nil {
		return err
	}
	if current == nil {
		return nil
	}
	existing, ok := current.(map[string]any)
	if !ok {
		return fmt.Errorf("error attempting to remove map property: invalid property value %v", current)
	}
	toRemove, isMap := value.(map[string]any)
	if !isMap {
		return resource.RemoveProperty(m.Path, value)
	}
	for key, want := range toRemove {
		if have, found := existing[key]; found && reflect.DeepEqual(have, want) {
			delete(existing, key)
		}
	}
	return resource.SetProperty(m.Path, existing)
}
// Details returns a pointer to the embedded PropertyDetails.
func (m *MapProperty) Details() *knowledgebase.PropertyDetails {
	details := &m.PropertyDetails
	return details
}
// Clone returns a copy of the map property with KeyProperty, ValueProperty and
// Properties deep-cloned. MinLength/MaxLength pointers are shared.
func (m *MapProperty) Clone() knowledgebase.Property {
	clone := *m
	if m.KeyProperty != nil {
		clone.KeyProperty = m.KeyProperty.Clone()
	}
	if m.ValueProperty != nil {
		clone.ValueProperty = m.ValueProperty.Clone()
	}
	if m.Properties != nil {
		clone.Properties = m.Properties.Clone()
	}
	return &clone
}
// GetDefaultValue parses the configured DefaultValue with Parse and returns
// (nil, nil) when no default is configured.
func (m *MapProperty) GetDefaultValue(ctx knowledgebase.DynamicContext, data knowledgebase.DynamicValueData) (any, error) {
	if def := m.DefaultValue; def != nil {
		return m.Parse(def, ctx, data)
	}
	return nil, nil
}
// Parse converts value into a map[string]any.
//
// Accepted inputs are map[string]any, construct.Properties, or a string. A
// string is rendered as a template and decoded directly into the result, with
// no further parsing.
//
// When Properties is set, each named sub-property is parsed from the input.
// Missing sub-properties fall back to their default. Failures are collected
// and returned together.
//
// When both KeyProperty and ValueProperty are set, every entry is parsed. Keys
// must parse to a string, construct.ResourceId or construct.PropertyRef; the
// latter two are stored under their String() form.
func (m *MapProperty) Parse(value any, ctx knowledgebase.DynamicContext, data knowledgebase.DynamicValueData) (any, error) {
	result := map[string]any{}
	mapVal, ok := value.(map[string]any)
	if !ok {
		// before we fail, check to see if the entire value is a template
		if strVal, ok := value.(string); ok {
			err := ctx.ExecuteDecode(strVal, data, &result)
			return result, err
		}
		mapVal, ok = value.(construct.Properties)
		if !ok {
			return nil, fmt.Errorf("invalid map value %v", value)
		}
	}

	// Object-shaped maps: each sub-property knows how to parse its own value.
	if len(m.Properties) != 0 {
		var errs error
		for key, prop := range m.Properties {
			if raw, found := mapVal[key]; found {
				val, err := prop.Parse(raw, ctx, data)
				if err != nil {
					errs = errors.Join(errs, fmt.Errorf("unable to parse value for sub property %s: %w", key, err))
					continue
				}
				result[key] = val
				continue
			}
			val, err := prop.GetDefaultValue(ctx, data)
			if err != nil {
				errs = errors.Join(errs, fmt.Errorf("unable to get default value for sub property %s: %w", key, err))
				continue
			}
			if val != nil {
				result[key] = val
			}
		}
		// Errors were collected but previously never returned, silently
		// yielding a map with missing entries.
		if errs != nil {
			return nil, errs
		}
	}

	if m.KeyProperty == nil || m.ValueProperty == nil {
		return result, nil
	}

	// Typed map: parse every key and value.
	for key, v := range mapVal {
		keyVal, err := m.KeyProperty.Parse(key, ctx, data)
		if err != nil {
			return nil, err
		}
		val, err := m.ValueProperty.Parse(v, ctx, data)
		if err != nil {
			return nil, err
		}
		switch keyVal := keyVal.(type) {
		case string:
			result[keyVal] = val
		case construct.ResourceId:
			result[keyVal.String()] = val
		case construct.PropertyRef:
			result[keyVal.String()] = val
		default:
			return nil, fmt.Errorf("invalid key type %T for map property type %s", keyVal, m.Type())
		}
	}
	return result, nil
}
// ZeroValue returns nil: an unset map has no zero-value representation.
func (m *MapProperty) ZeroValue() any { return nil }
// Contains reports whether the map value shares something with contains, which
// must also be a map[string]any. It returns true when either:
//   - a key exists in both maps with deeply-equal values, or
//   - any value of contains deeply equals any value of value, regardless of key.
func (m *MapProperty) Contains(value any, contains any) bool {
	mapVal, ok := value.(map[string]any)
	if !ok {
		return false
	}
	containsMap, ok := contains.(map[string]any)
	if !ok {
		return false
	}
	for k, v := range containsMap {
		// Require both key presence and value equality, matching the rule
		// RemoveProperty uses. With `||`, any shared key (or a missing key
		// paired with a nil expected value) was reported as contained.
		if val, found := mapVal[k]; found && reflect.DeepEqual(val, v) {
			return true
		}
	}
	for _, v := range mapVal {
		for _, cv := range containsMap {
			if reflect.DeepEqual(v, cv) {
				return true
			}
		}
	}
	return false
}
// Type returns "map(<key type>,<value type>)" for typed maps, otherwise "map".
func (m *MapProperty) Type() string {
	if m.KeyProperty == nil || m.ValueProperty == nil {
		return "map"
	}
	return "map(" + m.KeyProperty.Type() + "," + m.ValueProperty.Type() + ")"
}
// Validate checks that value is a valid map for this property on resource.
//
// Nil-value rules:
//   - a nil value passes for deploy-time properties on non-imported resources;
//   - otherwise a nil value fails only if the property is Required.
//
// A non-nil value must be a map[string]any within the optional
// MinLength/MaxLength bounds (inclusive). Keys and values are then validated
// with KeyProperty/ValueProperty when those are set.
//
// If validation only produced sanitize errors, a *knowledgebase.SanitizeError
// carrying the sanitized map is returned. Other errors are joined and
// returned.
func (m *MapProperty) Validate(resource *construct.Resource, value any, ctx knowledgebase.DynamicContext) error {
	if m.DeployTime && value == nil && !resource.Imported {
		return nil
	}
	if value == nil {
		if m.Required {
			return fmt.Errorf(knowledgebase.ErrRequiredProperty, m.Path, resource.ID)
		}
		return nil
	}
	mapVal, ok := value.(map[string]any)
	if !ok {
		return fmt.Errorf("invalid map value %v", value)
	}
	if m.MinLength != nil && len(mapVal) < *m.MinLength {
		return fmt.Errorf("map value %v is too short. min length is %d", value, *m.MinLength)
	}
	if m.MaxLength != nil && len(mapVal) > *m.MaxLength {
		return fmt.Errorf("map value %v is too long. max length is %d", value, *m.MaxLength)
	}

	var errs error
	hasSanitized := false
	validMap := make(map[string]any, len(mapVal))
	// Only keys and values of a primitive map are validated here; sub
	// properties are left to handle their own validation.
	for k, v := range mapVal {
		if m.KeyProperty != nil {
			var sanitizeErr *knowledgebase.SanitizeError
			if err := m.KeyProperty.Validate(resource, k, ctx); errors.As(err, &sanitizeErr) {
				// Map keys must remain strings; report an unexpected sanitized
				// type instead of panicking on an unchecked assertion.
				if sanitizedKey, isString := sanitizeErr.Sanitized.(string); isString {
					k = sanitizedKey
					hasSanitized = true
				} else {
					errs = errors.Join(errs, fmt.Errorf("invalid sanitized key %v (%T) for map property type %s",
						sanitizeErr.Sanitized, sanitizeErr.Sanitized, m.KeyProperty.Type()))
				}
			} else if err != nil {
				errs = errors.Join(errs, fmt.Errorf("invalid key %v for map property type %s: %w", k, m.KeyProperty.Type(), err))
			}
		}
		if m.ValueProperty != nil {
			var sanitizeErr *knowledgebase.SanitizeError
			if err := m.ValueProperty.Validate(resource, v, ctx); errors.As(err, &sanitizeErr) {
				v = sanitizeErr.Sanitized
				hasSanitized = true
			} else if err != nil {
				errs = errors.Join(errs, fmt.Errorf("invalid value %v for map property type %s: %w", v, m.ValueProperty.Type(), err))
			}
		}
		validMap[k] = v
	}
	if errs != nil {
		return errs
	}
	if hasSanitized {
		return &knowledgebase.SanitizeError{
			Input:     mapVal,
			Sanitized: validMap,
		}
	}
	return nil
}
// SubProperties returns the named field definitions of an object-shaped map.
func (m *MapProperty) SubProperties() knowledgebase.Properties {
	props := m.Properties
	return props
}
// Key returns the property describing map keys (may be nil).
func (m *MapProperty) Key() knowledgebase.Property { return m.KeyProperty }
// Value returns the property describing map values (may be nil).
func (m *MapProperty) Value() knowledgebase.Property { return m.ValueProperty }
package properties
import (
"bytes"
"fmt"
"text/template"
construct "github.com/klothoplatform/klotho/pkg/construct"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
)
type (
// SharedPropertyFields holds fields common to the property structs in this
// package (e.g. StringProperty, ListProperty), which embed it.
SharedPropertyFields struct {
// DefaultValue is the raw, unparsed default; each property type's
// GetDefaultValue passes it through that type's Parse.
DefaultValue any
// ValidityChecks are template-based checks; see PropertyValidityCheck.Validate.
ValidityChecks []PropertyValidityCheck
}
// PropertyValidityCheck wraps a text/template whose rendered output is an
// error message: empty output means the value passed the check.
PropertyValidityCheck struct {
template *template.Template
}
// ValidityCheckData is the data a PropertyValidityCheck template is executed
// against.
ValidityCheckData struct {
Properties construct.Properties `json:"properties" yaml:"properties"`
Value any `json:"value" yaml:"value"`
}
)
// ParsePropertyRef converts value into a construct.PropertyRef.
//
// Accepted inputs:
//   - string: rendered as a template and decoded into a PropertyRef.
//   - map[string]any: a "resource" entry, parsed like a ResourceProperty value
//     and required to yield a construct.ResourceId, plus a string "property"
//     entry.
//   - construct.PropertyRef: returned unchanged.
//
// Any other input, or a map whose entries have the wrong types, yields an error.
func ParsePropertyRef(value any, ctx knowledgebase.DynamicContext, data knowledgebase.DynamicValueData) (construct.PropertyRef, error) {
	switch val := value.(type) {
	case string:
		result := construct.PropertyRef{}
		err := ctx.ExecuteDecode(val, data, &result)
		return result, err
	case map[string]any:
		rp := ResourceProperty{}
		id, err := rp.Parse(val["resource"], ctx, data)
		if err != nil {
			return construct.PropertyRef{}, err
		}
		// Checked assertions: a malformed map (or a nested reference as the
		// resource) must surface as an error rather than a panic.
		resourceID, ok := id.(construct.ResourceId)
		if !ok {
			return construct.PropertyRef{}, fmt.Errorf("invalid property reference value %v: resource must be a resource id, got %T", value, id)
		}
		property, ok := val["property"].(string)
		if !ok {
			return construct.PropertyRef{}, fmt.Errorf("invalid property reference value %v: property must be a string, got %T", value, val["property"])
		}
		return construct.PropertyRef{
			Property: property,
			Resource: resourceID,
		}, nil
	case construct.PropertyRef:
		return val, nil
	}
	return construct.PropertyRef{}, fmt.Errorf("invalid property reference value %v", value)
}
// ValidatePropertyRef resolves a property reference to the concrete value it
// points at, checking along the way that each referenced property has type
// propertyType.
//
// Chains of references (a property whose stored value is itself a
// PropertyRef) are followed until a non-reference value is reached. It
// returns (nil, nil) when the chain reaches a deploy-time property or an unset
// value. A chain that loops back on itself is reported as an error instead of
// recursing forever.
func ValidatePropertyRef(value construct.PropertyRef, propertyType string, ctx knowledgebase.DynamicContext) (refVal any, err error) {
	seen := make(map[string]struct{})
	ref := value
	for {
		key := ref.String()
		if _, dup := seen[key]; dup {
			return nil, fmt.Errorf("circular property reference detected at %s while validating %s", key, value.String())
		}
		seen[key] = struct{}{}

		refVal, err = validatePropertyRefStep(ref, propertyType, ctx)
		if err != nil || refVal == nil {
			return nil, err
		}
		next, isRef := refVal.(construct.PropertyRef)
		if !isRef {
			return refVal, nil
		}
		ref = next
	}
}

// validatePropertyRefStep performs one non-recursive step of
// ValidatePropertyRef. It looks up ref's resource and property, checks the
// property's type, and returns the raw stored value, which may itself be a
// PropertyRef. It returns (nil, nil) for deploy-time properties.
func validatePropertyRefStep(ref construct.PropertyRef, propertyType string, ctx knowledgebase.DynamicContext) (any, error) {
	resource, err := ctx.DAG().Vertex(ref.Resource)
	if err != nil {
		return nil, fmt.Errorf("error getting resource %s, while validating property ref: %w", ref.Resource, err)
	}
	if resource == nil {
		return nil, fmt.Errorf("resource %s does not exist", ref.Resource)
	}
	rt, err := ctx.KB().GetResourceTemplate(ref.Resource)
	if err != nil {
		return nil, err
	}
	prop := rt.GetProperty(ref.Property)
	if prop == nil {
		return nil, fmt.Errorf("property %s does not exist on resource %s", ref.Property, ref.Resource)
	}
	if prop.Type() != propertyType {
		return nil, fmt.Errorf("property %s on resource %s is not of type %s", ref.Property, ref.Resource, propertyType)
	}
	if prop.Details().DeployTime {
		return nil, nil
	}
	propVal, err := resource.GetProperty(ref.Property)
	if err != nil {
		return nil, fmt.Errorf("error getting property %s on resource %s, while validating property ref: %w", ref.Property, ref.Resource, err)
	}
	return propVal, nil
}
// Validate executes the check's template against value and the supplied
// properties. Template execution errors are returned as-is. Any non-empty
// rendered output is treated as a validation failure message.
func (p *PropertyValidityCheck) Validate(value any, properties construct.Properties) error {
	var out bytes.Buffer
	checkData := ValidityCheckData{Properties: properties, Value: value}
	if err := p.template.Execute(&out, checkData); err != nil {
		return err
	}
	if msg := out.String(); msg != "" {
		return fmt.Errorf("invalid value %v: %s", value, msg)
	}
	return nil
}
package properties
import (
"fmt"
construct "github.com/klothoplatform/klotho/pkg/construct"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
)
type (
// ResourceProperty describes a property whose value references another
// resource by construct.ResourceId (or a construct.PropertyRef).
ResourceProperty struct {
// AllowedTypes, when non-empty, restricts which resource ids Parse and
// Validate accept.
AllowedTypes construct.ResourceList
SharedPropertyFields
knowledgebase.PropertyDetails
}
)
// SetProperty stores value at the property's path. Only construct.ResourceId
// and construct.PropertyRef values are accepted.
func (r *ResourceProperty) SetProperty(resource *construct.Resource, value any) error {
	switch v := value.(type) {
	case construct.ResourceId:
		return resource.SetProperty(r.Path, v)
	case construct.PropertyRef:
		return resource.SetProperty(r.Path, v)
	default:
		return fmt.Errorf("invalid resource value %v", value)
	}
}
// AppendProperty has no append semantics for a single resource reference; it
// overwrites the current value through SetProperty.
func (r *ResourceProperty) AppendProperty(resource *construct.Resource, value any) error {
	err := r.SetProperty(resource, value)
	return err
}
// RemoveProperty removes the stored resource id, but only when value is a
// construct.ResourceId that matches it. It is a no-op when nothing is stored.
func (r *ResourceProperty) RemoveProperty(resource *construct.Resource, value any) error {
	current, err := resource.GetProperty(r.Path)
	switch {
	case err != nil:
		return err
	case current == nil:
		return nil
	}

	storedID, ok := current.(construct.ResourceId)
	if !ok {
		return fmt.Errorf("error attempting to remove resource property: invalid property value %v", current)
	}
	targetID, ok := value.(construct.ResourceId)
	if !ok {
		return fmt.Errorf("error attempting to remove resource property: invalid resource value %v", value)
	}
	if storedID.Matches(targetID) {
		return resource.RemoveProperty(r.Path, value)
	}
	return fmt.Errorf("error attempting to remove resource property: resource value %v does not match property value %v", value, current)
}
// Details returns a pointer to the embedded PropertyDetails.
func (r *ResourceProperty) Details() *knowledgebase.PropertyDetails {
	details := &r.PropertyDetails
	return details
}
// Clone returns a shallow copy of the property. The AllowedTypes backing
// storage is shared with the original.
func (r *ResourceProperty) Clone() knowledgebase.Property {
	cp := new(ResourceProperty)
	*cp = *r
	return cp
}
// GetDefaultValue parses the configured DefaultValue with Parse and returns
// (nil, nil) when no default is configured.
func (r *ResourceProperty) GetDefaultValue(ctx knowledgebase.DynamicContext, data knowledgebase.DynamicValueData) (any, error) {
	if def := r.DefaultValue; def != nil {
		return r.Parse(def, ctx, data)
	}
	return nil, nil
}
// Parse converts value into a construct.ResourceId, or a construct.PropertyRef
// pointing at one. When AllowedTypes is non-empty, the result is checked
// against it.
//
// Accepted inputs:
//   - string: rendered as a template and decoded into a resource id.
//   - map[string]any with string "type", "name" and "provider" entries and an
//     optional string "namespace".
//   - construct.ResourceId: checked and returned.
//   - anything ParsePropertyRef accepts, including maps shaped like a property
//     reference. AllowedTypes is checked against the reference's resource.
func (r *ResourceProperty) Parse(value any, ctx knowledgebase.DynamicContext, data knowledgebase.DynamicValueData) (any, error) {
	if val, ok := value.(string); ok {
		id, err := knowledgebase.ExecuteDecodeAsResourceId(ctx, val, data)
		if !id.IsZero() && len(r.AllowedTypes) > 0 && !r.AllowedTypes.MatchesAny(id) {
			return nil, fmt.Errorf("resource value %v does not match allowed types %s", value, r.AllowedTypes)
		}
		return id, err
	}
	if val, ok := value.(map[string]any); ok {
		// Only treat the map as a resource id when the required fields are
		// strings. Otherwise fall through, so property-reference maps and
		// malformed maps yield a value or an error instead of a panic.
		typ, typeOK := val["type"].(string)
		name, nameOK := val["name"].(string)
		provider, providerOK := val["provider"].(string)
		if typeOK && nameOK && providerOK {
			id := construct.ResourceId{
				Type:     typ,
				Name:     name,
				Provider: provider,
			}
			if namespace, ok := val["namespace"]; ok && namespace != nil {
				ns, isString := namespace.(string)
				if !isString {
					return nil, fmt.Errorf("invalid namespace %v in resource value %v", namespace, value)
				}
				id.Namespace = ns
			}
			if len(r.AllowedTypes) > 0 && !r.AllowedTypes.MatchesAny(id) {
				return nil, fmt.Errorf("resource value %v does not match type %s", value, r.AllowedTypes)
			}
			return id, nil
		}
	}
	if val, ok := value.(construct.ResourceId); ok {
		if len(r.AllowedTypes) > 0 && !r.AllowedTypes.MatchesAny(val) {
			return nil, fmt.Errorf("resource value %v does not match type %s", value, r.AllowedTypes)
		}
		return val, nil
	}
	val, err := ParsePropertyRef(value, ctx, data)
	if err == nil {
		if len(r.AllowedTypes) > 0 && !r.AllowedTypes.MatchesAny(val.Resource) {
			return nil, fmt.Errorf("resource value %v does not match type %s", value, r.AllowedTypes)
		}
		return val, nil
	}
	return nil, fmt.Errorf("invalid resource value %v", value)
}
// ZeroValue returns the empty construct.ResourceId.
func (r *ResourceProperty) ZeroValue() any {
	var zero construct.ResourceId
	return zero
}
// Contains reports whether both arguments are resource ids and value matches
// contains according to ResourceId.Matches.
func (r *ResourceProperty) Contains(value any, contains any) bool {
	stored, ok := value.(construct.ResourceId)
	if !ok {
		return false
	}
	candidate, ok := contains.(construct.ResourceId)
	return ok && stored.Matches(candidate)
}
// Type returns "resource(<t1>, <t2>, ...)" listing AllowedTypes, or just
// "resource" when no types are restricted.
func (r *ResourceProperty) Type() string {
	if len(r.AllowedTypes) == 0 {
		return "resource"
	}
	var joined string
	for idx, t := range r.AllowedTypes {
		if idx > 0 {
			joined += ", "
		}
		joined += t.String()
	}
	return fmt.Sprintf("resource(%s)", joined)
}
// Validate checks that value is a construct.ResourceId permitted by
// AllowedTypes.
//
// A nil value passes for deploy-time properties on non-imported resources.
// Otherwise a nil value fails only if the property is Required.
func (r *ResourceProperty) Validate(resource *construct.Resource, value any, ctx knowledgebase.DynamicContext) error {
	if value == nil {
		if r.DeployTime && !resource.Imported {
			return nil
		}
		if r.Required {
			return fmt.Errorf(knowledgebase.ErrRequiredProperty, r.Path, resource.ID)
		}
		return nil
	}
	id, isID := value.(construct.ResourceId)
	switch {
	case !isID:
		return fmt.Errorf("invalid resource value %v", value)
	case len(r.AllowedTypes) > 0 && !r.AllowedTypes.MatchesAny(id):
		return fmt.Errorf("resource value %v does not match allowed types %s", value, r.AllowedTypes)
	default:
		return nil
	}
}
// SubProperties returns nil: a resource reference has no nested properties.
func (r *ResourceProperty) SubProperties() knowledgebase.Properties {
	var none knowledgebase.Properties
	return none
}
package properties
import (
"errors"
"fmt"
"reflect"
construct "github.com/klothoplatform/klotho/pkg/construct"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
"github.com/klothoplatform/klotho/pkg/set"
)
type (
// SetProperty describes a property whose value is a set.HashedSet[string, any].
// Elements are described either by ItemProperty or, when Properties is
// non-empty, as map-shaped objects whose fields are described by Properties.
SetProperty struct {
// MinLength and MaxLength are optional inclusive bounds on the set size,
// enforced by Validate.
MinLength *int
MaxLength *int
// ItemProperty describes each element of the set.
ItemProperty knowledgebase.Property
// Properties describes the fields of map-shaped elements.
Properties knowledgebase.Properties
SharedPropertyFields
knowledgebase.PropertyDetails
}
)
// SetProperty replaces the whole set at the property's path. Only
// set.HashedSet[string, any] values are accepted.
func (s *SetProperty) SetProperty(resource *construct.Resource, value any) error {
	val, ok := value.(set.HashedSet[string, any])
	if !ok {
		return fmt.Errorf("invalid set value %v", value)
	}
	return resource.SetProperty(s.Path, val)
}
// AppendProperty adds value to the set at the property's path. If nothing is
// stored yet and value is already a whole set, it is stored directly.
// Otherwise the append is delegated to resource.AppendProperty.
func (s *SetProperty) AppendProperty(resource *construct.Resource, value any) error {
	current, err := resource.GetProperty(s.Path)
	if err != nil {
		return err
	}
	if whole, isSet := value.(set.HashedSet[string, any]); isSet && current == nil {
		return s.SetProperty(resource, whole)
	}
	return resource.AppendProperty(s.Path, value)
}
// RemoveProperty removes every element of value from the set stored at the
// property's path and writes the result back. value must itself be a
// set.HashedSet[string, any]. It is a no-op when nothing is stored.
func (s *SetProperty) RemoveProperty(resource *construct.Resource, value any) error {
	current, err := resource.GetProperty(s.Path)
	switch {
	case err != nil:
		return err
	case current == nil:
		return nil
	}
	existing, ok := current.(set.HashedSet[string, any])
	if !ok {
		return errors.New("invalid set value")
	}
	toRemove, ok := value.(set.HashedSet[string, any])
	if !ok {
		return fmt.Errorf("invalid set value %v", value)
	}
	for _, item := range toRemove.ToSlice() {
		existing.Remove(item)
	}
	return s.SetProperty(resource, existing)
}
// Details returns a pointer to the embedded PropertyDetails.
func (s *SetProperty) Details() *knowledgebase.PropertyDetails {
	details := &s.PropertyDetails
	return details
}
// Clone returns a copy of the set property with ItemProperty and Properties
// deep-cloned. MinLength/MaxLength pointers are shared.
func (s *SetProperty) Clone() knowledgebase.Property {
	clone := *s
	if s.ItemProperty != nil {
		clone.ItemProperty = s.ItemProperty.Clone()
	}
	if s.Properties != nil {
		clone.Properties = s.Properties.Clone()
	}
	return &clone
}
// GetDefaultValue parses the configured DefaultValue with Parse and returns
// (nil, nil) when no default is configured.
func (s *SetProperty) GetDefaultValue(ctx knowledgebase.DynamicContext, data knowledgebase.DynamicValueData) (any, error) {
	if def := s.DefaultValue; def != nil {
		return s.Parse(def, ctx, data)
	}
	return nil, nil
}
// Parse converts value into a set.HashedSet[string, any]. Elements are hashed
// by their "%v" rendering and ordered by string comparison of those hashes.
//
// Accepted inputs:
//   - nil: yields an empty set.
//   - set.HashedSet[string, any], []any, or any other slice/array.
//   - string: rendered as a template and decoded into a []any first.
//
// Any other input type is rejected with an error.
//
// When Properties is non-empty, each element is parsed as an object by a
// MapProperty built from those properties. Otherwise ItemProperty parses each
// element. A set with neither keeps its elements unchanged.
func (s *SetProperty) Parse(value any, ctx knowledgebase.DynamicContext, data knowledgebase.DynamicValueData) (any, error) {
	var result = set.HashedSet[string, any]{
		Hasher: func(v any) string {
			return fmt.Sprintf("%v", v)
		},
		Less: func(s1, s2 string) bool {
			return s1 < s2
		},
	}

	var vals []any
	switch v := value.(type) {
	case nil:
		// nothing to parse; yields an empty set
	case set.HashedSet[string, any]:
		vals = v.ToSlice()
	case []any:
		vals = v
	case string:
		// the entire value may be a template that renders to a list
		if err := ctx.ExecuteDecode(v, data, &vals); err != nil {
			return nil, err
		}
	default:
		rv := reflect.ValueOf(value)
		if rv.Kind() != reflect.Slice && rv.Kind() != reflect.Array {
			return nil, fmt.Errorf("invalid set value %v", value)
		}
		vals = make([]any, rv.Len())
		for idx := range vals {
			vals[idx] = rv.Index(idx).Interface()
		}
	}

	for _, item := range vals {
		var (
			parsed any
			err    error
		)
		switch {
		case len(s.Properties) != 0:
			m := MapProperty{Properties: s.Properties}
			parsed, err = m.Parse(item, ctx, data)
		case s.ItemProperty != nil:
			parsed, err = s.ItemProperty.Parse(item, ctx, data)
		default:
			// Untyped set: nothing describes the items, keep them as-is.
			parsed = item
		}
		if err != nil {
			return nil, err
		}
		result.Add(parsed)
	}
	return result, nil
}
// ZeroValue returns nil: an unset set has no zero-value representation.
func (s *SetProperty) ZeroValue() any { return nil }
// Contains reports whether value is a set with a member deeply equal to
// contains.
func (s *SetProperty) Contains(value any, contains any) bool {
	if members, ok := value.(set.HashedSet[string, any]); ok {
		for _, member := range members.M {
			if reflect.DeepEqual(contains, member) {
				return true
			}
		}
	}
	return false
}
// Type returns "set(<item type>)" when ItemProperty is set, otherwise "set".
func (s *SetProperty) Type() string {
	if s.ItemProperty == nil {
		return "set"
	}
	return "set(" + s.ItemProperty.Type() + ")"
}
// Validate checks that value is a valid set for this property on resource.
//
// Nil-value rules:
//   - a nil value passes for deploy-time properties on non-imported resources;
//   - otherwise a nil value fails only if the property is Required.
//
// A non-nil value must be a set.HashedSet[string, any] within the optional
// MinLength/MaxLength bounds (inclusive). When ItemProperty is set, each item
// is validated with it. If validation only produced sanitize errors, a
// *knowledgebase.SanitizeError carrying the sanitized set is returned. Other
// item errors are wrapped, joined and returned.
func (s *SetProperty) Validate(resource *construct.Resource, value any, ctx knowledgebase.DynamicContext) error {
	if s.DeployTime && value == nil && !resource.Imported {
		return nil
	}
	if value == nil {
		if s.Required {
			return fmt.Errorf(knowledgebase.ErrRequiredProperty, s.Path, resource.ID)
		}
		return nil
	}
	setVal, ok := value.(set.HashedSet[string, any])
	if !ok {
		return fmt.Errorf("could not validate set property: invalid set value %v", value)
	}
	// %v rather than %s: the members are arbitrary values, not strings.
	if s.MinLength != nil && setVal.Len() < *s.MinLength {
		return fmt.Errorf("value %v is too short. minimum length is %d", setVal.M, *s.MinLength)
	}
	if s.MaxLength != nil && setVal.Len() > *s.MaxLength {
		return fmt.Errorf("value %v is too long. maximum length is %d", setVal.M, *s.MaxLength)
	}

	// Only validate items of a primitive set; sub properties are left to
	// handle their own validation.
	if s.ItemProperty == nil {
		return nil
	}
	var errs error
	hasSanitized := false
	// Copy Less as well as Hasher: Parse builds sets with both, and the
	// sanitized copy should not lose either.
	validSet := set.HashedSet[string, any]{Hasher: setVal.Hasher, Less: setVal.Less}
	for _, item := range setVal.ToSlice() {
		err := s.ItemProperty.Validate(resource, item, ctx)
		if err == nil {
			validSet.Add(item)
			continue
		}
		var sanitizeErr *knowledgebase.SanitizeError
		if errors.As(err, &sanitizeErr) {
			validSet.Add(sanitizeErr.Sanitized)
			hasSanitized = true
		} else {
			errs = errors.Join(errs, fmt.Errorf("invalid item %v: %w", item, err))
		}
	}
	if errs != nil {
		return errs
	}
	if hasSanitized {
		return &knowledgebase.SanitizeError{
			Input:     setVal,
			Sanitized: validSet,
		}
	}
	return nil
}
// SubProperties returns the field definitions for map-shaped set elements.
func (s *SetProperty) SubProperties() knowledgebase.Properties {
	props := s.Properties
	return props
}
// Item returns the property describing each set element (may be nil).
func (s *SetProperty) Item() knowledgebase.Property { return s.ItemProperty }
package properties
import (
"fmt"
"strings"
"github.com/klothoplatform/klotho/pkg/collectionutil"
construct "github.com/klothoplatform/klotho/pkg/construct"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
)
type (
// StringProperty describes a string-valued property. The value may also be
// a construct.PropertyRef to another string property.
StringProperty struct {
// SanitizeTmpl, when set, is checked against the string by Validate.
SanitizeTmpl *knowledgebase.SanitizeTmpl
// AllowedValues, when non-empty, lists the only strings Validate accepts.
AllowedValues []string
SharedPropertyFields
knowledgebase.PropertyDetails
}
)
// SetProperty stores value at the property's path. Only strings and
// construct.PropertyRef values are accepted.
func (s *StringProperty) SetProperty(resource *construct.Resource, value any) error {
	switch v := value.(type) {
	case string:
		return resource.SetProperty(s.Path, v)
	case construct.PropertyRef:
		return resource.SetProperty(s.Path, v)
	default:
		return fmt.Errorf("could not set string property: invalid string value %v", value)
	}
}
// AppendProperty has no append semantics for a scalar string; it overwrites
// the current value through SetProperty.
func (s *StringProperty) AppendProperty(resource *construct.Resource, value any) error {
	err := s.SetProperty(resource, value)
	return err
}
// RemoveProperty asks resource to remove the value at the property's path.
// value is ignored: resource.RemoveProperty is always called with nil. It is
// a no-op when nothing is stored.
func (s *StringProperty) RemoveProperty(resource *construct.Resource, value any) error {
	current, err := resource.GetProperty(s.Path)
	if err != nil || current == nil {
		return err
	}
	return resource.RemoveProperty(s.Path, nil)
}
// Details returns a pointer to the embedded PropertyDetails.
func (s *StringProperty) Details() *knowledgebase.PropertyDetails {
	details := &s.PropertyDetails
	return details
}
// Clone returns a shallow copy of the property. SanitizeTmpl and the
// AllowedValues backing array are shared with the original.
func (s *StringProperty) Clone() knowledgebase.Property {
	cp := new(StringProperty)
	*cp = *s
	return cp
}
// GetDefaultValue parses the configured DefaultValue with Parse and returns
// (nil, nil) when no default is configured.
func (s *StringProperty) GetDefaultValue(ctx knowledgebase.DynamicContext, data knowledgebase.DynamicValueData) (any, error) {
	if def := s.DefaultValue; def != nil {
		return s.Parse(def, ctx, data)
	}
	return nil, nil
}
// Parse converts value into a string, or into a construct.PropertyRef when it
// parses as one.
//
// A string is rendered as a template and decoded back into a string. Numeric
// and boolean scalars are formatted with %v. Anything else is rejected.
func (s *StringProperty) Parse(value any, ctx knowledgebase.DynamicContext, data knowledgebase.DynamicValueData) (any, error) {
	// The string form of a property ref would also parse as a plain string, so
	// the property-ref interpretation has to be attempted first.
	if ref, err := ParsePropertyRef(value, ctx, data); err == nil {
		return ref, nil
	}
	if text, ok := value.(string); ok {
		err := ctx.ExecuteDecode(text, data, &text)
		return text, err
	}
	switch value.(type) {
	case int, int32, int64, float32, float64, bool:
		return fmt.Sprintf("%v", value), nil
	}
	return nil, fmt.Errorf("could not parse string property: invalid string value %v (%[1]T)", value)
}
// ZeroValue returns the empty string.
func (s *StringProperty) ZeroValue() any {
	var zero string
	return zero
}
// Contains reports whether both arguments are strings and contains is a
// substring of value.
func (s *StringProperty) Contains(value any, contains any) bool {
	haystack, valueIsString := value.(string)
	needle, containsIsString := contains.(string)
	return valueIsString && containsIsString && strings.Contains(haystack, needle)
}
// Type returns the type name used for this property, "string".
func (s *StringProperty) Type() string { return "string" }
// Validate checks that value is an acceptable string for this property.
//
// Nil-value rules:
//   - a nil value passes for deploy-time properties on non-imported resources;
//   - otherwise a nil value fails only if the property is Required.
//
// A construct.PropertyRef is resolved with ValidatePropertyRef, and the
// resolved string is checked instead; a nil resolution passes. The string must
// be in AllowedValues (when non-empty). When SanitizeTmpl is set, its Check
// result is returned.
func (s *StringProperty) Validate(resource *construct.Resource, value any, ctx knowledgebase.DynamicContext) error {
	if value == nil {
		switch {
		case s.DeployTime && !resource.Imported:
			return nil
		case s.Required:
			return fmt.Errorf(knowledgebase.ErrRequiredProperty, s.Path, resource.ID)
		default:
			return nil
		}
	}

	text, isString := value.(string)
	if !isString {
		ref, isRef := value.(construct.PropertyRef)
		if !isRef {
			return fmt.Errorf("could not validate property: invalid string value %v", value)
		}
		resolved, err := ValidatePropertyRef(ref, s.Type(), ctx)
		if err != nil {
			return err
		}
		if resolved == nil {
			return nil
		}
		if text, isString = resolved.(string); !isString {
			return fmt.Errorf("could not validate property: invalid string value %v", value)
		}
	}

	if len(s.AllowedValues) > 0 && !collectionutil.Contains(s.AllowedValues, text) {
		return fmt.Errorf("value %s is not allowed. allowed values are %s", text, s.AllowedValues)
	}
	if s.SanitizeTmpl == nil {
		return nil
	}
	return s.SanitizeTmpl.Check(text)
}
// SubProperties returns nil: a string has no nested properties.
func (s *StringProperty) SubProperties() knowledgebase.Properties {
	var none knowledgebase.Properties
	return none
}
package reader
import (
"errors"
"fmt"
"io/fs"
construct "github.com/klothoplatform/klotho/pkg/construct"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
"go.uber.org/zap"
"gopkg.in/yaml.v3"
)
// NewKBFromFs builds a KnowledgeBase from three filesystems: resources holds
// resource templates, edges holds edge templates and models holds model
// definitions.
//
// Model loading or conversion failures, and template loading failures, abort
// with (nil, err). Errors adding individual resource or edge templates are
// collected and returned alongside the partially populated knowledge base.
func NewKBFromFs(resources, edges, models fs.FS) (*knowledgebase.KnowledgeBase, error) {
	kb := knowledgebase.NewKB()

	readerModels, err := ModelsFromFS(models)
	if err != nil {
		return nil, err
	}

	// Convert every model before failing so all conversion problems are
	// reported together.
	kbModels := map[string]*knowledgebase.Model{}
	var convertErrs error
	for name, model := range readerModels {
		converted, convErr := model.Convert()
		if convErr != nil {
			convertErrs = errors.Join(convertErrs, fmt.Errorf("error converting model %s: %w", name, convErr))
		}
		kbModels[name] = converted
	}
	if convertErrs != nil {
		return nil, convertErrs
	}
	kb.Models = kbModels

	templates, err := TemplatesFromFs(resources, readerModels)
	if err != nil {
		return nil, err
	}
	edgeTemplates, err := EdgeTemplatesFromFs(edges)
	if err != nil {
		return nil, err
	}

	var errs error
	for _, rt := range templates {
		if addErr := kb.AddResourceTemplate(rt); addErr != nil {
			errs = errors.Join(errs, fmt.Errorf("error adding resource template %s: %w", rt.QualifiedTypeName, addErr))
		}
	}
	for _, et := range edgeTemplates {
		if addErr := kb.AddEdgeTemplate(et); addErr != nil {
			errs = errors.Join(errs,
				fmt.Errorf("error adding edge template %s -> %s: %w",
					et.Source.QualifiedTypeName(),
					et.Target.QualifiedTypeName(),
					addErr),
			)
		}
	}
	return kb, errs
}
// ModelsFromFS decodes every file under dir as a YAML Model, keyed by the
// model's Name. It then runs updateModels on each model's properties against
// the full set of loaded models.
//
// Walk errors, including an unreadable root, and per-file open/decode errors
// stop the walk and are returned. updateModels errors are joined onto that
// error. The (possibly partial) model map is always returned.
func ModelsFromFS(dir fs.FS) (map[string]*Model, error) {
	inputModels := map[string]*Model{}
	log := zap.S().Named("kb.load.models")
	err := fs.WalkDir(dir, ".", func(path string, d fs.DirEntry, nerr error) error {
		// When WalkDir reports an error, d may be nil (e.g. the root cannot be
		// read), so the error must be checked before touching d.
		if nerr != nil {
			return fmt.Errorf("error walking models at %s: %w", path, nerr)
		}
		if d.IsDir() {
			return nil
		}
		log.Debugf("Loading model: %s", path)
		f, err := dir.Open(path)
		if err != nil {
			return fmt.Errorf("error opening model file %s: %w", path, err)
		}
		// Read-only file: a Close error carries no useful information here.
		defer f.Close()

		model := Model{}
		if err := yaml.NewDecoder(f).Decode(&model); err != nil {
			return fmt.Errorf("error decoding model file %s: %w", path, err)
		}
		inputModels[model.Name] = &model
		return nil
	})
	// Update models to only reference properties and not other models and then convert property/properties to internal types
	for _, model := range inputModels {
		if uerr := updateModels(nil, model.Properties, inputModels); uerr != nil {
			err = errors.Join(err, uerr)
		}
	}
	return inputModels, err
}
// TemplatesFromFs decodes every file under dir as a YAML ResourceTemplate and
// resolves model references with updateModels. Each template is converted to
// a knowledgebase.ResourceTemplate, keyed by the id parsed from its
// QualifiedTypeName.
//
// The walk stops at the first error, which is returned. Possible errors:
//   - a walk error, including an unreadable root;
//   - an open, decode, model-update, id-parse or conversion failure;
//   - a duplicate template id.
//
// Templates loaded before the error are still returned.
func TemplatesFromFs(dir fs.FS, models map[string]*Model) (map[construct.ResourceId]*knowledgebase.ResourceTemplate, error) {
	templates := map[construct.ResourceId]*knowledgebase.ResourceTemplate{}
	log := zap.S().Named("kb.load.resources")
	err := fs.WalkDir(dir, ".", func(path string, d fs.DirEntry, nerr error) error {
		// When WalkDir reports an error, d may be nil (e.g. the root cannot be
		// read), so the error must be checked before touching d.
		if nerr != nil {
			return fmt.Errorf("error walking resource templates at %s: %w", path, nerr)
		}
		if d.IsDir() {
			return nil
		}
		log.Debugf("Loading resource template: %s", path)
		f, err := dir.Open(path)
		if err != nil {
			return fmt.Errorf("error opening resource template %s: %w", path, err)
		}
		// Read-only file: a Close error carries no useful information here.
		defer f.Close()

		resTemplate := &ResourceTemplate{}
		if err := yaml.NewDecoder(f).Decode(resTemplate); err != nil {
			return fmt.Errorf("error decoding resource template %s: %w", path, err)
		}
		if err := updateModels(nil, resTemplate.Properties, models); err != nil {
			return fmt.Errorf("error updating models for resource template %s: %w", path, err)
		}
		id := construct.ResourceId{}
		if err := id.UnmarshalText([]byte(resTemplate.QualifiedTypeName)); err != nil {
			return fmt.Errorf("error unmarshalling resource template id for %s: %w", path, err)
		}
		if templates[id] != nil {
			return fmt.Errorf("duplicate template for %s in %s", id, path)
		}
		rt, err := resTemplate.Convert()
		if err != nil {
			return fmt.Errorf("error converting resource template %s: %w", path, err)
		}
		templates[id] = rt
		return nil
	})
	return templates, err
}
// EdgeTemplatesFromFs decodes every file found while walking dir as edge templates, keyed by
// "<source qualified type>-><target qualified type>".
//
// Each file is first decoded as a single knowledgebase.EdgeTemplate. If that leaves the source or target
// unset, the file is decoded again as a knowledgebase.MultiEdgeTemplate. If that template names a resource and
// at least one source or target, it is expanded with knowledgebase.EdgeTemplatesFromMulti. Otherwise the
// single-edge decoding is kept.
//
// Loading stops at the first failure (unreadable root or directory, open/decode error, duplicate key).
// Templates loaded before the failure are still returned alongside the error.
func EdgeTemplatesFromFs(dir fs.FS) (map[string]*knowledgebase.EdgeTemplate, error) {
	templates := map[string]*knowledgebase.EdgeTemplate{}
	log := zap.S().Named("kb.load.edges")
	err := fs.WalkDir(dir, ".", func(path string, d fs.DirEntry, nerr error) error {
		// nerr is set when the root cannot be stat'd (d is nil then) or a directory cannot be read.
		if nerr != nil {
			return fmt.Errorf("error walking %s: %w", path, nerr)
		}
		if d.IsDir() {
			return nil
		}
		log.Debugf("Loading edge template: %s", path)
		f, err := dir.Open(path)
		if err != nil {
			return fmt.Errorf("error opening edge template %s: %w", path, err)
		}
		// Both handles opened here are only read, so Close errors carry no useful information.
		defer f.Close()
		edgeTemplate := &knowledgebase.EdgeTemplate{}
		if err := yaml.NewDecoder(f).Decode(edgeTemplate); err != nil {
			return fmt.Errorf("error decoding edge template %s: %w", path, err)
		}
		if edgeTemplate.Source.IsZero() || edgeTemplate.Target.IsZero() {
			// Not a complete single edge: re-read the file as a multi-edge template.
			mf, err := dir.Open(path)
			if err != nil {
				return fmt.Errorf("error opening edge template %s: %w", path, err)
			}
			defer mf.Close()
			multiEdgeTemplate := &knowledgebase.MultiEdgeTemplate{}
			if err := yaml.NewDecoder(mf).Decode(multiEdgeTemplate); err != nil {
				return fmt.Errorf("error decoding edge template %s: %w", path, err)
			}
			if !multiEdgeTemplate.Resource.IsZero() && (len(multiEdgeTemplate.Sources) > 0 || len(multiEdgeTemplate.Targets) > 0) {
				for _, expanded := range knowledgebase.EdgeTemplatesFromMulti(*multiEdgeTemplate) {
					id := expanded.Source.QualifiedTypeName() + "->" + expanded.Target.QualifiedTypeName()
					if templates[id] != nil {
						return fmt.Errorf("duplicate template for %s in %s", id, path)
					}
					// Copy so every map entry points at its own template.
					et := expanded
					templates[id] = &et
				}
				return nil
			}
		}
		id := edgeTemplate.Source.QualifiedTypeName() + "->" + edgeTemplate.Target.QualifiedTypeName()
		if templates[id] != nil {
			return fmt.Errorf("duplicate template for %s in %s", id, path)
		}
		templates[id] = edgeTemplate
		return nil
	})
	return templates, err
}
package reader
import (
"fmt"
"strings"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
)
// Model is a named, reusable group of property definitions decoded from a knowledge base model file.
// It carries either a set of Properties or a single Property; properties that reference the model are
// rewritten by updateModels, and Convert produces the knowledgebase form.
type Model struct {
	// Name is the identifier that property types such as "model(<name>)" refer to.
	Name string `json:"name" yaml:"name"`
	// Properties are the definitions contributed when the model is spread into or nested under a property.
	Properties Properties `json:"properties" yaml:"properties"`
	// Property is the single definition a model can provide instead of Properties.
	Property *Property `json:"property" yaml:"property"`
}
// Convert returns the knowledgebase form of the model. Properties and Property are converted only when set.
// The first conversion error is returned with a nil model.
func (m Model) Convert() (*knowledgebase.Model, error) {
	converted := &knowledgebase.Model{Name: m.Name}
	if m.Properties != nil {
		props, err := m.Properties.Convert()
		if err != nil {
			return nil, err
		}
		converted.Properties = props
	}
	if m.Property == nil {
		return converted, nil
	}
	prop, err := m.Property.Convert()
	if err != nil {
		return nil, err
	}
	converted.Property = prop
	return converted, nil
}
// ModelType reports which model a property's type refers to, or nil when it does not refer to one.
//
// Recognized forms are "model", "model(X)", "list(model)" and "list(model(X))". A bare "model" (optionally
// inside a list) refers to the model that shares the property's own name; the parenthesized form names X.
func (p Property) ModelType() *string {
	inner := strings.TrimPrefix(p.Type, "list(")
	inner = strings.TrimSuffix(inner, ")")
	parts := strings.Split(inner, "(")
	if parts[0] != "model" {
		return nil
	}
	switch len(parts) {
	case 1:
		return &p.Name
	case 2:
		name := strings.TrimSuffix(parts[1], ")")
		return &name
	default:
		return nil
	}
}
// updateModels rewrites properties in place so that no entry refers to a model any more.
//
// property is the parent whose children are in properties (nil at the top level of a resource or model).
// models holds the available models by name. For each property whose type refers to a model (see
// Property.ModelType):
//   - If the model's name equals the property's name (always true for a bare "model" type), the property is
//     replaced by clones of the model's Properties ("spread"). A default value on the referencing property
//     must be a map and is distributed by key to the spread properties.
//   - Otherwise, a model with Properties turns "model(X)" into "map" and "list(model(X))" into "list", with
//     clones of the model's properties as children.
//   - Otherwise, a model with only a single Property replaces a "model(X)" property with a clone of it,
//     keeping the referencing property's name and path.
//
// Every property's own sub-properties are then processed recursively.
func updateModels(property *Property, properties Properties, models map[string]*Model) error {
	// Work from a snapshot of the keys. Spreading deletes and inserts entries, and Go leaves it unspecified
	// whether entries added during a range are visited. Names added by a spread are queued explicitly so
	// model references inside them are always resolved.
	queue := make([]string, 0, len(properties))
	for name := range properties {
		queue = append(queue, name)
	}
	for len(queue) > 0 {
		name := queue[0]
		queue = queue[1:]
		p, ok := properties[name]
		if !ok {
			// Removed by an earlier spread.
			continue
		}
		if modelType := p.ModelType(); modelType != nil {
			if len(p.Properties) != 0 {
				return fmt.Errorf("property %s has properties but is labeled as a model", name)
			}
			model := models[*modelType]
			if model == nil || (model.Properties == nil && model.Property == nil) {
				return fmt.Errorf("model %s not found", *modelType)
			}
			if p.Name == *modelType {
				// The property stands for the model's properties themselves: spread them into this level.
				if model.Property != nil {
					return fmt.Errorf("model %s as property can not be spread into properties", *modelType)
				}
				delete(properties, name)
				for childName, prop := range model.Properties {
					// Models can be reused, so always work on a clone of their properties.
					newProp := prop.Clone()
					// Prepend the parent property's path.
					if property != nil {
						newProp.Path = fmt.Sprintf("%s.%s", property.Path, prop.Path)
					}
					// Push a default on the referencing property down to each spread property.
					if p.DefaultValue != nil {
						defaultMap, ok := p.DefaultValue.(map[string]any)
						if !ok {
							return fmt.Errorf("default value for %s is not a map", p.Path)
						}
						newProp.DefaultValue = defaultMap[childName]
					}
					properties[childName] = newProp
					queue = append(queue, childName)
				}
				if property != nil {
					if err := updateModelPaths(property); err != nil {
						return err
					}
				}
			} else if model.Properties != nil {
				p.Properties = model.Properties.Clone()
				modelString := fmt.Sprintf("model(%s)", *modelType)
				if p.Type == modelString {
					p.Type = "map"
				} else if p.Type == fmt.Sprintf("list(%s)", modelString) {
					p.Type = "list"
				}
				if err := updateModelPaths(p); err != nil {
					return err
				}
			} else {
				// Single-property model: it can only stand in for the property directly, not as list items.
				if p.Type != fmt.Sprintf("model(%s)", *modelType) {
					return fmt.Errorf("model %s as property can only be referenced as model(%s), not %s", *modelType, *modelType, p.Type)
				}
				newProp := model.Property.Clone()
				newProp.Name = p.Name
				newProp.Path = p.Path
				if err := updateModelPaths(newProp); err != nil {
					return err
				}
				// Write the replacement back; assigning only to p would be lost.
				properties[name] = newProp
				p = newProp
			}
		}
		if err := updateModels(p, p.Properties, models); err != nil {
			return err
		}
	}
	return nil
}
// updateModelPaths recomputes the Path of every descendant of p as "<parent path>.<child name>", starting
// from p's own Path. It returns the first error from a recursive call.
func updateModelPaths(p *Property) error {
	for _, child := range p.Properties {
		child.Path = p.Path + "." + child.Name
		if err := updateModelPaths(child); err != nil {
			return err
		}
	}
	return nil
}
package reader
import (
"fmt"
"reflect"
"strings"
"github.com/google/uuid"
construct "github.com/klothoplatform/klotho/pkg/construct"
knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
"github.com/klothoplatform/klotho/pkg/knowledgebase/properties"
"gopkg.in/yaml.v3"
)
type (
// Properties defines the structure of properties defined in yaml as a part of a template.
// Entries are keyed by property name; Properties.UnmarshalYAML stamps each entry's Name and Path from that key.
Properties map[string]*Property
// Property defines the structure of a property defined in yaml as a part of a template.
// these fields must be exactly the union of all the fields in the different property types.
// Property.Convert copies each field by name, via reflection, onto the knowledgebase property chosen from
// Type; fields without a same-named settable field on that type are skipped.
Property struct {
// Name is the property's key in its parent Properties map (set during YAML decoding).
Name string `json:"name" yaml:"name"`
// Type defines the type of the property
// (e.g. "string", "map(string,int)", "list(model(x))"; see InitializeProperty and Property.ModelType).
Type string `json:"type" yaml:"type"`
Description string `json:"description" yaml:"description"`
IsImportant bool `json:"important" yaml:"important"`
Namespace bool `json:"namespace" yaml:"namespace"`
// DefaultValue is copied as-is. When a model is spread in place of this property (see updateModels), it must
// be a map and its entries become the defaults of the spread properties.
DefaultValue any `json:"default_value" yaml:"default_value"`
Required bool `json:"required" yaml:"required"`
ConfigurationDisabled bool `json:"configuration_disabled" yaml:"configuration_disabled"`
DeployTime bool `json:"deploy_time" yaml:"deploy_time"`
OperationalRule *knowledgebase.PropertyRule `json:"operational_rule" yaml:"operational_rule"`
// Properties holds nested sub-properties; each is converted recursively by Property.Convert.
Properties Properties `json:"properties" yaml:"properties"`
// MinLength defines the minimum length of a string, list, set, or map (number of entries)
MinLength *int `yaml:"min_length"`
MaxLength *int `yaml:"max_length"`
MinValue *float64 `yaml:"min_value"`
MaxValue *float64 `yaml:"max_value"`
// UniqueItems defines whether the items in a list or set must be unique
UniqueItems *bool `yaml:"unique_items"`
// UniqueKeys defines whether the keys in a map must be unique (default true)
UniqueKeys *bool `yaml:"unique_keys"`
// AllowedTypes lists resource types for resource properties; InitializeProperty also seeds it from a
// "resource(<type>)" Type.
AllowedTypes construct.ResourceList `yaml:"allowed_types"`
// SanitizeTmpl is sanitization template source text; Convert compiles it via fieldConversion instead of
// copying it.
SanitizeTmpl string `yaml:"sanitize"`
AllowedValues []string `yaml:"allowed_values"`
// KeyProperty and ValueProperty describe map keys and values; Convert rejects them unless Type is a map.
// An element without a Type inherits it from "map(<key>,<value>)".
KeyProperty *Property `yaml:"key_property"`
ValueProperty *Property `yaml:"value_property"`
// ItemProperty describes list/set items; Convert rejects it unless Type is a list or set.
// An item without a Type inherits it from "list(<item>)" / "set(<item>)".
ItemProperty *Property `yaml:"item_property"`
// Path is the dotted path from the template root. It is derived (UnmarshalYAML, setChildPaths,
// updateModelPaths) and never read from or written to YAML/JSON.
Path string `json:"-" yaml:"-"`
}
)
// UnmarshalYAML decodes a name -> property mapping and stamps every property with its Name (the map key) and
// Path. Nested properties get dotted paths such as "parent.child".
//
// An entry without a definition (e.g. `foo:` with a null value) is rejected with an error. Previously it
// caused a nil pointer dereference.
func (p *Properties) UnmarshalYAML(n *yaml.Node) error {
	// Decode through an alias type that has no methods so this method is not invoked recursively.
	type h Properties
	var p2 h
	if err := n.Decode(&p2); err != nil {
		return err
	}
	for name, property := range p2 {
		if property == nil {
			return fmt.Errorf("property %s has no definition", name)
		}
		// property is a pointer, so these updates land in p2 directly.
		property.Name = name
		property.Path = name
		setChildPaths(property, name)
	}
	*p = Properties(p2)
	return nil
}
// Convert converts every property definition to its knowledgebase.Property form.
//
// Conversion continues past failures so all broken properties are reported at once. The returned map holds
// the properties that converted successfully, and the error wraps every failure, one per line.
func (p *Properties) Convert() (knowledgebase.Properties, error) {
	var errs error
	props := knowledgebase.Properties{}
	for name, prop := range *p {
		propertyType, err := prop.Convert()
		if err != nil {
			// Seed with the first error directly: fmt.Errorf("%w", nil) would render as "%!w(<nil>)".
			if errs == nil {
				errs = err
			} else {
				errs = fmt.Errorf("%w\n%w", errs, err)
			}
			continue
		}
		props[name] = propertyType
	}
	return props, errs
}
// Convert builds the knowledgebase.Property for this YAML definition.
//
// The concrete implementation is chosen from p.Type by InitializeProperty. Every field of p is then copied
// by name, via reflection, onto the matching settable field of that implementation. Fields the
// implementation lacks are skipped, as are nil pointers/interfaces and empty arrays/slices. Nested
// definitions are converted recursively:
//   - Properties: each sub-property is converted; all failures are reported together.
//   - KeyProperty / ValueProperty: only valid on "map" types. With "map(K,V)" an untyped element inherits K or
//     V, and a conflicting element type is an error.
//   - ItemProperty: only valid on "list"/"set" types, with the same inheritance from "list(T)" / "set(T)".
//
// Fields whose types differ fall back to pointer-element conversion and then to fieldConversion. Anything
// else is an error.
//
// NOTE: inheriting an element type writes it back onto p's KeyProperty/ValueProperty/ItemProperty.
func (p *Property) Convert() (knowledgebase.Property, error) {
	propertyType, err := InitializeProperty(p.Type)
	if err != nil {
		return nil, err
	}
	propertyType.Details().Path = p.Path
	srcVal := reflect.ValueOf(p).Elem()
	srcType := srcVal.Type()
	dstVal := reflect.ValueOf(propertyType).Elem()
	for i := 0; i < srcVal.NumField(); i++ {
		srcField := srcVal.Field(i)
		fieldName := srcType.Field(i).Name
		dstField := dstVal.FieldByName(fieldName)
		if !dstField.IsValid() || !dstField.CanSet() {
			continue
		}
		// Skip nil pointers
		if (srcField.Kind() == reflect.Ptr || srcField.Kind() == reflect.Interface) && srcField.IsNil() {
			continue
			// skip empty arrays and slices
		} else if (srcField.Kind() == reflect.Array || srcField.Kind() == reflect.Slice) && srcField.Len() == 0 {
			continue
		}
		// Handle sub properties so we can recurse down the tree
		switch fieldName {
		case "Properties":
			subProps := srcField.Interface().(Properties)
			var errs error
			props := knowledgebase.Properties{}
			for name, prop := range subProps {
				converted, err := prop.Convert()
				if err != nil {
					// Seed with the first error directly: fmt.Errorf("%w", nil) would render as "%!w(<nil>)".
					if errs == nil {
						errs = err
					} else {
						errs = fmt.Errorf("%w\n%w", errs, err)
					}
					continue
				}
				props[name] = converted
			}
			if errs != nil {
				return nil, fmt.Errorf("could not convert sub properties: %w", errs)
			}
			dstField.Set(reflect.ValueOf(props))
			continue
		case "KeyProperty", "ValueProperty":
			if !strings.HasPrefix(p.Type, "map") {
				return nil, fmt.Errorf("property must be 'map' (was %s) for %s", p.Type, fieldName)
			}
			keyType, valueType, hasElementTypes := strings.Cut(
				strings.TrimSuffix(strings.TrimPrefix(p.Type, "map("), ")"),
				",",
			)
			elemProp := srcField.Interface().(*Property)
			// Add the element's type if it is not specified but is on the parent.
			// For example, 'map(string,string)' on the parent means the key_property doesn't need 'type: string'
			if hasElementTypes {
				if fieldName == "KeyProperty" {
					if elemProp.Type != "" && elemProp.Type != keyType {
						return nil, fmt.Errorf("key property type must be %s (was %s)", keyType, elemProp.Type)
					} else if elemProp.Type == "" {
						elemProp.Type = keyType
					}
				} else {
					if elemProp.Type != "" && elemProp.Type != valueType {
						return nil, fmt.Errorf("value property type must be %s (was %s)", valueType, elemProp.Type)
					} else if elemProp.Type == "" {
						elemProp.Type = valueType
					}
				}
			}
			converted, err := elemProp.Convert()
			if err != nil {
				return nil, fmt.Errorf("could not convert %s: %w", fieldName, err)
			}
			srcField = reflect.ValueOf(converted)
		case "ItemProperty":
			if !strings.HasPrefix(p.Type, "list") && !strings.HasPrefix(p.Type, "set") {
				return nil, fmt.Errorf("property must be 'list' or 'set' (was %s) for %s", p.Type, fieldName)
			}
			hasItemType := strings.Contains(p.Type, "(")
			elemProp := srcField.Interface().(*Property)
			if hasItemType {
				itemType := strings.TrimSuffix(
					strings.TrimPrefix(strings.TrimPrefix(p.Type, "list("), "set("),
					")",
				)
				if elemProp.Type != "" && elemProp.Type != itemType {
					return nil, fmt.Errorf("item property type must be %s (was %s)", itemType, elemProp.Type)
				} else if elemProp.Type == "" {
					elemProp.Type = itemType
				}
			}
			converted, err := elemProp.Convert()
			if err != nil {
				return nil, fmt.Errorf("could not convert %s: %w", fieldName, err)
			}
			srcField = reflect.ValueOf(converted)
		}
		if srcField.Type().AssignableTo(dstField.Type()) {
			dstField.Set(srcField)
			continue
		}
		if dstField.Kind() == reflect.Ptr && srcField.Kind() == reflect.Ptr {
			if srcField.Type().Elem().AssignableTo(dstField.Type().Elem()) {
				dstField.Set(srcField)
				continue
			} else if srcField.Type().Elem().ConvertibleTo(dstField.Type().Elem()) {
				val := srcField.Elem().Convert(dstField.Type().Elem())
				// set dest field to a pointer of val
				dstField.Set(reflect.New(dstField.Type().Elem()))
				dstField.Elem().Set(val)
				continue
			}
		}
		if conversion, found := fieldConversion[fieldName]; found {
			if err := conversion(srcField, p, propertyType); err != nil {
				return nil, err
			}
			continue
		}
		return nil, fmt.Errorf(
			"could not assign %s#%s (%s) to field in %T (%s)",
			p.Path, fieldName, srcField.Type(), propertyType, dstField.Type(),
		)
	}
	return propertyType, nil
}
// setChildPaths walks property's descendants, setting each child's Name to its map key and its Path to
// "<currPath>.<name>" (extended recursively for deeper levels).
func setChildPaths(property *Property, currPath string) {
	for childName, child := range property.Properties {
		childPath := currPath + "." + childName
		child.Name, child.Path = childName, childPath
		setChildPaths(child, childPath)
	}
}
// Clone returns a new map holding a Property.Clone of every entry, so the result can be modified without
// affecting p. The result is never nil.
func (p Properties) Clone() Properties {
	out := make(Properties, len(p))
	for name, prop := range p {
		out[name] = prop.Clone()
	}
	return out
}
// Clone returns a copy of p that can be modified independently of it. Sub-properties and the key, value and
// item element definitions are cloned recursively. Convert writes inherited element types onto those
// elements, and updateModels relies on clones of shared models not aliasing each other. Other reference
// fields (slices, DefaultValue, rule and limit pointers) are shared with p. The clone's Properties map is
// never nil.
func (p *Property) Clone() *Property {
	cloned := *p
	cloned.Properties = make(Properties, len(p.Properties))
	for k, v := range p.Properties {
		cloned.Properties[k] = v.Clone()
	}
	if p.KeyProperty != nil {
		cloned.KeyProperty = p.KeyProperty.Clone()
	}
	if p.ValueProperty != nil {
		cloned.ValueProperty = p.ValueProperty.Clone()
	}
	if p.ItemProperty != nil {
		cloned.ItemProperty = p.ItemProperty.Clone()
	}
	return &cloned
}
// fieldConversion maps a Property field name to a converter used by Property.Convert when the YAML field's
// type cannot be assigned directly to the same-named field of the knowledgebase property. Each converter
// receives the source field value, the YAML definition and the target property, and sets the target field.
var fieldConversion = map[string]func(val reflect.Value, p *Property, kp knowledgebase.Property) error{
	// SanitizeTmpl is template source text in YAML but a compiled sanitization template on the target.
	"SanitizeTmpl": func(val reflect.Value, _ *Property, kp knowledgebase.Property) error {
		raw, isString := val.Interface().(string)
		switch {
		case !isString:
			return fmt.Errorf("invalid sanitize template")
		case raw == "":
			return nil
		}
		// The template needs a name; a random UUID keeps every anonymous template distinct.
		tmpl, err := knowledgebase.NewSanitizationTmpl(uuid.New().String(), raw)
		if err != nil {
			return err
		}
		target := reflect.ValueOf(kp).Elem().FieldByName("SanitizeTmpl")
		target.Set(reflect.ValueOf(tmpl))
		return nil
	},
}
// InitializeProperty returns an empty knowledgebase.Property for a type string such as "string",
// "list(int)" or "map(string,int)". The text before the first '(' selects the constructor. The remainder,
// with one trailing ')' removed, is passed to the constructor as its argument ("" when there is none).
// It returns an error for an empty or unknown type.
func InitializeProperty(ptype string) (knowledgebase.Property, error) {
	if ptype == "" {
		return nil, fmt.Errorf("property does not have a type")
	}
	base, rest, hasArgs := strings.Cut(ptype, "(")
	newProp, known := initializePropertyFunc[base]
	if !known {
		return nil, fmt.Errorf("unknown property type '%s'", ptype)
	}
	arg := ""
	if hasArgs {
		arg = strings.TrimSuffix(rest, ")")
	}
	return newProp(arg)
}
// initializePropertyFunc maps a base property type name (the text before the first '(' of a type string) to a
// constructor for the matching knowledgebase property. The constructor receives the text inside the outer
// parentheses ("" when absent), e.g. "string,int" for "map(string,int)".
var initializePropertyFunc map[string]func(val string) (knowledgebase.Property, error)

func init() {
	// initializePropertyFunc initialization is deferred to prevent cyclic initialization (a compiler error) with `InitializeProperty`
	initializePropertyFunc = map[string]func(val string) (knowledgebase.Property, error){
		"string": func(val string) (knowledgebase.Property, error) { return &properties.StringProperty{}, nil },
		"int":    func(val string) (knowledgebase.Property, error) { return &properties.IntProperty{}, nil },
		"float":  func(val string) (knowledgebase.Property, error) { return &properties.FloatProperty{}, nil },
		"bool":   func(val string) (knowledgebase.Property, error) { return &properties.BoolProperty{}, nil },
		"resource": func(val string) (knowledgebase.Property, error) {
			id := construct.ResourceId{}
			err := id.UnmarshalText([]byte(val))
			if err != nil {
				return nil, fmt.Errorf("invalid resource id for property type %s: %w", val, err)
			}
			return &properties.ResourceProperty{
				AllowedTypes: construct.ResourceList{id},
			}, nil
		},
		"map": func(val string) (knowledgebase.Property, error) {
			if val == "" {
				return &properties.MapProperty{}, nil
			}
			// Split only on top-level commas so nested element types such as
			// "string,map(string,string)" keep their own commas.
			args := splitTopLevelArgs(val)
			if len(args) != 2 {
				return nil, fmt.Errorf("invalid number of arguments for map property type: %s", val)
			}
			keyVal, err := InitializeProperty(args[0])
			if err != nil {
				return nil, err
			}
			valProp, err := InitializeProperty(args[1])
			if err != nil {
				return nil, err
			}
			return &properties.MapProperty{KeyProperty: keyVal, ValueProperty: valProp}, nil
		},
		"list": func(val string) (knowledgebase.Property, error) {
			if val == "" {
				return &properties.ListProperty{}, nil
			}
			itemProp, err := InitializeProperty(val)
			if err != nil {
				return nil, err
			}
			return &properties.ListProperty{ItemProperty: itemProp}, nil
		},
		"set": func(val string) (knowledgebase.Property, error) {
			if val == "" {
				return &properties.SetProperty{}, nil
			}
			itemProp, err := InitializeProperty(val)
			if err != nil {
				return nil, err
			}
			return &properties.SetProperty{ItemProperty: itemProp}, nil
		},
		"any": func(val string) (knowledgebase.Property, error) { return &properties.AnyProperty{}, nil },
	}
}

// splitTopLevelArgs splits s on commas that are not nested inside parentheses. For example
// "string,map(string,string)" yields ["string", "map(string,string)"]. When no comma appears inside
// parentheses, the result equals strings.Split(s, ",").
func splitTopLevelArgs(s string) []string {
	var args []string
	depth, start := 0, 0
	for i, r := range s {
		switch r {
		case '(':
			depth++
		case ')':
			depth--
		case ',':
			if depth == 0 {
				args = append(args, s[start:i])
				start = i + 1
			}
		}
	}
	return append(args, s[start:])
}
package reader
import knowledgebase "github.com/klothoplatform/klotho/pkg/knowledgebase"
// ResourceTemplate is the YAML form of a knowledgebase.ResourceTemplate as read from a template file.
// Its Properties are still raw definitions and SanitizeNameTmpl is template source text. Convert turns both
// into their knowledgebase forms and copies the remaining fields unchanged.
type ResourceTemplate struct {
	QualifiedTypeName string                          `json:"qualified_type_name" yaml:"qualified_type_name"`
	DisplayName       string                          `json:"display_name" yaml:"display_name"`
	Properties        Properties                      `json:"properties" yaml:"properties"`
	Classification    knowledgebase.Classification    `json:"classification" yaml:"classification"`
	PathSatisfaction  knowledgebase.PathSatisfaction  `json:"path_satisfaction" yaml:"path_satisfaction"`
	AdditionalRules   []knowledgebase.AdditionalRule  `json:"additional_rules" yaml:"additional_rules"`
	Consumption       knowledgebase.Consumption       `json:"consumption" yaml:"consumption"`
	// DeleteContext defines the context in which a resource can be deleted
	DeleteContext knowledgebase.DeleteContext `json:"delete_context" yaml:"delete_context"`
	// Views defines the views that the resource should be added to as a distinct node
	Views                 map[string]string                   `json:"views" yaml:"views"`
	NoIac                 bool                                `json:"no_iac" yaml:"no_iac"`
	DeploymentPermissions knowledgebase.DeploymentPermissions `json:"deployment_permissions" yaml:"deployment_permissions"`
	// SanitizeNameTmpl is the source of the name-sanitization template ("" for none).
	SanitizeNameTmpl string `yaml:"sanitize_name"`
}
// Convert returns the knowledgebase form of the template. Properties are converted, and a non-empty
// SanitizeNameTmpl is compiled into a template named after the qualified type. All other fields are copied
// unchanged. The first error aborts the conversion with a nil result.
func (r *ResourceTemplate) Convert() (*knowledgebase.ResourceTemplate, error) {
	props, err := r.Properties.Convert()
	if err != nil {
		return nil, err
	}
	converted := &knowledgebase.ResourceTemplate{
		QualifiedTypeName:     r.QualifiedTypeName,
		DisplayName:           r.DisplayName,
		Properties:            props,
		AdditionalRules:       r.AdditionalRules,
		Classification:        r.Classification,
		PathSatisfaction:      r.PathSatisfaction,
		Consumption:           r.Consumption,
		DeleteContext:         r.DeleteContext,
		Views:                 r.Views,
		NoIac:                 r.NoIac,
		DeploymentPermissions: r.DeploymentPermissions,
	}
	if r.SanitizeNameTmpl == "" {
		return converted, nil
	}
	converted.SanitizeNameTmpl, err = knowledgebase.NewSanitizationTmpl(r.QualifiedTypeName, r.SanitizeNameTmpl)
	if err != nil {
		return nil, err
	}
	return converted, nil
}
package knowledgebase
import (
"errors"
"fmt"
"reflect"
"sort"
"strings"
"github.com/klothoplatform/klotho/pkg/collectionutil"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/set"
"gopkg.in/yaml.v3"
)
//go:generate mockgen -source=./resource_template.go --destination=./resource_template_mock_test.go --package=knowledgebase
//go:generate mockgen -source=./resource_template.go --destination=../engine/operational_eval/resource_template_mock_test.go --package=operational_eval
//go:generate mockgen -source=./resource_template.go --destination=../infra/state_reader/resource_template_mock_test.go --package=statereader
type (
// ResourceTemplate defines how rules are handled by the engine in terms of making sure they are functional in the graph
ResourceTemplate struct {
// QualifiedTypeName is the qualified type name of the resource, in the form "provider:type" (see Id)
QualifiedTypeName string `json:"qualified_type_name" yaml:"qualified_type_name"`
// DisplayName is the common name that refers to the resource
DisplayName string `json:"display_name" yaml:"display_name"`
// Properties defines the properties that the resource has
Properties Properties `json:"properties" yaml:"properties"`
// AdditionalRules lists the additional rules declared by the template
AdditionalRules []AdditionalRule `json:"additional_rules" yaml:"additional_rules"`
// Classification defines the classification of the resource
Classification Classification `json:"classification" yaml:"classification"`
// PathSatisfaction defines the paths the resource must be connected to and from
PathSatisfaction PathSatisfaction `json:"path_satisfaction" yaml:"path_satisfaction"`
// Consumption defines properties the resource may emit or consume from other resources it is connected to or expanded from
Consumption Consumption `json:"consumption" yaml:"consumption"`
// DeleteContext defines the context in which a resource can be deleted
DeleteContext DeleteContext `json:"delete_context" yaml:"delete_context"`
// Views defines the views that the resource should be added to as a distinct node
Views map[string]string `json:"views" yaml:"views"`
// NoIac defines if the resource should be ignored by the IaC engine
NoIac bool `json:"no_iac" yaml:"no_iac"`
// DeploymentPermissions defines the permissions that are required to deploy and tear down the resource
DeploymentPermissions DeploymentPermissions `json:"deployment_permissions" yaml:"deployment_permissions"`
// SanitizeNameTmpl defines a template that is used to sanitize the name of the resource.
// It is applied by SanitizeName; when nil, names are used unchanged.
SanitizeNameTmpl *SanitizeTmpl `yaml:"sanitize_name"`
}
// DeploymentPermissions lists the permissions needed to deploy, update and tear down a resource.
DeploymentPermissions struct {
// Deploy defines the permissions that are required to deploy the resource
Deploy []string `json:"deploy" yaml:"deploy"`
// TearDown defines the permissions that are required to tear down the resource
TearDown []string `json:"tear_down" yaml:"tear_down"`
// Update defines the permissions that are required to update the resource
Update []string `json:"update" yaml:"update"`
}
// PropertyDetails defines the common details of a property
PropertyDetails struct {
// Name is the property's name within its parent properties
Name string `json:"name" yaml:"name"`
// Namespace marks the property that ResourceTemplate.GetNamespacedProperty returns
Namespace bool `yaml:"namespace"`
// Required defines if the property is required
Required bool `json:"required" yaml:"required"`
// ConfigurationDisabled defines if the property is allowed to be configured by the user
ConfigurationDisabled bool `json:"configuration_disabled" yaml:"configuration_disabled"`
// DeployTime defines if the property is only available at deploy time
DeployTime bool `json:"deploy_time" yaml:"deploy_time"`
// OperationalRule defines a rule that is executed at runtime to determine the value of the property
OperationalRule *PropertyRule `json:"operational_rule" yaml:"operational_rule"`
// Description is a description of the property. This is not used in the engine solving,
// but is metadata returned by the `ListResourceTypes` CLI command.
Description string `json:"description" yaml:"description"`
// Path is the path to the property in the resource
Path string `json:"-" yaml:"-"`
// IsImportant is a flag to denote what properties are subjectively important to show in the InfracopilotUI via
// the `ListResourceTypes` CLI command. This field is not & should not be used in the engine.
IsImportant bool
}
// Property is an interface used to define a property that exists on a resource
// Properties are used to define the structure of a resource and how it is configured
// Each property implementation refers to a specific type of property, such as a string or a list, etc
Property interface {
// SetProperty sets the value of the property on the resource
SetProperty(resource *construct.Resource, value any) error
// AppendProperty appends the value to the property on the resource
AppendProperty(resource *construct.Resource, value any) error
// RemoveProperty removes the value from the property on the resource
RemoveProperty(resource *construct.Resource, value any) error
// Details returns the property details for the property
Details() *PropertyDetails
// Clone returns a clone of the property
Clone() Property
// Type returns the string representation of the type of the property, as it should appear in the resource template
Type() string
// GetDefaultValue returns the default value for the property,
// pertaining to the specific data being passed in for execution
GetDefaultValue(ctx DynamicContext, data DynamicValueData) (any, error)
// Validate ensures the value is valid for the property to `Set` (not `Append` for collection types)
// and returns an error if it is not
Validate(resource *construct.Resource, value any, ctx DynamicContext) error
// SubProperties returns the sub properties of the property, if any.
// This is used for properties that are complex structures, such as lists, sets, or maps
SubProperties() Properties
// Parse parses a given value to ensure it is the correct type for the property.
// If the given value cannot be converted to the respective property type an error is returned.
// The returned value will always be the correct type for the property
Parse(value any, ctx DynamicContext, data DynamicValueData) (any, error)
// ZeroValue returns the zero value for the property type
ZeroValue() any
// Contains returns true if the value contains the given value
Contains(value any, contains any) bool
}
// MapProperty is an interface for properties that implement map structures
MapProperty interface {
// Key returns the property representing the keys of the map
Key() Property
// Value returns the property representing the values of the map
Value() Property
}
// CollectionProperty is an interface for properties that implement collection structures
CollectionProperty interface {
// Item returns the structure of the items within the collection
Item() Property
}
// Properties is a map of properties
Properties map[string]Property
// Classification defines the classification of a resource
Classification struct {
// Is defines the classifications that the resource belongs to
Is []string `json:"is" yaml:"is"`
// Gives defines the attributes that the resource gives to other resources
Gives []Gives `json:"gives" yaml:"gives"`
}
// Gives defines an attribute that can be provided to other functionalities for the resource it belongs to.
// It is decoded from an "attribute[:functionality,...]" string by UnmarshalJSON / UnmarshalYAML.
Gives struct {
// Attribute is the attribute that is given
Attribute string
// Functionality is the list of functionalities that the attribute is given to ("*" means all)
Functionality []string
}
// DeleteContext is supposed to tell us when we are able to delete a resource based on its dependencies
DeleteContext struct {
// RequiresNoUpstream is a boolean that tells us if deletion relies on there being no upstream resources
RequiresNoUpstream bool `yaml:"requires_no_upstream" toml:"requires_no_upstream"`
// RequiresNoDownstream is a boolean that tells us if deletion relies on there being no downstream resources
RequiresNoDownstream bool `yaml:"requires_no_downstream" toml:"requires_no_downstream"`
// RequiresNoUpstreamOrDownstream is a boolean that tells us if deletion relies on there being no upstream or downstream resources
RequiresNoUpstreamOrDownstream bool `yaml:"requires_no_upstream_or_downstream" toml:"requires_no_upstream_or_downstream"`
}
// Functionality is the coarse category of a resource that ResourceTemplate.GetFunctionality derives from
// its classifications.
Functionality string
)
// ErrRequiredProperty is a format string for reporting an unset required property. It takes two %s
// operands (presumably the property, then the resource — confirm at call sites).
const ErrRequiredProperty = "required property %s is not set on resource %s"

// Functionality values returned by ResourceTemplate.GetFunctionality. Unknown is returned when a template has
// none, or more than one, of the recognized classifications.
const (
	Compute   Functionality = "compute"
	Cluster   Functionality = "cluster"
	Storage   Functionality = "storage"
	Api       Functionality = "api"
	Messaging Functionality = "messaging"
	Unknown   Functionality = "Unknown"
)
// UnmarshalJSON decodes a Gives from its JSON string form "attribute[:functionality,...]".
// A JSON null leaves g unchanged, following the encoding/json Unmarshaler convention.
func (g *Gives) UnmarshalJSON(content []byte) error {
	givesString := string(content)
	if givesString == "null" {
		return nil
	}
	g.setFromString(givesString)
	return nil
}

// UnmarshalYAML decodes a Gives from a scalar "attribute[:functionality,...]" node.
func (g *Gives) UnmarshalYAML(n *yaml.Node) error {
	g.setFromString(n.Value)
	return nil
}

// setFromString parses "attribute[:functionality,...]" into g; double quotes are stripped from both parts.
// With no ':' the attribute is given to every functionality ("*"). Only the text up to a second ':' is used as
// the functionality list. An empty string leaves g unchanged.
func (g *Gives) setFromString(givesString string) {
	if givesString == "" {
		return
	}
	gives := strings.Split(givesString, ":")
	g.Attribute = strings.ReplaceAll(gives[0], "\"", "")
	if len(gives) == 1 {
		g.Functionality = []string{"*"}
		return
	}
	g.Functionality = strings.Split(strings.ReplaceAll(gives[1], "\"", ""), ",")
}
// Clone returns a new map holding Property.Clone of every entry, so the result can be modified without
// affecting p.
func (p *Properties) Clone() Properties {
	src := *p
	out := make(Properties, len(src))
	for key, prop := range src {
		out[key] = prop.Clone()
	}
	return out
}
// Id returns a resource ID with only Provider and Type set, taken from QualifiedTypeName ("provider:type").
// A name without ':' yields an empty Type instead of panicking with an index out of range.
func (template ResourceTemplate) Id() construct.ResourceId {
	provider, rest, _ := strings.Cut(template.QualifiedTypeName, ":")
	// Only the segment directly after the provider is the type, matching strings.Split(...)[1].
	typ, _, _ := strings.Cut(rest, ":")
	return construct.ResourceId{
		Provider: provider,
		Type:     typ,
	}
}
// CreateResource creates an empty resource for the given ID, running any sanitization rules on the ID.
// It fails if kb has no template for the ID or if sanitizing the name fails.
// NOTE: Because of sanitization, once created callers must use the resulting ID for all future operations
// and not the input ID.
func CreateResource(kb TemplateKB, id construct.ResourceId) (*construct.Resource, error) {
	rt, err := kb.GetResourceTemplate(id)
	if err != nil {
		return nil, fmt.Errorf("could not create resource: get template err: %w", err)
	}
	sanitized, err := rt.SanitizeName(id.Name)
	if err != nil {
		return nil, fmt.Errorf("could not create resource: %w", err)
	}
	id.Name = sanitized
	return &construct.Resource{ID: id, Properties: make(construct.Properties)}, nil
}
// SanitizeName runs name through the template's SanitizeNameTmpl, or returns it unchanged when the template
// has none.
func (rt ResourceTemplate) SanitizeName(name string) (string, error) {
	if tmpl := rt.SanitizeNameTmpl; tmpl != nil {
		return tmpl.Execute(name)
	}
	return name, nil
}
// GivesAttributeForFunctionality reports whether any of the template's Gives entries provides attribute to
// functionality, either by listing it explicitly or via the "*" wildcard.
func (template ResourceTemplate) GivesAttributeForFunctionality(attribute string, functionality Functionality) bool {
	for _, give := range template.Classification.Gives {
		if give.Attribute != attribute {
			continue
		}
		if collectionutil.Contains(give.Functionality, string(functionality)) || collectionutil.Contains(give.Functionality, "*") {
			return true
		}
	}
	return false
}
// GetFunctionality maps the template's classifications to a single Functionality. Unrecognized
// classifications are ignored. Unknown is returned when none is recognized, or when more than one is
// (including the same one listed twice).
func (template ResourceTemplate) GetFunctionality() Functionality {
	if len(template.Classification.Is) == 0 {
		return Unknown
	}
	var result Functionality
	for _, class := range template.Classification.Is {
		var f Functionality
		switch class {
		case "compute":
			f = Compute
		case "cluster":
			f = Cluster
		case "storage":
			f = Storage
		case "api":
			f = Api
		case "messaging":
			f = Messaging
		default:
			continue
		}
		if result != "" {
			// A second recognized classification makes the functionality ambiguous.
			return Unknown
		}
		result = f
	}
	if result == "" {
		return Unknown
	}
	return result
}
// ResourceContainsClassifications reports whether every entry in needs is either one of the template's
// classifications or the template's own qualified type name. An empty needs is trivially satisfied.
func (template ResourceTemplate) ResourceContainsClassifications(needs []string) bool {
	for _, need := range needs {
		satisfied := template.QualifiedTypeName == need || collectionutil.Contains(template.Classification.Is, need)
		if !satisfied {
			return false
		}
	}
	return true
}
// GetNamespacedProperty returns a top-level property whose details have Namespace set, or nil if there is
// none. Properties are a map, so if several are marked, which one is returned is unspecified.
func (template ResourceTemplate) GetNamespacedProperty() Property {
	for _, prop := range template.Properties {
		if details := prop.Details(); details.Namespace {
			return prop
		}
	}
	return nil
}
// GetProperty resolves a dotted property path (segments may carry "[...]" index/key suffixes, which are
// ignored for lookup) to a clone of the matching property whose Path is set to path.
//
// When an intermediate property has no sub-properties, the lookup stops there. For a map it returns the
// map's value property, and for a collection the item property, regardless of any remaining segments.
// It returns nil when a segment does not match.
func (template ResourceTemplate) GetProperty(path string) Property {
	// Clone so the path can be set (it may contain indexes or map keys) without touching the template.
	withPath := func(p Property) Property {
		clone := p.Clone()
		clone.Details().Path = path
		return clone
	}
	fields := strings.Split(path, ".")
	props := template.Properties
	for i, field := range fields {
		name, _, _ := strings.Cut(field, "[")
		property, ok := props[name]
		if !ok {
			return nil
		}
		if i == len(fields)-1 {
			return withPath(property)
		}
		props = property.SubProperties()
		if len(props) > 0 {
			continue
		}
		if mp, isMap := property.(MapProperty); isMap {
			return withPath(mp.Value())
		}
		if cp, isCollection := property.(CollectionProperty); isCollection {
			return withPath(cp.Item())
		}
	}
	return nil
}
var (
	// ErrStopWalk can be returned from the callback passed to ResourceTemplate.LoopProperties to end the walk
	// early; LoopProperties then returns nil.
	ErrStopWalk = errors.New("stop walk")
)
// ReplacePath runs a simple [strings.ReplaceAll] on the path of the property and all of its sub properties,
// recursively.
// NOTE: this mutates the property, so make sure to [Property.Clone] it first if you don't want that.
func ReplacePath(p Property, original, replacement string) {
	details := p.Details()
	details.Path = strings.ReplaceAll(details.Path, original, replacement)
	for _, sub := range p.SubProperties() {
		ReplacePath(sub, original, replacement)
	}
}
// LoopProperties walks the template's properties breadth-first, calling addProp on each one. Within
// each group of sibling properties the keys are visited in sorted order.
//
// Sub-properties of list and set properties are only visited for elements that actually exist on
// res: for element i, the sub-properties are cloned and their paths rewritten from "<path>..." to
// "<path>[i]...". A property's sub-properties are skipped when addProp returns an error for it.
//
// If addProp returns [ErrStopWalk], the walk stops immediately and the errors collected so far are
// returned (nil if there were none). Any other error from addProp is collected with errors.Join and
// the walk continues.
func (tmpl ResourceTemplate) LoopProperties(res *construct.Resource, addProp func(Property) error) error {
	queue := []Properties{tmpl.Properties}
	var props Properties
	var errs error
	for len(queue) > 0 {
		props, queue = queue[0], queue[1:]
		propKeys := make([]string, 0, len(props))
		for k := range props {
			propKeys = append(propKeys, k)
		}
		sort.Strings(propKeys)
		for _, key := range propKeys {
			prop := props[key]
			if err := addProp(prop); err != nil {
				if errors.Is(err, ErrStopWalk) {
					// Stopping early must not silently discard errors addProp already reported.
					return errs
				}
				errs = errors.Join(errs, err)
				continue
			}
			isList := strings.HasPrefix(prop.Type(), "list")
			isSet := strings.HasPrefix(prop.Type(), "set")
			if !isList && !isSet {
				if prop.SubProperties() != nil {
					queue = append(queue, prop.SubProperties())
				}
				continue
			}
			// Lists/sets start out empty, so only recurse into their sub-properties for elements that
			// are present on the resource. This allows defaults within list objects and operational
			// rules to run for each existing element.
			p, err := res.GetProperty(prop.Details().Path)
			if err != nil || p == nil {
				continue
			}
			var count int
			if isList {
				rv := reflect.ValueOf(p)
				if rv.Kind() != reflect.Slice && rv.Kind() != reflect.Array {
					// reflect.Value.Len would panic on e.g. a scalar or struct value; report it instead.
					errs = errors.Join(errs, fmt.Errorf(
						"could not read list property %s: unexpected value of type %T", prop.Details().Path, p,
					))
					continue
				}
				count = rv.Len()
			} else {
				hs, ok := p.(set.HashedSet[string, any])
				if !ok {
					errs = errors.Join(errs, fmt.Errorf("could not cast property to set"))
					continue
				}
				count = hs.Len()
			}
			queue = append(queue, loopIndexedSubProperties(prop, count)...)
		}
	}
	return errs
}

// loopIndexedSubProperties returns one cloned copy of prop's sub-properties per element (count in
// total), with each copy's paths rewritten from prop's path to "<prop path>[i]". It returns nil when
// prop has no sub-properties.
func loopIndexedSubProperties(prop Property, count int) []Properties {
	subs := prop.SubProperties()
	if len(subs) == 0 {
		return nil
	}
	base := prop.Details().Path
	out := make([]Properties, 0, count)
	for i := 0; i < count; i++ {
		indexed := fmt.Sprintf("%s[%d]", base, i)
		elem := make(Properties, len(subs))
		for subK, subProp := range subs {
			clone := subProp.Clone()
			replacePathPrefix(clone, base, indexed)
			elem[subK] = clone
		}
		out = append(out, elem)
	}
	return out
}

// replacePathPrefix rewrites a leading prefix of p's path (and, recursively, of its sub-properties'
// paths) to replacement. Unlike [ReplacePath], later occurrences of prefix inside a path are left
// untouched, so e.g. "Listener.ListenerArn" becomes "Listener[0].ListenerArn".
func replacePathPrefix(p Property, prefix, replacement string) {
	details := p.Details()
	if strings.HasPrefix(details.Path, prefix) {
		details.Path = replacement + details.Path[len(prefix):]
	}
	for _, sub := range p.SubProperties() {
		replacePathPrefix(sub, prefix, replacement)
	}
}
// IsCollectionProperty reports whether p implements CollectionProperty or MapProperty.
func IsCollectionProperty(p Property) bool {
	switch p.(type) {
	case CollectionProperty, MapProperty:
		return true
	default:
		return false
	}
}
package knowledgebase
import (
"bytes"
"crypto/sha256"
"fmt"
"regexp"
"strings"
"sync"
"text/template"
)
type (
	// SanitizeTmpl is a parsed text/template that rewrites a value into its sanitized form. See
	// NewSanitizationTmpl for the helper functions the template may use.
	SanitizeTmpl struct {
		template *template.Template
	}

	// SanitizeError is returned by [SanitizeTmpl.Check] when a value differs from its sanitized form.
	// The Sanitized field is always the same type as the Input field.
	SanitizeError struct {
		Input     any
		Sanitized any
	}
)

// NewSanitizationTmpl parses tmpl as a template named name+"/sanitize". Besides the standard
// template functions, the template may call:
//   - replace PATTERN REPLACEMENT VALUE: regexp ReplaceAllString (compiled patterns are cached)
//   - length MIN MAX VALUE: pads VALUE with "0"s up to MIN bytes; if VALUE is longer than MAX bytes it
//     is truncated and suffixed with (up to) 8 hex characters of its SHA-256 so that distinct long
//     values stay distinct. The result never exceeds MAX bytes.
//   - lower, upper: strings.ToLower / strings.ToUpper
//
// On a parse error it returns a nil *SanitizeTmpl and the error.
func NewSanitizationTmpl(name string, tmpl string) (*SanitizeTmpl, error) {
	// Patterns given to "replace" are normally constants in the template, so cache the compiled
	// regexps rather than recompiling them on every Execute. *regexp.Regexp is safe for concurrent use.
	var compiled sync.Map // map[string]*regexp.Regexp
	t, err := template.New(name + "/sanitize").
		Funcs(template.FuncMap{
			"replace": func(pattern, replace, name string) (string, error) {
				if cached, ok := compiled.Load(pattern); ok {
					return cached.(*regexp.Regexp).ReplaceAllString(name, replace), nil
				}
				re, err := regexp.Compile(pattern)
				if err != nil {
					return name, err
				}
				compiled.Store(pattern, re)
				return re.ReplaceAllString(name, replace), nil
			},
			"length": func(minLen, maxLen int, name string) string {
				if len(name) < minLen {
					return name + strings.Repeat("0", minLen-len(name))
				}
				if len(name) <= maxLen {
					return name
				}
				if maxLen <= 0 {
					return ""
				}
				// Too long: keep a prefix and append part of the SHA-256 of the full name so that
				// different long names do not collide after truncation.
				const hashLen = 8
				h := sha256.New()
				fmt.Fprint(h, name)
				digest := fmt.Sprintf("%x", h.Sum(nil))
				if maxLen <= hashLen {
					// No room for a prefix; use only (part of) the hash.
					return digest[:maxLen]
				}
				return name[:maxLen-hashLen] + digest[:hashLen]
			},
			"lower": strings.ToLower,
			"upper": strings.ToUpper,
		}).
		Parse(tmpl)
	if err != nil {
		return nil, err
	}
	return &SanitizeTmpl{
		template: t,
	}, nil
}

// sanitizeBufs pools the buffers used by Execute so each call does not allocate a new one.
var sanitizeBufs = sync.Pool{
	New: func() interface{} {
		return new(bytes.Buffer)
	},
}

// Execute runs the template with value as dot and returns the output with surrounding whitespace
// trimmed. On error it returns value unchanged together with a wrapped error.
func (t SanitizeTmpl) Execute(value string) (string, error) {
	buf := sanitizeBufs.Get().(*bytes.Buffer)
	defer sanitizeBufs.Put(buf)
	buf.Reset()
	err := t.template.Execute(buf, value)
	if err != nil {
		return value, fmt.Errorf("could not execute sanitize name template on %q: %w", value, err)
	}
	return strings.TrimSpace(buf.String()), nil
}

// Check returns nil if value is already sanitized, a *SanitizeError carrying the suggested value if
// it is not, or the error from Execute.
func (t SanitizeTmpl) Check(value string) error {
	sanitized, err := t.Execute(value)
	if err != nil {
		return err
	}
	if sanitized != value {
		return &SanitizeError{
			Input:     value,
			Sanitized: sanitized,
		}
	}
	return nil
}

// Error implements error, reporting both the invalid input and the suggested sanitized value.
func (err SanitizeError) Error() string {
	return fmt.Sprintf("invalid value %q, suggested value: %q", err.Input, err.Sanitized)
}
package logging
import (
"errors"
"io"
"os"
"path/filepath"
"strings"
"sync"
"go.uber.org/zap/zapcore"
)
// CategoryWriter is a zapcore.Core that writes each entry to a per-category log file
// (<LogRootPath>/<category>.log). The category is the first "."-separated segment of the entry's
// logger name; entries from unnamed loggers, or with an empty category, are not written.
type CategoryWriter struct {
	Encoder     zapcore.Encoder
	LogRootPath string
	files       *sync.Map // map[string]io.Writer
}

// NewCategoryWriter returns a CategoryWriter that encodes entries with enc and writes category files
// under logRootPath.
func NewCategoryWriter(enc zapcore.Encoder, logRootPath string) *CategoryWriter {
	return &CategoryWriter{
		Encoder:     enc,
		LogRootPath: logRootPath,
		files:       &sync.Map{},
	}
}

// Enabled always returns true: every level is written.
func (c *CategoryWriter) Enabled(lvl zapcore.Level) bool {
	return true
}

// With returns a clone whose encoder has fields added. The clone shares the open files with c.
func (c *CategoryWriter) With(fields []zapcore.Field) zapcore.Core {
	clone := c.clone()
	for i := range fields {
		fields[i].AddTo(clone.Encoder)
	}
	return clone
}

// Check adds c to ce when the entry's level is enabled (which is always).
func (c *CategoryWriter) Check(ent zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore.CheckedEntry {
	if c.Enabled(ent.Level) {
		return ce.AddCore(ent, c)
	}
	return ce
}

// Write encodes the entry into its category's file, creating LogRootPath and the file (truncated) on
// first use. The category prefix is stripped from the logger name before encoding. Entries above
// ErrorLevel trigger a Sync of that file.
func (c *CategoryWriter) Write(ent zapcore.Entry, fields []zapcore.Field) error {
	if ent.LoggerName == "" {
		return nil
	}
	categ, rest, _ := strings.Cut(ent.LoggerName, ".")
	categ = strings.TrimSpace(categ)
	categ = strings.ReplaceAll(categ, string(os.PathSeparator), "_")
	if categ == "" {
		return nil
	}
	ent.LoggerName = rest // trim the category from the logger name for better readability
	w, ok := c.files.Load(categ)
	if !ok {
		err := os.MkdirAll(c.LogRootPath, 0755)
		if err != nil {
			return err
		}
		logPath := filepath.Join(c.LogRootPath, categ+".log")
		f, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY, 0644)
		if err != nil {
			return err
		}
		var loaded bool
		w, loaded = c.files.LoadOrStore(categ, f)
		if loaded {
			f.Close()
		} else {
			// Don't pass `O_TRUNC` to the open in case we open the file simultaneously.
			// Wait until we know for sure that the one we opened is the canonical one in the map.
			if _, err := f.Seek(0, io.SeekStart); err != nil {
				return err
			}
			if err := f.Truncate(0); err != nil {
				return err
			}
		}
	}
	buf, err := c.Encoder.EncodeEntry(ent, fields)
	if err != nil {
		return err
	}
	_, err = w.(io.Writer).Write(buf.Bytes())
	buf.Free()
	if err != nil {
		return err
	}
	if ent.Level > zapcore.ErrorLevel {
		if syncer, ok := w.(interface{ Sync() error }); ok {
			syncer.Sync() //nolint:errcheck
		}
	}
	return nil
}

// Sync flushes every open category file that supports Sync and returns the joined errors.
func (c *CategoryWriter) Sync() error {
	var errs error
	c.files.Range(func(key, value interface{}) bool {
		if syncer, ok := value.(interface{ Sync() error }); ok {
			errs = errors.Join(errs, syncer.Sync())
		}
		// Keep iterating: returning false would stop Range after the first file.
		return true
	})
	return errs
}

// clone returns a copy of c with a cloned encoder that shares c's open files.
func (c *CategoryWriter) clone() *CategoryWriter {
	return &CategoryWriter{
		Encoder:     c.Encoder.Clone(),
		LogRootPath: c.LogRootPath,
		files:       c.files,
	}
}
package logging
import (
"bufio"
"bytes"
"context"
"io"
"os/exec"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
// loggerWriter adapts a *zap.Logger to io.Writer (and io.ReaderFrom), logging each line of output
// as a separate entry at level.
type loggerWriter struct {
	logger *zap.Logger
	level  zapcore.Level
}

// CommandLogger configures how Command logs a subprocess's output: stdout lines go to
// RootLogger.Named("stdout") at StdoutLevel and stderr lines to RootLogger.Named("stderr") at
// StderrLevel.
type CommandLogger struct {
	RootLogger  *zap.Logger
	StdoutLevel zapcore.Level
	StderrLevel zapcore.Level
}
// NewLoggerWriter returns an io.Writer that logs every line written to it on logger at level.
func NewLoggerWriter(logger *zap.Logger, level zapcore.Level) io.Writer {
	var w io.Writer = loggerWriter{level: level, logger: logger}
	return w
}
// Write logs each non-empty "\n"-separated line of p as its own entry at w.level. It always reports
// len(p) bytes written and a nil error. Lines are not buffered across calls, so a line split over two
// writes is logged as two entries.
func (w loggerWriter) Write(p []byte) (n int, err error) {
	for _, line := range bytes.Split(p, []byte{'\n'}) {
		if len(line) == 0 {
			// Skip blanks (e.g. after a trailing newline, or an empty write) rather than logging
			// empty entries.
			continue
		}
		if ce := w.logger.Check(w.level, string(line)); ce != nil {
			ce.Write()
		}
	}
	return len(p), nil
}
// ReadFrom logs every line read from r (as split by bufio.ScanLines) as its own entry at w.level,
// until EOF or an error. It returns the number of bytes consumed from r, as io.ReaderFrom requires,
// and any scanning error.
func (w loggerWriter) ReadFrom(r io.Reader) (int64, error) {
	// Lines longer than the Scanner's default 64KiB limit would abort with bufio.ErrTooLong and stop
	// draining r, so allow considerably longer lines.
	const maxLineBytes = 1 << 20
	counter := &readCounter{r: r}
	scanner := bufio.NewScanner(counter)
	scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), maxLineBytes)
	for scanner.Scan() {
		txt := scanner.Text()
		if ce := w.logger.Check(w.level, txt); ce != nil {
			ce.Write()
		}
	}
	return counter.n, scanner.Err()
}

// readCounter wraps an io.Reader and counts the bytes read through it.
type readCounter struct {
	r io.Reader
	n int64
}

// Read reads from the wrapped reader and adds the number of bytes read to c.n.
func (c *readCounter) Read(p []byte) (int, error) {
	n, err := c.r.Read(p)
	c.n += int64(n)
	return n, err
}
// Command is like exec.CommandContext, but logs the command's stdout and stderr line by line
// through cfg.RootLogger (see CommandLogger).
func Command(ctx context.Context, cfg CommandLogger, name string, arg ...string) *exec.Cmd {
	cmd := exec.CommandContext(ctx, name, arg...)
	stdoutLog := &loggerWriter{logger: cfg.RootLogger.Named("stdout"), level: cfg.StdoutLevel}
	stderrLog := &loggerWriter{logger: cfg.RootLogger.Named("stderr"), level: cfg.StderrLevel}
	cmd.Stdout, cmd.Stderr = stdoutLog, stderrLog
	return cmd
}
package logging
import (
"strings"
"sync"
"go.uber.org/zap/zapcore"
)
// EntryLeveller is a zapcore.Core that filters log entries based on the module name
// similar to Log4j or python's logging module.
//
// The modules of an entry are the "."-separated prefixes of its logger name. The longest prefix with
// a configured level decides, falling back to the root module "". Entries with no matching module
// are passed to the wrapped Core's own Check.
type EntryLeveller struct {
	zapcore.Core
	levels sync.Map // map[string]zapcore.Level

	// hasLevels and minLevel summarise the configured levels so Enabled can be answered cheaply: a
	// module override may enable levels below the wrapped Core's own level.
	hasLevels bool
	minLevel  zapcore.Level
}

// NewEntryLeveller wraps core, applying the per-module levels in levels.
func NewEntryLeveller(core zapcore.Core, levels map[string]zapcore.Level) *EntryLeveller {
	el := &EntryLeveller{Core: core}
	for k, v := range levels {
		el.levels.Store(k, v)
		if !el.hasLevels || v < el.minLevel {
			el.minLevel = v
		}
		el.hasLevels = true
	}
	return el
}

// Enabled reports whether entries at lvl might be logged. zap's Logger consults Enabled before
// calling Check, so this must also say yes for levels that only a module override enables; Check
// then makes the per-module decision.
func (el *EntryLeveller) Enabled(lvl zapcore.Level) bool {
	if el.hasLevels && lvl >= el.minLevel {
		return true
	}
	return el.Core.Enabled(lvl)
}

// With returns a new EntryLeveller wrapping el.Core.With(f) with a copy of el's levels.
func (el *EntryLeveller) With(f []zapcore.Field) zapcore.Core {
	next := &EntryLeveller{
		Core:      el.Core.With(f),
		hasLevels: el.hasLevels,
		minLevel:  el.minLevel,
	}
	el.levels.Range(func(k, v interface{}) bool {
		next.levels.Store(k, v)
		return true
	})
	return next
}

// checkModule applies module's configured level (if any) to e. The boolean reports whether module
// had a level; in that case the returned entry is the final Check result.
func (el *EntryLeveller) checkModule(e zapcore.Entry, ce *zapcore.CheckedEntry, module string) (*zapcore.CheckedEntry, bool) {
	if level, ok := el.levels.Load(module); ok {
		// Cache the resolved level under the full logger name for faster lookups next time.
		el.levels.Store(e.LoggerName, level)
		if e.Level < level.(zapcore.Level) {
			// Filtered out: return ce unchanged (not nil) so cores already added to it, e.g. by a
			// surrounding Tee, still receive the entry.
			return ce, true
		}
		return ce.AddCore(e, el), true
	}
	return nil, false
}

// Check resolves the entry's module level (exact logger name, then successively shorter prefixes,
// then the root module "") and otherwise defers to the wrapped Core's Check.
func (el *EntryLeveller) Check(e zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore.CheckedEntry {
	if ce, ok := el.checkModule(e, ce, e.LoggerName); ok {
		return ce
	}
	if e.LoggerName == "" {
		return el.Core.Check(e, ce)
	}
	nameParts := strings.Split(e.LoggerName, ".")
	for i := len(nameParts); i > 0; i-- {
		module := strings.Join(nameParts[:i], ".")
		if ce, ok := el.checkModule(e, ce, module); ok {
			return ce
		}
	}
	if ce, ok := el.checkModule(e, ce, ""); ok {
		return ce
	}
	return el.Core.Check(e, ce)
}
package logging
import (
"context"
"go.uber.org/zap"
)
// contextKey is an unexported key type, so context values stored by this package cannot collide with
// keys defined in other packages.
type contextKey string

// logKey is the context key under which WithLogger stores the *zap.Logger.
var logKey contextKey = "log"
// GetLogger returns the logger stored in ctx by WithLogger, or zap's global logger (zap.L()) when
// ctx carries none.
func GetLogger(ctx context.Context) *zap.Logger {
	if logger, ok := ctx.Value(logKey).(*zap.Logger); ok {
		return logger
	}
	return zap.L()
}
// WithLogger returns a child of ctx that carries logger, retrievable with GetLogger.
func WithLogger(ctx context.Context, logger *zap.Logger) context.Context {
	derived := context.WithValue(ctx, logKey, logger)
	return derived
}
package logging
import (
"fmt"
"os"
"strings"
"time"
prettyconsole "github.com/thessem/zap-prettyconsole"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"golang.org/x/term"
)
// LogOpts configures the logger built by [LogOpts.NewLogger].
type LogOpts struct {
	// Verbose selects debug level (instead of info) and, for JSON, zap's development encoder config.
	Verbose bool
	// Color controls console coloring: "auto" (color only if stderr is a terminal), "always"/"on",
	// or "never"/"off". Any other value uses color.
	Color string
	// CategoryLogsDir, when non-empty, additionally writes per-category log files into this directory.
	CategoryLogsDir string
	// Encoding is "json" or "console"/"pretty_console"/"" (console). Other values panic when the
	// encoders are built.
	Encoding string
	// DefaultLevels are per-module level overrides, used unless the LOG_LEVEL environment variable is set.
	DefaultLevels map[string]zapcore.Level
}
// Encoder builds the zapcore.Encoder selected by opts.Encoding:
//   - "json": a JSON encoder using zap's development config when Verbose, production config otherwise.
//   - "console", "pretty_console" or "": a prettyconsole encoder when colored output is selected
//     (see LogOpts.Color), otherwise a plain console encoder; times are shown as offsets from when
//     the encoder was built.
//
// It panics on any other encoding.
func (opts LogOpts) Encoder() zapcore.Encoder {
	switch opts.Encoding {
	case "json":
		var cfg zapcore.EncoderConfig
		if opts.Verbose {
			cfg = zap.NewDevelopmentEncoderConfig()
		} else {
			cfg = zap.NewProductionEncoderConfig()
		}
		return zapcore.NewJSONEncoder(cfg)

	case "console", "pretty_console", "":
		// Any value other than "auto" or "never"/"off" (including "always"/"on") keeps color.
		colored := true
		switch opts.Color {
		case "auto":
			colored = term.IsTerminal(int(os.Stderr.Fd()))
		case "never", "off":
			colored = false
		}
		if !colored {
			cfg := zap.NewDevelopmentEncoderConfig()
			cfg.EncodeTime = TimeOffsetFormatter(time.Now(), colored)
			return zapcore.NewConsoleEncoder(cfg)
		}
		cfg := prettyconsole.NewEncoderConfig()
		cfg.EncodeTime = TimeOffsetFormatter(time.Now(), colored)
		return prettyconsole.NewEncoder(cfg)

	default:
		panic(fmt.Errorf("unknown encoding %q", opts.Encoding))
	}
}
// EntryLeveller wraps core in an [EntryLeveller] configured with per-module levels.
//
// The levels come from opts.DefaultLevels unless the LOG_LEVEL environment variable is set. In that
// case they are replaced by LOG_LEVEL's comma-separated entries:
//   - "module=level" sets the level for module (and, via prefix matching, its sub-modules);
//   - a bare "level" sets the level of the root module "" (all loggers).
//
// Surrounding whitespace is ignored, and entries whose level does not parse are skipped. When no
// levels are configured, core is returned unchanged.
func (opts LogOpts) EntryLeveller(core zapcore.Core) zapcore.Core {
	levels := opts.DefaultLevels
	if levelEnv, ok := os.LookupEnv("LOG_LEVEL"); ok {
		entries := strings.Split(levelEnv, ",")
		levels = make(map[string]zapcore.Level, len(entries))
		for _, entry := range entries {
			entry = strings.TrimSpace(entry)
			if entry == "" {
				continue
			}
			module, levelText, hasModule := strings.Cut(entry, "=")
			if !hasModule {
				module, levelText = "", entry
			}
			lvl, err := zapcore.ParseLevel(strings.TrimSpace(levelText))
			if err != nil {
				continue
			}
			levels[strings.TrimSpace(module)] = lvl
		}
	}
	if len(levels) > 0 {
		core = NewEntryLeveller(core, levels)
	}
	return core
}
// CategoryCore tees core with a CategoryWriter writing into opts.CategoryLogsDir, or returns core
// unchanged when CategoryLogsDir is empty. The category files use a development-config JSON or
// console encoder according to opts.Encoding; any other encoding panics.
func (opts LogOpts) CategoryCore(core zapcore.Core) zapcore.Core {
	if opts.CategoryLogsDir == "" {
		return core
	}
	var categEnc zapcore.Encoder
	switch opts.Encoding {
	case "json":
		categEnc = zapcore.NewJSONEncoder(zap.NewDevelopmentEncoderConfig())
	case "console", "pretty_console", "":
		categEnc = zapcore.NewConsoleEncoder(zap.NewDevelopmentEncoderConfig())
	default:
		panic(fmt.Errorf("unknown encoding %q", opts.Encoding))
	}
	categoryWriter := NewCategoryWriter(categEnc, opts.CategoryLogsDir)
	return zapcore.NewTee(core, categoryWriter)
}
// NewCore builds the full core writing to w: the configured encoder at debug level (Verbose) or
// info level, wrapped with the per-module EntryLeveller and the optional category-file tee.
func (opts LogOpts) NewCore(w zapcore.WriteSyncer) zapcore.Core {
	enc := opts.Encoder()
	baseLevel := zap.InfoLevel
	if opts.Verbose {
		baseLevel = zap.DebugLevel
	}
	base := zapcore.NewCore(enc, w, zap.NewAtomicLevelAt(baseLevel))
	return opts.CategoryCore(opts.EntryLeveller(base))
}
// NewLogger returns a *zap.Logger that writes to stderr using NewCore.
func (opts LogOpts) NewLogger() *zap.Logger {
	core := opts.NewCore(os.Stderr)
	return zap.New(core)
}
// TimeOffsetFormatter returns a time encoder that formats the time as an offset from the start time:
// in milliseconds below one second, in seconds below five minutes, and in minutes beyond that. When
// color is true, the offset is wrapped in the ANSI escapes "\x1b[90m" ... "\x1b[0m".
// This is mostly useful for CLI logging, not long-standing services, as times beyond a few minutes
// will be less readable.
func TimeOffsetFormatter(start time.Time, color bool) zapcore.TimeEncoder {
	colStart, colEnd := "", ""
	if color {
		colStart, colEnd = "\x1b[90m", "\x1b[0m"
	}
	return func(t time.Time, e zapcore.PrimitiveArrayEncoder) {
		var formatted string
		switch diff := t.Sub(start); {
		case diff < time.Second:
			formatted = fmt.Sprintf(" %s%3dms%s", colStart, diff.Milliseconds(), colEnd)
		case diff < 5*time.Minute:
			formatted = fmt.Sprintf("%s%5.1fs%s", colStart, diff.Seconds(), colEnd)
		default:
			formatted = fmt.Sprintf("%s%5.1fm%s", colStart, diff.Minutes(), colEnd)
		}
		e.AppendString(formatted)
	}
}
package multierr
import (
"bytes"
"errors"
"fmt"
)
// Error is a list of errors that itself implements error. Return it through [Error.ErrOrNil] so that
// an empty list becomes a true nil error.
type Error []error

// Error formats the contained errors: a single error's message as-is, or an "N errors occurred:"
// header followed by one "* "-prefixed line per error.
func (e Error) Error() string {
	switch len(e) {
	case 0:
		// Generally won't be called, but here for completion sake
		return "<nil>"
	case 1:
		return e[0].Error()
	default:
		buf := new(bytes.Buffer)
		fmt.Fprintf(buf, "%d errors occurred:", len(e))
		for _, err := range e {
			fmt.Fprintf(buf, `
	* %v`, err)
		}
		return buf.String()
	}
}

// Append will mutate e and append the error. Will no-op if `err == nil`.
// Typical usage should be via auto-referencing [syntax sugar](https://go.dev/ref/spec#Calls):
//
//	var e Error
//	e.Append(err)
func (e *Error) Append(err error) {
	switch {
	case e == nil:
		// if the pointer to the array is nil, nothing we can do.
		// this shouldn't normally happen unless callers are for some reason
		// using `*Error`, which they shouldn't (`Error` as an array is already a pointer)
	case err == nil:
		// Do nothing
	case *e == nil:
		*e = Error{err}
	default:
		*e = append(*e, err)
	}
}

// Append adds err2 to err1.
//   - If err1 and err2 are nil, returns nil
//   - If err1 is nil, returns an [Error] with only err2
//   - If err2 is nil, returns an [Error] with only err1
//   - If err1 is an [Error], it returns a copy with err2 appended
//   - Otherwise, returns a new [Error] with err1 and err2 as the sole elements
//
// NOTE: unlike `err1.Append`, this does not mutate err1, and the result never shares a backing
// array with err1.
func Append(err1, err2 error) Error {
	switch {
	case err1 == nil && err2 == nil:
		return nil
	case err1 == nil:
		return Error{err2}
	case err2 == nil:
		return Error{err1}
	}
	if merr, ok := err1.(Error); ok {
		// Copy into a fresh backing array: appending to merr directly could write into spare
		// capacity shared with err1, so a later Append on the same err1 would overwrite this result.
		out := make(Error, len(merr), len(merr)+1)
		copy(out, merr)
		return append(out, err2)
	}
	return Error{err1, err2}
}

// ErrOrNil is used to convert this multierr into a [error]. This is necessary because it is a typed nil
//
//	func example() error {
//		var e Error
//		return e
//	}
//	if example() != nil {
//		! this will run!
//	}
//
// in otherwords,
//
//	(Error)(nil) != nil
//
// Additionally, if there's only a single error, it will automatically unwrap it.
func (e Error) ErrOrNil() error {
	switch len(e) {
	case 0:
		return nil
	case 1:
		return e[0]
	default:
		return e
	}
}

// Unwrap implements the interface used in [errors.Unwrap]
func (e Error) Unwrap() error {
	switch len(e) {
	case 0:
		return nil
	case 1:
		return e[0]
	default:
		return e[1:]
	}
}

// As implements the interface used in [errors.As] by iterating through the members
// returning true on the first match.
func (e Error) As(target interface{}) bool {
	for _, err := range e {
		if errors.As(err, target) {
			return true
		}
	}
	return false
}

// Is implements the interface used in [errors.Is] by iterating through the members
// returning true on the first match.
func (e Error) Is(target error) bool {
	for _, err := range e {
		if errors.Is(err, target) {
			return true
		}
	}
	return false
}
package parseutils
import (
"fmt"
"regexp"
)
// ExpressionExtractor returns a function that returns up to n balanced expressions for the supplied
// start and end delimiters. A non-positive n returns every expression found.
//
// The escape argument is used for detecting escaped delimiters and should typically be either `\` or
// `\\` depending on the format of the input string. It is inserted verbatim into a regular expression
// (as `(?:escape)*`), and the regexps are compiled when ExpressionExtractor is called, which panics if
// they are invalid.
// NOTE(review): a lone `\` yields `(?:\)` which does not compile; confirm which escape values callers
// actually pass.
func ExpressionExtractor(escape string, start, end rune) func(input string, n int) []string {
	// The patterns depend only on the extractor's configuration, so compile them once here instead of
	// on every call of the returned function.
	escapedStartPattern := regexp.MustCompile(fmt.Sprintf(`^[^%c]*?((?:%s)*)\%c`, start, escape, start))
	escapedEndPattern := regexp.MustCompile(fmt.Sprintf(`^[^%c]*?((?:%s)*)\%c`, end, escape, end))
	return func(input string, n int) []string {
		sCount := 0
		eCount := 0
		exprStartIndex := -1
		lastMatchIndex := -1
		var expressions []string
		for i := 0; i < len(input); i++ {
			switch rune(input[i]) {
			case start:
				// match[1] is the run of escape sequences directly preceding this delimiter.
				match := escapedStartPattern.FindStringSubmatch(input[lastMatchIndex+1:])
				if match[1] == "" || len(match[1])%len(escape) != 0 {
					sCount++
				}
				lastMatchIndex = i
				if exprStartIndex < 0 {
					exprStartIndex = i
				}
			case end:
				match := escapedEndPattern.FindStringSubmatch(input[lastMatchIndex+1:])
				if match[1] == "" || len(match[1])%len(escape) != 0 {
					eCount++
				}
				lastMatchIndex = i
			}
			if sCount > 0 && sCount == eCount && exprStartIndex >= 0 {
				expressions = append(expressions, input[exprStartIndex:i+1])
				if n > 0 && len(expressions) == n {
					return expressions
				}
				// reset counters for next expression
				exprStartIndex = -1
				sCount = 0
				eCount = 0
			}
		}
		return expressions
	}
}
package aws
import (
"encoding/json"
"fmt"
"sort"
"github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/engine/solution"
"github.com/klothoplatform/klotho/pkg/set"
"go.uber.org/zap"
)
// DeploymentPermissionsPolicy returns the JSON (indented) "Policy" property of an in-memory
// iam_policy resource. The policy has a single Allow statement on Resource "*" whose Action list is
// the sorted union of the Deploy, TearDown and Update DeploymentPermissions of the resource templates
// of every AWS resource in ctx's dataflow graph.
//
// Resources whose template has NoIac set are skipped; resources without any deployment permissions
// are logged as a warning. It returns nil, nil when no actions are found.
func DeploymentPermissionsPolicy(ctx solution.Solution) ([]byte, error) {
	policy := &construct.Resource{
		ID: construct.ResourceId{
			Provider: "aws",
			Type:     "iam_policy",
			Name:     "deployment_permissions",
		},
		Properties: construct.Properties{
			"Policy": map[string]any{
				"Version": "2012-10-17",
			},
		},
	}
	kb := ctx.KnowledgeBase()
	policyRt, err := kb.GetResourceTemplate(policy.ID)
	if err != nil {
		return nil, err
	}
	if policyRt == nil {
		return nil, fmt.Errorf("resource template not found for resource %s", policy.ID)
	}
	// Find the StatementProperty so we can use its methods
	statementProperty := policyRt.GetProperty("Policy.Statement")
	actions := make(set.Set[string])
	err = construct.WalkGraph(ctx.DataflowGraph(), func(id construct.ResourceId, resource *construct.Resource, nerr error) error {
		if nerr != nil {
			return nerr
		}
		if id.Provider != "aws" {
			return nil
		}
		rt, err := kb.GetResourceTemplate(resource.ID)
		if err != nil {
			return err
		}
		if rt == nil {
			return fmt.Errorf("resource template not found for resource %s", resource.ID)
		}
		if rt.NoIac {
			return nil
		}
		resActions := make(set.Set[string])
		resActions.Add(rt.DeploymentPermissions.Deploy...)
		resActions.Add(rt.DeploymentPermissions.TearDown...)
		resActions.Add(rt.DeploymentPermissions.Update...)
		if len(resActions) == 0 {
			zap.S().Warnf("No deployment permissions found for resource %s", resource.ID)
			return nil
		}
		actions.AddFrom(resActions)
		return nil
	})
	if err != nil {
		return nil, err
	}
	if actions.Len() == 0 {
		return nil, nil
	}
	if statementProperty == nil {
		// Without this check the AppendProperty call below panics on a nil interface.
		return nil, fmt.Errorf("property Policy.Statement not found in resource template for %s", policy.ID)
	}
	actionList := actions.ToSlice()
	sort.Strings(actionList)
	statement := map[string]any{
		"Effect":   "Allow",
		"Action":   actionList,
		"Resource": "*",
	}
	err = statementProperty.AppendProperty(policy, statement)
	if err != nil {
		return nil, err
	}
	pol, err := policy.GetProperty("Policy")
	if err != nil {
		return nil, err
	}
	return json.MarshalIndent(pol, "", " ")
}
package query
import (
"fmt"
sitter "github.com/smacker/go-tree-sitter"
)
// NextFunc is an iterator: each call returns the next element and true, or a zero value and false
// once there are no more elements.
type NextFunc[T any] func() (T, bool)

// MatchNodes maps a query's capture names to the nodes captured by a single match.
type MatchNodes = map[string]*sitter.Node

// NextMatchFunc iterates over query matches; see Exec.
type NextMatchFunc = NextFunc[MatchNodes]
// Exec runs query q (in language lang) over node c and returns a lazy iterator over its matches. Each
// call yields a fresh map from capture name (as written in the query) to captured node. A nil c yields
// an iterator that is immediately exhausted. Exec panics if q does not compile, since that is a
// programmer error in the query string.
func Exec(lang *sitter.Language, c *sitter.Node, q string) NextMatchFunc {
	if c == nil {
		return func() (map[string]*sitter.Node, bool) { return nil, false }
	}
	query, err := sitter.NewQuery([]byte(q), lang)
	if err != nil {
		panic(fmt.Errorf("Error constructing query for %s: %w", q, err))
	}
	cursor := sitter.NewQueryCursor()
	cursor.Exec(query, c)
	return func() (map[string]*sitter.Node, bool) {
		match, found := cursor.NextMatch()
		if !found || match == nil {
			return nil, false
		}
		captured := make(map[string]*sitter.Node, len(match.Captures))
		for _, capture := range match.Captures {
			captured[query.CaptureNameForId(capture.Index)] = capture.Node
		}
		return captured, true
	}
}
// Collect drains the iterator f and returns its elements in order (nil if it yields none).
func Collect[T any](f NextFunc[T]) []T {
	var collected []T
	for elem, ok := f(); ok; elem, ok = f() {
		collected = append(collected, elem)
	}
	return collected
}
package reflectutil
import (
"fmt"
"reflect"
)
// MapContainsKey reports whether key is present in map m. m may be the map itself or a
// [reflect.Value] wrapping one. It returns an error if m is not a map, if key is nil, or if key's type
// is not assignable to m's key type (reflect.Value.MapIndex would otherwise panic).
func MapContainsKey(m any, key interface{}) (bool, error) {
	var mapValue reflect.Value
	if mValue, ok := m.(reflect.Value); ok {
		mapValue = mValue
	} else {
		mapValue = reflect.ValueOf(m)
	}
	if mapValue.Kind() != reflect.Map {
		return false, fmt.Errorf("value is not a map")
	}
	keyValue := reflect.ValueOf(key)
	if !keyValue.IsValid() {
		return false, fmt.Errorf("invalid key")
	}
	keyType := mapValue.Type().Key()
	if !keyValue.Type().AssignableTo(keyType) {
		return false, fmt.Errorf("key of type %s cannot be used with map key type %s", keyValue.Type(), keyType)
	}
	return mapValue.MapIndex(keyValue).IsValid(), nil
}
package reflectutil
import (
"errors"
"fmt"
"reflect"
"strconv"
"strings"
)
// GetConcreteValue returns the concrete value of a reflect.Value.
// This function is used to get the concrete value of a reflect.Value even if it is a pointer or interface.
// Concrete values are values that are not pointers or interfaces (including maps, slices, structs, etc.).
// It returns nil if v is invalid or if the chain of pointers/interfaces ends in nil.
func GetConcreteValue(v reflect.Value) any {
	if !v.IsValid() {
		return nil
	}
	concrete := GetConcreteElement(v)
	if !concrete.IsValid() {
		// e.g. a nil pointer: Interface() on the resulting zero Value would panic.
		return nil
	}
	return concrete.Interface()
}

// IsNotConcrete returns true if the reflect.Value is a pointer or interface.
func IsNotConcrete(v reflect.Value) bool {
	return v.Kind() == reflect.Pointer || v.Kind() == reflect.Interface
}

// GetConcreteElement returns the concrete reflect.Value of a reflect.Value
// when it is a pointer or interface or the same reflect.Value when it is already concrete.
// A nil pointer or interface along the way yields the invalid zero reflect.Value.
func GetConcreteElement(v reflect.Value) reflect.Value {
	for IsNotConcrete(v) {
		v = v.Elem()
	}
	return v
}
// GetField returns the [reflect.Value] reached by following fieldExpr from v through structs, maps,
// slices and arrays. fieldExpr is split with [SplitPath]; each part is a field name ("Name" or ".Name"),
// a bracketed index ("[0]") into a slice or array, or a bracketed key ("[key]") into a map with
// string-kinded keys. Pointers and interfaces are dereferenced before each step.
func GetField(v reflect.Value, fieldExpr string) (reflect.Value, error) {
	if v.Kind() == reflect.Invalid {
		return reflect.Value{}, fmt.Errorf("value is nil")
	}
	fields := SplitPath(fieldExpr)
	for _, field := range fields {
		if strings.Contains(field, "[") != strings.Contains(field, "]") {
			return reflect.Value{}, errors.New("invalid path: unclosed brackets ")
		}
		v = GetConcreteElement(v)
		// Handle bracketed indices / keys
		if strings.Contains(field, "[") && !strings.Contains(field, ".") {
			open := strings.Index(field, "[")
			closeIdx := strings.Index(field, "]")
			if closeIdx < open {
				return reflect.Value{}, fmt.Errorf("invalid path: malformed brackets in %q", field)
			}
			fieldName := field[:open]
			key := field[open+1 : closeIdx]
			if fieldName != "" {
				switch v.Kind() {
				case reflect.Map:
					entry, err := getFieldMapEntry(v, fieldName)
					if err != nil {
						return reflect.Value{}, err
					}
					v = GetConcreteElement(entry)
				case reflect.Struct:
					v = GetConcreteElement(v.FieldByName(fieldName))
				default:
					return reflect.Value{}, fmt.Errorf("field is not a struct or map: %s", fieldName)
				}
				if !v.IsValid() {
					return reflect.Value{}, fmt.Errorf("invalid field name: %s", fieldName)
				}
			}
			switch v.Kind() {
			case reflect.Map:
				// Bracketed map key, e.g. Tags[env]
				entry, err := getFieldMapEntry(v, key)
				if err != nil {
					return reflect.Value{}, err
				}
				v = entry
			case reflect.Slice, reflect.Array:
				index, err := strconv.Atoi(key)
				if err != nil {
					return reflect.Value{}, err
				}
				elem, err := getFieldIndex(v, index)
				if err != nil {
					return reflect.Value{}, err
				}
				v = elem
			default:
				return reflect.Value{}, fmt.Errorf("field is not a slice or array: %s", fieldName)
			}
		} else {
			field = strings.TrimSuffix(strings.TrimLeft(field, ".["), "]")
			switch v.Kind() {
			case reflect.Map:
				entry, err := getFieldMapEntry(v, field)
				if err != nil {
					return reflect.Value{}, err
				}
				v = entry
			case reflect.Struct:
				v = v.FieldByName(field)
				if !v.IsValid() {
					return reflect.Value{}, fmt.Errorf("invalid field name: %s", field)
				}
			case reflect.Slice, reflect.Array:
				index, err := strconv.Atoi(field)
				if err != nil {
					return reflect.Value{}, fmt.Errorf("invalid slice or array index: %s", field)
				}
				elem, err := getFieldIndex(v, index)
				if err != nil {
					return reflect.Value{}, err
				}
				v = elem
			default:
				return reflect.Value{}, fmt.Errorf("unsupported type for field: %s", field)
			}
		}
	}
	return v, nil
}

// getFieldMapEntry returns the entry of map m stored under key. key is converted to m's key type,
// which must be string-kinded (named string types are supported).
func getFieldMapEntry(m reflect.Value, key string) (reflect.Value, error) {
	keyType := m.Type().Key()
	if keyType.Kind() != reflect.String {
		return reflect.Value{}, fmt.Errorf("unsupported map key type: %s: key type must be 'String'", keyType)
	}
	entry := m.MapIndex(reflect.ValueOf(key).Convert(keyType))
	if !entry.IsValid() {
		return reflect.Value{}, fmt.Errorf("invalid map key: %s", key)
	}
	return entry, nil
}

// getFieldIndex returns element index of slice/array s, rejecting out-of-range (including negative)
// indexes instead of letting reflect panic.
func getFieldIndex(s reflect.Value, index int) (reflect.Value, error) {
	if index < 0 || index >= s.Len() {
		return reflect.Value{}, fmt.Errorf("index out of range: %d", index)
	}
	return s.Index(index), nil
}
// GetTypedField resolves fieldExpr from v with GetField and converts the result to T with
// GetTypedValue. It returns the zero T and false if the field cannot be resolved or is not a T.
func GetTypedField[T any](v reflect.Value, fieldExpr string) (T, bool) {
	fieldValue, err := GetField(v, fieldExpr)
	if err == nil {
		return GetTypedValue[T](fieldValue)
	}
	var zero T
	return zero, false
}
// GetTypedValue converts v (a plain value or a reflect.Value) to T. Pointer and interface values are
// asserted to T as-is; any other value is first resolved with GetConcreteValue. The boolean reports
// whether the assertion succeeded.
func GetTypedValue[T any](v any) (T, bool) {
	rv, isReflectValue := v.(reflect.Value)
	if !isReflectValue {
		rv = reflect.ValueOf(v)
	}
	var candidate any
	switch rv.Kind() {
	case reflect.Pointer, reflect.Interface:
		candidate = rv.Interface()
	default:
		candidate = GetConcreteValue(rv)
	}
	typed, ok := candidate.(T)
	return typed, ok
}
// TracePath follows the "."-separated fieldExpr from v (each part resolved with GetField) and
// returns every value visited, starting with v itself. An empty fieldExpr yields just [v].
func TracePath(v reflect.Value, fieldExpr string) ([]reflect.Value, error) {
	if !v.IsValid() {
		return nil, fmt.Errorf("value is invalid")
	}
	trace := []reflect.Value{v}
	if fieldExpr == "" {
		return trace, nil
	}
	current := v
	for _, part := range strings.Split(fieldExpr, ".") {
		next, err := GetField(current, part)
		if err != nil {
			return nil, err
		}
		trace = append(trace, next)
		current = next
	}
	return trace, nil
}
// FirstOfType returns the first value in the slice that matches the specified type.
// Invalid values and values that cannot be interfaced (e.g. obtained via unexported fields) are
// skipped. If no matching value is found, it returns the zero value of the type and false.
func FirstOfType[T any](values []reflect.Value) (T, bool) {
	for _, v := range values {
		// CanInterface panics on the zero reflect.Value, so check validity first.
		if !v.IsValid() || !v.CanInterface() {
			continue
		}
		if val, ok := v.Interface().(T); ok {
			return val, true
		}
	}
	var zero T
	return zero, false
}
// LastOfType returns the last value in the slice that matches the specified type, skipping invalid
// and non-interfaceable values. If no matching value is found, it returns the zero value of the type
// and false.
func LastOfType[T any](values []reflect.Value) (T, bool) {
	// Walk backwards directly instead of allocating a reversed copy.
	for i := len(values) - 1; i >= 0; i-- {
		v := values[i]
		if !v.IsValid() || !v.CanInterface() {
			continue
		}
		if val, ok := v.Interface().(T); ok {
			return val, true
		}
	}
	var zero T
	return zero, false
}
// IsAnyOf reports whether v's Kind is one of types. With no types it returns false.
func IsAnyOf(v reflect.Value, types ...reflect.Kind) bool {
	kind := v.Kind()
	for i := range types {
		if types[i] == kind {
			return true
		}
	}
	return false
}
// SplitPath splits a path string into parts separated by '.' and '[', ']'.
// It is used to split a path string into parts that can be used to access fields in a slice, array,
// struct, or map. A part keeps its leading '.' separator (e.g. "a.b" -> "a", ".b"), and bracketed
// components, including nested brackets, are kept as a single part with their brackets
// (e.g. "a[b.c]" -> "a", "[b.c]"). An empty path yields nil.
func SplitPath(path string) []string {
	var parts []string
	depth := 0
	partStart := 0
	for i, ch := range path {
		switch ch {
		case '.', '[':
			if depth == 0 {
				// A top-level separator ends the current (non-empty) part and starts a new one
				// that includes the separator itself.
				if i > partStart {
					parts = append(parts, path[partStart:i])
				}
				partStart = i
			}
			if ch == '[' {
				depth++
			}
		case ']':
			depth--
			if depth == 0 {
				parts = append(parts, path[partStart:i+1])
				partStart = i + 1
			}
		}
	}
	if partStart < len(path) {
		parts = append(parts, path[partStart:])
	}
	return parts
}
package set
import (
"sort"
"gopkg.in/yaml.v3"
)
// HashedSet is a set of T values deduplicated by the key Hasher computes for each value; adding a
// value whose key is already present replaces the stored value.
type HashedSet[K comparable, T any] struct {
	Hasher func(T) K
	M      map[K]T
	// Less is used to sort the keys of the set when converting to a slice.
	// If Less is nil, the keys will be sorted in an arbitrary order according to [map] iteration.
	Less func(K, K) bool
}

// HashedSetOf returns a HashedSet using hasher that contains vs.
func HashedSetOf[K comparable, T any](hasher func(T) K, vs ...T) HashedSet[K, T] {
	var result HashedSet[K, T]
	result.Hasher = hasher
	result.Add(vs...)
	return result
}

// Add inserts vs, allocating M on first use.
func (s *HashedSet[K, T]) Add(vs ...T) {
	if s.M == nil {
		s.M = make(map[K]T, len(vs))
	}
	for i := range vs {
		s.M[s.Hasher(vs[i])] = vs[i]
	}
}

// Remove deletes the value with v's key and reports whether one was present.
func (s *HashedSet[K, T]) Remove(v T) bool {
	if s.M == nil {
		return false
	}
	key := s.Hasher(v)
	if _, present := s.M[key]; !present {
		return false
	}
	delete(s.M, key)
	return true
}

// Contains reports whether a value with v's key is in the set.
func (s HashedSet[K, T]) Contains(v T) bool {
	if s.M == nil {
		return false
	}
	_, present := s.M[s.Hasher(v)]
	return present
}

// ContainsAll reports whether every one of vs is in the set (true for no arguments).
func (s HashedSet[K, T]) ContainsAll(vs ...T) bool {
	for i := range vs {
		if !s.Contains(vs[i]) {
			return false
		}
	}
	return true
}

// ContainsAny reports whether at least one of vs is in the set (false for no arguments).
func (s HashedSet[K, T]) ContainsAny(vs ...T) bool {
	for i := range vs {
		if s.Contains(vs[i]) {
			return true
		}
	}
	return false
}

// Len returns the number of values in the set.
func (s HashedSet[K, T]) Len() int {
	return len(s.M)
}

// ToSlice returns the set's values, ordered by key with Less when it is set. It returns nil when M is nil.
func (s HashedSet[K, T]) ToSlice() []T {
	if s.M == nil {
		return nil
	}
	out := make([]T, 0, len(s.M))
	if s.Less == nil {
		for _, v := range s.M {
			out = append(out, v)
		}
		return out
	}
	keys := make([]K, 0, len(s.M))
	for k := range s.M {
		keys = append(keys, k)
	}
	sort.Slice(keys, func(a, b int) bool {
		return s.Less(keys[a], keys[b])
	})
	for _, k := range keys {
		out = append(out, s.M[k])
	}
	return out
}

// ToMap returns a copy of the key-to-value map.
func (s HashedSet[K, T]) ToMap() map[K]T {
	copied := make(map[K]T, len(s.M))
	for k, v := range s.M {
		copied[k] = v
	}
	return copied
}

// Union returns a new set with the values of both sets (other's value wins on shared keys), using
// s's Hasher and Less.
func (s HashedSet[K, T]) Union(other HashedSet[K, T]) HashedSet[K, T] {
	merged := make(map[K]T, len(s.M)+len(other.M))
	for k, v := range s.M {
		merged[k] = v
	}
	for k, v := range other.M {
		merged[k] = v
	}
	return HashedSet[K, T]{Hasher: s.Hasher, M: merged, Less: s.Less}
}

// Intersection returns a new set with s's values whose keys are also in other, using s's Hasher and Less.
func (s HashedSet[K, T]) Intersection(other HashedSet[K, T]) HashedSet[K, T] {
	result := HashedSet[K, T]{Hasher: s.Hasher, M: make(map[K]T), Less: s.Less}
	for k, v := range s.M {
		if _, shared := other.M[k]; shared {
			result.M[k] = v
		}
	}
	return result
}
// MarshalYAML encodes the set as a YAML sequence of its values (ordered by Less when set).
func (s HashedSet[K, T]) MarshalYAML() (interface{}, error) {
	values := s.ToSlice()
	return values, nil
}

// UnmarshalYAML decodes a YAML sequence and adds its elements to the set. Add calls s.Hasher, so the
// Hasher must already be set when the sequence is non-empty.
func (s *HashedSet[K, T]) UnmarshalYAML(node *yaml.Node) error {
	var values []T
	if err := node.Decode(&values); err != nil {
		return err
	}
	s.Add(values...)
	return nil
}
package set
import (
"fmt"
"strings"
)
// Set is an unordered collection of unique comparable values, backed by a map with empty-struct values.
type Set[T comparable] map[T]struct{}

// SetOf returns a new, non-nil Set containing vs.
func SetOf[T comparable](vs ...T) Set[T] {
	result := Set[T]{}
	result.Add(vs...)
	return result
}

// Add inserts vs into the set.
func (s Set[T]) Add(vs ...T) {
	for i := range vs {
		s[vs[i]] = struct{}{}
	}
}

// AddFrom inserts every element of other into s.
func (s Set[T]) AddFrom(other Set[T]) {
	for v := range other {
		s.Add(v)
	}
}

// Remove deletes v and reports whether it was present.
func (s Set[T]) Remove(v T) bool {
	if !s.Contains(v) {
		return false
	}
	delete(s, v)
	return true
}

// Contains reports whether v is in the set.
func (s Set[T]) Contains(v T) bool {
	_, found := s[v]
	return found
}

// ContainsAll reports whether every one of vs is in the set (true for no arguments).
func (s Set[T]) ContainsAll(vs ...T) bool {
	for _, v := range vs {
		if _, found := s[v]; !found {
			return false
		}
	}
	return true
}

// ContainsAny reports whether at least one of vs is in the set (false for no arguments).
func (s Set[T]) ContainsAny(vs ...T) bool {
	for _, v := range vs {
		if _, found := s[v]; found {
			return true
		}
	}
	return false
}

// Len returns the number of elements.
func (s Set[T]) Len() int {
	return len(s)
}

// ToSlice returns the elements in arbitrary (map iteration) order; never nil.
func (s Set[T]) ToSlice() []T {
	out := make([]T, len(s))
	i := 0
	for v := range s {
		out[i] = v
		i++
	}
	return out
}

// Union returns a new set containing the elements of s and other.
func (s Set[T]) Union(other Set[T]) Set[T] {
	merged := make(Set[T], len(s)+len(other))
	merged.AddFrom(s)
	merged.AddFrom(other)
	return merged
}

// Intersection returns a new set of the elements present in both s and other.
func (s Set[T]) Intersection(other Set[T]) Set[T] {
	common := make(Set[T])
	for v := range s {
		if other.Contains(v) {
			common[v] = struct{}{}
		}
	}
	return common
}

// Difference returns a new set of the elements of s that are not in other.
func (s Set[T]) Difference(other Set[T]) Set[T] {
	remaining := make(Set[T])
	for v := range s {
		if !other.Contains(v) {
			remaining[v] = struct{}{}
		}
	}
	return remaining
}

// String formats the set as "{a, b, c}" using %v for each element, in arbitrary order.
func (s Set[T]) String() string {
	var sb strings.Builder
	sb.WriteByte('{')
	sep := ""
	for _, v := range s.ToSlice() {
		sb.WriteString(sep)
		fmt.Fprintf(&sb, "%v", v)
		sep = ", "
	}
	sb.WriteByte('}')
	return sb.String()
}
package templates
import (
"embed"
"github.com/klothoplatform/klotho/pkg/knowledgebase"
"github.com/klothoplatform/klotho/pkg/knowledgebase/reader"
)
// ResourceTemplates embeds the resource template files matching */resources/*.yaml.
//go:embed */resources/*.yaml
var ResourceTemplates embed.FS
// EdgeTemplates embeds the edge template files matching */edges/*.yaml.
//go:embed */edges/*.yaml
var EdgeTemplates embed.FS
// Models embeds the model files matching */models/*.yaml and models/*.yaml.
//go:embed */models/*.yaml models/*.yaml
var Models embed.FS
// NewKBFromTemplates builds a knowledge base from the embedded resource templates, edge templates
// and models.
func NewKBFromTemplates() (knowledgebase.TemplateKB, error) {
	resources, edges, models := ResourceTemplates, EdgeTemplates, Models
	return reader.NewKBFromFs(resources, edges, models)
}
package templateutils
import (
"embed"
"strings"
"text/template"
sprig "github.com/Masterminds/sprig/v3"
)
// MustTemplate reads the file called name from fs and parses it as a text/template named name.
// The template can use mustTemplateFuncs and sprig's hermetic text functions. Sprig's functions are
// registered second, so they win on a name clash. It panics if the file cannot be read or parsed.
func MustTemplate(fs embed.FS, name string) *template.Template {
	content, err := fs.ReadFile(name)
	if err != nil {
		panic(err)
	}
	tmpl := template.New(name)
	tmpl = tmpl.Funcs(mustTemplateFuncs)
	tmpl = tmpl.Funcs(sprig.HermeticTxtFuncMap())
	tmpl, err = tmpl.Parse(string(content))
	if err != nil {
		panic(err)
	}
	return tmpl
}
// mustTemplateFuncs are the helper functions registered on every template built by MustTemplate.
// Sprig's hermetic functions are added after these, so a sprig function with the same name takes precedence.
var mustTemplateFuncs = template.FuncMap{
"joinString": strings.Join,
"json": ToJSON,
"jsonPretty": ToJSONPretty,
"fileBase": FileBase,
"fileTrimExt": FileTrimExtFunc,
"fileSep": FileSep,
// literal (non-regex) substring replacement
"replaceAll": ReplaceAll,
}
package templateutils
import (
"bytes"
"encoding/json"
"fmt"
"path/filepath"
"reflect"
"regexp"
"strings"
"text/template"
)
// UtilityFunctions is a shared set of template helpers. WithCommonFuncs merges them into a caller's FuncMap.
var UtilityFunctions = template.FuncMap{
"split": strings.Split,
"join": strings.Join,
"basename": filepath.Base,
"filterMatch": FilterMatch,
"mapString": MapString,
"zipToMap": ZipToMap,
"keysToMapWithDefault": KeysToMapWithDefault,
// "replace" is regex-based (ReplaceAllRegex); "replaceAll" below is a literal substring replace.
"replace": ReplaceAllRegex,
"replaceAll": ReplaceAll,
"hasSuffix": strings.HasSuffix,
"toLower": strings.ToLower,
"toUpper": strings.ToUpper,
"add": Add,
"sub": Sub,
"last": Last,
"makeSlice": MakeSlice,
"appendSlice": AppendSlice,
"sliceContains": SliceContains,
"matches": Matches,
"trimLeft": strings.TrimLeft,
"trimRight": strings.TrimRight,
"trimSpace": strings.TrimSpace,
"trimPrefix": strings.TrimPrefix,
"trimSuffix": strings.TrimSuffix,
}
// WithCommonFuncs copies every entry of UtilityFunctions into funcMap and returns it. Entries with
// the same name are overwritten. funcMap is modified in place. If it is nil, a new map is allocated
// instead of panicking on the write, so WithCommonFuncs(nil) returns just the common functions.
func WithCommonFuncs(funcMap template.FuncMap) template.FuncMap {
	if funcMap == nil {
		funcMap = make(template.FuncMap, len(UtilityFunctions))
	}
	for name, fn := range UtilityFunctions {
		funcMap[name] = fn
	}
	return funcMap
}
// ToJSON encodes v as compact JSON using encoding/json's Encoder, which escapes HTML-sensitive
// characters by default. The trailing newline the encoder appends is trimmed. Encoding errors,
// such as unsupported types, are returned.
func ToJSON(v any) (string, error) {
	var out bytes.Buffer
	if err := json.NewEncoder(&out).Encode(v); err != nil {
		return "", err
	}
	return strings.TrimSpace(out.String()), nil
}
// ToJSONPretty encodes v as indented JSON. The trailing newline the encoder appends is trimmed.
// Encoding errors are returned.
func ToJSONPretty(v any) (string, error) {
	var out bytes.Buffer
	encoder := json.NewEncoder(&out)
	encoder.SetIndent("", " ")
	if err := encoder.Encode(v); err != nil {
		return "", err
	}
	return strings.TrimSpace(out.String()), nil
}
// FileBase returns the final element of path, as filepath.Base does ("." for an empty path).
func FileBase(path string) string {
	base := filepath.Base(path)
	return base
}
// FileTrimExtFunc returns path with its final extension (as reported by filepath.Ext) removed.
// Only the last extension is removed, so "a.tar.gz" becomes "a.tar".
func FileTrimExtFunc(path string) string {
	ext := filepath.Ext(path)
	return path[:len(path)-len(ext)]
}
// FileSep returns the current OS's path separator as a one-character string.
func FileSep() string {
	sep := filepath.Separator
	return string(sep)
}
// ReplaceAll returns s with every non-overlapping occurrence of the literal string oldStr replaced
// by newStr. No regex is involved. It follows strings.ReplaceAll semantics, including for an empty
// oldStr.
func ReplaceAll(s string, oldStr string, newStr string) string {
	return strings.ReplaceAll(s, oldStr, newStr)
}
// Matches reports whether value contains a match of the regular expression pattern. An invalid
// pattern is returned as an error, together with false.
func Matches(pattern, value string) (bool, error) {
	return regexp.MatchString(pattern, value)
}
// FilterMatch returns the elements of values that contain a match of the regex pattern, in their
// original order. The result is never nil. An invalid pattern is returned as an error.
func FilterMatch(pattern string, values []string) ([]string, error) {
	re, err := regexp.Compile(pattern)
	if err != nil {
		return nil, err
	}
	kept := make([]string, 0, len(values))
	for i := range values {
		if re.MatchString(values[i]) {
			kept = append(kept, values[i])
		}
	}
	return kept, nil
}
// MapString applies a regex substitution to every element of values, roughly
// `sed s/pattern/replace/g` per element. replace may use $1-style expansions as in
// regexp.ReplaceAllString. The result has the same length as values. An invalid pattern is
// returned as an error.
func MapString(pattern, replace string, values []string) ([]string, error) {
	re, err := regexp.Compile(pattern)
	if err != nil {
		return nil, err
	}
	mapped := make([]string, 0, len(values))
	for _, v := range values {
		mapped = append(mapped, re.ReplaceAllString(v, replace))
	}
	return mapped, nil
}
// ZipToMap pairs keys[i] with the i-th element of valuesArg and returns the resulting map.
// Example: zipToMap(['a', 'b'], [1, 2]) => {"a": 1, "b": 2}
//
// valuesArg may be a slice or array of any element type. Reflection is used because, for example,
// a []string is not assignable to []any. It is an error if valuesArg is not a slice or array, or if
// its length differs from len(keys).
func ZipToMap(keys []string, valuesArg any) (map[string]any, error) {
	vals := reflect.ValueOf(valuesArg)
	switch vals.Kind() {
	case reflect.Slice, reflect.Array:
		// supported
	default:
		return nil, fmt.Errorf("values is not a slice or array")
	}
	if n := vals.Len(); n != len(keys) {
		return nil, fmt.Errorf("key length (%d) != value length (%d)", len(keys), n)
	}
	zipped := make(map[string]any, len(keys))
	for i := range keys {
		zipped[keys[i]] = vals.Index(i).Interface()
	}
	return zipped, nil
}
// KeysToMapWithDefault maps every key to the same defaultValue.
// Example: keysToMapWithDefault(0, ['a', 'b']) => {"a": 0, "b": 0}
// The error result is always nil. It exists to fit the template-function calling convention.
func KeysToMapWithDefault(defaultValue any, keys []string) (map[string]any, error) {
	out := make(map[string]any, len(keys))
	for i := range keys {
		out[keys[i]] = defaultValue
	}
	return out, nil
}
// Add returns the sum of its arguments. It returns 0 when there are none.
func Add(args ...int) int {
	sum := 0
	for i := range args {
		sum += args[i]
	}
	return sum
}
// Sub subtracts every subsequent argument from the first: Sub(a, b, c) == a - b - c.
// It returns 0 when there are no arguments.
func Sub(args ...int) int {
	if len(args) == 0 {
		return 0
	}
	result := args[0]
	for i := 1; i < len(args); i++ {
		result -= args[i]
	}
	return result
}
// Last returns the final element of a slice or array. It returns an error if list is neither, or if
// it is empty.
func Last(list any) (any, error) {
	rv := reflect.ValueOf(list)
	if k := rv.Kind(); k != reflect.Slice && k != reflect.Array {
		return nil, fmt.Errorf("list is not a slice or array, is %s", k)
	}
	n := rv.Len()
	if n == 0 {
		return nil, fmt.Errorf("list is empty")
	}
	return rv.Index(n - 1).Interface(), nil
}
// ReplaceAllRegex replaces every match of the regex pattern in value with replace. replace may use
// $1-style expansions. An invalid pattern is returned as an error.
func ReplaceAllRegex(pattern, replace, value string) (string, error) {
	re, err := regexp.Compile(pattern)
	if err != nil {
		return "", err
	}
	return re.ReplaceAllString(value, replace), nil
}
// MakeSlice returns its arguments as a []any. This lets templates build a list literal.
func MakeSlice(args ...any) []any {
	result := args
	return result
}
// AppendSlice returns a new slice holding the elements of slice followed by value.
//
// The result never shares a backing array with slice. A plain append would reuse slice's spare
// capacity, so two appends to the same base (e.g. `appendSlice $base 1` and `appendSlice $base 2`
// in one template) would overwrite each other's last element.
func AppendSlice(slice []any, value ...any) []any {
	out := make([]any, 0, len(slice)+len(value))
	out = append(out, slice...)
	return append(out, value...)
}
// SliceContains reports whether any element of slice equals value under interface (==)
// comparison. Both dynamic type and value must match, so int(1) does not equal int64(1).
func SliceContains(slice []any, value any) bool {
	for i := range slice {
		if slice[i] == value {
			return true
		}
	}
	return false
}
package testutil
// NewSet returns a set (a map with empty-struct values) containing each of ss. Duplicates collapse
// to a single entry.
func NewSet[T comparable](ss ...T) map[T]struct{} {
	set := make(map[T]struct{}, len(ss))
	for i := range ss {
		set[ss[i]] = struct{}{}
	}
	return set
}
package testutil
import (
sitter "github.com/smacker/go-tree-sitter"
)
// FindNodeByContent returns node itself, or a descendant of it, whose content equals content. The
// search is depth-first, pre-order, and stops at the first hit. When several nodes match, which one
// is returned should not be relied on. It returns nil when nothing matches.
func FindNodeByContent(node *sitter.Node, content string) *sitter.Node {
	if node.Content() == content {
		return node
	}
	childCount := int(node.ChildCount())
	for i := 0; i < childCount; i++ {
		if found := FindNodeByContent(node.Child(i), content); found != nil {
			return found
		}
	}
	return nil
}
package testutil
import (
"fmt"
"strings"
"github.com/vmware-labs/yaml-jsonpath/pkg/yamlpath"
"gopkg.in/yaml.v3"
)
// UnIndent removes one level of indentation from the given string. The rules are very simple:
//   - first, drop any leading newlines from the string
//   - then, find the indentation of the first line, defined as its leading run of tabs and/or spaces
//   - then, trim exactly that prefix from every line
//
// The prefix must match exactly. It is removed with [strings.CutPrefix] and nothing fancier. In
// particular, mixing tabs and spaces can bite you: if the first line is indented with
// "<space><tab>" and a later line with "<tab><space>", the later line is not treated as indented and
// is left untouched.
//
// A newline is appended after every "\n"-separated line. The result therefore always ends with a
// newline, and an input that already ends with one gains a trailing blank line.
//
// You can use this to embed yaml within test code as literal strings:
//
//	SomeField: MyStruct{
//		foo: UnIndent(`
//			hello: world
//			counts:
//			  - 1
//			  - 2`),
//	}
//
// yields "hello: world\ncounts:\n  - 1\n  - 2\n".
//
// Empty input, or input made only of newlines and indentation, is handled without panicking.
func UnIndent(y string) string {
	y = strings.TrimLeft(y, "\n")
	// The character test must be parenthesized. Written as `a && b || c`, the `c` operand indexes
	// y[indentLen] even once indentLen == len(y), which panics for empty or all-whitespace input.
	indentLen := 0
	for indentLen < len(y) && (y[indentLen] == '\t' || y[indentLen] == ' ') {
		indentLen++
	}
	prefix := y[:indentLen]

	var sb strings.Builder
	sb.Grow(len(y) + 1)
	for _, line := range strings.Split(y, "\n") {
		line, _ = strings.CutPrefix(line, prefix)
		sb.WriteString(line)
		sb.WriteByte('\n')
	}
	return sb.String()
}
// SafeYamlPath returns the part of yamlStr selected by path, re-marshaled as YAML. It uses
// [yamlpath] under the hood. See the [yamlpath's github page] for the path syntax, though the
// package's godocs are easier to read.
//
// tldr: `$.path.to.your[0].subdocument` (the `$` is literally a dollar sign you should use to anchor the path).
//
// The path must select exactly one node. If you want a list, select the list's parent instead.
//
// SafeYamlPath never fails outright. Any error along the way (invalid path, unparsable YAML, zero or
// multiple matches, marshal failure) is returned as a string of the form `// ERROR: ${msg}`.
//
// [yamlpath's github page]: https://github.com/vmware-labs/yaml-jsonpath
func SafeYamlPath(yamlStr string, path string) string {
	compiled, err := yamlpath.NewPath(path)
	if err != nil {
		return fmt.Sprintf("// ERROR: %s", err)
	}
	var root yaml.Node
	if err := yaml.Unmarshal([]byte(yamlStr), &root); err != nil {
		return fmt.Sprintf("// ERROR: %s", err)
	}
	matches, err := compiled.Find(&root)
	if err != nil {
		return fmt.Sprintf("// ERROR: %s", err)
	}
	if len(matches) != 1 {
		return fmt.Sprintf("// ERROR: expected exactly one match, but found %d", len(matches))
	}
	out, err := yaml.Marshal(matches[0])
	if err != nil {
		return fmt.Sprintf("// ERROR: %s", err)
	}
	return string(out)
}
package tui
import (
"os"
"strings"
tea "github.com/charmbracelet/bubbletea"
"github.com/klothoplatform/klotho/pkg/logging"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
// LogCore is a zapcore.Core that forwards log entries to the TUI program instead of writing them
// directly. It embeds a regular core, whose Enabled is used by Check, and overrides With, Check,
// and Write.
type LogCore struct {
zapcore.Core
// verbosity decides whether logs are combined (printed above the TUI) or grouped per construct.
verbosity Verbosity
// program receives the encoded log lines (via Println or Send).
program *tea.Program
// enc encodes entries for the TUI. It carries the fields accumulated via With.
enc zapcore.Encoder
// construct is the value of the last "construct" field passed to With, if any.
construct string
}
// NewLogCore builds the core used while the TUI is running. The level is taken from verbosity. The
// embedded core writes to stderr with opts' encoder, and LogCore sends the lines to program. The
// result is then wrapped by opts' entry leveller and category core, in that order.
func NewLogCore(opts logging.LogOpts, verbosity Verbosity, program *tea.Program) zapcore.Core {
	enc := opts.Encoder()
	level := zap.NewAtomicLevelAt(verbosity.LogLevel())
	var core zapcore.Core = &LogCore{
		Core:      zapcore.NewCore(enc, os.Stderr, level),
		verbosity: verbosity,
		program:   program,
		enc:       enc,
	}
	return opts.CategoryCore(opts.EntryLeveller(core))
}
// With returns a copy of the core with fields f attached. Fields are added to the embedded core and
// to a clone of this core's encoder, with one exception. A "construct" field is remembered in
// construct, and it is only encoded when logs are combined. Otherwise the lines already appear in
// that construct's own section of the output, so repeating the name would be redundant.
func (c *LogCore) With(f []zapcore.Field) zapcore.Core {
	clone := *c
	clone.Core = c.Core.With(f)
	clone.enc = c.enc.Clone()
	for _, field := range f {
		isConstruct := field.Key == "construct"
		if isConstruct {
			clone.construct = field.String
		}
		if !isConstruct || c.verbosity.CombineLogs() {
			field.AddTo(clone.enc)
		}
	}
	return &clone
}
// Check adds this core to ce when the entry's level is enabled. Otherwise ce is returned unchanged.
func (c *LogCore) Check(e zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore.CheckedEntry {
	if !c.Enabled(e.Level) {
		return ce
	}
	return ce.AddCore(e, c)
}
// Write encodes ent and forwards it to the TUI program. The trailing newline is trimmed.
//
// When logs are combined, the fully-encoded line, including any construct field, is printed above
// the TUI via Println. Otherwise a per-entry "construct" field is stripped from the output and
// overrides the construct captured by With. Error-level (or higher) entries on a core with no
// captured construct are sent as an ErrorMessage. Everything else is sent as a LogMessage for the
// resolved construct.
func (c *LogCore) Write(ent zapcore.Entry, fields []zapcore.Field) error {
	if c.verbosity.CombineLogs() {
		buf, err := c.enc.EncodeEntry(ent, fields)
		if err != nil {
			return err
		}
		line := strings.TrimSuffix(buf.String(), "\n")
		c.program.Println(line)
		buf.Free()
		return nil
	}

	construct := c.construct
	kept := make([]zapcore.Field, 0, len(fields))
	for _, f := range fields {
		if f.Key != "construct" {
			kept = append(kept, f)
			continue
		}
		construct = f.String
	}
	buf, err := c.enc.EncodeEntry(ent, kept)
	if err != nil {
		return err
	}
	line := strings.TrimSuffix(buf.String(), "\n")

	var out tea.Msg = LogMessage{Construct: construct, Message: line}
	// NOTE(review): this tests the construct captured by With (c.construct), not the per-entry
	// override resolved above, so an error entry that names its construct only via a call-site
	// field still goes to the error list. Confirm this is intended.
	if c.construct == "" && zapcore.ErrorLevel.Enabled(ent.Level) {
		out = ErrorMessage{Message: line}
	}
	c.program.Send(out)
	buf.Free()
	return nil
}
package tui
import (
"fmt"
"io"
"strings"
"sync"
"github.com/charmbracelet/bubbles/progress"
"github.com/charmbracelet/bubbles/spinner"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/mitchellh/go-wordwrap"
)
type (
// model is the bubbletea model that renders construct progress, logs, outputs, and errors.
model struct {
// mu is locked by both Update and View around all state access.
mu sync.Mutex
verbosity Verbosity
// constructs holds per-construct state keyed by construct name ("" is the general section).
constructs map[string]*constructModel
// constructOrder lists construct names in the order they were first seen.
constructOrder []string
// consoleWidth comes from the last tea.WindowSizeMsg and is used for word wrapping.
consoleWidth int
// constructWidth and statusWidth are the longest name/status lengths, used to align the compact view.
constructWidth int
statusWidth int
// errors collects ErrorMessage contents, shown after the constructs.
errors []string
}
// constructModel is the display state for a single construct.
constructModel struct {
// logs keeps the most recent log lines (capacity set where the model is created).
logs *RingBuffer[string]
// outputs maps output name to value; outputOrder keeps first-seen order for display.
outputs map[string]any
outputOrder []string
status string
// hasProgress selects a determinate progress bar over the spinner.
hasProgress bool
// complete hides the progress/spinner indicator.
complete bool
progress progress.Model
spinner spinner.Model
}
)
// NewModel returns an empty TUI model for the given verbosity.
func NewModel(verbosity Verbosity) *model {
	m := &model{verbosity: verbosity}
	m.constructs = make(map[string]*constructModel)
	return m
}
// Init implements tea.Model. No startup command is needed, because all state arrives through messages.
func (m *model) Init() tea.Cmd {
	var none tea.Cmd
	return none
}
// constructModel returns the state for construct. On first use it creates the state, with a 10-line
// log buffer, and appends the name to constructOrder. Callers must hold m.mu.
func (m *model) constructModel(construct string) *constructModel {
	if existing, ok := m.constructs[construct]; ok {
		return existing
	}
	created := &constructModel{
		outputs:  make(map[string]any),
		progress: progress.New(),
		spinner:  spinner.New(spinner.WithSpinner(spinner.Dot)),
		logs:     NewRingBuffer[string](10),
	}
	m.constructs[construct] = created
	m.constructOrder = append(m.constructOrder, construct)
	return created
}
// Update implements tea.Model and applies one message under m.mu:
//   - tea.WindowSizeMsg records the console width used for wrapping.
//   - LogMessage appends to the construct's recent logs. It is ignored at verbosity 0 or when empty.
//   - ErrorMessage is collected for display after the constructs.
//   - UpdateMessage sets the construct's status, switches between progress bar and spinner, and
//     recomputes the column widths used by the compact view.
//   - OutputMessage records an output, remembering first-seen order.
//   - progress.FrameMsg and spinner.TickMsg are fanned out to every construct's widget.
func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
	m.mu.Lock()
	defer m.mu.Unlock()
	var cmd tea.Cmd
	switch msg := msg.(type) {
	case tea.WindowSizeMsg:
		m.consoleWidth = msg.Width
	case LogMessage:
		if m.verbosity > 0 && msg.Message != "" {
			m.constructModel(msg.Construct).logs.Push(msg.Message)
		}
	case ErrorMessage:
		m.errors = append(m.errors, msg.Message)
	case UpdateMessage:
		cm := m.constructModel(msg.Construct)
		cm.hasProgress = !msg.Indeterminate
		cm.complete = msg.Complete
		if cm.hasProgress {
			cmd = cm.progress.SetPercent(msg.Percent)
		} else {
			cmd = tea.Batch(
				cm.spinner.Tick,
				// Reset the progress to 0. This isn't guaranteed, but if we switched from a progress
				// to a spinner, most likely it's changing what's measured. By setting this to 0 now,
				// we try to prevent the progress from "bouncing back" from the last value it was at (likely at 100%)
				// due to the animation of the progress bar.
				cm.progress.SetPercent(0.0),
			)
		}
		cm.status = msg.Status
		// Recompute the alignment widths only after the new status is assigned. Otherwise the
		// status just received would not be counted until the next update.
		m.constructWidth, m.statusWidth = 0, 0
		for name, other := range m.constructs {
			m.constructWidth = max(m.constructWidth, len(name))
			m.statusWidth = max(m.statusWidth, len(other.status))
		}
	case OutputMessage:
		cm := m.constructModel(msg.Construct)
		if _, seen := cm.outputs[msg.Name]; !seen {
			cm.outputOrder = append(cm.outputOrder, msg.Name)
		}
		cm.outputs[msg.Name] = msg.Value
	case progress.FrameMsg:
		cmds := make([]tea.Cmd, 0, len(m.constructs))
		for _, cm := range m.constructs {
			pm, c := cm.progress.Update(msg)
			cm.progress = pm.(progress.Model)
			cmds = append(cmds, c)
		}
		cmd = tea.Batch(cmds...)
	case spinner.TickMsg:
		cmds := make([]tea.Cmd, 0, len(m.constructs))
		for _, cm := range m.constructs {
			sm, c := cm.spinner.Update(msg)
			cm.spinner = sm
			cmds = append(cmds, c)
		}
		cmd = tea.Batch(cmds...)
	}
	return m, cmd
}
// View implements tea.Model. VerbosityConcise uses the compact view, combined-log verbosity uses the
// debug view, and everything else uses the boxed verbose view.
func (m *model) View() string {
	m.mu.Lock()
	defer m.mu.Unlock()
	switch {
	case m.verbosity == VerbosityConcise:
		return m.viewCompact()
	case m.verbosity.CombineLogs():
		return m.viewDebug()
	default:
		return m.viewVerbose()
	}
}
// viewCompact renders one line per construct, in creation order. Each line shows the construct name
// (bold, right-aligned to the widest name), its status (right-aligned to the widest status), and a
// progress bar or spinner unless the construct is complete. The construct's outputs are listed
// beneath it as a tree, and collected errors are appended at the end.
func (m *model) viewCompact() string {
	sb := new(strings.Builder)
	sb.WriteString(RenderLogo())
	sb.WriteString("\n")
	// constructWidth is only recomputed on UpdateMessage, so it can still be 0 here, for example when
	// an OutputMessage created the construct first. Clamp the count so strings.Repeat never panics
	// on a negative value.
	outputPad := strings.Repeat(" ", max(m.constructWidth-1, 0))
	for _, c := range m.constructOrder {
		cm := m.constructs[c]
		if pad := m.constructWidth - len(c); pad > 0 {
			sb.WriteString(strings.Repeat(" ", pad))
		}
		sb.WriteString(boxTitleStyle.Render(c))
		sb.WriteString(" ")
		if pad := m.statusWidth - len(cm.status); pad > 0 {
			sb.WriteString(strings.Repeat(" ", pad))
		}
		sb.WriteString(cm.status)
		sb.WriteString(" ")
		switch {
		case cm.complete:
			// No indicator once the construct is done.
		case cm.hasProgress:
			sb.WriteString(cm.progress.View())
		default:
			sb.WriteString(cm.spinner.View())
		}
		sb.WriteRune('\n')
		last := len(cm.outputOrder) - 1
		for i, name := range cm.outputOrder {
			sb.WriteString(outputPad)
			branch := "├"
			if i == last {
				branch = "└"
			}
			fmt.Fprintf(sb, "%s %s: %v\n", branch, name, cm.outputs[name])
		}
	}
	if len(m.constructs) > 0 && len(m.errors) > 0 {
		sb.WriteRune('\n')
	}
	for _, log := range m.errors {
		sb.WriteString(log)
		sb.WriteRune('\n')
	}
	return sb.String()
}
var (
// boxTitleStyle renders construct names and box titles in bold.
boxTitleStyle = lipgloss.NewStyle().Bold(true)
// boxSectionHeadingStyle underlines the "Logs" and "Outputs" section headings in the verbose view.
boxSectionHeadingStyle = lipgloss.NewStyle().Underline(true)
)
// boxLine is one entry in a construct box drawn by renderConstructBox.
type boxLine struct {
// Content is the text to draw. It may be word-wrapped onto continuation lines.
Content string
// NoWrap disables word wrapping, e.g. for progress bars and spinners.
NoWrap bool
}
// renderConstructBox renders a construct box to the given writer. There are 2 kinds of boxes:
// The general logs box:
//
//	┌ General
//	├ Logs
//	├ 12.1s DBG > populated default value ...
//	├ ...
//	└ 40.6s DBG > Shutting down TUI
//
// And a construct box:
//
//	┌ my-api Success (dry run)
//	├ Logs
//	├ 29.3s DBG > AddEdge ...
//	├ ...
//	├ Outputs
//	└ └ Endpoint: <aws:api_stage:my-api-api:my-api-stage#InvokeUrl>
//
// A box holding a single one-line entry is drawn with "─". Unless an entry sets NoWrap, long lines
// are word-wrapped to the console width minus a 4-column margin, with continuation lines marked "│":
//
//	├ 44.6s INF pulumi.preview >
//	│ error: Preview failed: resource 'preview(id=aws:subnet:default-network-vpc:default-network-public-subnet-1)' does
//	│ not exist
//
// Wrapping is skipped while the width is unknown (consoleWidth is 0 before the first
// tea.WindowSizeMsg) or too small to leave any room. Write errors are ignored.
func (m *model) renderConstructBox(lines []boxLine, w io.Writer) {
	write := func(s string) {
		_, _ = w.Write([]byte(s))
	}
	// Compute the limit once, and only wrap when it is positive. This avoids converting a negative
	// int to uint, which would wrap around to a huge limit.
	wrapWidth := m.consoleWidth - 4
	for i, elem := range lines {
		msg := elem.Content
		if !elem.NoWrap && wrapWidth > 0 {
			msg = wordwrap.WrapString(elem.Content, uint(wrapWidth))
		}
		elemLines := strings.Split(msg, "\n")
		for j, line := range elemLines {
			switch {
			case len(lines) == 1 && len(elemLines) == 1:
				// Single line special case
				write("─ ")
			case i == 0 && j == 0:
				// First line in the box
				write("┌ ")
			case i == len(lines)-1 && j == len(elemLines)-1:
				// Last line in the box
				write("└ ")
			case j == 0:
				// A list element
				write("├ ")
			default:
				// A continuation line
				write("│ ")
			}
			write(line + "\n")
		}
	}
}
// viewVerbose renders one box per construct (see renderConstructBox), in creation order. The
// construct "" is titled "General". Every other box is titled with the construct name and status,
// followed by an unwrapped spinner or progress bar until the construct completes. Each box then
// lists the retained log lines under a "Logs" heading and the outputs under an "Outputs" heading.
// A heading is only emitted when its section is non-empty. Collected errors follow all boxes.
func (m *model) viewVerbose() string {
	out := new(strings.Builder)
	out.WriteString(RenderLogo())
	out.WriteString("\n")
	for _, name := range m.constructOrder {
		cm := m.constructs[name]
		var lines []boxLine
		add := func(content string, noWrap bool) {
			lines = append(lines, boxLine{Content: content, NoWrap: noWrap})
		}
		if name == "" {
			add(boxTitleStyle.Render("General"), false)
		} else {
			add(boxTitleStyle.Render(name)+" "+cm.status, false)
			switch {
			case cm.complete:
				// No indicator once complete.
			case cm.hasProgress:
				add(cm.progress.View(), true)
			default:
				add(cm.spinner.View(), true)
			}
		}
		cm.logs.ForEach(func(idx int, msg string) {
			if idx == 0 {
				// Only render the heading if there are actually logs to show
				// Check inside the ForEach instead of using Len to prevent a race condition
				add(boxSectionHeadingStyle.Render("Logs"), false)
			}
			add(msg, false)
		})
		last := len(cm.outputs) - 1
		for i, outName := range cm.outputOrder {
			if i == 0 {
				add(boxSectionHeadingStyle.Render("Outputs"), false)
			}
			branch := "├"
			if i == last {
				branch = "└"
			}
			add(fmt.Sprintf("%s %s: %v", branch, outName, cm.outputs[outName]), false)
		}
		m.renderConstructBox(lines, out)
		out.WriteRune('\n') // extra newline to separate constructs
	}
	for _, e := range m.errors {
		out.WriteString(e)
		out.WriteRune('\n')
	}
	return out.String()
}
// viewDebug is the view for combined-log verbosity. In that mode log lines are printed above the TUI
// outside the model, so the compact view is reused to avoid showing the logs twice.
func (m *model) viewDebug() string {
	compact := m.viewCompact()
	return compact
}
package tui
import (
"context"
"github.com/klothoplatform/klotho/pkg/logging"
)
// contextKey is an unexported key type so this package's context values cannot collide with other packages'.
type contextKey string
// progressKey stores the Progress installed by WithProgress.
var progressKey contextKey = "progress"
// GetProgress returns the Progress stored in ctx by WithProgress. If none is stored, it falls back
// to a LogProgress that writes to the context's logger, named "progress".
func GetProgress(ctx context.Context) Progress {
	if stored := ctx.Value(progressKey); stored != nil {
		return stored.(Progress)
	}
	return LogProgress{Logger: logging.GetLogger(ctx).Named("progress")}
}
// WithProgress returns a child of ctx carrying progress, retrievable with GetProgress.
func WithProgress(ctx context.Context, progress Progress) context.Context {
	child := context.WithValue(ctx, progressKey, progress)
	return child
}
package tui
import (
"context"
tea "github.com/charmbracelet/bubbletea"
)
type (
// UpdateMessage reports a construct's status to the TUI model.
UpdateMessage struct {
Construct string
Status string
// Percent is the completed fraction (current/total) for determinate progress.
Percent float64
// Indeterminate shows a spinner instead of a progress bar.
Indeterminate bool
// Complete hides the progress indicator.
Complete bool
}
// LogMessage is an encoded log line for a construct ("" for the general section).
LogMessage struct {
Construct string
Message string
}
// OutputMessage sets the named output of a construct.
OutputMessage struct {
Construct string
Name string
Value any
}
// ErrorMessage is an encoded error line shown after all constructs.
ErrorMessage struct {
Message string
}
// TuiProgress implements Progress by sending UpdateMessages for Construct to Prog.
TuiProgress struct {
Prog *tea.Program
Construct string
}
)
// programKey stores the *tea.Program installed by WithProgram.
var programKey contextKey = "tui-prog"
// WithProgram returns a child of ctx carrying the TUI program p, retrievable with GetProgram.
func WithProgram(ctx context.Context, p *tea.Program) context.Context {
	child := context.WithValue(ctx, programKey, p)
	return child
}
// GetProgram returns the TUI program stored in ctx by WithProgram, or nil if none is stored.
func GetProgram(ctx context.Context) *tea.Program {
	stored := ctx.Value(programKey)
	if stored == nil {
		return nil
	}
	return stored.(*tea.Program)
}
// Update sends determinate progress for p.Construct: the status plus current/total as a fraction.
// A non-positive total reports 0 instead of NaN or ±Inf from float division by zero, since the
// progress bar cannot render those meaningfully.
func (p *TuiProgress) Update(status string, current, total int) {
	percent := 0.0
	if total > 0 {
		percent = float64(current) / float64(total)
	}
	p.Prog.Send(UpdateMessage{
		Construct: p.Construct,
		Status:    status,
		Percent:   percent,
	})
}
// UpdateIndeterminate sends status for p.Construct with no known total, which shows a spinner.
func (p *TuiProgress) UpdateIndeterminate(status string) {
	msg := UpdateMessage{Construct: p.Construct, Status: status}
	msg.Indeterminate = true
	p.Prog.Send(msg)
}
// Complete sends the final status for p.Construct and marks it complete, hiding its indicator.
func (p *TuiProgress) Complete(status string) {
	msg := UpdateMessage{Construct: p.Construct, Status: status}
	msg.Complete = true
	p.Prog.Send(msg)
}
package tui
import (
"go.uber.org/zap"
)
// Progress receives status updates for a long-running operation.
type Progress interface {
// Update updates the progress status with the current and total count.
Update(status string, current, total int)
// UpdateIndeterminate updates the status when no count is known.
UpdateIndeterminate(status string)
// Complete reports the final status.
Complete(status string)
}
// LogProgress implements Progress by writing to Logger instead of a TUI.
type LogProgress struct {
Logger *zap.Logger
}
// Update logs "<status> current/total (pct%)" at info level. A non-positive total logs 0.0%
// instead of NaN or ±Inf from float division by zero.
func (p LogProgress) Update(status string, current, total int) {
	pct := 0.0
	if total > 0 {
		pct = float64(current) / float64(total) * 100
	}
	p.Logger.Sugar().Infof("%s %d/%d (%.1f%%)", status, current, total, pct)
}
// UpdateIndeterminate logs status at info level. There is no count to report.
func (p LogProgress) UpdateIndeterminate(status string) {
	logger := p.Logger
	logger.Info(status)
}
// Complete logs "Complete: <status>" at debug level.
func (p LogProgress) Complete(status string) {
	sugar := p.Logger.Sugar()
	sugar.Debugf("Complete: %s", status)
}
package prompt
import (
"fmt"
"github.com/charmbracelet/bubbles/textinput"
"github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/klothoplatform/klotho/pkg/tui"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"strings"
)
// MultiFlagPromptModel is a bubbletea model that prompts for a sequence of flag values, one at a time.
type MultiFlagPromptModel struct {
// Prompts are the prompts rendered so far. The next one is created lazily via PromptCreator.
Prompts []FlagPromptModel
// CurrentIndex is the focused prompt. It reaches len(FlagNames) once every flag is answered.
CurrentIndex int
// Quit is set when the user aborts with Ctrl+C or Esc.
Quit bool
// Width is updated from tea.WindowSizeMsg.
Width int
// Height pads the view so the footer sits lower. NOTE(review): Update never sets it; confirm the caller does.
Height int
// FlagNames lists every flag to prompt for, in order.
FlagNames []string
// NOTE(review): Cmd and Helpers are not read by the Update/View code here; presumably used by the caller.
Cmd *cobra.Command
Helpers map[string]Helper
// PromptCreator builds the prompt for a flag name.
PromptCreator func(string) FlagPromptModel
}
// FlagPromptModel is the prompt for a single flag.
type FlagPromptModel struct {
TextInput textinput.Model
// Flag receives the submitted value.
Flag *pflag.Flag
// InitialValue is shown as the default and used when the input is submitted empty.
InitialValue string
// Description is the placeholder text (flag usage, plus a required marker).
Description string
// IsRequired refuses an empty submission when there is no InitialValue.
IsRequired bool
FlagHelpers Helper
// Completed marks a prompt whose value has been submitted.
Completed bool
}
// Update implements tea.Model and advances the prompt flow:
//   - Up moves focus to the previous prompt. It does nothing on the first prompt.
//   - Enter submits the focused prompt. Down does the same, except on the last rendered prompt.
//     An empty input falls back to InitialValue, and an empty required value is refused. The value
//     is then checked by FlagHelpers.ValidateFunc and written into the flag. Validation and
//     flag-parsing failures are shown on the prompt via TextInput.Err, and focus stays put.
//   - After the last flag is submitted, or on Ctrl+C / Esc (which also sets Quit), the program quits.
//
// All other messages are forwarded to the focused text input.
func (m MultiFlagPromptModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
	if m.CurrentIndex >= len(m.FlagNames) {
		return m, tea.Quit
	}
	switch msg := msg.(type) {
	case tea.KeyMsg:
		switch msg.Type {
		case tea.KeyUp:
			// Go back to the previous prompt
			if m.CurrentIndex == 0 {
				return m, nil
			}
			m.Prompts[m.CurrentIndex].TextInput.Blur()
			m.CurrentIndex--
			m.Prompts[m.CurrentIndex].TextInput.Focus()
			return m, nil
		case tea.KeyDown, tea.KeyEnter:
			if msg.Type == tea.KeyDown && m.CurrentIndex == len(m.Prompts)-1 {
				// Down arrow on the last prompt should do nothing (down only submits if there's a next prompt already rendered)
				return m, nil
			}
			currentPrompt := &m.Prompts[m.CurrentIndex]
			value := currentPrompt.TextInput.Value()
			if currentPrompt.TextInput.Err != nil {
				return m, nil
			}
			if value == "" {
				if currentPrompt.InitialValue != "" {
					value = currentPrompt.InitialValue
				} else if currentPrompt.IsRequired {
					return m, nil
				}
			}
			if validate := currentPrompt.FlagHelpers.ValidateFunc; validate != nil {
				if err := validate(value); err != nil {
					currentPrompt.TextInput.Err = err
					return m, nil
				}
			}
			if err := currentPrompt.Flag.Value.Set(value); err != nil {
				// Show why the flag rejected the value instead of silently ignoring the keypress.
				currentPrompt.TextInput.Err = err
				return m, nil
			}
			currentPrompt.Flag.Changed = true
			currentPrompt.Completed = true
			currentPrompt.TextInput.Blur()
			m.CurrentIndex++
			if m.CurrentIndex >= len(m.FlagNames) {
				// If we've completed all prompts, quit immediately
				return m, tea.Quit
			}
			if m.CurrentIndex == len(m.Prompts) {
				// Create the next prompt the first time it is reached.
				m.Prompts = append(m.Prompts, m.PromptCreator(m.FlagNames[m.CurrentIndex]))
			}
			m.Prompts[m.CurrentIndex].TextInput.Focus()
			return m, nil
		case tea.KeyCtrlC, tea.KeyEsc:
			m.Quit = true
			return m, tea.Quit
		}
	case tea.WindowSizeMsg:
		m.Width = msg.Width
		return m, tea.ClearScreen
	}
	var cmd tea.Cmd
	m.Prompts[m.CurrentIndex].TextInput, cmd = m.Prompts[m.CurrentIndex].TextInput.Update(msg)
	return m, cmd
}
// View renders the logo, an intro line, and one line per rendered prompt. The focused prompt's name
// uses LogoColor. Pending prompts show their default in brackets, and a validation error is appended
// in red. It then pads with blank lines based on Height and adds the quit hint as a footer.
func (m MultiFlagPromptModel) View() string {
	var out strings.Builder
	out.WriteString(tui.RenderLogo())
	out.WriteString("\n\n")
	out.WriteString("Please provide the following information to initialize your Klotho application:\n\n")
	for idx, p := range m.Prompts {
		nameColor := lipgloss.Color("230")
		if idx == m.CurrentIndex {
			nameColor = lipgloss.Color(tui.LogoColor)
		}
		name := lipgloss.NewStyle().Foreground(nameColor).Render(p.Flag.Name)
		var line string
		if p.Completed {
			line = fmt.Sprintf("%s: %s", name, p.TextInput.View())
		} else {
			defaultHint := ""
			if p.InitialValue != "" {
				hint := lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("36")).Render(p.InitialValue)
				defaultHint = fmt.Sprintf(" [%s]", hint)
			}
			line = fmt.Sprintf("%s%s %s", name, defaultHint, p.TextInput.View())
		}
		if p.TextInput.Err != nil {
			line += lipgloss.NewStyle().Foreground(lipgloss.Color("9")).Render(fmt.Sprintf(" (%s)", p.TextInput.Err))
		}
		out.WriteString(line + "\n")
	}
	for pad := m.Height - len(m.Prompts) - 3; pad > 0; pad-- {
		out.WriteString("\n")
	}
	// Add footer
	out.WriteString("\nPress Esc or Ctrl+C to quit")
	return out.String()
}
// Helper supplies optional per-flag behavior to a prompt.
type Helper struct {
// SuggestionResolverFunc returns autocomplete suggestions, given the flag's current value.
SuggestionResolverFunc func(string) []string
// ValidateFunc validates input. It is used as the text input's validator and again on submit.
ValidateFunc func(string) error
}
// CreatePromptModel builds the prompt for one flag. The flag's usage becomes the placeholder,
// suffixed with a bold " (required)" when isRequired. The flag's current value pre-fills the input
// and becomes InitialValue. flagHelpers provides validation and optional suggestions. The input is
// returned focused, with a 156-character limit.
func CreatePromptModel(flag *pflag.Flag, flagHelpers Helper, isRequired bool) FlagPromptModel {
	current := flag.Value.String()
	placeholder := flag.Usage
	if isRequired {
		placeholder += lipgloss.NewStyle().Bold(true).Render(" (required)")
	}

	input := textinput.New()
	input.Placeholder = placeholder
	input.Validate = flagHelpers.ValidateFunc
	input.CharLimit = 156
	input.SetValue(current)
	if suggest := flagHelpers.SuggestionResolverFunc; suggest != nil {
		input.SetSuggestions(suggest(current))
	}
	input.ShowSuggestions = true
	input.Focus()

	return FlagPromptModel{
		TextInput:    input,
		Flag:         flag,
		InitialValue: current,
		Description:  placeholder,
		IsRequired:   isRequired,
		FlagHelpers:  flagHelpers,
	}
}
// Init implements tea.Model by starting the text-input cursor blink.
func (m MultiFlagPromptModel) Init() tea.Cmd {
	blink := textinput.Blink
	return blink
}
package tui
// RingBuffer is a fixed-capacity FIFO buffer. Once it is full, each Push overwrites the oldest
// element. It has no internal locking.
//
// An explicit element count is kept so that a full buffer can be told apart from an empty one. The
// previous head==tail-means-empty scheme could only ever hold Cap()-1 elements.
type RingBuffer[T any] struct {
	buf  []T
	head int // index the next Push writes to
	tail int // index of the oldest element
	n    int // number of stored elements
}

// NewRingBuffer returns an empty buffer that holds at most size elements. A non-positive size
// yields a buffer that always stays empty, and Push on it is a no-op.
func NewRingBuffer[T any](size int) *RingBuffer[T] {
	if size < 0 {
		size = 0
	}
	return &RingBuffer[T]{buf: make([]T, size)}
}

// Push appends v as the newest element. If the buffer is full, the oldest element is evicted.
func (r *RingBuffer[T]) Push(v T) {
	if len(r.buf) == 0 {
		return
	}
	r.buf[r.head] = v
	r.head = (r.head + 1) % len(r.buf)
	if r.n == len(r.buf) {
		// Full: the write above replaced the oldest element, so the next slot is now the oldest.
		r.tail = r.head
	} else {
		r.n++
	}
}

// Pop removes and returns the oldest element. It returns the zero value if the buffer is empty.
func (r *RingBuffer[T]) Pop() T {
	var zero T
	if r.n == 0 {
		return zero
	}
	v := r.buf[r.tail]
	r.buf[r.tail] = zero // drop the reference so it can be garbage collected
	r.tail = (r.tail + 1) % len(r.buf)
	r.n--
	return v
}

// Len returns the number of stored elements.
func (r *RingBuffer[T]) Len() int {
	return r.n
}

// Cap returns the maximum number of elements the buffer holds.
func (r *RingBuffer[T]) Cap() int {
	return len(r.buf)
}

// slot maps a logical index to a position in buf. Index 0 is the oldest element, and negative
// indices count back from the newest, so -1 is the newest. The boolean is false when i is out of
// range.
func (r *RingBuffer[T]) slot(i int) (int, bool) {
	if i < 0 {
		i += r.n
	}
	if i < 0 || i >= r.n {
		return 0, false
	}
	return (r.tail + i) % len(r.buf), true
}

// Get returns the element at logical index i (see slot). It reports false if i is out of range.
func (r *RingBuffer[T]) Get(i int) (T, bool) {
	pos, ok := r.slot(i)
	if !ok {
		var zero T
		return zero, false
	}
	return r.buf[pos], true
}

// Set replaces the element at logical index i (see slot). It reports false if i is out of range.
func (r *RingBuffer[T]) Set(i int, v T) bool {
	pos, ok := r.slot(i)
	if !ok {
		return false
	}
	r.buf[pos] = v
	return true
}

// Clear empties the buffer and zeroes the slots so stale elements are not retained.
func (r *RingBuffer[T]) Clear() {
	var zero T
	for i := range r.buf {
		r.buf[i] = zero
	}
	r.head, r.tail, r.n = 0, 0, 0
}

// ForEach calls f with each logical index and element, from oldest to newest.
func (r *RingBuffer[T]) ForEach(f func(int, T)) {
	for i := 0; i < r.n; i++ {
		v, _ := r.Get(i)
		f(i, v)
	}
}
package tui
import (
_ "embed"
"github.com/charmbracelet/lipgloss"
)
// LogoColor is the logo's foreground color. It is also reused elsewhere in the TUI, e.g. to highlight the focused prompt.
const LogoColor = "#816FA6"
var (
// logo is the logo text, embedded from logo.txt at build time.
//go:embed logo.txt
logo string
// logoStyle renders the logo in LogoColor.
logoStyle = lipgloss.NewStyle().Foreground(lipgloss.Color(LogoColor))
)
// RenderLogo returns the embedded logo text styled in LogoColor.
func RenderLogo() string {
	styled := logoStyle.Render(logo)
	return styled
}
package tui
import "go.uber.org/zap/zapcore"
// Verbosity selects how much the TUI logs and which view it renders. Higher values log more.
type Verbosity int
var (
// VerbosityConcise logs at error level and uses the compact view.
VerbosityConcise Verbosity = 0
// VerbosityVerbose logs at info level and uses the boxed per-construct view.
VerbosityVerbose Verbosity = 1
// VerbosityDebug logs at debug level and uses the boxed per-construct view.
VerbosityDebug Verbosity = 2
// VerbosityDebugMore logs at debug level with all logs combined (see CombineLogs).
VerbosityDebugMore Verbosity = 3
)
// LogLevel returns the minimum log level for v. Concise maps to error and verbose to info. Debug,
// and any higher or unknown value, maps to debug.
func (v Verbosity) LogLevel() zapcore.Level {
	switch v {
	case VerbosityConcise:
		return zapcore.ErrorLevel
	case VerbosityVerbose:
		return zapcore.InfoLevel
	}
	return zapcore.DebugLevel
}
// CombineLogs reports whether logs from all constructs are shown commingled (sorted by timestamp)
// rather than grouped per construct. Only VerbosityDebugMore combines logs.
func (v Verbosity) CombineLogs() bool {
	return v == VerbosityDebugMore
}
package updater
import (
"encoding/json"
"fmt"
"io"
"net/http"
"runtime"
"strings"
"github.com/coreos/go-semver/semver"
"github.com/fatih/color"
"github.com/gojek/heimdall/v7/httpclient"
"github.com/inconshreveable/go-update"
"github.com/pkg/errors"
"github.com/schollz/progressbar/v3"
"go.uber.org/zap"
)
// OS and Arch select which binary getLatest downloads. They default to the running platform; they are
// vars, presumably so they can be overridden (e.g. in tests) — TODO confirm.
var (
OS string = runtime.GOOS
Arch string = runtime.GOARCH
)
// DefaultServer is the klotho update server base URL, presumably the default for Updater.ServerURL — confirm at call sites.
const (
DefaultServer string = "http://srv.klo.dev"
)
// Updater checks for and applies klotho CLI updates from an update server.
type Updater struct {
// ServerURL is the base URL for the /update/... endpoints.
ServerURL string
// Stream is the update stream to check
Stream string
// CurrentStream is the stream this binary came from
CurrentStream string
// Client performs the HTTP GET requests.
Client *httpclient.Client
}
// selfUpdate replaces the running executable with the binary read from data, using go-update's
// default options. The download is not signature-verified.
func selfUpdate(data io.Reader) error {
	// TODO: add signature verification if we want it.
	var opts update.Options
	return update.Apply(data, opts)
}
// CheckUpdate asks the update server for the latest version on u.Stream and reports whether this
// binary (currentVersion, with an optional "v" prefix) should be replaced. The stream rules are
// described inline below.
//
// The response body is always closed, including on a non-200 status. A version the server sends
// that cannot be parsed is reported together with the raw string and the parse error.
func (u *Updater) CheckUpdate(currentVersion string) (bool, error) {
	endpoint := fmt.Sprintf("%s/update/check-latest-version?stream=%s", u.ServerURL, u.Stream)
	res, err := u.Client.Get(endpoint, nil)
	if err != nil {
		return false, fmt.Errorf("failed to query for latest version: %w", err)
	}
	// Close before checking the status so non-200 responses don't leak the connection.
	defer res.Body.Close()
	if res.StatusCode != http.StatusOK {
		return false, fmt.Errorf("failed to query for latest version, bad response from server: %d", res.StatusCode)
	}
	result := make(map[string]string)
	if err := json.NewDecoder(res.Body).Decode(&result); err != nil {
		return false, fmt.Errorf("failed to decode body: %w", err)
	}
	ver, ok := result["latest_version"]
	if !ok {
		return false, errors.New("no version found in result")
	}
	latestVersion, err := semver.NewVersion(ver)
	if err != nil {
		// latestVersion is nil here, so report the raw string and the parse error instead.
		return false, fmt.Errorf("strange version received: %s: %w", ver, err)
	}
	currVersion, err := semver.NewVersion(strings.TrimPrefix(currentVersion, "v"))
	if err != nil {
		return false, fmt.Errorf("invalid version %s: %w", currentVersion, err)
	}
	// Given a stream "xxx:yyyy", the qualifier is the "xxx" and the tag is the "yyyy".
	//
	// (1) If the qualifiers are different, always update (this is to handle open <--> pro)
	// Otherwise, check the cli's version against latest. This is a bit trickier:
	//
	// (2a) If the tags are the same, then either it's a specific version or it's a monotonic tag like "latest".
	// • If it's a monotonic tag, we only want to perform upgrades. A downgrade would be a situation like if we gave
	// someone a pre-release, in which case we don't want to downgrade them.
	// • If it's a specific version, we can assume that the version will never change.
	// • So in either case, we want to only perform upgrades.
	// (2b) If the tags are different, then someone is either pinning to a specific version, or going from a pinned
	// version to a monotonic version. In either case, we should allow downgrades. (Going from pinned to monotonic
	// *may* be an incorrect downgrade, with a similar pre-release reason. But if someone has a pre-release, they
	// shouldn't be worrying about any upgrade stuff, including not changing their update stream from pinned to
	// monotonic.)
	// case (1): different qualifiers always update
	if strings.Split(u.CurrentStream, ":")[0] != strings.Split(u.Stream, ":")[0] {
		return true, nil
	}
	// the qualifiers are the same, so the tags are the same iff the full stream strings are the same
	if u.CurrentStream == u.Stream {
		return currVersion.LessThan(*latestVersion), nil // case (2a): only upgrades
	}
	return !currVersion.Equal(*latestVersion), nil // case (2b): upgrades or downgrades
}
// Update replaces the running binary with the latest one from u.Stream when CheckUpdate says an
// update is due. It shows a download progress bar when color output is enabled. Check failures are
// logged and returned. Being already up to date is logged and is not an error.
func (u *Updater) Update(currentVersion string) error {
	needsUpdate, err := u.CheckUpdate(currentVersion)
	if err != nil {
		zap.S().Errorf(`error checking for updates on stream "%s": %v`, u.Stream, err)
		return err
	}
	if !needsUpdate {
		zap.S().Infof(`already up to date on stream "%s".`, u.Stream)
		return nil
	}

	resp, err := u.getLatest()
	if err != nil {
		return errors.Wrapf(err, "failed to get latest")
	}
	if resp.Body == nil {
		return errors.New("No response body from download")
	}
	defer resp.Body.Close()

	var src io.Reader = resp.Body
	// color.NoColor doubles as a rough "output is not a terminal" signal. It isn't exact (the
	// NO_COLOR env var also sets it), but it's close enough to decide whether to draw a bar.
	if !color.NoColor {
		bar := progressbar.DefaultBytes(resp.ContentLength, "downloading")
		tracked := progressbar.NewReader(src, bar)
		src = &tracked
	}
	if err := selfUpdate(src); err != nil {
		return errors.Wrapf(err, "failed to update klotho")
	}
	zap.S().Infof(`updated to the latest version on stream "%s"`, u.Stream)
	return nil
}
// getLatest requests the latest release for this OS/architecture on u.Stream
// from the klotho server at u.ServerURL.
//
// On success the caller owns the returned response and must close its Body.
// On a non-200 status the response body is closed here (so the connection is
// not leaked) and an error carrying the status code is returned.
func (u *Updater) getLatest() (*http.Response, error) {
	endpoint := fmt.Sprintf("%s/update/latest/%s/%s?stream=%s", u.ServerURL, OS, Arch, u.Stream)
	res, err := u.Client.Get(endpoint, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to query for latest version: %w", err)
	}
	if res.StatusCode != http.StatusOK {
		// The response is discarded, so release its body before returning.
		if res.Body != nil {
			_ = res.Body.Close() // best effort; the status error is what matters
		}
		return nil, fmt.Errorf("failed to query for latest version, bad response from server: %d", res.StatusCode)
	}
	return res, nil
}
package visualizer
import (
"fmt"
"io"
"sort"
"strings"
construct "github.com/klothoplatform/klotho/pkg/construct"
klotho_io "github.com/klothoplatform/klotho/pkg/io"
"github.com/klothoplatform/klotho/pkg/ioutil"
"gopkg.in/yaml.v3"
)
// indent is the unit of indentation used when writing the topology YAML.
const indent = " "

// File is the topology spec produced by the visualizer plugin. It is written
// to "<FilenamePrefix>topology.yaml".
type File struct {
	// FilenamePrefix is prepended to the output file name.
	FilenamePrefix string
	// AppName is the application the topology describes.
	AppName string
	// Provider is written at the top of the file; resource keys belonging to
	// it (and without a namespace) omit the provider prefix.
	Provider string
	// Graph holds the resources and edges to serialize.
	Graph VisGraph
}
// Path returns the output file name: FilenamePrefix followed by "topology.yaml".
func (f *File) Path() string {
	return f.FilenamePrefix + "topology.yaml"
}
// Clone satisfies klotho_io.File.
//
// NOTE(review): this returns the receiver itself, not a copy, so the "clone"
// shares all state (including Graph) with f.
func (f *File) Clone() klotho_io.File {
	var same klotho_io.File = f
	return same
}
// WriteTo serializes the topology as YAML to w and reports the bytes written.
//
// Output is "provider: <Provider>" followed by a "resources:" section listing
// every vertex of f.Graph in topological order. Each resource entry carries its
// tag, parent key and sorted children when set. It is followed by one
// "src -> dst:" entry per outgoing edge (targets sorted), plus the edge's data
// when present.
//
// The graph is queried (topological sort, adjacency map) before anything is
// written. The write helper records write failures into the named err, and
// reusing err for those lookups would silently overwrite such a failure.
func (f *File) WriteTo(w io.Writer) (n int64, err error) {
	resourceIds, sortErr := construct.TopologicalSort(f.Graph)
	if sortErr != nil {
		return 0, sortErr
	}
	adj, adjErr := f.Graph.AdjacencyMap()
	if adjErr != nil {
		return 0, adjErr
	}

	wh := ioutil.NewWriteToHelper(w, &n, &err)
	wh.Writef("provider: %s\n", f.Provider)
	wh.Write("resources:\n")

	for _, id := range resourceIds {
		res, vertexErr := f.Graph.Vertex(id)
		if vertexErr != nil {
			return n, vertexErr
		}
		src := f.KeyFor(id)
		if src == "" {
			continue
		}
		wh.Writef(indent+"%s:\n", src)

		props := make(map[string]any)
		if res.Tag != "" {
			props["tag"] = res.Tag
		}
		if !res.Parent.IsZero() {
			props["parent"] = f.KeyFor(res.Parent)
		}
		if len(res.Children) > 0 {
			children := res.Children.ToSlice()
			sort.Sort(construct.SortedIds(children))
			props["children"] = children
		}
		if len(props) > 0 {
			writeYaml(props, 2, wh)
		} else {
			wh.Write("\n")
		}

		// Map iteration order is random; sort targets for stable output.
		targets := make([]construct.ResourceId, 0, len(adj[id]))
		for target := range adj[id] {
			targets = append(targets, target)
		}
		sort.Sort(construct.SortedIds(targets))
		for _, target := range targets {
			dst := f.KeyFor(target)
			if dst == "" {
				// Never write edge data without its "src -> dst:" header.
				continue
			}
			wh.Writef(indent+"%s -> %s:\n", src, dst)
			edge, edgeErr := f.Graph.Edge(id, target)
			if edgeErr != nil {
				return n, edgeErr
			}
			if edge.Properties.Data != nil {
				writeYaml(edge.Properties.Data, 2, wh)
			}
		}
	}
	return n, err
}
// KeyFor renders res as its key in the topology file, lower-cased:
// "[provider:]type[:namespace]/name". The provider prefix is omitted only when
// res belongs to f.Provider and has no namespace.
func (f *File) KeyFor(res construct.ResourceId) string {
	prefix := ""
	if res.Namespace != "" || res.Provider != f.Provider {
		prefix = res.Provider + `:`
	}
	qualifier := ""
	if res.Namespace != "" {
		qualifier = ":" + res.Namespace
	}
	key := fmt.Sprintf("%s%s%s/%s", prefix, res.Type, qualifier, res.Name)
	return strings.ToLower(key)
}
// writeYaml marshals e to YAML and writes it to out.
//
// Every non-blank line is prefixed with indentCount copies of indent, and every
// element of the newline-split output is terminated with "\n". When the
// marshalled text ends in a newline, the trailing empty element therefore
// produces one extra blank line. A marshalling error is recorded on out and
// nothing is written.
func writeYaml(e any, indentCount int, out ioutil.WriteToHelper) {
	encoded, err := yaml.Marshal(e)
	if err != nil {
		out.AddErr(err)
		return
	}
	var prefix string
	if indentCount > 0 {
		prefix = strings.Repeat(indent, indentCount)
	}
	for _, line := range strings.Split(string(encoded), "\n") {
		if prefix != "" && strings.TrimSpace(line) != "" {
			out.Write(prefix)
		}
		out.Write(line + "\n")
	}
}
package visualizer
import (
"bytes"
"fmt"
"io"
"net/http"
"github.com/klothoplatform/klotho/pkg/cli_config"
construct "github.com/klothoplatform/klotho/pkg/construct"
klotho_io "github.com/klothoplatform/klotho/pkg/io"
)
// Plugin turns a construct graph into a topology spec and an architecture
// diagram rendered by the visualizer service.
type Plugin struct {
	// AppName is copied into the generated topology File.
	AppName string
	// Provider is copied into the generated topology File.
	Provider string
	// Client is the HTTP client used to call the visualizer service.
	Client *http.Client
}
// visApi is a small client for the visualizer HTTP API. buf is a scratch
// buffer used by its request method.
type visApi struct {
	client *http.Client
	buf    bytes.Buffer
}

// httpStatusBad is the error returned for an unsuccessful HTTP status from the
// visualizer; its value is the status code.
type httpStatusBad int
// Name implements compiler.Plugin; this plugin is always called "visualizer".
func (p Plugin) Name() string {
	const pluginName = "visualizer"
	return pluginName
}
// visualizerBaseUrl is the root URL of the visualizer service. It comes from
// the KLOTHO_VIZ_URL_BASE environment variable, falling back to
// https://viz.klo.dev (presumably GetOr returns the fallback only when the
// variable is unset — confirm against cli_config).
var (
	visualizerBaseUrlEnv = cli_config.EnvVar("KLOTHO_VIZ_URL_BASE")
	visualizerBaseUrl    = visualizerBaseUrlEnv.GetOr("https://viz.klo.dev")
)
// request serializes f as the request body and sends it with method to
// visualizerBaseUrl + "/api/v1/" + path. When non-empty, contentType and accept
// are sent as the Content-Type and Accept headers. If a.client is nil,
// http.DefaultClient is used.
//
// It returns the response body. Any status outside 2xx yields an httpStatusBad
// error (along with whatever body was read), taking precedence over a body
// read error. Otherwise the read error, if any, is returned.
func (a *visApi) request(method string, path string, contentType string, accept string, f io.WriterTo) ([]byte, error) {
	a.buf.Reset()
	if _, err := f.WriteTo(&a.buf); err != nil {
		return nil, err
	}
	req, err := http.NewRequest(method, visualizerBaseUrl+`/api/v1/`+path, &a.buf)
	if err != nil {
		return nil, err
	}
	if contentType != "" {
		req.Header.Set("Content-Type", contentType)
	}
	if accept != "" {
		req.Header.Set("Accept", accept)
	}

	client := a.client
	if client == nil {
		client = http.DefaultClient
	}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	// Read into a fresh slice instead of reusing a.buf. The transport may still
	// be reading the request body from a.buf after Do returns, and the returned
	// bytes must not alias a buffer a later request would overwrite.
	body, err := io.ReadAll(resp.Body)
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		err = httpStatusBad(resp.StatusCode)
	}
	return body, err
}
// Translate implements compiler.IaCPlugin. The visualizer is not really an IaC
// plugin, but it shares that API.
//
// It converts dag into a topology spec, posts the spec to the visualizer
// service, and returns two files: the spec ("topology.yaml") and the rendered
// PNG ("diagram.png").
func (p Plugin) Translate(dag construct.Graph) ([]klotho_io.File, error) {
	visGraph, err := ConstructToVis(dag)
	if err != nil {
		return nil, err
	}
	spec := &File{
		AppName:  p.AppName,
		Provider: p.Provider,
		Graph:    visGraph,
	}

	api := visApi{client: p.Client}
	png, err := api.request(http.MethodPost, `generate-infra-diagram`, "application/yaml", "image/png", spec)
	if err != nil {
		return nil, err
	}
	return []klotho_io.File{
		spec,
		&klotho_io.RawFile{FPath: "diagram.png", Content: png},
	}, nil
}
// Generate converts dag into a topology spec, posts it to the visualizer
// service, and returns the spec and the rendered PNG. Both output names are
// prefixed with filenamePrefix + "-": "<prefix>-topology.yaml" and
// "<prefix>-diagram.png".
func (p Plugin) Generate(dag construct.Graph, filenamePrefix string) ([]klotho_io.File, error) {
	visGraph, err := ConstructToVis(dag)
	if err != nil {
		return nil, err
	}
	spec := &File{
		FilenamePrefix: filenamePrefix + "-",
		AppName:        p.AppName,
		Provider:       p.Provider,
		Graph:          visGraph,
	}

	api := visApi{client: p.Client}
	png, err := api.request(http.MethodPost, `generate-infra-diagram`, "application/yaml", "image/png", spec)
	if err != nil {
		return nil, err
	}
	return []klotho_io.File{
		spec,
		&klotho_io.RawFile{FPath: filenamePrefix + "-diagram.png", Content: png},
	}, nil
}
// Error reports the HTTP status code the visualizer responded with.
func (h httpStatusBad) Error() string {
	code := int(h)
	return fmt.Sprintf("visualizer returned status code %d", code)
}
package visualizer
import (
"errors"
"sort"
"github.com/dominikbraun/graph"
construct "github.com/klothoplatform/klotho/pkg/construct"
"github.com/klothoplatform/klotho/pkg/graph_addons"
"github.com/klothoplatform/klotho/pkg/set"
)
// VisResource is a vertex of a VisGraph: a resource ID plus the presentation
// details written to the topology file.
type VisResource struct {
	ID       construct.ResourceId
	Tag      string
	Parent   construct.ResourceId
	Children set.Set[construct.ResourceId]
}

// VisEdgeData is data attached to a VisGraph edge; its MarshalYAML emits
// PathResources under the key "path".
type VisEdgeData struct {
	PathResources set.Set[construct.ResourceId]
}

// VisGraph is a graph of *VisResource vertices keyed by resource ID.
type VisGraph graph.Graph[construct.ResourceId, *VisResource]
// NewVisGraph creates an empty, directed VisGraph keyed by each resource's ID
// and backed by graph_addons' memory store. Extra traits may be passed in
// options; graph.Directed() is always applied after them.
func NewVisGraph(options ...func(*graph.Traits)) VisGraph {
	// Build a fresh slice: appending directly to options could write into the
	// backing array of a caller who passed NewVisGraph(opts...) with spare capacity.
	traits := make([]func(*graph.Traits), 0, len(options)+1)
	traits = append(traits, options...)
	traits = append(traits, graph.Directed())
	return VisGraph(graph.NewWithStore(
		func(r *VisResource) construct.ResourceId { return r.ID },
		graph_addons.NewMemoryStore[construct.ResourceId, *VisResource](),
		traits...,
	))
}
// ConstructToVis builds a VisGraph with the same vertices and edges as g. Each
// vertex becomes a VisResource carrying only its ID; edges are added without
// data.
//
// Insertion errors are collected and joined: first for all vertices, then for
// all edges. On any error the returned graph is nil, so callers never receive a
// partially built graph.
func ConstructToVis(g construct.Graph) (VisGraph, error) {
	adj, err := g.AdjacencyMap()
	if err != nil {
		return nil, err
	}
	vis := NewVisGraph()

	var errs error
	for id := range adj {
		errs = errors.Join(errs, vis.AddVertex(&VisResource{ID: id}))
	}
	if errs != nil {
		return nil, errs
	}

	for source, targets := range adj {
		for target := range targets {
			errs = errors.Join(errs, vis.AddEdge(source, target))
		}
	}
	if errs != nil {
		return nil, errs
	}
	return vis, nil
}
// MarshalYAML implements yaml.Marshaler, emitting the edge's path resources,
// sorted, under the single key "path".
func (d VisEdgeData) MarshalYAML() (interface{}, error) {
	// TODO infacopilot frontend currently just uses 'path' as the collection of
	// additional resources to show in the graph for that edge. We have more
	// information we could give, but for compatibility until more is added to
	// the frontend, just flatten everything and call it 'path'.
	sorted := d.PathResources.ToSlice()
	sort.Sort(construct.SortedIds(sorted))
	out := map[string]any{"path": sorted}
	return out, nil
}
// VertexAncestors returns id plus every ancestor reached by following
// VisResource.Parent links in g, stopping at a zero parent.
//
// If a vertex lookup fails, the walk stops and the error is returned along with
// the ids collected so far (including the one whose lookup failed). The walk
// also stops when the chain revisits an id already collected, or a lookup
// yields no vertex, so a cyclic Parent chain cannot loop forever.
func VertexAncestors(g VisGraph, id construct.ResourceId) (set.Set[construct.ResourceId], error) {
	ancestors := make(set.Set[construct.ResourceId])
	for current := id; !current.IsZero(); {
		if _, seen := ancestors[current]; seen {
			break // parent cycle: every member is already collected
		}
		ancestors.Add(current)
		vert, err := g.Vertex(current)
		if err != nil {
			return ancestors, err
		}
		if vert == nil {
			break
		}
		current = vert.Parent
	}
	return ancestors, nil
}
package yaml_util
import (
"errors"
"fmt"
"sort"
"gopkg.in/yaml.v3"
)
// nullNode is a YAML null scalar (tag "!!null", empty value); MarshalMap uses
// it to encode an empty map.
var nullNode = &yaml.Node{Tag: "!!null", Kind: yaml.ScalarNode, Value: ""}
// MarshalMap encodes m as a YAML mapping node whose entries are ordered by
// less, so output is deterministic regardless of Go's map iteration order.
//
// Values that are already *yaml.Node are embedded as copies, and a nil
// *yaml.Node value is encoded as null. An empty map encodes as a null scalar.
//
// Every key or value that fails to encode is reported. If any fail, the joined
// error is returned with a nil node.
func MarshalMap[K comparable, V any](m map[K]V, less func(K, K) bool) (*yaml.Node, error) {
	if len(m) == 0 {
		// Return a copy so callers cannot mutate the shared package-level node.
		empty := *nullNode
		return &empty, nil
	}

	keys := make([]K, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Slice(keys, func(i, j int) bool { return less(keys[i], keys[j]) })

	node := &yaml.Node{Kind: yaml.MappingNode}
	var errs error
	for _, k := range keys {
		var keyNode yaml.Node
		if err := keyNode.Encode(k); err != nil {
			errs = errors.Join(errs, fmt.Errorf("failed to encode key %v: %w", k, err))
			continue
		}

		var valueNode yaml.Node
		var v any = m[k]
		if raw, ok := v.(*yaml.Node); ok {
			if raw != nil {
				valueNode = *raw
			} else {
				valueNode = *nullNode // dereferencing a nil node would panic
			}
		} else if err := valueNode.Encode(m[k]); err != nil {
			errs = errors.Join(errs, fmt.Errorf("failed to encode value for %v: %w", k, err))
			continue
		}

		node.Content = append(node.Content, &keyNode, &valueNode)
	}
	if errs != nil {
		return nil, errs
	}
	return node, nil
}
package yaml_util
import "gopkg.in/yaml.v3"
// RawNode captures a YAML value without decoding it, keeping the original
// *yaml.Node for later inspection or deferred decoding.
type RawNode struct{ *yaml.Node }

// UnmarshalYAML implements yaml.Unmarshaler by storing value as-is; it never
// returns an error.
func (n *RawNode) UnmarshalYAML(value *yaml.Node) error {
	*n = RawNode{Node: value}
	return nil
}
package yaml_util
import (
"bytes"
"github.com/pkg/errors"
"gopkg.in/yaml.v3"
"strings"
)
// CheckMode selects whether CheckValid accepts YAML keys that the target type
// does not declare.
type CheckMode bool

const (
	// Lenient ignores unknown fields.
	Lenient CheckMode = false
	// Strict rejects unknown fields.
	Strict CheckMode = true
)
// SetValue upserts a scalar into the yaml in content at a dotted path and
// returns the re-serialized yaml. For example, setting `foo.bar.baz` to
// `hello, world` is equivalent to upserting the following yaml:
//
//	foo:
//	  bar:
//	    baz: hello, world
//
// This method makes a best effort to preserve comments, as far as the `yaml`
// package allows. You may overwrite scalars, but not non-scalars. You may also
// specify a path that doesn't exist in the source yaml, as long as none of its
// prefixes correspond to existing elements other than yaml mappings.
//
// The path must not be empty or contain empty segments (e.g. "a..b"). The value
// is written without a tag, so its yaml type is inferred from its text.
func SetValue(content []byte, optionPath string, optionValue string) ([]byte, error) {
	// General approach:
	//  1. parse the yaml into a yaml.Node tree (not a map), which keeps comments
	//  2. walk down to the mapping that should hold the last path segment,
	//     creating intermediate mappings as needed
	//  3. set that segment's value, assuming it's a scalar (or absent)
	//  4. write the node tree back into bytes
	// step 1
	var tree yaml.Node
	if err := yaml.Unmarshal(content, &tree); err != nil {
		return nil, err
	}
	segments := strings.Split(optionPath, ".")
	for _, segment := range segments {
		if segment == "" {
			// An empty path or a leading/trailing/doubled dot would create an empty key.
			return nil, errors.Errorf(`can't set the path "%s"`, optionPath)
		}
	}

	// step 2
	var topNode *yaml.Node
	if len(tree.Content) == 0 {
		topNode = &yaml.Node{Kind: yaml.MappingNode} // empty input: start from an empty mapping
	} else {
		topNode = tree.Content[0] // the tree's root is a DocumentNode; we assume one document
	}
	setOptionAtNode := topNode
	for _, segment := range segments[:len(segments)-1] {
		if setOptionAtNode.Kind != yaml.MappingNode {
			return nil, errors.Errorf(`can't set the path "%s"`, optionPath)
		}
		if child := findChild(setOptionAtNode.Content, segment); child != nil {
			setOptionAtNode = child
			continue
		}
		newSubMap := &yaml.Node{Kind: yaml.MappingNode}
		setOptionAtNode.Content = append(setOptionAtNode.Content,
			&yaml.Node{Kind: yaml.ScalarNode, Value: segment},
			newSubMap,
		)
		setOptionAtNode = newSubMap
	}

	// step 3
	if setOptionAtNode.Kind != yaml.MappingNode {
		return nil, errors.Errorf(`can't set the path "%s"`, optionPath)
	}
	lastSegment := segments[len(segments)-1]
	if currValue := findChild(setOptionAtNode.Content, lastSegment); currValue != nil {
		if currValue.Kind != yaml.ScalarNode {
			return nil, errors.Errorf(`"%s" cannot be a scalar`, optionPath)
		}
		// Drop any existing tag (e.g. !!int) so the new value's type is inferred
		// from its own text rather than the old value's type.
		currValue.Tag = ""
		currValue.Value = optionValue
	} else {
		setOptionAtNode.Content = append(setOptionAtNode.Content,
			&yaml.Node{Kind: yaml.ScalarNode, Value: lastSegment},
			&yaml.Node{Kind: yaml.ScalarNode, Value: optionValue},
		)
	}

	// step 4
	return yaml.Marshal(topNode)
}
// CheckValid checks that content decodes as type T, returning a non-nil error
// describing the problem if it doesn't. T must be given explicitly:
//
//	CheckValid[MyCoolType](contents, Strict)
//
// mode controls unknown fields: Strict rejects yaml keys that T does not
// declare, and Lenient ignores them. Whitespace-only content is treated as
// valid. Only the first yaml document in content is checked.
func CheckValid[T any](content []byte, mode CheckMode) error {
	// The decoder would fail on blank input (EOF), but blank yaml counts as valid.
	if len(bytes.TrimSpace(content)) == 0 {
		return nil
	}
	dec := yaml.NewDecoder(bytes.NewReader(content))
	dec.KnownFields(bool(mode))
	var target T
	return dec.Decode(&target)
}
// YamlErrors returns the individual messages of err when it is a
// *yaml.TypeError. Otherwise it returns a single-element slice holding
// err.Error(); wrapped errors are not inspected. A nil err yields nil instead
// of panicking.
func YamlErrors(err error) []string {
	if err == nil {
		return nil
	}
	if typeErr, ok := err.(*yaml.TypeError); ok {
		return typeErr.Errors
	}
	return []string{err.Error()}
}
// findChild searches the Content of a yaml mapping node, which alternates key,
// value, key, value. It returns the value node of the first scalar key equal to
// named, or nil if there is none. A trailing key with no value (malformed
// content) is ignored rather than causing an index-out-of-range panic.
func findChild(within []*yaml.Node, named string) *yaml.Node {
	for i := 0; i+1 < len(within); i += 2 {
		if key := within[i]; key.Kind == yaml.ScalarNode && key.Value == named {
			return within[i+1]
		}
	}
	return nil
}