// Package main provides the entry point for the pulumi-gcp-vertex-model-deployment provider.
package main

import (
	"context"
	"fmt"
	"os"

	"github.com/pulumi/pulumi-go-provider/infer"

	"github.com/davidmontoyago/pulumi-gcp-vertex-model-deployment/pkg/resources"
	"github.com/davidmontoyago/pulumi-gcp-vertex-model-deployment/pkg/version"
)

func main() {
	provider, err := infer.NewProviderBuilder().
		WithResources(
			infer.Resource(&resources.VertexModelDeployment{}),
		).
		WithNamespace("davidmontoyago").
		WithDisplayName("pulumi-gcp-vertex-model-deployment").
		WithLicense("Apache-2.0").
		WithKeywords("pulumi", "gcp", "vertex", "model").
		WithDescription("Deploy AI models to Vertex endpoints").
		WithRepository("github.com/davidmontoyago/pulumi-gcp-vertex-model-deployment").
		WithPluginDownloadURL("github://api.github.com/davidmontoyago/pulumi-gcp-vertex-model-deployment").
		Build()
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error building provider: %s\n", err.Error())
		os.Exit(1)
	}

	// Name of the Pulumi plugin.
	pluginName := "gcp-vertex-model-deployment"

	err = provider.Run(context.Background(), pluginName, version.Version)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error running provider: %s\n", err.Error())
		os.Exit(1)
	}
}
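// NOTE: illustrative sketch, not part of the provider source. The pkg/version
// package is imported by main above but not shown in this listing; a minimal
// version of it, assuming the value is stamped at build time via:
//
//	go build -ldflags "-X github.com/davidmontoyago/pulumi-gcp-vertex-model-deployment/pkg/version.Version=v0.1.0"
package version

// Version holds the provider's semantic version. The default applies to local
// builds where no -ldflags override is supplied.
var Version = "0.0.1-dev"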
package resources

import (
	"context"

	"github.com/davidmontoyago/pulumi-gcp-vertex-model-deployment/pkg/services"
	"github.com/pulumi/pulumi-go-provider/infer"
)

func readEndpointModel(
	ctx context.Context,
	endpointClient services.VertexEndpointClient,
	req infer.ReadRequest[VertexModelDeploymentArgs, VertexModelDeploymentState],
	state *VertexModelDeploymentState,
) error {
	endpointGetter := services.NewVertexEndpointModelGetter(endpointClient, req.State.ProjectID, req.State.Region)

	endpoint, foundDeployedModel, err := endpointGetter.Get(ctx, req.State.EndpointName, req.State.DeployedModelID)
	if err != nil {
		return err
	}

	if foundDeployedModel == nil {
		// The model is no longer deployed to the endpoint - leave the state
		// untouched so the caller can decide how to handle the drift.
		return nil
	}

	// Update state with current endpoint and deployed model information
	state.EndpointName = endpoint.Name
	state.DeployedModelID = foundDeployedModel.Id

	// Update endpoint deployment configuration with current values if available
	if state.EndpointModelDeployment == nil {
		return nil
	}

	// Extract the current deployment configuration from the deployed model
	if dedicatedResources := foundDeployedModel.GetDedicatedResources(); dedicatedResources != nil {
		if machineSpec := dedicatedResources.MachineSpec; machineSpec != nil {
			state.EndpointModelDeployment.MachineType = machineSpec.MachineType
		}
		state.EndpointModelDeployment.MinReplicas = int(dedicatedResources.MinReplicaCount)
		state.EndpointModelDeployment.MaxReplicas = int(dedicatedResources.MaxReplicaCount)
	}

	// Update the traffic percentage from the endpoint's traffic split if available
	if endpoint.TrafficSplit != nil {
		if trafficPercent, exists := endpoint.TrafficSplit[foundDeployedModel.Id]; exists {
			state.EndpointModelDeployment.TrafficPercent = int(trafficPercent)
		}
	}

	return nil
}
// Package resources provides Pulumi resource implementations for GCP Vertex AI model upload and deployment.
package resources

import "github.com/pulumi/pulumi-go-provider/infer"

// VertexModelDeploymentArgs defines the input arguments for creating a Vertex AI model deployment.
type VertexModelDeploymentArgs struct {
	ProjectID                        string `pulumi:"projectId"`
	Region                           string `pulumi:"region"`
	ModelImageURL                    string `pulumi:"modelImageUrl"`
	ModelArtifactsBucketURI          string `pulumi:"modelArtifactsBucketUri"`
	ModelPredictionInputSchemaURI    string `pulumi:"modelPredictionInputSchemaUri"`
	ModelPredictionOutputSchemaURI   string `pulumi:"modelPredictionOutputSchemaUri"`
	ModelPredictionBehaviorSchemaURI string `pulumi:"modelPredictionBehaviorSchemaUri,optional"`

	// If ModelImageURL points to a private registry, this service account
	// must have read access to the registry.
	ServiceAccount string `pulumi:"serviceAccount"`

	// Path on the container to send prediction requests to.
	// Not required for Endpoints.
	PredictRoute string `pulumi:"predictRoute,optional"`

	// Path on the container to send health requests to.
	// Not required for Endpoints.
	HealthRoute string `pulumi:"healthRoute,optional"`

	// Target endpoint for the model deployment.
	//
	// Set only when serving the model on a Vertex AI Endpoint.
	// Deploying a model to a Vertex AI Endpoint is not yet supported by
	// Terraform or the Pulumi Google Cloud Native provider, hence this
	// custom provider.
	// See: https://github.com/hashicorp/terraform-provider-google/issues/15303
	//
	// When deploying the model as a Batch Prediction Job, this field must be
	// unset and the batch job must be created with the Pulumi Google Cloud
	// Native provider.
	EndpointModelDeployment *EndpointModelDeploymentArgs `pulumi:"endpointModelDeployment,optional"`

	Labels map[string]string `pulumi:"labels,optional"`
}

// EndpointModelDeploymentArgs defines the input arguments for deploying an
// uploaded model to a Vertex AI endpoint.
type EndpointModelDeploymentArgs struct {
	EndpointID     string `pulumi:"endpointId"`
	MachineType    string `pulumi:"machineType,optional"`
	MinReplicas    int    `pulumi:"minReplicas,optional"`
	MaxReplicas    int    `pulumi:"maxReplicas,optional"`
	TrafficPercent int    `pulumi:"trafficPercent,optional"`
}

// Annotate provides metadata and default values for the VertexModelDeploymentArgs.
func (args *VertexModelDeploymentArgs) Annotate(annotator infer.Annotator) {
	annotator.Describe(&args.ProjectID, "Google Cloud Project ID")
	annotator.Describe(&args.Region, "Google Cloud region")
	annotator.Describe(&args.ModelImageURL, "Vertex AI image URL of a custom or prebuilt container model server. See: https://cloud.google.com/vertex-ai/docs/predictions/pre-built-containers")
	annotator.Describe(&args.ModelArtifactsBucketURI, "Bucket URI to the model artifacts. For instance, gs://my-bucket/my-model-artifacts/ - See: https://cloud.google.com/vertex-ai/docs/training/exporting-model-artifacts")
	annotator.Describe(&args.ModelPredictionInputSchemaURI, "Bucket URI to the schema for the model input")
	annotator.Describe(&args.ModelPredictionOutputSchemaURI, "Bucket URI to the schema for the model output")
	annotator.Describe(&args.ModelPredictionBehaviorSchemaURI, "Bucket URI to the schema for the model inference behavior")
	annotator.Describe(&args.ServiceAccount, "Service account for the model. If ModelImageURL points to a private registry, this service account must have read access to the registry.")
	annotator.Describe(&args.EndpointModelDeployment, "Configuration for deploying the model to a Vertex AI endpoint. Leave empty to upload the model for batch predictions only.")
	annotator.Describe(&args.Labels, "Labels for the deployment")
}

// Annotate provides metadata and default values for the EndpointModelDeploymentArgs.
func (args *EndpointModelDeploymentArgs) Annotate(annotator infer.Annotator) {
	annotator.Describe(&args.EndpointID, "Vertex AI Endpoint ID")
	annotator.Describe(&args.MachineType, "Machine type for deployment")
	annotator.Describe(&args.MinReplicas, "Minimum number of replicas")
	annotator.Describe(&args.MaxReplicas, "Maximum number of replicas")
	annotator.Describe(&args.TrafficPercent, "Traffic percentage for this deployment")

	// Set defaults
	annotator.SetDefault(&args.MachineType, "n1-standard-2")
	annotator.SetDefault(&args.MinReplicas, 1)
	annotator.SetDefault(&args.MaxReplicas, 3)
	annotator.SetDefault(&args.TrafficPercent, 100)
}
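// NOTE: illustrative sketch, not part of the provider source. It shows how the
// input types above compose for an endpoint-backed deployment; all values are
// hypothetical placeholders.
package resources

// exampleEndpointDeploymentArgs returns a sample configuration that uploads a
// model and serves it behind a Vertex AI endpoint. Omitting
// EndpointModelDeployment would instead upload the model for batch
// predictions only.
func exampleEndpointDeploymentArgs() VertexModelDeploymentArgs {
	return VertexModelDeploymentArgs{
		ProjectID:                      "my-project",
		Region:                         "us-central1",
		ModelImageURL:                  "us-docker.pkg.dev/my-project/models/my-model-server:latest",
		ModelArtifactsBucketURI:        "gs://my-bucket/my-model-artifacts/",
		ModelPredictionInputSchemaURI:  "gs://my-bucket/schemas/instance.yaml",
		ModelPredictionOutputSchemaURI: "gs://my-bucket/schemas/prediction.yaml",
		ServiceAccount:                 "model-runner@my-project.iam.gserviceaccount.com",
		EndpointModelDeployment: &EndpointModelDeploymentArgs{
			EndpointID:     "1234567890",
			MachineType:    "n1-standard-2",
			MinReplicas:    1,
			MaxReplicas:    3,
			TrafficPercent: 100,
		},
		Labels: map[string]string{"env": "dev"},
	}
}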
package resources

import (
	"context"
	"fmt"

	"github.com/davidmontoyago/pulumi-gcp-vertex-model-deployment/pkg/services"
	"github.com/pulumi/pulumi-go-provider/infer"
)

func readRegistryModel(
	ctx context.Context,
	modelClient services.VertexModelClient,
	req infer.ReadRequest[VertexModelDeploymentArgs, VertexModelDeploymentState],
	state *VertexModelDeploymentState,
) error {
	modelGetter := services.NewVertexModelGet(ctx, modelClient, req.State.ModelName)

	model, err := modelGetter.Get(ctx, req.State.ModelName)
	if err != nil {
		return fmt.Errorf("failed to get model: %w", err)
	}

	// Update state with current model values
	state.ModelName = model.Name
	state.ModelArtifactsBucketURI = model.ArtifactUri
	state.Labels = model.Labels

	// Safely access ContainerSpec fields
	if model.ContainerSpec != nil {
		state.ModelImageURL = model.ContainerSpec.ImageUri
		state.PredictRoute = model.ContainerSpec.PredictRoute
		state.HealthRoute = model.ContainerSpec.HealthRoute
	}

	// Read schema URIs if available.
	if model.PredictSchemata != nil {
		// These are immutable and should be ignored during updates.
		// The URI returned on output will likely differ from the one given
		// on input, including the URI scheme.
		// See: https://cloud.google.com/vertex-ai/docs/reference/rest/v1/PredictSchemata
		state.ModelPredictionInputSchemaURI = model.PredictSchemata.InstanceSchemaUri
		state.ModelPredictionOutputSchemaURI = model.PredictSchemata.PredictionSchemaUri
		state.ModelPredictionBehaviorSchemaURI = model.PredictSchemata.ParametersSchemaUri
	}

	return nil
}
package resources

import (
	"context"
	"fmt"

	"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
	"github.com/pulumi/pulumi-go-provider/infer"
	"google.golang.org/protobuf/types/known/fieldmaskpb"

	"github.com/davidmontoyago/pulumi-gcp-vertex-model-deployment/pkg/services"
)

func updateRegistryModel(
	ctx context.Context,
	req infer.UpdateRequest[VertexModelDeploymentArgs, VertexModelDeploymentState],
	modelClient services.VertexModelClient,
	updatePaths []string,
) (*aiplatformpb.Model, error) {
	predictionSchema := &aiplatformpb.PredictSchemata{
		InstanceSchemaUri:   req.Inputs.ModelPredictionInputSchemaURI,
		PredictionSchemaUri: req.Inputs.ModelPredictionOutputSchemaURI,
	}
	if req.Inputs.ModelPredictionBehaviorSchemaURI != "" {
		predictionSchema.ParametersSchemaUri = req.Inputs.ModelPredictionBehaviorSchemaURI
	}

	// Build container spec (consistent with model creation)
	containerSpec := &aiplatformpb.ModelContainerSpec{
		ImageUri: req.Inputs.ModelImageURL,
		// TODO add args input parameter
		Args: []string{
			"--allow_precompilation=false",
			"--disable_optimizer=true",
			"--saved_model_tags='serve,tpu'",
			"--use_tfrt=true",
		},
	}
	if req.Inputs.PredictRoute != "" {
		containerSpec.PredictRoute = req.Inputs.PredictRoute
	}
	if req.Inputs.HealthRoute != "" {
		containerSpec.HealthRoute = req.Inputs.HealthRoute
	}

	updatedModel, err := modelClient.UpdateModel(ctx, &aiplatformpb.UpdateModelRequest{
		Model: &aiplatformpb.Model{
			Name:            req.State.ModelName,
			DisplayName:     req.ID,                                           // Use resource ID as display name (consistent with creation)
			Description:     "Uploaded model for " + req.Inputs.ModelImageURL, // Consistent with creation
			Labels:          req.Inputs.Labels,
			ArtifactUri:     req.Inputs.ModelArtifactsBucketURI,
			ContainerSpec:   containerSpec,
			PredictSchemata: predictionSchema,
		},
		UpdateMask: &fieldmaskpb.FieldMask{
			Paths: updatePaths,
		},
	})
	if err != nil {
		return nil, fmt.Errorf("failed to update model: %w", err)
	}

	return updatedModel, nil
}

func setModelStateUpdates(
	req infer.UpdateRequest[VertexModelDeploymentArgs, VertexModelDeploymentState],
	updatedModel *aiplatformpb.Model,
) VertexModelDeploymentState {
	updatedState := VertexModelDeploymentState{
		VertexModelDeploymentArgs: req.Inputs,
		DeployedModelID:           req.State.DeployedModelID,
		ModelName:                 updatedModel.Name,
		EndpointName:              req.State.EndpointName,
		CreateTime:                req.State.CreateTime,
	}

	// Update state fields from the updated model response
	if updatedModel.Labels != nil {
		updatedState.Labels = updatedModel.Labels
	}

	// Update ModelArtifactsBucketURI from the model response
	if updatedModel.ArtifactUri != "" {
		updatedState.ModelArtifactsBucketURI = updatedModel.ArtifactUri
	}

	// Update container spec fields if available
	if updatedModel.ContainerSpec != nil {
		// ImageUri is immutable and requires replacement.
		// See: https://cloud.google.com/vertex-ai/docs/reference/rest/v1/ModelContainerSpec
		if updatedModel.ContainerSpec.PredictRoute != "" {
			updatedState.PredictRoute = updatedModel.ContainerSpec.PredictRoute
		}
		if updatedModel.ContainerSpec.HealthRoute != "" {
			updatedState.HealthRoute = updatedModel.ContainerSpec.HealthRoute
		}
	}

	// Update predict schemata fields if available
	if updatedModel.PredictSchemata != nil {
		if updatedModel.PredictSchemata.InstanceSchemaUri != "" {
			updatedState.ModelPredictionInputSchemaURI = updatedModel.PredictSchemata.InstanceSchemaUri
		}
		if updatedModel.PredictSchemata.PredictionSchemaUri != "" {
			updatedState.ModelPredictionOutputSchemaURI = updatedModel.PredictSchemata.PredictionSchemaUri
		}
		if updatedModel.PredictSchemata.ParametersSchemaUri != "" {
			updatedState.ModelPredictionBehaviorSchemaURI = updatedModel.PredictSchemata.ParametersSchemaUri
		}
	}

	return updatedState
}

func collectUpdates(req infer.UpdateRequest[VertexModelDeploymentArgs, VertexModelDeploymentState]) (bool, []string) {
	needsUpdate := false
	updatePathsMap := make(map[string]bool)

	// Check if labels have changed
	if !mapsEqual(req.Inputs.Labels, req.State.Labels) {
		needsUpdate = true
		updatePathsMap["labels"] = true
	}

	// Check if ModelImageURL has changed (affects description and container spec)
	if req.Inputs.ModelImageURL != req.State.ModelImageURL {
		needsUpdate = true
		updatePathsMap["description"] = true
		updatePathsMap["container_spec"] = true
	}

	// Check if ModelArtifactsBucketURI has changed
	if req.Inputs.ModelArtifactsBucketURI != req.State.ModelArtifactsBucketURI {
		needsUpdate = true
		updatePathsMap["artifact_uri"] = true
	}

	// Check if prediction schema URIs have changed
	if req.Inputs.ModelPredictionInputSchemaURI != req.State.ModelPredictionInputSchemaURI ||
		req.Inputs.ModelPredictionOutputSchemaURI != req.State.ModelPredictionOutputSchemaURI ||
		req.Inputs.ModelPredictionBehaviorSchemaURI != req.State.ModelPredictionBehaviorSchemaURI {
		needsUpdate = true
		updatePathsMap["predict_schemata"] = true
	}

	// Check if container routes have changed
	if req.Inputs.PredictRoute != req.State.PredictRoute ||
		req.Inputs.HealthRoute != req.State.HealthRoute {
		needsUpdate = true
		updatePathsMap["container_spec"] = true
	}

	// Convert map to slice
	updatePaths := make([]string, 0, len(updatePathsMap))
	for path := range updatePathsMap {
		updatePaths = append(updatePaths, path)
	}

	return needsUpdate, updatePaths
}
// Package resources provides Pulumi resource implementations for GCP Vertex AI model upload and deployment.
package resources

import (
	"context"
	"fmt"
	"log"
	"log/slog"
	"strings"
	"time"

	p "github.com/pulumi/pulumi-go-provider"
	"github.com/pulumi/pulumi-go-provider/infer"

	"github.com/davidmontoyago/pulumi-gcp-vertex-model-deployment/pkg/services"
)

// VertexModelDeployment represents a Pulumi resource for deploying models to Vertex AI endpoints.
type VertexModelDeployment struct{}

// Compile-time interface compliance checks
var _ infer.CustomCreate[VertexModelDeploymentArgs, VertexModelDeploymentState] = (*VertexModelDeployment)(nil)
var _ infer.CustomRead[VertexModelDeploymentArgs, VertexModelDeploymentState] = (*VertexModelDeployment)(nil)
var _ infer.CustomUpdate[VertexModelDeploymentArgs, VertexModelDeploymentState] = (*VertexModelDeployment)(nil)
var _ infer.CustomDelete[VertexModelDeploymentState] = (*VertexModelDeployment)(nil)
var _ infer.CustomDiff[VertexModelDeploymentArgs, VertexModelDeploymentState] = (*VertexModelDeployment)(nil)

// Annotate provides metadata and descriptions for the VertexModelDeployment resource.
func (VertexModelDeployment) Annotate(annotator infer.Annotator) {
	annotator.Describe(&VertexModelDeployment{}, "Deploys a model to a Vertex AI endpoint")
}

// VertexModelDeploymentState represents the state of a deployed Vertex AI model.
type VertexModelDeploymentState struct {
	VertexModelDeploymentArgs

	ModelName       string `pulumi:"modelName"`
	DeployedModelID string `pulumi:"deployedModelId"`
	EndpointName    string `pulumi:"endpointName"`
	CreateTime      string `pulumi:"createTime"`
}

// Annotate provides metadata and descriptions for the VertexModelDeploymentState outputs.
func (state *VertexModelDeploymentState) Annotate(annotator infer.Annotator) {
	annotator.Describe(&state.DeployedModelID, "ID of the deployed model")
	annotator.Describe(&state.EndpointName, "Full name of the endpoint")
	annotator.Describe(&state.CreateTime, "Creation timestamp")
}

// Create implements the creation logic
func (v VertexModelDeployment) Create(
	ctx context.Context,
	req infer.CreateRequest[VertexModelDeploymentArgs],
) (infer.CreateResponse[VertexModelDeploymentState], error) {
	state := VertexModelDeploymentState{
		VertexModelDeploymentArgs: req.Inputs,
	}

	resourceID := fmt.Sprintf("%s-%s-%s", req.Inputs.ProjectID, req.Inputs.Region, req.Name)

	if req.DryRun {
		return infer.CreateResponse[VertexModelDeploymentState]{
			ID: resourceID,
		}, nil
	}

	modelClientFactory := v.getModelClientFactory()
	modelClient, err := modelClientFactory(ctx, req.Inputs.Region)
	if err != nil {
		return infer.CreateResponse[VertexModelDeploymentState]{}, fmt.Errorf("failed to create model client: %w", err)
	}
	defer func() {
		if closeErr := modelClient.Close(); closeErr != nil {
			log.Printf("failed to close model client: %v", closeErr)
		}
	}()

	// Create the model upload service
	uploader := services.NewVertexModelUpload(ctx, modelClient, req.Inputs.ProjectID, req.Inputs.Region, req.Inputs.Labels)
	defer func() {
		if closeErr := uploader.Close(); closeErr != nil {
			log.Printf("failed to close model upload service: %v", closeErr)
		}
	}()

	// Upload the model
	modelName, err := uploader.Upload(ctx, services.ModelUpload{
		Name:                             req.Name,
		ModelImageURL:                    req.Inputs.ModelImageURL,
		ModelArtifactsBucketURI:          req.Inputs.ModelArtifactsBucketURI,
		ServiceAccountEmail:              req.Inputs.ServiceAccount,
		ModelPredictionInputSchemaURI:    req.Inputs.ModelPredictionInputSchemaURI,
		ModelPredictionOutputSchemaURI:   req.Inputs.ModelPredictionOutputSchemaURI,
		ModelPredictionBehaviorSchemaURI: req.Inputs.ModelPredictionBehaviorSchemaURI,
		PredictRoute:                     req.Inputs.PredictRoute,
		HealthRoute:                      req.Inputs.HealthRoute,
	})
	if err != nil {
		return infer.CreateResponse[VertexModelDeploymentState]{}, fmt.Errorf("failed to upload model: %w", err)
	}

	state.ModelName = modelName
	state.CreateTime = time.Now().Format(time.RFC3339)

	// Only deploy to an endpoint if endpoint deployment is configured
	if isEndpointDeploymentEnabled(req.Inputs) {
		endpointClientFactory := v.getEndpointClientFactory()
		endpointClient, err := endpointClientFactory(ctx, req.Inputs.Region)
		if err != nil {
			return infer.CreateResponse[VertexModelDeploymentState]{}, fmt.Errorf("failed to create endpoint client: %w", err)
		}
		defer func() {
			if closeErr := endpointClient.Close(); closeErr != nil {
				log.Printf("failed to close endpoint client: %v", closeErr)
			}
		}()

		// Create the model deployment service
		deployer := services.NewVertexModelDeploy(ctx, endpointClient, req.Inputs.ProjectID, req.Inputs.Region)
		defer func() {
			if closeErr := deployer.Close(); closeErr != nil {
				log.Printf("failed to close endpoint model deployment service: %v", closeErr)
			}
		}()

		// Deploy the model to the endpoint
		endpointConfig := toEndpointDeploymentConfig(req.Inputs.EndpointModelDeployment)
		deployedModelID, err := deployer.Deploy(
			ctx,
			modelName,
			req.Name,
			req.Inputs.ServiceAccount,
			endpointConfig,
		)
		if err != nil {
			return infer.CreateResponse[VertexModelDeploymentState]{}, fmt.Errorf("failed to deploy model: %w", err)
		}

		state.DeployedModelID = deployedModelID
		state.EndpointName = req.Inputs.EndpointModelDeployment.EndpointID
	}

	return infer.CreateResponse[VertexModelDeploymentState]{
		ID:     resourceID,
		Output: state,
	}, nil
}

// Delete implements the deletion logic
func (v VertexModelDeployment) Delete(
	ctx context.Context,
	req infer.DeleteRequest[VertexModelDeploymentState],
) (infer.DeleteResponse, error) {
	// Only undeploy from the endpoint if the model was deployed to one
	if req.State.DeployedModelID != "" && req.State.EndpointName != "" {
		// Create the endpoint client using the factory
		endpointClientFactory := v.getEndpointClientFactory()
		endpointClient, err := endpointClientFactory(ctx, req.State.Region)
		if err != nil {
			return infer.DeleteResponse{}, fmt.Errorf("failed to create endpoint client for undeployment: %w", err)
		}
		defer func() {
			if closeErr := endpointClient.Close(); closeErr != nil {
				log.Printf("failed to close endpoint client for undeployment: %v", closeErr)
			}
		}()

		undeployer := services.NewVertexModelUndeploy(ctx, endpointClient, req.State.ProjectID, req.State.Region)

		err = undeployer.Undeploy(ctx, req.State.EndpointName, req.State.DeployedModelID)
		if err != nil {
			return infer.DeleteResponse{}, fmt.Errorf("failed to undeploy model: %w", err)
		}
	}

	// After undeploying, delete the model
	modelClientFactory := v.getModelClientFactory()
	modelClient, err := modelClientFactory(ctx, req.State.Region)
	if err != nil {
		return infer.DeleteResponse{}, fmt.Errorf("failed to create model client: %w", err)
	}
	defer func() {
		if closeErr := modelClient.Close(); closeErr != nil {
			log.Printf("failed to close model client: %v", closeErr)
		}
	}()

	deleter := services.NewVertexModelDelete(ctx, modelClient)

	err = deleter.Delete(ctx, req.State.ModelName)
	if err != nil {
		return infer.DeleteResponse{}, fmt.Errorf("failed to delete model: %w", err)
	}

	return infer.DeleteResponse{}, nil
}
// Update implements the update logic
func (v VertexModelDeployment) Update(
	ctx context.Context,
	req infer.UpdateRequest[VertexModelDeploymentArgs, VertexModelDeploymentState],
) (infer.UpdateResponse[VertexModelDeploymentState], error) {
	// Handle dry run - return the updated state without actually making changes
	if req.DryRun {
		return infer.UpdateResponse[VertexModelDeploymentState]{
			Output: VertexModelDeploymentState{
				VertexModelDeploymentArgs: req.Inputs,
				DeployedModelID:           req.State.DeployedModelID,
				ModelName:                 req.State.ModelName,
				EndpointName:              req.State.EndpointName,
				CreateTime:                req.State.CreateTime,
			},
		}, nil
	}

	// Check whether any model properties actually need updating to avoid unnecessary API calls
	needsUpdate, updatePaths := collectUpdates(req)

	// If no updates are needed, return the current state
	if !needsUpdate {
		return infer.UpdateResponse[VertexModelDeploymentState]{
			Output: req.State,
		}, nil
	}

	modelClientFactory := v.getModelClientFactory()
	modelClient, err := modelClientFactory(ctx, req.State.Region)
	if err != nil {
		return infer.UpdateResponse[VertexModelDeploymentState]{}, fmt.Errorf("failed to create model client: %w", err)
	}
	defer func() {
		if closeErr := modelClient.Close(); closeErr != nil {
			log.Printf("failed to close model client: %v", closeErr)
		}
	}()

	// Update the model in the registry; the prediction schema and container
	// spec are rebuilt consistently with model creation
	updatedModel, err := updateRegistryModel(ctx, req, modelClient, updatePaths)
	if err != nil {
		return infer.UpdateResponse[VertexModelDeploymentState]{}, fmt.Errorf("failed to update registry model: %w", err)
	}

	// Create the updated state from the API response
	updatedState := setModelStateUpdates(req, updatedModel)

	// TODO update endpoint model deployment

	return infer.UpdateResponse[VertexModelDeploymentState]{
		Output: updatedState,
	}, nil
}

// Read implements the read logic for drift detection
func (v VertexModelDeployment) Read(
	ctx context.Context,
	req infer.ReadRequest[VertexModelDeploymentArgs, VertexModelDeploymentState],
) (infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState], error) {
	// Validate that we have the minimum required information to read the resource
	if req.ID == "" {
		return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{},
			fmt.Errorf("resource ID is required for read operation")
	}

	// Create a copy of the current state to modify
	state := req.State

	// Always attempt to read the model if we have a model name
	if req.State.ModelName == "" {
		return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{},
			fmt.Errorf("model name is required in state to read the resource")
	}

	modelClientFactory := v.getModelClientFactory()
	modelClient, err := modelClientFactory(ctx, req.State.Region)
	if err != nil {
		// If we can't create the client, don't assume the resource is gone
		return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{},
			fmt.Errorf("failed to create model client: %w", err)
	}
	defer func() {
		if err := modelClient.Close(); err != nil {
			log.Printf("failed to close model client: %v", err)
		}
	}()

	err = readRegistryModel(ctx, modelClient, req, &state)
	if err != nil {
		// If the model truly doesn't exist, return an empty response to
		// indicate the resource is gone; otherwise, propagate the error
		if isResourceNotFoundError(err) {
			slog.Warn("Model no longer exists", "modelName", req.State.ModelName)

			return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{}, nil
		}

		return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{},
			fmt.Errorf("failed to read model from registry: %w", err)
	}

	// Read the endpoint if the model is deployed to one
	if req.State.DeployedModelID != "" && req.State.EndpointName != "" {
		endpointClientFactory := v.getEndpointClientFactory()
		endpointClient, err := endpointClientFactory(ctx, req.State.Region)
		if err != nil {
			return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{},
				fmt.Errorf("failed to create endpoint client: %w", err)
		}
		defer func() {
			if closeErr := endpointClient.Close(); closeErr != nil {
				log.Printf("failed to close endpoint client: %v", closeErr)
			}
		}()

		err = readEndpointModel(ctx, endpointClient, req, &state)
		if err != nil {
			return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{},
				fmt.Errorf("failed to read model endpoint: %w", err)
		}
	}

	return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{
		ID:     req.ID,
		Inputs: req.Inputs,
		State:  state,
	}, nil
}

// Diff implements the diff logic to control which changes require replacement vs. update
func (v VertexModelDeployment) Diff(
	_ context.Context,
	req infer.DiffRequest[VertexModelDeploymentArgs, VertexModelDeploymentState],
) (p.DiffResponse, error) {
	diff := p.DiffResponse{
		HasChanges:   false,
		DetailedDiff: make(map[string]p.PropertyDiff),
	}

	// Properties that require replacement (immutable)
	immutableProperties := map[string]bool{
		"projectId": true,
		"region":    true,
	}

	// Check ProjectID
	if req.Inputs.ProjectID != req.State.ProjectID {
		diff.HasChanges = true
		if immutableProperties["projectId"] {
			diff.DetailedDiff["projectId"] = p.PropertyDiff{
				Kind:      p.UpdateReplace,
				InputDiff: true,
			}
		} else {
			diff.DetailedDiff["projectId"] = p.PropertyDiff{
				Kind:      p.Update,
				InputDiff: true,
			}
		}
	}

	// Check Region
	if req.Inputs.Region != req.State.Region {
		diff.HasChanges = true
		if immutableProperties["region"] {
			diff.DetailedDiff["region"] = p.PropertyDiff{
				Kind:      p.UpdateReplace,
				InputDiff: true,
			}
		} else {
			diff.DetailedDiff["region"] = p.PropertyDiff{
				Kind:      p.Update,
				InputDiff: true,
			}
		}
	}

	// Check ModelImageURL - this requires replacement
	if req.Inputs.ModelImageURL != req.State.ModelImageURL {
		diff.HasChanges = true
		diff.DetailedDiff["modelImageUrl"] = p.PropertyDiff{
			// The image URL is immutable; changing it requires replacement of the model resource.
			// See: https://cloud.google.com/vertex-ai/docs/reference/rest/v1/ModelContainerSpec
			Kind:      p.UpdateReplace,
			InputDiff: true,
		}
	}

	// Check ModelArtifactsBucketURI - this can be updated
	if req.Inputs.ModelArtifactsBucketURI != req.State.ModelArtifactsBucketURI {
		diff.HasChanges = true
		diff.DetailedDiff["modelArtifactsBucketUri"] = p.PropertyDiff{
			Kind:      p.Update,
			InputDiff: true,
		}
	}

	// Prediction schema URIs are immutable
	// See: https://cloud.google.com/vertex-ai/docs/reference/rest/v1/PredictSchemata

	// Check service account - this can be updated
	if req.Inputs.ServiceAccount != req.State.ServiceAccount {
		diff.HasChanges = true
		diff.DetailedDiff["serviceAccount"] = p.PropertyDiff{
			Kind:      p.Update,
			InputDiff: true,
		}
	}

	// Check route configurations - these can be updated
	if req.Inputs.PredictRoute != req.State.PredictRoute {
		diff.HasChanges = true
		diff.DetailedDiff["predictRoute"] = p.PropertyDiff{
			Kind:      p.Update,
			InputDiff: true,
		}
	}
	if req.Inputs.HealthRoute != req.State.HealthRoute {
		diff.HasChanges = true
		diff.DetailedDiff["healthRoute"] = p.PropertyDiff{
			Kind:      p.Update,
			InputDiff: true,
		}
	}

	// Check labels - these can be updated
	if !mapsEqual(req.Inputs.Labels, req.State.Labels) {
		diff.HasChanges = true
		diff.DetailedDiff["labels"] = p.PropertyDiff{
			Kind:      p.Update,
			InputDiff: true,
		}
	}
	// Check the endpoint model deployment configuration - this can be updated
	if !endpointDeploymentEqual(req.Inputs.EndpointModelDeployment, req.State.EndpointModelDeployment) {
		diff.HasChanges = true
		diff.DetailedDiff["endpointModelDeployment"] = p.PropertyDiff{
			Kind:      p.Update,
			InputDiff: true,
		}
	}

	return diff, nil
}

// testFactoryRegistry holds test factories for dependency injection during testing
var testFactoryRegistry struct {
	modelClientFactory    services.ModelClientFactory
	endpointClientFactory services.EndpointClientFactory
}

// getModelClientFactory returns the model client factory, defaulting to the production factory if nil.
func (v VertexModelDeployment) getModelClientFactory() services.ModelClientFactory {
	if testFactoryRegistry.modelClientFactory == nil {
		return services.DefaultModelClientFactory
	}

	return testFactoryRegistry.modelClientFactory
}

// getEndpointClientFactory returns the endpoint client factory, defaulting to the production factory if nil.
func (v VertexModelDeployment) getEndpointClientFactory() services.EndpointClientFactory {
	if testFactoryRegistry.endpointClientFactory == nil {
		return services.DefaultEndpointClientFactory
	}

	return testFactoryRegistry.endpointClientFactory
}

// toEndpointDeploymentConfig converts EndpointModelDeploymentArgs to services.EndpointModelDeploymentConfig
func toEndpointDeploymentConfig(args *EndpointModelDeploymentArgs) services.EndpointModelDeploymentConfig {
	return services.EndpointModelDeploymentConfig{
		EndpointID:     args.EndpointID,
		MachineType:    args.MachineType,
		MinReplicas:    safeIntToInt32(args.MinReplicas),
		MaxReplicas:    safeIntToInt32(args.MaxReplicas),
		TrafficPercent: safeIntToInt32(args.TrafficPercent),
	}
}

// isEndpointDeploymentEnabled checks whether endpoint deployment is configured
func isEndpointDeploymentEnabled(args VertexModelDeploymentArgs) bool {
	return args.EndpointModelDeployment != nil && args.EndpointModelDeployment.EndpointID != ""
}

// mapsEqual checks whether two string maps are equal
func mapsEqual(a, b map[string]string) bool {
	if len(a) != len(b) {
		return false
	}
	for k, v := range a {
		if b[k] != v {
			return false
		}
	}

	return true
}

// endpointDeploymentEqual checks whether two EndpointModelDeploymentArgs are equal
func endpointDeploymentEqual(inputDeployment, stateDeployment *EndpointModelDeploymentArgs) bool {
	// Both nil
	if inputDeployment == nil && stateDeployment == nil {
		return true
	}

	// One nil, one not nil
	if inputDeployment == nil || stateDeployment == nil {
		return false
	}

	// Compare all fields
	return inputDeployment.EndpointID == stateDeployment.EndpointID &&
		inputDeployment.MachineType == stateDeployment.MachineType &&
		inputDeployment.MinReplicas == stateDeployment.MinReplicas &&
		inputDeployment.MaxReplicas == stateDeployment.MaxReplicas &&
		inputDeployment.TrafficPercent == stateDeployment.TrafficPercent
}

// isResourceNotFoundError detects whether the error indicates the resource doesn't exist
func isResourceNotFoundError(err error) bool {
	if err == nil {
		return false
	}

	errStr := err.Error()

	return strings.Contains(errStr, "not found") ||
		strings.Contains(errStr, "does not exist") ||
		strings.Contains(errStr, "404")
}
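// NOTE: illustrative sketch, not part of the provider source. Because
// testFactoryRegistry above is package-scoped, an in-package _test.go file can
// swap in fake client factories so Create/Read/Update/Delete never reach GCP.
// withTestClientFactories is a hypothetical helper name.
package resources

import "github.com/davidmontoyago/pulumi-gcp-vertex-model-deployment/pkg/services"

// withTestClientFactories installs the given fake factories and returns a
// restore function; a test can defer restore() to avoid leaking fakes into
// other tests.
func withTestClientFactories(model services.ModelClientFactory, endpoint services.EndpointClientFactory) (restore func()) {
	prevModel := testFactoryRegistry.modelClientFactory
	prevEndpoint := testFactoryRegistry.endpointClientFactory

	testFactoryRegistry.modelClientFactory = model
	testFactoryRegistry.endpointClientFactory = endpoint

	return func() {
		testFactoryRegistry.modelClientFactory = prevModel
		testFactoryRegistry.endpointClientFactory = prevEndpoint
	}
}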
package resources

import "math"

// safeIntToInt32 safely converts an int to int32, clamping to the int32 range.
func safeIntToInt32(value int) int32 {
	if value < math.MinInt32 {
		return math.MinInt32
	}
	if value > math.MaxInt32 {
		return math.MaxInt32
	}

	return int32(value)
}
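// NOTE: illustrative sketch, not part of the provider source. A minimal
// table-driven test for the clamping behavior above; it assumes a 64-bit
// platform (where int is wider than int32) and would live in an in-package
// _test.go file.
package resources

import (
	"math"
	"testing"
)

func TestSafeIntToInt32Clamps(t *testing.T) {
	cases := []struct {
		in   int
		want int32
	}{
		{0, 0},                             // in range, unchanged
		{math.MaxInt32 + 1, math.MaxInt32}, // overflow clamps high
		{math.MinInt32 - 1, math.MinInt32}, // underflow clamps low
	}
	for _, c := range cases {
		if got := safeIntToInt32(c.in); got != c.want {
			t.Errorf("safeIntToInt32(%d) = %d, want %d", c.in, got, c.want)
		}
	}
}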
package services

import (
	"context"
	"fmt"

	aiplatform "cloud.google.com/go/aiplatform/apiv1"
	"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
	gax "github.com/googleapis/gax-go/v2"
	"google.golang.org/api/option"
)

// VertexModelClient interface defines operations for uploading models.
type VertexModelClient interface {
	UploadModel(ctx context.Context, req *aiplatformpb.UploadModelRequest, opts ...gax.CallOption) (*aiplatform.UploadModelOperation, error)
	DeleteModel(ctx context.Context, req *aiplatformpb.DeleteModelRequest, opts ...gax.CallOption) (*aiplatform.DeleteModelOperation, error)
	GetModel(ctx context.Context, req *aiplatformpb.GetModelRequest, opts ...gax.CallOption) (*aiplatformpb.Model, error)
	UpdateModel(context.Context, *aiplatformpb.UpdateModelRequest, ...gax.CallOption) (*aiplatformpb.Model, error)
	Close() error
}

// VertexEndpointClient interface defines operations for deploying models.
type VertexEndpointClient interface {
	DeployModel(ctx context.Context, req *aiplatformpb.DeployModelRequest, opts ...gax.CallOption) (*aiplatform.DeployModelOperation, error)
	UndeployModel(ctx context.Context, req *aiplatformpb.UndeployModelRequest, opts ...gax.CallOption) (*aiplatform.UndeployModelOperation, error)
	GetEndpoint(ctx context.Context, req *aiplatformpb.GetEndpointRequest, opts ...gax.CallOption) (*aiplatformpb.Endpoint, error)
	Close() error
}

// ModelClientFactory function type for creating model clients
type ModelClientFactory func(ctx context.Context, region string) (VertexModelClient, error)

// EndpointClientFactory function type for creating endpoint clients
type EndpointClientFactory func(ctx context.Context, region string) (VertexEndpointClient, error)

// DefaultModelClientFactory creates the production GCP model client.
//
//nolint:ireturn // Returning interface for testability
func DefaultModelClientFactory(ctx context.Context, region string) (VertexModelClient, error) {
	// Regional resources must be accessed through the matching regional API endpoint
	apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", region)
	clientEndpointOpt := option.WithEndpoint(apiEndpoint)

	// Create the model client
	modelClient, err := aiplatform.NewModelClient(ctx, clientEndpointOpt)
	if err != nil {
		return nil, fmt.Errorf("failed to create model client: %w", err)
	}

	return modelClient, nil
}

// DefaultEndpointClientFactory creates the production GCP endpoint client.
//
//nolint:ireturn // Returning interface for testability
func DefaultEndpointClientFactory(ctx context.Context, region string) (VertexEndpointClient, error) {
	// Regional resources must be accessed through the matching regional API endpoint
	apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", region)
	clientEndpointOpt := option.WithEndpoint(apiEndpoint)

	// Create the endpoint client
	endpointClient, err := aiplatform.NewEndpointClient(ctx, clientEndpointOpt)
	if err != nil {
		return nil, fmt.Errorf("failed to create endpoint client: %w", err)
	}

	return endpointClient, nil
}
package services

import (
	"context"
	"fmt"
	"log"

	"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
)

// ModelDeleter allows deleting models from the registry.
type ModelDeleter interface {
	Delete(ctx context.Context, modelName string) error
	Close() error
}

// VertexModelDelete implements the ModelDeleter interface for Vertex AI.
type VertexModelDelete struct {
	modelClient VertexModelClient
}

// NewVertexModelDelete creates a new VertexModelDelete with the provided model client.
func NewVertexModelDelete(_ context.Context, modelClient VertexModelClient) *VertexModelDelete {
	return &VertexModelDelete{
		modelClient: modelClient,
	}
}

// Delete deletes a model from Vertex AI.
func (d *VertexModelDelete) Delete(ctx context.Context, modelName string) error {
	deleteReq := &aiplatformpb.DeleteModelRequest{
		// Model name is already in the format projects/{project}/locations/{location}/models/{model ID}
		Name: modelName,
	}

	deleteOperation, err := d.modelClient.DeleteModel(ctx, deleteReq)
	if err != nil {
		return fmt.Errorf("failed to delete model: %w", err)
	}

	if deleteOperation == nil {
		log.Printf("Warning: model delete operation is nil; assuming a mocked client and skipping the wait.")
	} else {
		err = deleteOperation.Wait(ctx)
		if err != nil {
			return fmt.Errorf("failed to wait for deletion: %w", err)
		}
	}

	return nil
}

// Close closes the model client, satisfying the ModelDeleter interface.
func (d *VertexModelDelete) Close() error {
	return d.modelClient.Close()
}
// Package services provides implementations for GCP Vertex AI model and Endpoint deployment operations.
package services

import (
	"context"
	"fmt"
	"log"

	"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
)

// EndpointModelDeploymentConfig holds configuration for deploying a model to an endpoint.
type EndpointModelDeploymentConfig struct {
	EndpointID     string
	MachineType    string
	MinReplicas    int32
	MaxReplicas    int32
	TrafficPercent int32
}

// ModelDeployer interface defines operations for deploying models.
type ModelDeployer interface {
	Deploy(ctx context.Context, modelName, displayName, serviceAccount string, endpointConfig EndpointModelDeploymentConfig) (string, error)
	Close() error
}

// VertexModelDeploy implements the ModelDeployer interface for Vertex AI.
type VertexModelDeploy struct {
	endpointClient VertexEndpointClient
	projectID      string
	region         string
}

// NewVertexModelDeploy creates a new VertexModelDeploy with the provided endpoint client.
func NewVertexModelDeploy(_ context.Context, endpointClient VertexEndpointClient, projectID, region string) *VertexModelDeploy {
	return &VertexModelDeploy{
		endpointClient: endpointClient,
		projectID:      projectID,
		region:         region,
	}
}

// Deploy deploys a model to a Vertex AI endpoint and returns the deployed model ID.
func (d *VertexModelDeploy) Deploy(ctx context.Context, modelName, displayName, serviceAccount string, endpointConfig EndpointModelDeploymentConfig) (string, error) {
	// Build the deployment request
	deployedModel := &aiplatformpb.DeployedModel{
		// Expected format: "projects/%s/locations/%s/models/%s"
		Model:       modelName,
		DisplayName: displayName,
		PredictionResources: &aiplatformpb.DeployedModel_DedicatedResources{
			DedicatedResources: &aiplatformpb.DedicatedResources{
				MachineSpec: &aiplatformpb.MachineSpec{
					MachineType: endpointConfig.MachineType,
				},
				MinReplicaCount: endpointConfig.MinReplicas,
				MaxReplicaCount: endpointConfig.MaxReplicas,
			},
		},
	}
	if serviceAccount != "" {
		deployedModel.ServiceAccount = serviceAccount
	}

	deployReq := &aiplatformpb.DeployModelRequest{
		Endpoint:      fmt.Sprintf("projects/%s/locations/%s/endpoints/%s", d.projectID, d.region, endpointConfig.EndpointID),
		DeployedModel: deployedModel,
		TrafficSplit:  map[string]int32{},
	}

	// Set the traffic split if specified. Per the DeployModelRequest API, the
	// key "0" refers to the model being deployed in this request.
	if endpointConfig.TrafficPercent > 0 {
		deployReq.TrafficSplit = map[string]int32{
			"0": endpointConfig.TrafficPercent,
		}
	}

	// Execute the deployment
	deployOperation, err := d.endpointClient.DeployModel(ctx, deployReq)
	if err != nil {
		return "", fmt.Errorf("failed to deploy model: %w", err)
	}

	if deployOperation == nil {
		log.Printf("Warning: deploy operation is nil; assuming a mocked client and returning an empty ID.")

		return "", nil
	}

	// Wait for completion
	result, err := deployOperation.Wait(ctx)
	if err != nil {
		return "", fmt.Errorf("failed to wait for deployment: %w", err)
	}

	return result.GetDeployedModel().GetId(), nil
}

// Close closes the endpoint client.
func (d *VertexModelDeploy) Close() error {
	if d.endpointClient != nil {
		if err := d.endpointClient.Close(); err != nil {
			return fmt.Errorf("failed to close endpoint client: %w", err)
		}
	}

	return nil
}
package services

import (
	"context"
	"fmt"
	"strings"

	"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
)

// EndpointModelGetter allows getting endpoints and their deployed models from the registry.
type EndpointModelGetter interface {
	Get(ctx context.Context, endpointName, deployedModelID string) (*aiplatformpb.Endpoint, *aiplatformpb.DeployedModel, error)
	Close() error
}

// VertexEndpointModelGetter implements the EndpointModelGetter interface for Vertex AI.
type VertexEndpointModelGetter struct {
	endpointClient VertexEndpointClient
	projectID      string
	region         string
}

// NewVertexEndpointModelGetter creates a new VertexEndpointModelGetter with the provided endpoint client.
func NewVertexEndpointModelGetter(endpointClient VertexEndpointClient, projectID, region string) *VertexEndpointModelGetter {
	return &VertexEndpointModelGetter{
		endpointClient: endpointClient,
		projectID:      projectID,
		region:         region,
	}
}

// Get retrieves an endpoint and finds the specified deployed model within it.
// Returns the endpoint, the deployed model (if found), and any error.
func (g *VertexEndpointModelGetter) Get(ctx context.Context, endpointName, deployedModelID string) (*aiplatformpb.Endpoint, *aiplatformpb.DeployedModel, error) {
	// Accept either a bare endpoint ID or a full resource name; state may hold
	// either, depending on whether it was set at create time or refreshed from
	// the API.
	name := endpointName
	if !strings.Contains(name, "/") {
		name = fmt.Sprintf("projects/%s/locations/%s/endpoints/%s", g.projectID, g.region, endpointName)
	}

	getReq := &aiplatformpb.GetEndpointRequest{
		Name: name,
	}

	endpoint, err := g.endpointClient.GetEndpoint(ctx, getReq)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to get endpoint: %w", err)
	}

	// Find the deployed model within the endpoint, if it still exists
	var foundDeployedModel *aiplatformpb.DeployedModel
	for _, deployedModel := range endpoint.DeployedModels {
		if deployedModel.Id == deployedModelID {
			foundDeployedModel = deployedModel
			break
		}
	}

	return endpoint, foundDeployedModel, nil
}

// Close closes the endpoint client.
func (g *VertexEndpointModelGetter) Close() error {
	return g.endpointClient.Close()
}
package services

import (
	"context"
	"fmt"
	"log"
	"strings"

	"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
)

// ModelUndeployer allows undeploying models from Vertex AI endpoints.
type ModelUndeployer interface {
	Undeploy(ctx context.Context, endpointName, deployedModelID string) error
	Close() error
}

// VertexModelUndeploy implements the ModelUndeployer interface for Vertex AI.
type VertexModelUndeploy struct {
	endpointClient VertexEndpointClient
	project        string
	region         string
}

// NewVertexModelUndeploy creates a new VertexModelUndeploy with the provided endpoint client.
func NewVertexModelUndeploy(_ context.Context, endpointClient VertexEndpointClient, project, region string) *VertexModelUndeploy {
	return &VertexModelUndeploy{
		endpointClient: endpointClient,
		project:        project,
		region:         region,
	}
}

// Undeploy undeploys a model from an endpoint.
func (u *VertexModelUndeploy) Undeploy(ctx context.Context, endpointName, deployedModelID string) error {
	// Accept either a bare endpoint ID or a full resource name, matching the
	// behavior of VertexEndpointModelGetter.Get
	endpoint := endpointName
	if !strings.Contains(endpoint, "/") {
		endpoint = fmt.Sprintf("projects/%s/locations/%s/endpoints/%s", u.project, u.region, endpointName)
	}

	undeployReq := &aiplatformpb.UndeployModelRequest{
		Endpoint:        endpoint,
		DeployedModelId: deployedModelID,
	}

	undeployOperation, err := u.endpointClient.UndeployModel(ctx, undeployReq)
	if err != nil {
		return fmt.Errorf("failed to undeploy model: %w", err)
	}

	if undeployOperation == nil {
		log.Printf("Warning: model undeploy operation is nil; assuming a mocked client and skipping the wait.")
	} else {
		_, err = undeployOperation.Wait(ctx)
		if err != nil {
			return fmt.Errorf("failed to wait for undeployment: %w", err)
		}
	}

	return nil
}

// Close closes the endpoint client, satisfying the ModelUndeployer interface.
func (u *VertexModelUndeploy) Close() error {
	return u.endpointClient.Close()
}
package services

import (
	"context"
	"fmt"

	"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
)

// ModelGetter allows getting models from the registry.
type ModelGetter interface {
	Get(ctx context.Context, modelName string) (*aiplatformpb.Model, error)
	Close() error
}

// VertexModelGet implements the ModelGetter interface for Vertex AI.
type VertexModelGet struct {
	modelClient VertexModelClient
	modelName   string
}

// NewVertexModelGet creates a new VertexModelGet with the provided model client.
func NewVertexModelGet(_ context.Context, modelClient VertexModelClient, modelName string) *VertexModelGet {
	return &VertexModelGet{
		modelClient: modelClient,
		modelName:   modelName,
	}
}

// Get gets a model from Vertex AI.
func (g *VertexModelGet) Get(ctx context.Context, modelName string) (*aiplatformpb.Model, error) {
	getReq := &aiplatformpb.GetModelRequest{
		Name: modelName,
	}

	model, err := g.modelClient.GetModel(ctx, getReq)
	if err != nil {
		return nil, fmt.Errorf("failed to get model: %w", err)
	}

	return model, nil
}

// Close closes the model client.
func (g *VertexModelGet) Close() error {
	return g.modelClient.Close()
}
// Package services provides implementations for GCP Vertex AI model upload operations.
package services

import (
	"context"
	"errors"
	"fmt"
	"log"
	"time"

	"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
	gax "github.com/googleapis/gax-go/v2"
	"github.com/googleapis/gax-go/v2/apierror"
)

// ModelUpload represents the parameters needed to upload a model to Vertex AI.
type ModelUpload struct {
	Name                             string
	ModelImageURL                    string
	ModelArtifactsBucketURI          string
	ServiceAccountEmail              string
	ModelPredictionInputSchemaURI    string
	ModelPredictionOutputSchemaURI   string
	ModelPredictionBehaviorSchemaURI string
	PredictRoute                     string
	HealthRoute                      string
}

// ModelUploader interface defines operations for uploading models.
type ModelUploader interface {
	Upload(ctx context.Context, uploadParams ModelUpload) (string, error)
	Close() error
}

// VertexModelUpload implements the ModelUploader interface for Vertex AI.
type VertexModelUpload struct {
	modelClient VertexModelClient
	projectID   string
	region      string
	labels      map[string]string
}

// NewVertexModelUpload creates a new VertexModelUpload with the provided model client.
func NewVertexModelUpload(_ context.Context, modelClient VertexModelClient, projectID, region string, labels map[string]string) *VertexModelUpload {
	return &VertexModelUpload{
		modelClient: modelClient,
		projectID:   projectID,
		region:      region,
		labels:      labels,
	}
}

// logUploadAPIError logs the details of an APIError returned by the upload API, if any.
func logUploadAPIError(err error) {
	var apiError *apierror.APIError
	if !errors.As(err, &apiError) {
		return
	}
	log.Printf("Model upload returned APIError details: %v\n", apiError)
	log.Printf("APIError reason: %v\n", apiError.Reason())
	log.Printf("APIError details: %v\n", apiError.Details())
	log.Printf("APIError help: %v\n", apiError.Details().Help)
	// If a gRPC transport was used, the google.golang.org/grpc/status.Status
	// can be extracted from the error
	log.Printf("APIError GRPCStatus: %+v\n", apiError.GRPCStatus())
	log.Printf("APIError HTTPCode: %+v\n", apiError.HTTPCode())
}

// Upload uploads a model to Vertex AI and returns the model name.
func (u *VertexModelUpload) Upload(ctx context.Context, params ModelUpload) (string, error) {
	predictionSchema := &aiplatformpb.PredictSchemata{
		// Schema for the model input
		InstanceSchemaUri: params.ModelPredictionInputSchemaURI,
		// Schema for the model output
		PredictionSchemaUri: params.ModelPredictionOutputSchemaURI,
	}
	if params.ModelPredictionBehaviorSchemaURI != "" {
		// Schema for the model inference behavior. Optional depending on the model.
		predictionSchema.ParametersSchemaUri = params.ModelPredictionBehaviorSchemaURI
	}

	modelArgs := &aiplatformpb.Model{
		DisplayName: params.Name,
		Description: "Uploaded model for " + params.ModelImageURL,
		ContainerSpec: &aiplatformpb.ModelContainerSpec{
			ImageUri: params.ModelImageURL,
			Args: []string{
				"--allow_precompilation=false",
				"--disable_optimizer=true",
				"--saved_model_tags='serve,tpu'",
				"--use_tfrt=true",
			},
		},
		Labels:          u.labels,
		ArtifactUri:     params.ModelArtifactsBucketURI,
		PredictSchemata: predictionSchema,
	}
	if params.PredictRoute != "" {
		modelArgs.ContainerSpec.PredictRoute = params.PredictRoute
	}
	if params.HealthRoute != "" {
		modelArgs.ContainerSpec.HealthRoute = params.HealthRoute
	}

	modelUploadOp, err := u.modelClient.UploadModel(ctx, &aiplatformpb.UploadModelRequest{
		// TODO support non traditional / global models
		// GCP endpoint to which the model is attached. It can be regional or
		// global, depending on the model type.
		Parent:         fmt.Sprintf("projects/%s/locations/%s", u.projectID, u.region),
		ServiceAccount: params.ServiceAccountEmail,
		Model:          modelArgs,
	}, gax.WithTimeout(5*time.Minute))
	if err != nil {
		logUploadAPIError(err)

		return "", fmt.Errorf("failed to upload model: %w", err)
	}

	if modelUploadOp == nil {
		log.Printf("Warning: model upload operation is nil; assuming a mocked client and returning a placeholder name.")

		return "MOCKED_MODEL_NAME", nil
	}

	// TODO make timeout configurable
	modelUploadResult, err := modelUploadOp.Wait(ctx, gax.WithTimeout(10*time.Minute))
	if err != nil {
		if modelUploadOp.Done() {
			log.Printf("Model upload operation completed with failure: %v\n", err)
		}
		logUploadAPIError(err)

		return "", fmt.Errorf("failed to wait for model upload: %w", err)
	}

	return modelUploadResult.GetModel(), nil
}

// Close closes the model client.
func (u *VertexModelUpload) Close() error {
	if u.modelClient != nil {
		if err := u.modelClient.Close(); err != nil {
			return fmt.Errorf("failed to close model client: %w", err)
		}
	}

	return nil
}