// Package main provides the entry point for the pulumi-gcp-vertex-model-deployment provider.
package main
import (
"context"
"fmt"
"os"
"github.com/pulumi/pulumi-go-provider/infer"
"github.com/davidmontoyago/pulumi-gcp-vertex-model-deployment/pkg/resources"
"github.com/davidmontoyago/pulumi-gcp-vertex-model-deployment/pkg/version"
)
func main() {
provider, err := infer.NewProviderBuilder().
WithResources(
infer.Resource(&resources.VertexModelDeployment{}),
).
WithNamespace("davidmontoyago").
WithDisplayName("pulumi-gcp-vertex-model-deployment").
WithLicense("Apache-2.0").
WithKeywords("pulumi", "gcp", "vertex", "model").
WithDescription("Deploy AI models to Vertex endpoints").
WithRepository("github.com/davidmontoyago/pulumi-gcp-vertex-model-deployment").
WithPluginDownloadURL("github://api.github.com/davidmontoyago/pulumi-gcp-vertex-model-deployment").
Build()
if err != nil {
fmt.Fprintf(os.Stderr, "Error building provider: %s", err.Error())
os.Exit(1)
}
// Name of the Pulumi plugin.
pluginName := "gcp-vertex-model-deployment"
err = provider.Run(context.Background(), pluginName, version.Version)
if err != nil {
fmt.Fprintf(os.Stderr, "Error running provider: %s", err.Error())
os.Exit(1)
}
}
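// Note (sketch): with the builder above, Pulumi programs reference the resource
// by a token derived from the plugin name and the Go type name, likely of the form:
//
//	gcp-vertex-model-deployment:index:VertexModelDeployment
//
// The exact token is determined by the pulumi-go-provider infer defaults.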
package resources
import (
"context"
"github.com/davidmontoyago/pulumi-gcp-vertex-model-deployment/pkg/services"
"github.com/pulumi/pulumi-go-provider/infer"
)
func readEndpointModel(ctx context.Context,
endpointClient services.VertexEndpointClient,
req infer.ReadRequest[VertexModelDeploymentArgs, VertexModelDeploymentState],
state *VertexModelDeploymentState) error {
endpointGetter := services.NewVertexEndpointModelGetter(endpointClient, req.State.ProjectID, req.State.Region)
endpoint, foundDeployedModel, err := endpointGetter.Get(ctx, req.State.EndpointName, req.State.DeployedModelID)
if err != nil {
return err
}
if foundDeployedModel == nil {
// Model is no longer deployed to the endpoint; there is nothing to refresh here
return nil
}
// Update state with current endpoint and deployed model information
state.EndpointName = endpoint.Name
state.DeployedModelID = foundDeployedModel.Id
// Update endpoint deployment configuration with current values if available
if state.EndpointModelDeployment == nil {
return nil
}
// Extract current deployment configuration from the deployed model
if dedicatedResources := foundDeployedModel.GetDedicatedResources(); dedicatedResources != nil {
if machineSpec := dedicatedResources.MachineSpec; machineSpec != nil {
state.EndpointModelDeployment.MachineType = machineSpec.MachineType
}
state.EndpointModelDeployment.MinReplicas = int(dedicatedResources.MinReplicaCount)
state.EndpointModelDeployment.MaxReplicas = int(dedicatedResources.MaxReplicaCount)
}
// Update traffic percentage from endpoint's traffic split if available
if endpoint.TrafficSplit != nil {
if trafficPercent, exists := endpoint.TrafficSplit[foundDeployedModel.Id]; exists {
state.EndpointModelDeployment.TrafficPercent = int(trafficPercent)
}
}
return nil
}
// Package resources provides Pulumi resource implementations for GCP Vertex model upload to the registry and endpoint deployment.
package resources
import "github.com/pulumi/pulumi-go-provider/infer"
// VertexModelDeploymentArgs defines the input arguments for creating a Vertex AI model deployment.
type VertexModelDeploymentArgs struct {
ProjectID string `pulumi:"projectId"`
Region string `pulumi:"region"`
ModelImageURL string `pulumi:"modelImageUrl"`
ModelArtifactsBucketURI string `pulumi:"modelArtifactsBucketUri,optional"`
ModelPredictionInputSchemaURI string `pulumi:"modelPredictionInputSchemaUri,optional"`
ModelPredictionOutputSchemaURI string `pulumi:"modelPredictionOutputSchemaUri,optional"`
ModelPredictionBehaviorSchemaURI string `pulumi:"modelPredictionBehaviorSchemaUri,optional"`
// If ModelImageURL points to a private registry, this service account
// must have read access to the registry.
ServiceAccount string `pulumi:"serviceAccount"`
// Path on the container to send prediction requests to.
// Not required for Endpoints.
PredictRoute string `pulumi:"predictRoute,optional"`
// Path on the container to send health requests to.
// Not required for Endpoints.
HealthRoute string `pulumi:"healthRoute,optional"`
Args []string `pulumi:"args,optional"`
EnvVars map[string]string `pulumi:"env,optional"`
Port int32 `pulumi:"port,optional"`
// Target endpoint for the model deployment.
//
// Set only when serving the model on a Vertex AI Endpoint.
// Deploying a custom or dockerized model to a Vertex AI Endpoint is not yet supported
// by either Terraform or the Pulumi Google Cloud Native provider; hence this custom
// provider exists.
// See: https://github.com/hashicorp/terraform-provider-google/issues/15303
//
// When deploying the model for Batch Prediction Jobs, this field must be
// unset and the batch job must be created with the Pulumi Google Cloud Native
// provider.
EndpointModelDeployment *EndpointModelDeploymentArgs `pulumi:"endpointModelDeployment,optional"`
Labels map[string]string `pulumi:"labels,optional"`
}
// EndpointModelDeploymentArgs defines the input arguments for deploying an
// uploaded model to a Vertex AI endpoint.
type EndpointModelDeploymentArgs struct {
EndpointID string `pulumi:"endpointId"`
MachineType string `pulumi:"machineType,optional"`
AcceleratorType string `pulumi:"acceleratorType,optional"`
AcceleratorCount int32 `pulumi:"acceleratorCount,optional"`
MinReplicas int `pulumi:"minReplicas,optional"`
MaxReplicas int `pulumi:"maxReplicas,optional"`
TrafficPercent int `pulumi:"trafficPercent,optional"`
DisableContainerLogging bool `pulumi:"disableContainerLogging,optional"`
EnableAccessLogging bool `pulumi:"enableAccessLogging,optional"`
EnableSpotVMs bool `pulumi:"enableSpotVMs,optional"`
}
// Annotate provides metadata and default values for the VertexModelDeploymentArgs.
func (args *VertexModelDeploymentArgs) Annotate(annotator infer.Annotator) {
annotator.Describe(&args.ProjectID, "Google Cloud Project ID")
annotator.Describe(&args.Region, "Google Cloud region")
annotator.Describe(&args.ModelImageURL, "Vertex AI Image URL of a custom or prebuilt container model server. See: https://cloud.google.com/vertex-ai/docs/predictions/pre-built-containers")
annotator.Describe(&args.ModelArtifactsBucketURI, "Bucket URI to the model artifacts. For instance, gs://my-bucket/my-model-artifacts/ - See: https://cloud.google.com/vertex-ai/docs/training/exporting-model-artifacts")
annotator.Describe(&args.ModelPredictionInputSchemaURI, "Bucket URI to the schema for the model input")
annotator.Describe(&args.ModelPredictionOutputSchemaURI, "Bucket URI to the schema for the model output")
annotator.Describe(&args.ModelPredictionBehaviorSchemaURI, "Bucket URI to the schema for the model inference behavior")
annotator.Describe(&args.ServiceAccount, "Service account for the model. If ModelImage is pointing to a private registry, this service account must have read access to the registry.")
annotator.Describe(&args.Args, "Dockerized model server command line arguments")
annotator.Describe(&args.EnvVars, "Environment variables")
annotator.Describe(&args.Port, "Port for the model server. Defaults to 8080.")
annotator.Describe(&args.EndpointModelDeployment, "Configuration for deploying the model to a Vertex AI endpoint. Leave empty to upload model only for batched predictions.")
annotator.Describe(&args.Labels, "Labels for the deployment")
}
// Annotate provides metadata and default values for the EndpointModelDeploymentArgs.
func (args *EndpointModelDeploymentArgs) Annotate(annotator infer.Annotator) {
annotator.Describe(&args.EndpointID, "Vertex AI Endpoint ID")
annotator.Describe(&args.MachineType, "Machine type for deployment")
annotator.Describe(&args.AcceleratorType, "Accelerator type for endpoint deployment. Defaults to ACCELERATOR_TYPE_UNSPECIFIED. E.g.: NVIDIA_TESLA_P4, NVIDIA_TESLA_T4")
annotator.Describe(&args.AcceleratorCount, "Accelerator count for deployment")
annotator.Describe(&args.MinReplicas, "Minimum number of replicas")
annotator.Describe(&args.MaxReplicas, "Maximum number of replicas")
annotator.Describe(&args.TrafficPercent, "Traffic percentage for this deployment")
annotator.Describe(&args.DisableContainerLogging, "Disable container logging")
annotator.Describe(&args.EnableAccessLogging, "Enable access logging")
annotator.Describe(&args.EnableSpotVMs, "Enable spot VMs")
// Set defaults
annotator.SetDefault(&args.MachineType, "n1-standard-8")
annotator.SetDefault(&args.MinReplicas, 1)
annotator.SetDefault(&args.MaxReplicas, 3)
annotator.SetDefault(&args.TrafficPercent, 100)
}
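// Illustrative sketch (not part of the provider): constructing the inputs
// directly in Go. All literal values below are hypothetical; only projectId,
// region, modelImageUrl and serviceAccount are required, and
// endpointModelDeployment opts into serving the model on an endpoint.
//
//	args := VertexModelDeploymentArgs{
//		ProjectID:      "my-project",
//		Region:         "us-central1",
//		ModelImageURL:  "us-docker.pkg.dev/my-project/models/server:latest",
//		ServiceAccount: "model-sa@my-project.iam.gserviceaccount.com",
//		EndpointModelDeployment: &EndpointModelDeploymentArgs{
//			EndpointID:     "1234567890",
//			MachineType:    "n1-standard-8",
//			MinReplicas:    1,
//			MaxReplicas:    3,
//			TrafficPercent: 100,
//		},
//	}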
package resources
import (
"context"
"fmt"
"github.com/davidmontoyago/pulumi-gcp-vertex-model-deployment/pkg/services"
"github.com/pulumi/pulumi-go-provider/infer"
)
func readRegistryModel(ctx context.Context,
modelClient services.VertexModelClient,
req infer.ReadRequest[VertexModelDeploymentArgs,
VertexModelDeploymentState], state *VertexModelDeploymentState) error {
modelGetter := services.NewVertexModelGet(ctx, modelClient, req.State.ModelName)
model, err := modelGetter.Get(ctx, req.State.ModelName)
if err != nil {
return fmt.Errorf("failed to get model: %w", err)
}
// Update state with current model values
state.ModelName = model.Name
state.ModelArtifactsBucketURI = model.ArtifactUri
state.Labels = model.Labels
// Safely access ContainerSpec fields
if model.ContainerSpec != nil {
state.ModelImageURL = model.ContainerSpec.ImageUri
state.PredictRoute = model.ContainerSpec.PredictRoute
state.HealthRoute = model.ContainerSpec.HealthRoute
}
// Read schema URIs if available.
if model.PredictSchemata != nil {
// These are immutable and should be ignored during updates.
// The URIs returned on output are immutable and may differ from the
// values supplied on input, including the URI scheme.
// See: https://cloud.google.com/vertex-ai/docs/reference/rest/v1/PredictSchemata
state.ModelPredictionInputSchemaURI = model.PredictSchemata.InstanceSchemaUri
state.ModelPredictionOutputSchemaURI = model.PredictSchemata.PredictionSchemaUri
state.ModelPredictionBehaviorSchemaURI = model.PredictSchemata.ParametersSchemaUri
}
return nil
}
package resources
import (
"context"
"fmt"
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
"github.com/davidmontoyago/pulumi-gcp-vertex-model-deployment/pkg/services"
"github.com/pulumi/pulumi-go-provider/infer"
"google.golang.org/protobuf/types/known/fieldmaskpb"
)
func updateRegistryModel(ctx context.Context, req infer.UpdateRequest[VertexModelDeploymentArgs, VertexModelDeploymentState], modelClient services.VertexModelClient, updatePaths []string) (*aiplatformpb.Model, error) {
predictionSchema := &aiplatformpb.PredictSchemata{}
if req.Inputs.ModelPredictionInputSchemaURI != "" {
predictionSchema.InstanceSchemaUri = req.Inputs.ModelPredictionInputSchemaURI
}
if req.Inputs.ModelPredictionOutputSchemaURI != "" {
predictionSchema.PredictionSchemaUri = req.Inputs.ModelPredictionOutputSchemaURI
}
if req.Inputs.ModelPredictionBehaviorSchemaURI != "" {
predictionSchema.ParametersSchemaUri = req.Inputs.ModelPredictionBehaviorSchemaURI
}
// Build container spec (consistent with model creation)
containerSpec := &aiplatformpb.ModelContainerSpec{
ImageUri: req.Inputs.ModelImageURL,
Args: req.Inputs.Args,
}
// Add environment variables
envVars := []*aiplatformpb.EnvVar{}
for name, value := range req.Inputs.EnvVars {
envVars = append(envVars, &aiplatformpb.EnvVar{
Name: name,
Value: value,
})
}
containerSpec.Env = envVars
// Add port configuration
modelServerPort := req.Inputs.Port
if modelServerPort == 0 {
modelServerPort = 8080
}
containerSpec.Ports = []*aiplatformpb.Port{
{
ContainerPort: modelServerPort,
},
}
if req.Inputs.PredictRoute != "" {
containerSpec.PredictRoute = req.Inputs.PredictRoute
}
if req.Inputs.HealthRoute != "" {
containerSpec.HealthRoute = req.Inputs.HealthRoute
}
updatedModel, err := modelClient.UpdateModel(ctx, &aiplatformpb.UpdateModelRequest{
Model: &aiplatformpb.Model{
Name: req.State.ModelName,
DisplayName: req.ID, // Only applied when "display_name" is included in the update mask
Description: "Uploaded model for " + req.Inputs.ModelImageURL, // Consistent with creation
Labels: req.Inputs.Labels,
ArtifactUri: req.Inputs.ModelArtifactsBucketURI,
ContainerSpec: containerSpec,
PredictSchemata: predictionSchema,
},
UpdateMask: &fieldmaskpb.FieldMask{
Paths: updatePaths,
},
})
if err != nil {
return nil, fmt.Errorf("failed to update model: %w", err)
}
return updatedModel, nil
}
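// Note on the update mask: UpdateModel only applies the fields named in the
// mask paths, so unchanged fields on the Model message above are ignored. For
// example (sketch), a label-only update would send:
//
//	UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"labels"}}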
func setModelStateUpdates(req infer.UpdateRequest[VertexModelDeploymentArgs, VertexModelDeploymentState], updatedModel *aiplatformpb.Model) VertexModelDeploymentState {
updatedState := VertexModelDeploymentState{
VertexModelDeploymentArgs: req.Inputs,
DeployedModelID: req.State.DeployedModelID,
ModelName: updatedModel.Name,
EndpointName: req.State.EndpointName,
CreateTime: req.State.CreateTime,
}
// Update state fields from the updated model response
if updatedModel.Labels != nil {
updatedState.Labels = updatedModel.Labels
}
// Update ModelArtifactsBucketURI from the model response
if updatedModel.ArtifactUri != "" {
updatedState.ModelArtifactsBucketURI = updatedModel.ArtifactUri
}
// Update container spec fields if available
if updatedModel.ContainerSpec != nil {
// ImageUri is immutable, requires replacement.
// See: https://cloud.google.com/vertex-ai/docs/reference/rest/v1/ModelContainerSpec
setContainerSpecStateUpdates(updatedModel, &updatedState)
}
// Update predict schemata fields if available
if updatedModel.PredictSchemata != nil {
if updatedModel.PredictSchemata.InstanceSchemaUri != "" {
updatedState.ModelPredictionInputSchemaURI = updatedModel.PredictSchemata.InstanceSchemaUri
}
if updatedModel.PredictSchemata.PredictionSchemaUri != "" {
updatedState.ModelPredictionOutputSchemaURI = updatedModel.PredictSchemata.PredictionSchemaUri
}
if updatedModel.PredictSchemata.ParametersSchemaUri != "" {
updatedState.ModelPredictionBehaviorSchemaURI = updatedModel.PredictSchemata.ParametersSchemaUri
}
}
return updatedState
}
// setContainerSpecStateUpdates mutates the state in place, so it takes a pointer.
func setContainerSpecStateUpdates(updatedModel *aiplatformpb.Model, updatedState *VertexModelDeploymentState) {
if updatedModel.ContainerSpec.PredictRoute != "" {
updatedState.PredictRoute = updatedModel.ContainerSpec.PredictRoute
}
if updatedModel.ContainerSpec.HealthRoute != "" {
updatedState.HealthRoute = updatedModel.ContainerSpec.HealthRoute
}
// Update container args
if len(updatedModel.ContainerSpec.Args) > 0 {
updatedState.Args = updatedModel.ContainerSpec.Args
}
// Update environment variables
if len(updatedModel.ContainerSpec.Env) > 0 {
updatedState.EnvVars = make(map[string]string)
for _, env := range updatedModel.ContainerSpec.Env {
updatedState.EnvVars[env.Name] = env.Value
}
}
// Update port
if len(updatedModel.ContainerSpec.Ports) > 0 {
updatedState.Port = updatedModel.ContainerSpec.Ports[0].ContainerPort
}
}
func collectUpdates(req infer.UpdateRequest[VertexModelDeploymentArgs, VertexModelDeploymentState]) (bool, []string) {
needsUpdate := false
updatePathsMap := make(map[string]bool)
// Check if labels have changed
if !mapsEqual(req.Inputs.Labels, req.State.Labels) {
needsUpdate = true
updatePathsMap["labels"] = true
}
// Check if ModelImageURL has changed (affects description and container spec)
if req.Inputs.ModelImageURL != req.State.ModelImageURL {
needsUpdate = true
updatePathsMap["description"] = true
updatePathsMap["container_spec"] = true
}
// Check if ModelArtifactsBucketURI has changed
if req.Inputs.ModelArtifactsBucketURI != req.State.ModelArtifactsBucketURI {
needsUpdate = true
updatePathsMap["artifact_uri"] = true
}
// Check if prediction schema URIs have changed
if req.Inputs.ModelPredictionInputSchemaURI != req.State.ModelPredictionInputSchemaURI ||
req.Inputs.ModelPredictionOutputSchemaURI != req.State.ModelPredictionOutputSchemaURI ||
req.Inputs.ModelPredictionBehaviorSchemaURI != req.State.ModelPredictionBehaviorSchemaURI {
needsUpdate = true
updatePathsMap["predict_schemata"] = true
}
// Check if container routes have changed
if req.Inputs.PredictRoute != req.State.PredictRoute ||
req.Inputs.HealthRoute != req.State.HealthRoute {
needsUpdate = true
updatePathsMap["container_spec"] = true
}
// Check if container args have changed
if !slicesEqual(req.Inputs.Args, req.State.Args) {
needsUpdate = true
updatePathsMap["container_spec"] = true
}
// Check if environment variables have changed
if !mapsEqual(req.Inputs.EnvVars, req.State.EnvVars) {
needsUpdate = true
updatePathsMap["container_spec"] = true
}
// Check if port has changed
if req.Inputs.Port != req.State.Port {
needsUpdate = true
updatePathsMap["container_spec"] = true
}
// Convert map to slice
updatePaths := make([]string, 0, len(updatePathsMap))
for path := range updatePathsMap {
updatePaths = append(updatePaths, path)
}
return needsUpdate, updatePaths
}
// slicesEqual compares two string slices for equality
func slicesEqual(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i, v := range a {
if v != b[i] {
return false
}
}
return true
}
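// Note: on Go 1.21+ the standard library's slices.Equal and maps.Equal could
// replace slicesEqual and mapsEqual.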
// Package resources provides Pulumi resource implementations for GCP Vertex AI model upload and deployment.
package resources
import (
"context"
"fmt"
"log"
"log/slog"
"strings"
"time"
p "github.com/pulumi/pulumi-go-provider"
"github.com/pulumi/pulumi-go-provider/infer"
"github.com/davidmontoyago/pulumi-gcp-vertex-model-deployment/pkg/services"
)
// VertexModelDeployment represents a Pulumi resource for deploying models to Vertex AI endpoints.
type VertexModelDeployment struct{}
// Compile-time interface compliance checks
var _ infer.CustomCreate[VertexModelDeploymentArgs, VertexModelDeploymentState] = (*VertexModelDeployment)(nil)
var _ infer.CustomRead[VertexModelDeploymentArgs, VertexModelDeploymentState] = (*VertexModelDeployment)(nil)
var _ infer.CustomUpdate[VertexModelDeploymentArgs, VertexModelDeploymentState] = (*VertexModelDeployment)(nil)
var _ infer.CustomDelete[VertexModelDeploymentState] = (*VertexModelDeployment)(nil)
var _ infer.CustomDiff[VertexModelDeploymentArgs, VertexModelDeploymentState] = (*VertexModelDeployment)(nil)
// Annotate provides metadata and descriptions for the VertexModelDeployment resource.
func (VertexModelDeployment) Annotate(annotator infer.Annotator) {
annotator.Describe(&VertexModelDeployment{}, "Deploys a model to a Vertex AI endpoint")
}
// VertexModelDeploymentState represents the state of a deployed Vertex AI model.
type VertexModelDeploymentState struct {
VertexModelDeploymentArgs
ModelName string `pulumi:"modelName"`
DeployedModelID string `pulumi:"deployedModelId"`
EndpointName string `pulumi:"endpointName"`
CreateTime string `pulumi:"createTime"`
}
// Annotate provides metadata and descriptions for the VertexModelDeploymentState outputs.
func (state *VertexModelDeploymentState) Annotate(annotator infer.Annotator) {
annotator.Describe(&state.DeployedModelID, "ID of the deployed model")
annotator.Describe(&state.EndpointName, "Full name of the endpoint")
annotator.Describe(&state.CreateTime, "Creation timestamp")
}
// Create implements the creation logic
func (v VertexModelDeployment) Create(
ctx context.Context,
req infer.CreateRequest[VertexModelDeploymentArgs],
) (infer.CreateResponse[VertexModelDeploymentState], error) {
state := VertexModelDeploymentState{
VertexModelDeploymentArgs: req.Inputs,
}
resourceID := fmt.Sprintf("%s-%s-%s", req.Inputs.ProjectID, req.Inputs.Region, req.Name)
if req.DryRun {
return infer.CreateResponse[VertexModelDeploymentState]{
ID: resourceID,
}, nil
}
modelClientFactory := v.getModelClientFactory()
modelClient, err := modelClientFactory(ctx, req.Inputs.Region)
if err != nil {
return infer.CreateResponse[VertexModelDeploymentState]{},
fmt.Errorf("failed to create model client: %w", err)
}
defer func() {
if closeErr := modelClient.Close(); closeErr != nil {
log.Printf("failed to close model client: %v", closeErr)
}
}()
// Create the model upload service
uploader := services.NewVertexModelUpload(ctx, modelClient, req.Inputs.ProjectID, req.Inputs.Region, req.Inputs.Labels)
defer func() {
if closeErr := uploader.Close(); closeErr != nil {
log.Printf("failed to close model upload service: %v", closeErr)
}
}()
// Upload the model
modelName, err := uploader.Upload(ctx, services.ModelUpload{
Name: req.Name,
ModelImageURL: req.Inputs.ModelImageURL,
ModelArtifactsBucketURI: req.Inputs.ModelArtifactsBucketURI,
ModelPredictionInputSchemaURI: req.Inputs.ModelPredictionInputSchemaURI,
ModelPredictionOutputSchemaURI: req.Inputs.ModelPredictionOutputSchemaURI,
ModelPredictionBehaviorSchemaURI: req.Inputs.ModelPredictionBehaviorSchemaURI,
ServiceAccountEmail: req.Inputs.ServiceAccount,
PredictRoute: req.Inputs.PredictRoute,
HealthRoute: req.Inputs.HealthRoute,
Args: req.Inputs.Args,
EnvVars: req.Inputs.EnvVars,
Port: req.Inputs.Port,
})
if err != nil {
return infer.CreateResponse[VertexModelDeploymentState]{},
fmt.Errorf("failed to upload model: %w", err)
}
state.ModelName = modelName
state.CreateTime = time.Now().Format(time.RFC3339)
// Only deploy to endpoint if endpoint deployment is configured
if isEndpointDeploymentEnabled(req.Inputs) {
endpointClientFactory := v.getEndpointClientFactory()
endpointClient, err := endpointClientFactory(ctx, req.Inputs.Region)
if err != nil {
return infer.CreateResponse[VertexModelDeploymentState]{},
fmt.Errorf("failed to create endpoint client: %w", err)
}
defer func() {
if closeErr := endpointClient.Close(); closeErr != nil {
log.Printf("failed to close endpoint client: %v", closeErr)
}
}()
// Create the model deployment service
deployer := services.NewVertexModelDeploy(ctx, endpointClient, req.Inputs.ProjectID, req.Inputs.Region)
defer func() {
if closeErr := deployer.Close(); closeErr != nil {
log.Printf("failed to close endpoint model deployment service: %v", closeErr)
}
}()
// Deploy the model to the endpoint
endpointConfig := toEndpointDeploymentConfig(req.Inputs.EndpointModelDeployment)
deployedModelID, err := deployer.Deploy(
ctx,
modelName,
req.Name,
req.Inputs.ServiceAccount,
endpointConfig,
)
if err != nil {
return infer.CreateResponse[VertexModelDeploymentState]{},
fmt.Errorf("failed to deploy model: %w", err)
}
state.DeployedModelID = deployedModelID
// Store the endpoint ID; downstream services expand it to the full
// "projects/.../locations/.../endpoints/..." resource name when needed.
state.EndpointName = req.Inputs.EndpointModelDeployment.EndpointID
}
return infer.CreateResponse[VertexModelDeploymentState]{
ID: resourceID,
Output: state,
}, nil
}
// Delete implements the deletion logic
func (v VertexModelDeployment) Delete(
ctx context.Context,
req infer.DeleteRequest[VertexModelDeploymentState],
) (infer.DeleteResponse, error) {
// Only undeploy from endpoint if the model was deployed to an endpoint
if req.State.DeployedModelID != "" && req.State.EndpointName != "" {
// Create endpoint client using the factory
endpointClientFactory := v.getEndpointClientFactory()
endpointClient, err := endpointClientFactory(ctx, req.State.Region)
if err != nil {
return infer.DeleteResponse{}, fmt.Errorf("failed to create endpoint client for undeployment: %w", err)
}
defer func() {
if closeErr := endpointClient.Close(); closeErr != nil {
log.Printf("failed to close endpoint client for undeployment: %v", closeErr)
}
}()
undeployer := services.NewVertexModelUndeploy(ctx, endpointClient, req.State.ProjectID, req.State.Region)
err = undeployer.Undeploy(ctx, req.State.EndpointName, req.State.DeployedModelID)
if err != nil {
return infer.DeleteResponse{}, fmt.Errorf("failed to undeploy model: %w", err)
}
}
// After undeploying, delete the model
modelClientFactory := v.getModelClientFactory()
modelClient, err := modelClientFactory(ctx, req.State.Region)
if err != nil {
return infer.DeleteResponse{}, fmt.Errorf("failed to create model client: %w", err)
}
defer func() {
if closeErr := modelClient.Close(); closeErr != nil {
log.Printf("failed to close model client: %v", closeErr)
}
}()
deleter := services.NewVertexModelDelete(ctx, modelClient, req.State.ProjectID, req.State.Region)
err = deleter.Delete(ctx, req.State.ModelName)
if err != nil {
return infer.DeleteResponse{}, fmt.Errorf("failed to delete model: %w", err)
}
return infer.DeleteResponse{}, nil
}
// Update implements the update logic
func (v VertexModelDeployment) Update(
ctx context.Context,
req infer.UpdateRequest[VertexModelDeploymentArgs, VertexModelDeploymentState],
) (infer.UpdateResponse[VertexModelDeploymentState], error) {
// Handle dry run - return the updated state without actually making changes
if req.DryRun {
return infer.UpdateResponse[VertexModelDeploymentState]{
Output: VertexModelDeploymentState{
VertexModelDeploymentArgs: req.Inputs,
DeployedModelID: req.State.DeployedModelID,
ModelName: req.State.ModelName,
EndpointName: req.State.EndpointName,
CreateTime: req.State.CreateTime,
},
}, nil
}
// Check if any model properties actually need updating to avoid unnecessary API calls
needsUpdate, updatePaths := collectUpdates(req)
// If no updates are needed, return current state
if !needsUpdate {
return infer.UpdateResponse[VertexModelDeploymentState]{
Output: req.State,
}, nil
}
modelClientFactory := v.getModelClientFactory()
modelClient, err := modelClientFactory(ctx, req.State.Region)
if err != nil {
return infer.UpdateResponse[VertexModelDeploymentState]{}, fmt.Errorf("failed to create model client: %w", err)
}
defer func() {
if closeErr := modelClient.Close(); closeErr != nil {
log.Printf("failed to close model client: %v", closeErr)
}
}()
// Apply the collected updates to the model in the registry
updatedModel, err := updateRegistryModel(ctx, req, modelClient, updatePaths)
if err != nil {
return infer.UpdateResponse[VertexModelDeploymentState]{}, fmt.Errorf("failed to update registry model: %w", err)
}
// Create updated state with the response from the API
updatedState := setModelStateUpdates(req, updatedModel)
// TODO update endpoint model deployment
return infer.UpdateResponse[VertexModelDeploymentState]{
Output: updatedState,
}, nil
}
// Read implements the read logic for drift detection
func (v VertexModelDeployment) Read(
ctx context.Context,
req infer.ReadRequest[VertexModelDeploymentArgs, VertexModelDeploymentState],
) (infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState], error) {
// Validate that we have the minimum required information to read the resource
if req.ID == "" {
return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{},
fmt.Errorf("resource ID is required for read operation")
}
// Create a copy of the current state to modify
state := req.State
// A model name in state is required to look the model up in the registry
if req.State.ModelName == "" {
return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{},
fmt.Errorf("model name is required in state to read the resource")
}
modelClientFactory := v.getModelClientFactory()
modelClient, err := modelClientFactory(ctx, req.State.Region)
if err != nil {
// If we can't create the client, don't assume the resource is gone
return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{},
fmt.Errorf("failed to create model client: %w", err)
}
defer func() {
if err := modelClient.Close(); err != nil {
log.Printf("failed to close model client: %v", err)
}
}()
err = readRegistryModel(ctx, modelClient, req, &state)
if err != nil {
// Check if this is a "not found" error specifically
// If the model truly doesn't exist, return empty response
// Otherwise, return the error
if isResourceNotFoundError(err) {
slog.Warn("Model no longer exists", "modelName", req.State.ModelName)
return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{}, nil
}
return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{},
fmt.Errorf("failed to read model from registry: %w", err)
}
if req.State.DeployedModelID != "" && req.State.EndpointName != "" {
// Read the endpoint if model is deployed to an endpoint
endpointClientFactory := v.getEndpointClientFactory()
endpointClient, err := endpointClientFactory(ctx, req.State.Region)
if err != nil {
return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{}, fmt.Errorf("failed to create endpoint client: %w", err)
}
defer func() {
if closeErr := endpointClient.Close(); closeErr != nil {
log.Printf("failed to close endpoint client: %v", closeErr)
}
}()
err = readEndpointModel(ctx, endpointClient, req, &state)
if err != nil {
return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{}, fmt.Errorf("failed to read model endpoint: %w", err)
}
}
return infer.ReadResponse[VertexModelDeploymentArgs, VertexModelDeploymentState]{
ID: req.ID,
Inputs: req.Inputs,
State: state,
}, nil
}
// Diff implements the diff logic to control what changes require replacement vs update
func (v VertexModelDeployment) Diff(
_ context.Context,
req infer.DiffRequest[VertexModelDeploymentArgs, VertexModelDeploymentState],
) (p.DiffResponse, error) {
diff := p.DiffResponse{
HasChanges: false,
DetailedDiff: make(map[string]p.PropertyDiff),
}
// ProjectID and Region are immutable; changing either requires replacement
if req.Inputs.ProjectID != req.State.ProjectID {
diff.HasChanges = true
diff.DetailedDiff["projectId"] = p.PropertyDiff{
Kind: p.UpdateReplace,
InputDiff: true,
}
}
if req.Inputs.Region != req.State.Region {
diff.HasChanges = true
diff.DetailedDiff["region"] = p.PropertyDiff{
Kind: p.UpdateReplace,
InputDiff: true,
}
}
// Check ModelImageURL - immutable, so a change requires replacement
if req.Inputs.ModelImageURL != req.State.ModelImageURL {
diff.HasChanges = true
diff.DetailedDiff["modelImageUrl"] = p.PropertyDiff{
// Image URL is immutable, requires replacement of the model resource
// See: https://cloud.google.com/vertex-ai/docs/reference/rest/v1/ModelContainerSpec
Kind: p.UpdateReplace,
InputDiff: true,
}
}
// Check ModelArtifactsBucketURI - this can be updated
if req.Inputs.ModelArtifactsBucketURI != req.State.ModelArtifactsBucketURI {
diff.HasChanges = true
diff.DetailedDiff["modelArtifactsBucketUri"] = p.PropertyDiff{
Kind: p.Update,
InputDiff: true,
}
}
// Prediction schema URIs are immutable
// See: https://cloud.google.com/vertex-ai/docs/reference/rest/v1/PredictSchemata
// Check service account - this can be updated
if req.Inputs.ServiceAccount != req.State.ServiceAccount {
diff.HasChanges = true
diff.DetailedDiff["serviceAccount"] = p.PropertyDiff{
Kind: p.Update,
InputDiff: true,
}
}
// Check route configurations - these can be updated
if req.Inputs.PredictRoute != req.State.PredictRoute {
diff.HasChanges = true
diff.DetailedDiff["predictRoute"] = p.PropertyDiff{
Kind: p.Update,
InputDiff: true,
}
}
if req.Inputs.HealthRoute != req.State.HealthRoute {
diff.HasChanges = true
diff.DetailedDiff["healthRoute"] = p.PropertyDiff{
Kind: p.Update,
InputDiff: true,
}
}
// Check labels - these can be updated
if !mapsEqual(req.Inputs.Labels, req.State.Labels) {
diff.HasChanges = true
diff.DetailedDiff["labels"] = p.PropertyDiff{
Kind: p.Update,
InputDiff: true,
}
}
// Check endpoint model deployment configuration - this can be updated
if !endpointDeploymentEqual(req.Inputs.EndpointModelDeployment, req.State.EndpointModelDeployment) {
diff.HasChanges = true
diff.DetailedDiff["endpointModelDeployment"] = p.PropertyDiff{
Kind: p.Update,
InputDiff: true,
}
}
return diff, nil
}
// testFactoryRegistry holds test factories for dependency injection during testing
var testFactoryRegistry struct {
modelClientFactory services.ModelClientFactory
endpointClientFactory services.EndpointClientFactory
}
// getModelClientFactory returns the model client factory, defaulting to production factory if nil.
func (v VertexModelDeployment) getModelClientFactory() services.ModelClientFactory {
if testFactoryRegistry.modelClientFactory == nil {
return services.DefaultModelClientFactory
}
return testFactoryRegistry.modelClientFactory
}
// getEndpointClientFactory returns the endpoint client factory, defaulting to production factory if nil.
func (v VertexModelDeployment) getEndpointClientFactory() services.EndpointClientFactory {
if testFactoryRegistry.endpointClientFactory == nil {
return services.DefaultEndpointClientFactory
}
return testFactoryRegistry.endpointClientFactory
}
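// Test wiring sketch: a test in this package can inject mocks before
// exercising the resource (fakeModelClient is hypothetical and must implement
// services.VertexModelClient):
//
//	testFactoryRegistry.modelClientFactory = func(ctx context.Context, region string) (services.VertexModelClient, error) {
//		return &fakeModelClient{}, nil
//	}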
// toEndpointDeploymentConfig converts EndpointModelDeploymentArgs to services.EndpointModelDeploymentConfig
func toEndpointDeploymentConfig(args *EndpointModelDeploymentArgs) services.EndpointModelDeploymentConfig {
return services.EndpointModelDeploymentConfig{
EndpointID: args.EndpointID,
MachineType: args.MachineType,
AcceleratorType: args.AcceleratorType,
AcceleratorCount: args.AcceleratorCount,
MinReplicas: safeIntToInt32(args.MinReplicas),
MaxReplicas: safeIntToInt32(args.MaxReplicas),
TrafficPercent: safeIntToInt32(args.TrafficPercent),
DisableContainerLogging: args.DisableContainerLogging,
EnableAccessLogging: args.EnableAccessLogging,
EnableSpotVMs: args.EnableSpotVMs,
}
}
// isEndpointDeploymentEnabled checks if endpoint deployment is configured
func isEndpointDeploymentEnabled(args VertexModelDeploymentArgs) bool {
return args.EndpointModelDeployment != nil && args.EndpointModelDeployment.EndpointID != ""
}
// mapsEqual checks if two string maps are equal
func mapsEqual(a, b map[string]string) bool {
if len(a) != len(b) {
return false
}
for k, v := range a {
if b[k] != v {
return false
}
}
return true
}
// endpointDeploymentEqual checks if two EndpointModelDeploymentArgs are equal
func endpointDeploymentEqual(inputDeployment, stateDeployment *EndpointModelDeploymentArgs) bool {
// Both nil
if inputDeployment == nil && stateDeployment == nil {
return true
}
// One nil, one not nil
if inputDeployment == nil || stateDeployment == nil {
return false
}
// Compare all fields
return inputDeployment.EndpointID == stateDeployment.EndpointID &&
inputDeployment.MachineType == stateDeployment.MachineType &&
inputDeployment.AcceleratorType == stateDeployment.AcceleratorType &&
inputDeployment.AcceleratorCount == stateDeployment.AcceleratorCount &&
inputDeployment.MinReplicas == stateDeployment.MinReplicas &&
inputDeployment.MaxReplicas == stateDeployment.MaxReplicas &&
inputDeployment.TrafficPercent == stateDeployment.TrafficPercent &&
inputDeployment.DisableContainerLogging == stateDeployment.DisableContainerLogging &&
inputDeployment.EnableAccessLogging == stateDeployment.EnableAccessLogging &&
inputDeployment.EnableSpotVMs == stateDeployment.EnableSpotVMs
}
// isResourceNotFoundError detects if the error indicates the resource doesn't exist
func isResourceNotFoundError(err error) bool {
if err == nil {
return false
}
errStr := err.Error()
return strings.Contains(errStr, "not found") ||
strings.Contains(errStr, "does not exist") ||
strings.Contains(errStr, "404")
}
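// A sturdier alternative (sketch) would match on the gRPC status code instead
// of the message text, assuming the client surfaces gRPC errors:
//
//	import (
//		"google.golang.org/grpc/codes"
//		"google.golang.org/grpc/status"
//	)
//
//	func isNotFoundStatus(err error) bool {
//		return status.Code(err) == codes.NotFound
//	}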
package resources
import "math"
// safeIntToInt32 safely converts an int to int32, clamping to int32 range
func safeIntToInt32(value int) int32 {
if value < math.MinInt32 {
return math.MinInt32
}
if value > math.MaxInt32 {
return math.MaxInt32
}
return int32(value)
}
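// For example, on a 64-bit platform:
//
//	safeIntToInt32(5)           // 5
//	safeIntToInt32(math.MaxInt) // math.MaxInt32
//	safeIntToInt32(math.MinInt) // math.MinInt32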
package services
import (
"context"
"fmt"
aiplatform "cloud.google.com/go/aiplatform/apiv1"
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
gax "github.com/googleapis/gax-go/v2"
"google.golang.org/api/option"
)
// VertexModelClient interface defines operations for uploading models.
type VertexModelClient interface {
UploadModel(ctx context.Context, req *aiplatformpb.UploadModelRequest, opts ...gax.CallOption) (*aiplatform.UploadModelOperation, error)
DeleteModel(ctx context.Context, req *aiplatformpb.DeleteModelRequest, opts ...gax.CallOption) (*aiplatform.DeleteModelOperation, error)
GetModel(ctx context.Context, req *aiplatformpb.GetModelRequest, opts ...gax.CallOption) (*aiplatformpb.Model, error)
UpdateModel(context.Context, *aiplatformpb.UpdateModelRequest, ...gax.CallOption) (*aiplatformpb.Model, error)
Close() error
}
// VertexEndpointClient interface defines operations for deploying models.
type VertexEndpointClient interface {
DeployModel(ctx context.Context, req *aiplatformpb.DeployModelRequest, opts ...gax.CallOption) (*aiplatform.DeployModelOperation, error)
UndeployModel(ctx context.Context, req *aiplatformpb.UndeployModelRequest, opts ...gax.CallOption) (*aiplatform.UndeployModelOperation, error)
GetEndpoint(ctx context.Context, req *aiplatformpb.GetEndpointRequest, opts ...gax.CallOption) (*aiplatformpb.Endpoint, error)
Close() error
}
// ModelClientFactory function type for creating model clients
type ModelClientFactory func(ctx context.Context, region string) (VertexModelClient, error)
// EndpointClientFactory function type for creating endpoint clients
type EndpointClientFactory func(ctx context.Context, region string) (VertexEndpointClient, error)
// DefaultModelClientFactory creates the production GCP model client.
//
//nolint:ireturn // Returning interface for testability
func DefaultModelClientFactory(ctx context.Context, region string) (VertexModelClient, error) {
// Vertex AI regional resources must be accessed via the matching regional API endpoint
apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", region)
clientEndpointOpt := option.WithEndpoint(apiEndpoint)
// Create model client
modelClient, err := aiplatform.NewModelClient(ctx, clientEndpointOpt)
if err != nil {
return nil, fmt.Errorf("failed to create model client: %w", err)
}
return modelClient, nil
}
// DefaultEndpointClientFactory creates the production GCP endpoint client.
//
//nolint:ireturn // Returning interface for testability
func DefaultEndpointClientFactory(ctx context.Context, region string) (VertexEndpointClient, error) {
// Vertex AI regional resources must be accessed via the matching regional API endpoint
apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", region)
clientEndpointOpt := option.WithEndpoint(apiEndpoint)
// Create endpoint client
endpointClient, err := aiplatform.NewEndpointClient(ctx, clientEndpointOpt)
if err != nil {
return nil, fmt.Errorf("failed to create endpoint client: %w", err)
}
return endpointClient, nil
}
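// For region "us-central1" both factories dial
// "us-central1-aiplatform.googleapis.com:443". Credentials are resolved via
// Application Default Credentials unless additional client options are passed.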
package services
import (
"context"
"fmt"
"log"
"strings"
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
)
// ModelDeleter allows deleting models from the registry.
type ModelDeleter interface {
Delete(ctx context.Context, modelName string) error
Close() error
}
// VertexModelDelete implements the ModelDeleter interface for Vertex AI.
type VertexModelDelete struct {
modelClient VertexModelClient
projectID string
region string
}
// NewVertexModelDelete creates a new VertexModelDelete with the provided model client.
func NewVertexModelDelete(_ context.Context, modelClient VertexModelClient, projectID, region string) *VertexModelDelete {
return &VertexModelDelete{
modelClient: modelClient,
projectID: projectID,
region: region,
}
}
// Delete deletes a model from Vertex AI.
func (d *VertexModelDelete) Delete(ctx context.Context, modelName string) error {
modelFullName := modelName
if !strings.HasPrefix(modelName, "projects/") {
modelFullName = fmt.Sprintf("projects/%s/locations/%s/models/%s",
d.projectID, d.region, modelName)
}
deleteReq := &aiplatformpb.DeleteModelRequest{
Name: modelFullName,
}
deleteOperation, err := d.modelClient.DeleteModel(ctx, deleteReq)
if err != nil {
return fmt.Errorf("failed to delete model: %w", err)
}
if deleteOperation == nil {
log.Printf("Warning: model delete operation is nil?!? This must be a mocked client. Logging error and moving on.")
} else {
err = deleteOperation.Wait(ctx)
if err != nil {
return fmt.Errorf("failed to wait for deletion: %w", err)
}
}
return nil
}
// Package services provides implementations for GCP Vertex AI model and Endpoint deployment operations.
package services
import (
"context"
"fmt"
"log"
"time"
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
gax "github.com/googleapis/gax-go/v2"
)
// EndpointModelDeploymentConfig holds configuration for deploying a model to an endpoint.
type EndpointModelDeploymentConfig struct {
EndpointID string
MachineType string
AcceleratorType string
AcceleratorCount int32
MinReplicas int32
MaxReplicas int32
TrafficPercent int32
DisableContainerLogging bool
EnableAccessLogging bool
EnableSpotVMs bool
}
// ModelDeployer interface defines operations for deploying models.
type ModelDeployer interface {
Deploy(ctx context.Context, modelName, displayName, serviceAccount string, endpointConfig EndpointModelDeploymentConfig) (string, error)
Close() error
}
// VertexModelDeploy implements the ModelDeployer interface for Vertex AI.
type VertexModelDeploy struct {
endpointClient VertexEndpointClient
projectID string
region string
}
// NewVertexModelDeploy creates a new VertexModelDeploy with the provided endpoint client.
func NewVertexModelDeploy(_ context.Context, endpointClient VertexEndpointClient, projectID, region string) *VertexModelDeploy {
return &VertexModelDeploy{
endpointClient: endpointClient,
projectID: projectID,
region: region,
}
}
// Deploy deploys a model to a Vertex AI endpoint and returns the deployed model ID.
func (d *VertexModelDeploy) Deploy(ctx context.Context, modelName, displayName, serviceAccount string, endpointConfig EndpointModelDeploymentConfig) (string, error) {
// Build the deployment request
deployedModel := &aiplatformpb.DeployedModel{
// Expected format: "projects/%s/locations/%s/models/%s"
Model: modelName,
DisplayName: displayName,
PredictionResources: &aiplatformpb.DeployedModel_DedicatedResources{
DedicatedResources: &aiplatformpb.DedicatedResources{
MachineSpec: &aiplatformpb.MachineSpec{
MachineType: endpointConfig.MachineType,
// An unrecognized accelerator name maps to 0 (ACCELERATOR_TYPE_UNSPECIFIED), since a missing map key yields the zero value
AcceleratorType: aiplatformpb.AcceleratorType(aiplatformpb.AcceleratorType_value[endpointConfig.AcceleratorType]),
AcceleratorCount: endpointConfig.AcceleratorCount,
},
MinReplicaCount: endpointConfig.MinReplicas,
MaxReplicaCount: endpointConfig.MaxReplicas,
Spot: endpointConfig.EnableSpotVMs,
},
},
DisableContainerLogging: endpointConfig.DisableContainerLogging,
EnableAccessLogging: endpointConfig.EnableAccessLogging,
}
if serviceAccount != "" {
deployedModel.ServiceAccount = serviceAccount
}
deployReq := &aiplatformpb.DeployModelRequest{
Endpoint: fmt.Sprintf("projects/%s/locations/%s/endpoints/%s",
d.projectID, d.region, endpointConfig.EndpointID),
DeployedModel: deployedModel,
TrafficSplit: map[string]int32{},
}
// Set traffic split if specified; in a DeployModelRequest traffic split, the key "0" refers to the model being deployed by this request
if endpointConfig.TrafficPercent > 0 {
deployReq.TrafficSplit = map[string]int32{
"0": endpointConfig.TrafficPercent,
}
}
// Execute the deployment
deployOperation, err := d.endpointClient.DeployModel(ctx, deployReq)
if err != nil {
return "", fmt.Errorf("failed to deploy model: %w", err)
}
if deployOperation == nil {
log.Printf("Warning: deploy operation is nil?!? This must be a mocked client. Logging error and moving on.")
return "", nil
}
// Wait for completion with timeout
result, err := deployOperation.Wait(ctx, gax.WithTimeout(10*time.Minute))
if err != nil {
return "", fmt.Errorf("failed to wait for deployment: %w", err)
}
return result.GetDeployedModel().GetId(), nil
}
// Close closes the endpoint client.
func (d *VertexModelDeploy) Close() error {
if d.endpointClient != nil {
if err := d.endpointClient.Close(); err != nil {
return fmt.Errorf("failed to close endpoint client: %w", err)
}
}
return nil
}
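// Usage sketch (hypothetical values):
//
//	deployer := NewVertexModelDeploy(ctx, endpointClient, "my-project", "us-central1")
//	deployedID, err := deployer.Deploy(ctx,
//		"projects/my-project/locations/us-central1/models/123",
//		"my-model",
//		"model-sa@my-project.iam.gserviceaccount.com",
//		EndpointModelDeploymentConfig{
//			EndpointID:     "456",
//			MachineType:    "n1-standard-8",
//			MinReplicas:    1,
//			MaxReplicas:    1,
//			TrafficPercent: 100,
//		},
//	)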
package services
import (
"context"
"fmt"
"strings"
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
)
// EndpointModelGetter allows getting endpoints and their deployed models from the registry.
type EndpointModelGetter interface {
Get(ctx context.Context, endpointName, deployedModelID string) (*aiplatformpb.Endpoint, *aiplatformpb.DeployedModel, error)
Close() error
}
// VertexEndpointModelGetter implements the EndpointModelGetter interface for Vertex AI.
type VertexEndpointModelGetter struct {
endpointClient VertexEndpointClient
projectID string
region string
}
// NewVertexEndpointModelGetter creates a new VertexEndpointModelGetter with the provided endpoint client.
func NewVertexEndpointModelGetter(endpointClient VertexEndpointClient, projectID, region string) *VertexEndpointModelGetter {
return &VertexEndpointModelGetter{
endpointClient: endpointClient,
projectID: projectID,
region: region,
}
}
// Get retrieves an endpoint and finds the specified deployed model within it.
// Returns the endpoint, the deployed model (if found), and any error.
func (g *VertexEndpointModelGetter) Get(ctx context.Context, endpointName, deployedModelID string) (*aiplatformpb.Endpoint, *aiplatformpb.DeployedModel, error) {
endpointFullName := endpointName
if !strings.HasPrefix(endpointName, "projects/") {
endpointFullName = fmt.Sprintf("projects/%s/locations/%s/endpoints/%s",
g.projectID, g.region, endpointName)
}
getReq := &aiplatformpb.GetEndpointRequest{
Name: endpointFullName,
}
endpoint, err := g.endpointClient.GetEndpoint(ctx, getReq)
if err != nil {
return nil, nil, fmt.Errorf("failed to get endpoint: %w", err)
}
// Find the specified deployed model on the endpoint, if it is still deployed
var foundDeployedModel *aiplatformpb.DeployedModel
for _, deployedModel := range endpoint.DeployedModels {
if deployedModel.Id == deployedModelID {
foundDeployedModel = deployedModel
break
}
}
return endpoint, foundDeployedModel, nil
}
// Close closes the endpoint client.
func (g *VertexEndpointModelGetter) Close() error {
return g.endpointClient.Close()
}
package services
import (
"context"
"fmt"
"log"
"strings"
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
)
// ModelUndeployer allows undeploying models from Vertex AI endpoints.
type ModelUndeployer interface {
Undeploy(ctx context.Context, endpointName, deployedModelID string) error
Close() error
}
// VertexModelUndeploy implements the ModelUndeployer interface for Vertex AI.
type VertexModelUndeploy struct {
endpointClient VertexEndpointClient
projectID string
region string
}
// NewVertexModelUndeploy creates a new VertexModelUndeploy with the provided endpoint client.
func NewVertexModelUndeploy(_ context.Context, endpointClient VertexEndpointClient, projectID, region string) *VertexModelUndeploy {
return &VertexModelUndeploy{
endpointClient: endpointClient,
projectID: projectID,
region: region,
}
}
// Undeploy undeploys a model from an endpoint.
func (u *VertexModelUndeploy) Undeploy(ctx context.Context, endpointName, deployedModelID string) error {
endpointFullName := endpointName
if !strings.HasPrefix(endpointName, "projects/") {
endpointFullName = fmt.Sprintf("projects/%s/locations/%s/endpoints/%s",
u.projectID, u.region, endpointName)
}
undeployReq := &aiplatformpb.UndeployModelRequest{
Endpoint: endpointFullName,
DeployedModelId: deployedModelID,
}
undeployOperation, err := u.endpointClient.UndeployModel(ctx, undeployReq)
if err != nil {
return fmt.Errorf("failed to undeploy model: %w", err)
}
if undeployOperation == nil {
log.Printf("Warning: model undeploy operation is nil?!? This must be a mocked client. Logging error and moving on.")
} else {
_, err = undeployOperation.Wait(ctx)
if err != nil {
return fmt.Errorf("failed to wait for undeployment: %w", err)
}
}
return nil
}
package services
import (
"context"
"fmt"
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
)
// ModelGetter allows getting models from the registry.
type ModelGetter interface {
Get(ctx context.Context, modelName string) (*aiplatformpb.Model, error)
Close() error
}
// VertexModelGet implements the ModelGetter interface for Vertex AI.
type VertexModelGet struct {
modelClient VertexModelClient
modelName string
}
// NewVertexModelGet creates a new VertexModelGet with the provided model client.
func NewVertexModelGet(_ context.Context, modelClient VertexModelClient, modelName string) *VertexModelGet {
return &VertexModelGet{
modelClient: modelClient,
modelName: modelName,
}
}
// Get gets a model from Vertex AI.
func (g *VertexModelGet) Get(ctx context.Context, modelName string) (*aiplatformpb.Model, error) {
getReq := &aiplatformpb.GetModelRequest{
Name: modelName,
}
model, err := g.modelClient.GetModel(ctx, getReq)
if err != nil {
return nil, fmt.Errorf("failed to get model: %w", err)
}
return model, nil
}
// Close closes the model client.
func (g *VertexModelGet) Close() error {
return g.modelClient.Close()
}
// Package services provides implementations for GCP Vertex AI model upload operations.
package services
import (
"context"
"errors"
"fmt"
"log"
"time"
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
gax "github.com/googleapis/gax-go/v2"
"github.com/googleapis/gax-go/v2/apierror"
)
// ModelUpload represents the parameters needed to upload a model to Vertex AI.
type ModelUpload struct {
Name string
ModelImageURL string
ModelArtifactsBucketURI string
ServiceAccountEmail string
ModelPredictionInputSchemaURI string
ModelPredictionOutputSchemaURI string
ModelPredictionBehaviorSchemaURI string
PredictRoute string
HealthRoute string
Args []string
EnvVars map[string]string
Port int32
}
// ModelUploader interface defines operations for uploading models.
type ModelUploader interface {
Upload(ctx context.Context, uploadParams ModelUpload) (string, error)
Close() error
}
// VertexModelUpload implements the ModelUploader interface for Vertex AI.
type VertexModelUpload struct {
modelClient VertexModelClient
projectID string
region string
labels map[string]string
}
// NewVertexModelUpload creates a new VertexModelUpload with the provided model client.
func NewVertexModelUpload(_ context.Context, modelClient VertexModelClient, projectID, region string, labels map[string]string) *VertexModelUpload {
return &VertexModelUpload{
modelClient: modelClient,
projectID: projectID,
region: region,
labels: labels,
}
}
// Upload uploads a model to Vertex AI and returns the model name.
func (u *VertexModelUpload) Upload(ctx context.Context, params ModelUpload) (string, error) {
envVars := []*aiplatformpb.EnvVar{}
for name, value := range params.EnvVars {
envVars = append(envVars, &aiplatformpb.EnvVar{
Name: name,
Value: value,
})
}
modelServerPort := params.Port
if modelServerPort == 0 {
modelServerPort = 8080
}
modelArgs := &aiplatformpb.Model{
DisplayName: params.Name,
Description: "Uploaded model for " + params.ModelImageURL,
ContainerSpec: &aiplatformpb.ModelContainerSpec{
ImageUri: params.ModelImageURL,
Args: params.Args,
Env: envVars,
Ports: []*aiplatformpb.Port{
{
ContainerPort: modelServerPort,
},
},
},
Labels: u.labels,
ArtifactUri: params.ModelArtifactsBucketURI,
}
if params.ModelPredictionInputSchemaURI != "" {
modelArgs.PredictSchemata = &aiplatformpb.PredictSchemata{
// Schema for the model input
InstanceSchemaUri: params.ModelPredictionInputSchemaURI,
// Schema for the model output
PredictionSchemaUri: params.ModelPredictionOutputSchemaURI,
}
if params.ModelPredictionBehaviorSchemaURI != "" {
// Schema for the model inference behavior. Optional depending on the model.
modelArgs.PredictSchemata.ParametersSchemaUri = params.ModelPredictionBehaviorSchemaURI
}
}
if params.PredictRoute != "" {
modelArgs.ContainerSpec.PredictRoute = params.PredictRoute
}
if params.HealthRoute != "" {
modelArgs.ContainerSpec.HealthRoute = params.HealthRoute
}
modelUploadOp, err := u.modelClient.UploadModel(ctx, &aiplatformpb.UploadModelRequest{
// TODO support non-traditional / global models
// GCP endpoint to which the model is attached; it can be regional or global, depending on the model type.
Parent: fmt.Sprintf("projects/%s/locations/%s", u.projectID, u.region),
ServiceAccount: params.ServiceAccountEmail,
Model: modelArgs,
}, gax.WithTimeout(5*time.Minute))
if err != nil {
var apiError *apierror.APIError
if errors.As(err, &apiError) {
// TODO DRY up
log.Printf("Model upload returned APIError details: %v\n", apiError)
log.Printf("APIError reason: %v\n", apiError.Reason())
log.Printf("APIError details : %v\n", apiError.Details())
// If a gRPC transport was used you can extract the
// google.golang.org/grpc/status.Status from the error
log.Printf("APIError GRPCStatus: %+v\n", apiError.GRPCStatus())
log.Printf("APIError HTTPCode: %+v\n", apiError.HTTPCode())
}
return "", fmt.Errorf("failed to upload model: %w", err)
}
if modelUploadOp == nil {
log.Printf("Warning: model upload operation is nil?!? This must be a mocked client. Logging error and moving on.")
return "MOCKED_MODEL_NAME", nil
}
modelUploadResult, err := modelUploadOp.Wait(ctx, gax.WithTimeout(10*time.Minute))
if err != nil {
if modelUploadOp.Done() {
log.Printf("Model upload operation completed with failure: %v\n", err)
}
var apiError *apierror.APIError
if errors.As(err, &apiError) {
// TODO DRY up
log.Printf("Model upload returned APIError details: %v\n", apiError)
log.Printf("APIError reason: %v\n", apiError.Reason())
log.Printf("APIError details : %v\n", apiError.Details())
log.Printf("APIError help: %v\n", apiError.Details().Help)
// If a gRPC transport was used you can extract the
// google.golang.org/grpc/status.Status from the error
log.Printf("APIError GRPCStatus: %v\n", apiError.GRPCStatus())
log.Printf("APIError HTTPCode: %v\n", apiError.HTTPCode())
}
return "", fmt.Errorf("failed to wait for model upload: %w", err)
}
return modelUploadResult.GetModel(), nil
}
// Close closes the model client.
func (u *VertexModelUpload) Close() error {
if u.modelClient != nil {
if err := u.modelClient.Close(); err != nil {
return fmt.Errorf("failed to close model client: %w", err)
}
}
return nil
}
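// End-to-end wiring sketch (illustrative; values are hypothetical):
//
//	ctx := context.Background()
//	client, err := DefaultModelClientFactory(ctx, "us-central1")
//	if err != nil {
//		log.Fatal(err)
//	}
//	uploader := NewVertexModelUpload(ctx, client, "my-project", "us-central1", nil)
//	defer uploader.Close() // also closes the underlying model client
//	modelName, err := uploader.Upload(ctx, ModelUpload{
//		Name:          "my-model",
//		ModelImageURL: "us-docker.pkg.dev/my-project/models/server:latest",
//	})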