// Package gcp provides Google Cloud Platform infrastructure components for Vertex AI Batch Prediction Jobs.
package gcp
import (
"fmt"
namer "github.com/davidmontoyago/commodity-namer"
vertexmodeldeployment "github.com/davidmontoyago/pulumi-gcp-vertex-model-deployment/sdk/go/pulumi-gcp-vertex-model-deployment/resources"
"github.com/pulumi/pulumi-gcp/sdk/v8/go/gcp/artifactregistry"
"github.com/pulumi/pulumi-gcp/sdk/v8/go/gcp/projects"
"github.com/pulumi/pulumi-gcp/sdk/v8/go/gcp/storage"
v1 "github.com/pulumi/pulumi-google-native/sdk/go/google/aiplatform/v1"
"github.com/pulumi/pulumi/sdk/v3/go/pulumi"
)
// AIBatch represents a GCP Vertex AI Model Deployment and Batch Prediction Job.
// It is a Pulumi component resource (constructed via NewAIBatch): the exported
// fields hold the resolved configuration, the unexported fields track the child
// resources the component creates.
type AIBatch struct {
	pulumi.ResourceState
	// Namer builds the names of child resources (via NewResourceName).
	namer.Namer
	// Project is the GCP project the child resources are created in.
	Project string
	// Region is the location used for the artifacts bucket, the model and the job.
	Region string
	// ModelImageURL is the serving container image used for a custom model.
	ModelImageURL pulumi.StringOutput
	// ModelDir is the local directory of custom model artifacts. When empty,
	// no model is deployed and ModelName is referenced by the job as-is.
	ModelDir string
	// ModelName is the model referenced by the job when ModelDir is empty.
	ModelName string
	// The schema paths below are joined onto the uploaded model artifacts URI.
	ModelPredictionInputSchemaPath string
	ModelPredictionOutputSchemaPath string
	// Optional: only passed to the model deployment when non-empty.
	ModelPredictionBehaviorSchemaPath string
	// ModelBucketBasePath is the object prefix the model artifacts are uploaded under.
	ModelBucketBasePath string
	MachineType pulumi.StringOutput
	JobDisplayName pulumi.StringOutput
	ModelDisplayName pulumi.StringOutput
	// Batch prediction job specific fields
	// InputDataPath is the local input data directory; it is also exported as a
	// component output.
	InputDataPath pulumi.StringOutput
	InputFormat pulumi.StringOutput
	// InputFileName is the file pattern appended to the input data bucket URI.
	InputFileName pulumi.StringOutput
	// OutputDataPath is the predictions prefix inside the artifacts bucket.
	OutputDataPath pulumi.StringOutput
	OutputFormat pulumi.StringOutput
	StartingReplicaCount pulumi.IntOutput
	MaxReplicaCount pulumi.IntOutput
	BatchSize pulumi.IntOutput
	AcceleratorType pulumi.StringOutput
	AcceleratorCount pulumi.IntOutput
	// Labels are applied to the artifacts bucket and the batch prediction job.
	Labels map[string]string
	// inputDataLocalDir is the local directory uploaded as job input data.
	inputDataLocalDir string
	// inputDataTargetDir is the bucket prefix the input data is uploaded under.
	inputDataTargetDir string
	// retainJobOnDelete is passed to pulumi.RetainOnDelete for the job resource.
	retainJobOnDelete bool
	// Core resources
	// modelServiceAccountEmail is only populated for custom models (ModelDir set).
	modelServiceAccountEmail pulumi.StringOutput
	batchPredictionJob *v1.BatchPredictionJob
	artifactsBucket *storage.Bucket
	// modelDeployment is nil when no custom model is deployed.
	modelDeployment *vertexmodeldeployment.VertexModelDeployment
	// uploadedModelFiles holds the names of every uploaded object: model
	// artifacts followed by input data files.
	uploadedModelFiles pulumi.StringArrayOutput
	// jobState mirrors the job's State output.
	// NOTE(review): nothing in this view reads it; confirm it is still needed.
	jobState pulumi.StringOutput
	// IAM bindings for the model service account
	iamMembers []*projects.IAMMember
	// repoIamMember is nil unless private registry access is enabled.
	repoIamMember *artifactregistry.RepositoryIamMember
}
// NewAIBatch creates a new AIBatch component resource. It registers the
// component, provisions the artifacts bucket, uploads model artifacts and input
// data, optionally registers a custom model, and launches a Vertex AI Batch
// Prediction Job.
//
// args must be non-nil. Project and Region are required, as is one of ModelDir
// (custom model) or ModelName (a model used as-is, e.g. from the Model Garden).
// A custom model additionally requires the input and output prediction schema
// paths. Defaults are applied to a private copy of args, so the caller's struct
// is never modified.
func NewAIBatch(ctx *pulumi.Context, name string, args *AIBatchArgs, opts ...pulumi.ResourceOption) (*AIBatch, error) {
	if args == nil {
		return nil, fmt.Errorf("args are required")
	}
	// Shallow copy: the defaults applied below must not leak back into the
	// caller's args (which may be shared between several components).
	resolved := *args

	if resolved.Project == "" {
		return nil, fmt.Errorf("project is required")
	}
	if resolved.Region == "" {
		return nil, fmt.Errorf("region is required")
	}
	if resolved.ModelDir == "" && resolved.ModelName == "" {
		return nil, fmt.Errorf("one of model directory or model name is required")
	}
	if resolved.ModelDir != "" {
		if resolved.ModelPredictionInputSchemaPath == "" {
			return nil, fmt.Errorf("model prediction input schema path is required")
		}
		if resolved.ModelPredictionOutputSchemaPath == "" {
			return nil, fmt.Errorf("model prediction output schema path is required")
		}
	}
	if resolved.ModelBucketBasePath == "" {
		resolved.ModelBucketBasePath = "model"
	}
	// Model input data defaults
	if resolved.InputDataPath == "" {
		resolved.InputDataPath = "inputs"
	}
	if resolved.InputFileName == "" {
		resolved.InputFileName = "*.jsonl"
	}
	if resolved.InputFormat == "" {
		resolved.InputFormat = "jsonl"
	}

	component := &AIBatch{
		Namer:                             namer.New(name, namer.WithReplace()),
		Project:                           resolved.Project,
		Region:                            resolved.Region,
		ModelDir:                          resolved.ModelDir,
		ModelName:                         resolved.ModelName,
		ModelPredictionInputSchemaPath:    resolved.ModelPredictionInputSchemaPath,
		ModelPredictionOutputSchemaPath:   resolved.ModelPredictionOutputSchemaPath,
		ModelPredictionBehaviorSchemaPath: resolved.ModelPredictionBehaviorSchemaPath,
		ModelBucketBasePath:               resolved.ModelBucketBasePath,
		// Default to the latest TensorFlow 2.15 CPU prediction container
		ModelImageURL:    setDefaultString(resolved.ModelImageURL, "us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-15:latest"),
		MachineType:      setDefaultString(resolved.MachineType, "n1-highmem-4"),
		JobDisplayName:   setDefaultString(resolved.JobDisplayName, name),
		ModelDisplayName: setDefaultString(resolved.ModelDisplayName, name+"-model"),
		// Model input data
		InputDataPath: pulumi.String(resolved.InputDataPath).ToStringOutput(),
		InputFormat:   pulumi.String(resolved.InputFormat).ToStringOutput(),
		InputFileName: pulumi.String(resolved.InputFileName).ToStringOutput(),
		// Batch prediction job specific defaults
		OutputDataPath:       setDefaultString(resolved.OutputDataPath, "predictions/"),
		OutputFormat:         setDefaultString(resolved.OutputFormat, "jsonl"),
		StartingReplicaCount: setDefaultInt(resolved.StartingReplicaCount, 1),
		MaxReplicaCount:      setDefaultInt(resolved.MaxReplicaCount, 3),
		BatchSize:            setDefaultInt(resolved.BatchSize, 0), // 0 means auto-configure
		AcceleratorType:      setDefaultString(resolved.AcceleratorType, "ACCELERATOR_TYPE_UNSPECIFIED"),
		AcceleratorCount:     setDefaultInt(resolved.AcceleratorCount, 1),
		Labels:               resolved.Labels,
		// Initial job state until we create the job
		jobState:           pulumi.String("").ToStringOutput(),
		inputDataLocalDir:  resolved.InputDataPath,
		inputDataTargetDir: "inputs", // Upload input data to a separate "inputs" directory in bucket
		retainJobOnDelete:  resolved.RetainJobOnDelete,
	}

	if err := ctx.RegisterComponentResource("pulumi-ai-batch:gcp:AIBatch", name, component, opts...); err != nil {
		return nil, fmt.Errorf("failed to register component resource: %w", err)
	}

	// Deploy the infrastructure
	if err := component.deploy(ctx, &resolved); err != nil {
		return nil, fmt.Errorf("failed to deploy AI batch: %w", err)
	}

	// Prepare resource outputs
	outputs := pulumi.Map{
		"vertex_ai_batch_model_service_account_email": component.modelServiceAccountEmail,
		"vertex_ai_batch_job_id":                      component.batchPredictionJob.ID(),
		"vertex_ai_batch_job_name":                    component.batchPredictionJob.Name,
		"vertex_ai_batch_job_display_name":            component.batchPredictionJob.DisplayName,
		"vertex_ai_batch_job_state":                   component.batchPredictionJob.State,
		"vertex_ai_batch_artifacts_bucket_name":       component.artifactsBucket.Name,
		"vertex_ai_batch_uploaded_model_files":        component.uploadedModelFiles,
		"vertex_ai_batch_input_data_uri":              component.InputDataPath,
		"vertex_ai_batch_output_data_uri_prefix":      component.OutputDataPath,
	}
	// Add model deployment specific outputs only if model deployment exists
	if component.modelDeployment != nil {
		outputs["vertex_ai_batch_model_image_url"] = component.modelDeployment.ModelImageUrl
		outputs["vertex_ai_batch_model_artifacts_bucket_uri"] = component.modelDeployment.ModelArtifactsBucketUri
		outputs["vertex_ai_batch_model_deployment_id"] = component.modelDeployment.ID()
		outputs["vertex_ai_batch_deployed_model_id"] = component.modelDeployment.DeployedModelId
		outputs["vertex_ai_batch_model_prediction_input_schema_uri"] = component.modelDeployment.ModelPredictionInputSchemaUri
		outputs["vertex_ai_batch_model_prediction_output_schema_uri"] = component.modelDeployment.ModelPredictionOutputSchemaUri
		outputs["vertex_ai_batch_model_prediction_behavior_schema_uri"] = component.modelDeployment.ModelPredictionBehaviorSchemaUri
	}

	if err := ctx.RegisterResourceOutputs(component, outputs); err != nil {
		return nil, fmt.Errorf("failed to register resource outputs: %w", err)
	}
	return component, nil
}
// deploy provisions all the resources for the Vertex AI Batch Prediction Job:
// a dedicated service account with IAM bindings (custom models only), the
// artifacts bucket with the uploaded model artifacts and input data, the model
// registration (custom models only) and finally the job itself.
func (v *AIBatch) deploy(ctx *pulumi.Context, args *AIBatchArgs) error {
	isCustomModel := args.ModelDir != ""

	// Declared here and assigned (not redeclared) inside the branch: a ":="
	// there would shadow this variable, leaving it a zero-value output and
	// silently dropping the custom service account from the model and the job.
	var modelServiceAccountEmail pulumi.StringOutput
	if isCustomModel {
		// Custom model. Run it with custom GSA.
		// Create service account for the model deployment
		email, iamMembers, repoIamMember, err := v.setupCustomModelIAM(ctx, args)
		if err != nil {
			return fmt.Errorf("failed to setup custom model IAM: %w", err)
		}
		modelServiceAccountEmail = email
		v.modelServiceAccountEmail = email
		v.iamMembers = iamMembers
		v.repoIamMember = repoIamMember
	}
	// Else, it's a model from the garden. We have to run it with the default agent GSA,
	// otherwise the internal endpoint automation fails with missing permissions
	// ('storage.objects.list') error on bucket "vertex-model-garden-restricted-us".

	// Upload model artifacts (including schemas) to bucket
	modelArtifactsURI, uploadedModelArtifacts, err := v.setupModelBucket(ctx, args.ModelDir, args.ModelBucketBasePath, args.Labels)
	if err != nil {
		return fmt.Errorf("failed to upload model to bucket: %w", err)
	}

	// Upload input data to bucket
	inputDataBucketURI, uploadedDataObjects, err := v.uploadInputDataToBucket(ctx, v.inputDataLocalDir, v.inputDataTargetDir)
	if err != nil {
		return fmt.Errorf("failed to upload input data to bucket: %w", err)
	}
	// The job consumes this URI as an input; making it depend on the uploaded
	// objects ensures the job is not created before its input data exists.
	inputDataBucketURI = dependOnBucketObjects(inputDataBucketURI, uploadedDataObjects)

	// Collect uploaded data file names for outputs
	v.uploadedModelFiles = collectBucketObjectNames(uploadedModelArtifacts, uploadedDataObjects)

	var modelDeployment *vertexmodeldeployment.VertexModelDeployment
	if isCustomModel {
		// Upload the model to the model registry and get a model ID for the job
		modelDeployment, err = v.deployModel(ctx, modelArtifactsURI, modelServiceAccountEmail, uploadedModelArtifacts)
		if err != nil {
			return fmt.Errorf("failed to deploy model /o\\: %w", err)
		}
		v.modelDeployment = modelDeployment
	}

	// Create the batch prediction job
	batchPredictionJob, err := v.createBatchPredictionJob(ctx, modelDeployment, inputDataBucketURI, modelServiceAccountEmail)
	if err != nil {
		return fmt.Errorf("failed to create batch prediction job: %w", err)
	}

	// track the job state to retry on failure
	v.batchPredictionJob = batchPredictionJob
	v.jobState = batchPredictionJob.State
	return nil
}

// dependOnBucketObjects returns an output with the same value as uri that also
// carries a dependency on every *storage.BucketObject in objects, so resources
// taking the result as an input wait for those uploads to finish.
func dependOnBucketObjects(uri pulumi.StringOutput, objects []pulumi.Resource) pulumi.StringOutput {
	deps := []interface{}{uri}
	for _, res := range objects {
		if obj, ok := res.(*storage.BucketObject); ok {
			deps = append(deps, obj.Name)
		}
	}
	if len(deps) == 1 {
		// Nothing uploaded: no extra dependency to add.
		return uri
	}
	return pulumi.All(deps...).ApplyT(func(values []interface{}) string {
		// values[0] is the original URI; the rest only exist for ordering.
		return values[0].(string)
	}).(pulumi.StringOutput)
}
// Accessors exposing the component's child resources and outputs.

// GetModelServiceAccountEmail returns the email of the dedicated model
// service account. It is only populated when a custom model (ModelDir) is
// deployed.
func (v *AIBatch) GetModelServiceAccountEmail() pulumi.StringOutput {
	return v.modelServiceAccountEmail
}

// GetBatchPredictionJob returns the underlying Vertex AI Batch Prediction Job.
func (v *AIBatch) GetBatchPredictionJob() *v1.BatchPredictionJob {
	return v.batchPredictionJob
}

// GetModelDeployment returns the Vertex AI model deployment, or nil when the
// job runs a model referenced by name instead of a custom model.
func (v *AIBatch) GetModelDeployment() *vertexmodeldeployment.VertexModelDeployment {
	return v.modelDeployment
}

// GetIAMMembers returns the project IAM bindings created for the model
// service account (empty for models referenced by name).
func (v *AIBatch) GetIAMMembers() []*projects.IAMMember {
	return v.iamMembers
}

// GetUploadedModelArtifacts returns the names of all objects uploaded to the
// artifacts bucket: the model artifacts followed by the input data files.
func (v *AIBatch) GetUploadedModelArtifacts() pulumi.StringArrayOutput {
	return v.uploadedModelFiles
}
// Package config provides an environment config helper
package config
import (
"fmt"
"log"
"github.com/kelseyhightower/envconfig"
"github.com/pulumi/pulumi/sdk/v3/go/pulumi"
"github.com/davidmontoyago/pulumi-gcp-ai-batch/pkg/gcp"
)
// Config allows setting the vertex batch prediction job configuration via environment variables
// (parsed by envconfig from the struct tags below; see LoadConfig).
//
// NOTE(review): some env defaults differ from the AIBatch component's own
// defaults, and because ToAIBatchArgs always sets these fields, the env defaults
// win: MODEL_BUCKET_BASE_PATH defaults to "model/" here vs "model" in NewAIBatch,
// and MACHINE_TYPE defaults to "n1-standard-2" here vs "n1-highmem-4".
type Config struct {
	GCPProject string `envconfig:"GCP_PROJECT" required:"true"`
	GCPRegion string `envconfig:"GCP_REGION" required:"true"`
	// Exactly one of ModelDir (custom model) or ModelName is expected by NewAIBatch.
	ModelDir string `envconfig:"MODEL_DIR" required:"false"`
	ModelName string `envconfig:"MODEL_NAME" required:"false"`
	// Required by NewAIBatch when ModelDir is set.
	ModelPredictionInputSchemaPath string `envconfig:"MODEL_PREDICTION_INPUT_SCHEMA_PATH" required:"false"`
	ModelPredictionOutputSchemaPath string `envconfig:"MODEL_PREDICTION_OUTPUT_SCHEMA_PATH" required:"false"`
	ModelPredictionBehaviorSchemaPath string `envconfig:"MODEL_PREDICTION_BEHAVIOR_SCHEMA_PATH" default:""`
	ModelBucketBasePath string `envconfig:"MODEL_BUCKET_BASE_PATH" default:"model/"`
	ModelImageURL string `envconfig:"MODEL_IMAGE_URL" default:"us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-15:latest"`
	// Grants the model service account read access to the image's registry repository.
	EnablePrivateRegistryAccess bool `envconfig:"ENABLE_PRIVATE_REGISTRY_ACCESS" default:"false"`
	MachineType string `envconfig:"MACHINE_TYPE" default:"n1-standard-2"`
	// Empty display names are left unset so the component derives them from its name.
	JobDisplayName string `envconfig:"JOB_DISPLAY_NAME" default:""`
	ModelDisplayName string `envconfig:"MODEL_DISPLAY_NAME" default:""`
	// Batch prediction job specific configuration
	// InputDataURI is passed as AIBatchArgs.InputDataPath, which the component
	// uploads as a local directory (despite the "URI" name).
	InputDataURI string `envconfig:"INPUT_DATA_URI" default:"inputs/"`
	InputFileName string `envconfig:"INPUT_FILE_NAME" default:"*.jsonl"`
	InputFormat string `envconfig:"INPUT_FORMAT" default:"jsonl"`
	// Prefix inside the artifacts bucket where predictions are written.
	OutputDataURIPrefix string `envconfig:"OUTPUT_DATA_URI_PREFIX" default:"predictions/"`
	OutputFormat string `envconfig:"OUTPUT_FORMAT" default:"jsonl"`
	StartingReplicaCount int `envconfig:"STARTING_REPLICA_COUNT" default:"1"`
	MaxReplicaCount int `envconfig:"MAX_REPLICA_COUNT" default:"3"`
	BatchSize int `envconfig:"BATCH_SIZE" default:"0"`
	AcceleratorType string `envconfig:"ACCELERATOR_TYPE" default:"ACCELERATOR_TYPE_UNSPECIFIED"`
	AcceleratorCount int `envconfig:"ACCELERATOR_COUNT" default:"1"`
	// Passed through to pulumi.RetainOnDelete for the batch prediction job.
	RetainJobOnDelete bool `envconfig:"RETAIN_JOB_ON_DELETE" default:"false"`
}
// LoadConfig reads the batch job configuration from environment variables
// (driven by the envconfig tags on Config) and logs every resolved setting.
// It returns an error when envconfig.Process fails, e.g. because a variable
// marked required (GCP_PROJECT, GCP_REGION) is not set.
func LoadConfig() (*Config, error) {
	cfg := &Config{}
	if err := envconfig.Process("", cfg); err != nil {
		return nil, fmt.Errorf("failed to load configuration from environment variables: %w", err)
	}

	// One log line per setting, in a fixed order; each entry carries its own
	// verb so booleans and integers keep their original formatting.
	settings := []struct {
		format string
		value  interface{}
	}{
		{" GCP Project: %s", cfg.GCPProject},
		{" GCP Region: %s", cfg.GCPRegion},
		{" Model Dir: %s", cfg.ModelDir},
		{" Model Name: %s", cfg.ModelName},
		{" Model Prediction Input Schema Path: %s", cfg.ModelPredictionInputSchemaPath},
		{" Model Prediction Output Schema Path: %s", cfg.ModelPredictionOutputSchemaPath},
		{" Model Prediction Behavior Schema Path: %s", cfg.ModelPredictionBehaviorSchemaPath},
		{" Model Bucket Base Path: %s", cfg.ModelBucketBasePath},
		{" Model Image URL: %s", cfg.ModelImageURL},
		{" Enable Private Registry Access: %t", cfg.EnablePrivateRegistryAccess},
		{" Machine Type: %s", cfg.MachineType},
		{" Job Display Name: %s", cfg.JobDisplayName},
		{" Model Display Name: %s", cfg.ModelDisplayName},
		{" Input Data URI: %s", cfg.InputDataURI},
		{" Input File Name: %s", cfg.InputFileName},
		{" Input Format: %s", cfg.InputFormat},
		{" Output Data URI Prefix: %s", cfg.OutputDataURIPrefix},
		{" Output Format: %s", cfg.OutputFormat},
		{" Starting Replica Count: %d", cfg.StartingReplicaCount},
		{" Max Replica Count: %d", cfg.MaxReplicaCount},
		{" Batch Size: %d", cfg.BatchSize},
		{" Accelerator Type: %s", cfg.AcceleratorType},
		{" Accelerator Count: %d", cfg.AcceleratorCount},
		{" Retain Job On Delete: %t", cfg.RetainJobOnDelete},
	}

	log.Printf("Configuration loaded successfully:")
	for _, setting := range settings {
		log.Printf(setting.format, setting.value)
	}
	return cfg, nil
}
// ToAIBatchArgs builds the gcp.AIBatchArgs consumed by the AIBatch component
// from this environment configuration. JobDisplayName and ModelDisplayName are
// only set when non-empty, so the component can fall back to its own defaults.
func (c *Config) ToAIBatchArgs() *gcp.AIBatchArgs {
	out := &gcp.AIBatchArgs{
		// Where to deploy
		Project:           c.GCPProject,
		Region:            c.GCPRegion,
		RetainJobOnDelete: c.RetainJobOnDelete,

		// Model source and serving
		ModelDir:                        c.ModelDir,
		ModelName:                       c.ModelName,
		ModelPredictionInputSchemaPath:  c.ModelPredictionInputSchemaPath,
		ModelPredictionOutputSchemaPath: c.ModelPredictionOutputSchemaPath,
		// A plain string: an empty value is identical to leaving it unset.
		ModelPredictionBehaviorSchemaPath: c.ModelPredictionBehaviorSchemaPath,
		ModelBucketBasePath:               c.ModelBucketBasePath,
		ModelImageURL:                     pulumi.String(c.ModelImageURL),
		EnablePrivateRegistryAccess:       c.EnablePrivateRegistryAccess,

		// Job input / output
		InputDataPath:  c.InputDataURI,
		InputFileName:  c.InputFileName,
		InputFormat:    c.InputFormat,
		OutputDataPath: pulumi.String(c.OutputDataURIPrefix),
		OutputFormat:   pulumi.String(c.OutputFormat),

		// Compute
		MachineType:          pulumi.String(c.MachineType),
		StartingReplicaCount: pulumi.Int(c.StartingReplicaCount),
		MaxReplicaCount:      pulumi.Int(c.MaxReplicaCount),
		BatchSize:            pulumi.Int(c.BatchSize),
		AcceleratorType:      pulumi.String(c.AcceleratorType),
		AcceleratorCount:     pulumi.Int(c.AcceleratorCount),
	}

	// Display names stay nil unless provided.
	if display := c.JobDisplayName; display != "" {
		out.JobDisplayName = pulumi.String(display)
	}
	if display := c.ModelDisplayName; display != "" {
		out.ModelDisplayName = pulumi.String(display)
	}
	return out
}
package gcp
import (
"fmt"
"time"
vertexmodeldeployment "github.com/davidmontoyago/pulumi-gcp-vertex-model-deployment/sdk/go/pulumi-gcp-vertex-model-deployment/resources"
v1 "github.com/pulumi/pulumi-google-native/sdk/go/google/aiplatform/v1"
"github.com/pulumi/pulumi/sdk/v3/go/pulumi"
)
// createBatchPredictionJob creates a Vertex AI Batch Prediction Job that reads
// the files matching InputFileName under inputDataBucketURI and writes
// predictions under OutputDataPath in the artifacts bucket.
//
// modelDeployment is nil for a model referenced by name (e.g. from the Model
// Garden); the job then uses v.ModelName and no explicit service account. For a
// custom model the job uses the deployed model's name and runs as
// serviceAccountEmail, after waiting for the model deployment, the account's
// project IAM bindings and (if any) the registry IAM binding.
//
// The Pulumi resource name is suffixed with the current Unix time in
// milliseconds, so every `pulumi up` launches a new job; v.retainJobOnDelete is
// passed to pulumi.RetainOnDelete.
func (v *AIBatch) createBatchPredictionJob(ctx *pulumi.Context,
	modelDeployment *vertexmodeldeployment.VertexModelDeployment,
	inputDataBucketURI pulumi.StringOutput,
	serviceAccountEmail pulumi.StringOutput) (*v1.BatchPredictionJob, error) {

	dependencies := []pulumi.Resource{v.artifactsBucket}

	var modelName pulumi.StringOutput
	isCustomModel := modelDeployment != nil
	if isCustomModel {
		dependencies = append(dependencies, modelDeployment)
		modelName = modelDeployment.ModelName
	} else {
		// if no model deployment, it's a model from the garden
		modelName = pulumi.String(v.ModelName).ToStringOutput()
	}

	// The job runs as the custom service account (custom models only); wait for
	// its project role bindings so it does not start before it can access its
	// inputs and outputs. Only the account email reaches the job's inputs, so
	// without this the bindings are not ordered before the job.
	for _, member := range v.iamMembers {
		if member != nil {
			dependencies = append(dependencies, member)
		}
	}

	if v.repoIamMember != nil {
		// wait for IAM binding to access a private registry
		dependencies = append(dependencies, v.repoIamMember)
	}

	// Construct the input config
	inputConfig := &v1.GoogleCloudAiplatformV1BatchPredictionJobInputConfigArgs{
		InstancesFormat: v.InputFormat,
		GcsSource: &v1.GoogleCloudAiplatformV1GcsSourceArgs{
			Uris: pulumi.StringArray{
				// URI to the data just uploaded by this component
				pulumi.Sprintf("%s/%s", inputDataBucketURI, v.InputFileName),
			},
		},
	}

	// Construct the output config
	outputConfig := &v1.GoogleCloudAiplatformV1BatchPredictionJobOutputConfigArgs{
		PredictionsFormat: v.OutputFormat,
		GcsDestination: &v1.GoogleCloudAiplatformV1GcsDestinationArgs{
			OutputUriPrefix: pulumi.Sprintf("gs://%s/%s", v.artifactsBucket.Name, v.OutputDataPath),
		},
	}

	// Construct dedicated resources for the job
	dedicatedResources := &v1.GoogleCloudAiplatformV1BatchDedicatedResourcesArgs{
		MachineSpec: &v1.GoogleCloudAiplatformV1MachineSpecArgs{
			MachineType:      v.MachineType,
			AcceleratorCount: v.AcceleratorCount,
			AcceleratorType: v.AcceleratorType.ApplyT(func(accelType string) v1.GoogleCloudAiplatformV1MachineSpecAcceleratorType {
				return v1.GoogleCloudAiplatformV1MachineSpecAcceleratorType(accelType)
			}).(v1.GoogleCloudAiplatformV1MachineSpecAcceleratorTypeOutput),
		},
		StartingReplicaCount: v.StartingReplicaCount,
		MaxReplicaCount:      v.MaxReplicaCount,
	}

	batchJobArgs := &v1.BatchPredictionJobArgs{
		Project:            pulumi.String(v.Project),
		Location:           pulumi.String(v.Region),
		DisplayName:        v.JobDisplayName,
		Model:              modelName, // Use the deployed model name or the name of a model from the garden
		InputConfig:        inputConfig,
		OutputConfig:       outputConfig,
		DedicatedResources: dedicatedResources,
		ManualBatchTuningParameters: &v1.GoogleCloudAiplatformV1ManualBatchTuningParametersArgs{
			BatchSize: v.BatchSize,
		},
		Labels: pulumi.ToStringMap(v.Labels),
	}
	if isCustomModel {
		batchJobArgs.ServiceAccount = serviceAccountEmail
	}

	// every pulumi up operation is a new launch
	jobName := fmt.Sprintf("%s-%d", v.NewResourceName("batch-prediction-job", "", 63), time.Now().UnixMilli())
	batchPredictionJob, err := v1.NewBatchPredictionJob(ctx,
		jobName,
		batchJobArgs,
		pulumi.Parent(v),
		pulumi.DependsOn(dependencies),
		pulumi.RetainOnDelete(v.retainJobOnDelete),
	)
	if err != nil {
		return nil, fmt.Errorf("failed to create batch prediction job: %w", err)
	}

	return batchPredictionJob, nil
}
package gcp
import (
"fmt"
"mime"
"os"
"path/filepath"
"strings"
"github.com/pulumi/pulumi-gcp/sdk/v8/go/gcp/storage"
"github.com/pulumi/pulumi/sdk/v3/go/pulumi"
)
// uploadDirectoryToBucket walks localDir and creates one storage.BucketObject
// in the artifacts bucket for every file whose name does not start with ".".
// Each object keeps its path relative to localDir ('/'-separated) and is placed
// under baseObjectPath when that is non-empty. An empty localDir uploads
// nothing. The created objects are returned as pulumi.Resource values so that
// callers can depend on them.
//
// NOTE(review): the Pulumi resource name is derived only from the path relative
// to localDir, so two uploaded directories (or two AIBatch components) holding
// the same relative file path would yield duplicate resource names. Renaming
// would replace existing objects; confirm before changing it.
func (v *AIBatch) uploadDirectoryToBucket(ctx *pulumi.Context, localDir, baseObjectPath string) ([]pulumi.Resource, error) {
	uploaded := []pulumi.Resource{}
	if localDir == "" {
		// no model artifacts to upload. skip
		return uploaded, nil
	}

	// Resource names turn path separators and dots into hyphens.
	nameSanitizer := strings.NewReplacer("/", "-", ".", "-")

	walkErr := filepath.Walk(localDir, func(filePath string, info os.FileInfo, err error) error {
		if err != nil {
			return fmt.Errorf("error walking path %s: %w", filePath, err)
		}
		// Directories are descended into but not uploaded; hidden files are skipped.
		if info.IsDir() || strings.HasPrefix(info.Name(), ".") {
			return nil
		}

		relPath, err := filepath.Rel(localDir, filePath)
		if err != nil {
			return fmt.Errorf("error calculating relative path: %w", err)
		}
		objectKey := filepath.ToSlash(relPath)

		// The resource name is derived before the base path is prepended.
		resourceName := nameSanitizer.Replace("file-" + objectKey)

		if baseObjectPath != "" {
			objectKey = filepath.ToSlash(filepath.Join(baseObjectPath, objectKey))
		}

		object, err := storage.NewBucketObject(ctx, resourceName, &storage.BucketObjectArgs{
			Name:        pulumi.String(objectKey),
			Bucket:      v.artifactsBucket.Name,
			Source:      pulumi.NewFileAsset(filePath),
			ContentType: pulumi.String(detectContentType(filePath)),
		}, pulumi.Parent(v))
		if err != nil {
			return fmt.Errorf("error creating bucket object for %s: %w", filePath, err)
		}

		uploaded = append(uploaded, object)
		return nil
	})
	if walkErr != nil {
		return nil, fmt.Errorf("error uploading directory %s: %w", localDir, walkErr)
	}

	return uploaded, nil
}
// detectContentType returns the MIME type that mime.TypeByExtension associates
// with filePath's extension, or "application/octet-stream" when none is known
// (including files without an extension).
func detectContentType(filePath string) string {
	if contentType := mime.TypeByExtension(filepath.Ext(filePath)); contentType != "" {
		return contentType
	}
	// Unknown type: fall back to generic binary.
	return "application/octet-stream"
}
// setupModelBucket creates the artifacts bucket (named after the component,
// located in v.Region, labelled with "purpose=model-storage" plus labels) and
// uploads modelDir, if any, under modelBucketBasePath.
// It returns the gs:// URI of the model artifacts prefix (without a trailing
// slash) and the uploaded objects for dependency tracking. The bucket is stored
// in v.artifactsBucket for the other setup steps.
func (v *AIBatch) setupModelBucket(ctx *pulumi.Context, modelDir string, modelBucketBasePath string, labels map[string]string) (pulumi.StringOutput, []pulumi.Resource, error) {
	// Create the bucket for model artifacts
	bucketName := v.NewResourceName("vertex-model", "bucket", 63)

	// Merge default labels with provided labels
	bucketLabels := pulumi.StringMap{
		"purpose": pulumi.String("model-storage"),
	}
	// Add user-provided labels
	for key, value := range labels {
		bucketLabels[key] = pulumi.String(value)
	}

	artifactsBucket, err := storage.NewBucket(ctx, bucketName, &storage.BucketArgs{
		Name:         pulumi.String(bucketName),
		Location:     pulumi.String(v.Region),
		Project:      pulumi.String(v.Project),
		ForceDestroy: pulumi.Bool(true), // Model data is part of the pipeline, safe to implode.
		// Enable Uniform Bucket Level Access (UBLA) for enhanced security
		// This is required for SBOMs and prevents ACL-based access control
		UniformBucketLevelAccess: pulumi.Bool(true),
		Versioning: &storage.BucketVersioningArgs{
			Enabled: pulumi.Bool(true), // Enable versioning for audit trail
		},
		Labels: bucketLabels,
	}, pulumi.Parent(v))
	if err != nil {
		return pulumi.StringOutput{}, nil, fmt.Errorf("failed to create artifacts bucket: %w", err)
	}
	v.artifactsBucket = artifactsBucket

	// No luck with https://github.com/pulumi/pulumi-synced-folder /o\
	// Upload the model artifacts, if any
	uploadedObjects, err := v.uploadDirectoryToBucket(ctx, modelDir, modelBucketBasePath)
	if err != nil {
		return pulumi.StringOutput{}, nil, fmt.Errorf("failed to upload model artifacts: %w", err)
	}

	// Callers join schema paths onto this URI with "/". The uploads go through
	// filepath.Join, which drops a trailing slash ("model/" + "x" -> "model/x"),
	// so the URI must drop it too. Otherwise it would point at "model//x",
	// a different (non-existent) object.
	basePath := strings.TrimRight(modelBucketBasePath, "/")
	modelArtifactsURI := pulumi.Sprintf("gs://%s/%s", artifactsBucket.Name, basePath)

	return modelArtifactsURI, uploadedObjects, nil
}
// uploadInputDataToBucket uploads the files found under inputDataDir (as
// selected by uploadDirectoryToBucket) into the artifacts bucket below
// inputDataBasePath. It returns the gs:// URI of that prefix together with the
// created objects for dependency tracking.
func (v *AIBatch) uploadInputDataToBucket(ctx *pulumi.Context, inputDataDir string, inputDataBasePath string) (pulumi.StringOutput, []pulumi.Resource, error) {
	objects, uploadErr := v.uploadDirectoryToBucket(ctx, inputDataDir, inputDataBasePath)
	if uploadErr != nil {
		return pulumi.StringOutput{}, nil, fmt.Errorf("failed to upload input data to bucket: %w", uploadErr)
	}
	return pulumi.Sprintf("gs://%s/%s", v.artifactsBucket.Name, inputDataBasePath), objects, nil
}
// collectBucketObjectNames returns the object names of every
// *storage.BucketObject in uploadedModelArtifacts followed by those in
// uploadedDataObjects; other resource types are ignored. The result is never
// nil, even when nothing was uploaded.
func collectBucketObjectNames(
	uploadedModelArtifacts []pulumi.Resource,
	uploadedDataObjects []pulumi.Resource,
) pulumi.StringArrayOutput {
	names := make(pulumi.StringArray, 0, len(uploadedModelArtifacts)+len(uploadedDataObjects))
	for _, group := range [][]pulumi.Resource{uploadedModelArtifacts, uploadedDataObjects} {
		for _, res := range group {
			if bucketObject, ok := res.(*storage.BucketObject); ok {
				// Name is already a StringOutput; an identity ApplyT is unnecessary.
				names = append(names, bucketObject.Name)
			}
		}
	}
	return names.ToStringArrayOutput()
}
package gcp
import (
vertexmodeldeployment "github.com/davidmontoyago/pulumi-gcp-vertex-model-deployment/sdk/go/pulumi-gcp-vertex-model-deployment/resources"
"github.com/pulumi/pulumi/sdk/v3/go/pulumi"
)
// deployModel registers the custom model with Vertex AI through a
// VertexModelDeployment resource. Batch prediction jobs only need the
// registered model, not an endpoint.
//
// The schema URIs are built by joining the configured schema paths onto
// modelArtifactsURI with "/"; the behavior schema is only set when configured.
// The deployment waits for the artifacts bucket and every uploaded object.
func (v *AIBatch) deployModel(ctx *pulumi.Context, modelArtifactsURI pulumi.StringOutput, serviceAccountEmail pulumi.StringOutput, uploadedObjects []pulumi.Resource) (*vertexmodeldeployment.VertexModelDeployment, error) {
	schemaURI := func(relPath string) pulumi.StringOutput {
		return pulumi.Sprintf("%s/%s", modelArtifactsURI, relPath)
	}

	deploymentArgs := &vertexmodeldeployment.VertexModelDeploymentArgs{
		ProjectId:                      pulumi.String(v.Project),
		Region:                         pulumi.String(v.Region),
		ModelArtifactsBucketUri:        modelArtifactsURI,
		ModelImageUrl:                  v.ModelImageURL,
		ModelPredictionInputSchemaUri:  schemaURI(v.ModelPredictionInputSchemaPath),
		ModelPredictionOutputSchemaUri: schemaURI(v.ModelPredictionOutputSchemaPath),
		ServiceAccount:                 serviceAccountEmail,
		// TODO make me configurable
		PredictRoute: pulumi.String("/predict"),
		HealthRoute:  pulumi.String("/health"),
	}
	if behaviorPath := v.ModelPredictionBehaviorSchemaPath; behaviorPath != "" {
		deploymentArgs.ModelPredictionBehaviorSchemaUri = schemaURI(behaviorPath)
	}

	// The bucket first, then each uploaded artifact.
	waitFor := append([]pulumi.Resource{v.artifactsBucket}, uploadedObjects...)

	return vertexmodeldeployment.NewVertexModelDeployment(ctx,
		v.NewResourceName("vertex-model-deployment", "", 63),
		deploymentArgs,
		pulumi.Parent(v),
		pulumi.DependsOn(waitFor),
	)
}
package gcp
import (
"fmt"
"github.com/pulumi/pulumi-gcp/sdk/v8/go/gcp/artifactregistry"
"github.com/pulumi/pulumi-gcp/sdk/v8/go/gcp/projects"
"github.com/pulumi/pulumi-gcp/sdk/v8/go/gcp/serviceaccount"
"github.com/pulumi/pulumi/sdk/v3/go/pulumi"
)
// grantModelIAMRoles grants the project-level IAM roles the batch prediction
// job needs to the model service account, one projects.IAMMember per role.
// It returns the created bindings in role order.
func (v *AIBatch) grantModelIAMRoles(ctx *pulumi.Context, serviceAccountEmail pulumi.StringOutput) ([]*projects.IAMMember, error) {
	// IAM roles specific to what the batch prediction job needs to operate.
	// New roles are appended at the end; binding names are derived from the
	// role, so existing bindings keep their names.
	roles := []string{
		"roles/storage.bucketViewer",    // List and get buckets
		"roles/storage.objectCreator",   // For writing prediction results to GCS
		"roles/logging.logWriter",       // For writing logs during prediction
		"roles/monitoring.metricWriter", // For writing custom metrics
		"roles/aiplatform.user",         // For accessing Vertex AI resources
		// For reading input data and model artifacts: neither bucketViewer nor
		// objectCreator grants storage.objects.get.
		"roles/storage.objectViewer",
	}

	iamMembers := make([]*projects.IAMMember, len(roles))
	for roleIndex, role := range roles {
		// NOTE(review): names are truncated to 63 chars by NewResourceName; with a
		// long component name, roles sharing a prefix could collide — verify.
		bindingName := v.NewResourceName(fmt.Sprintf("model-sa-iam-%s", role), "", 63)
		member, err := projects.NewIAMMember(ctx, bindingName, &projects.IAMMemberArgs{
			Project: pulumi.String(v.Project),
			Role:    pulumi.String(role),
			Member:  pulumi.Sprintf("serviceAccount:%s", serviceAccountEmail),
		}, pulumi.Parent(v))
		if err != nil {
			return nil, fmt.Errorf("failed to create IAM member for role %s: %w", role, err)
		}
		iamMembers[roleIndex] = member
	}

	return iamMembers, nil
}
// createModelServiceAccount creates the dedicated service account used for the
// custom model's Vertex AI operations and returns its email.
func (v *AIBatch) createModelServiceAccount(ctx *pulumi.Context) (pulumi.StringOutput, error) {
	// The account ID gets a 30-character budget (presumably the GCP limit for
	// service account IDs); the Pulumi resource name gets 63.
	accountArgs := &serviceaccount.AccountArgs{
		Project:     pulumi.String(v.Project),
		AccountId:   pulumi.String(v.NewResourceName("model-account", "", 30)),
		DisplayName: pulumi.Sprintf("%s Vertex AI Service Account", v.ModelDisplayName),
		Description: pulumi.String("Service account for deployed model operations"),
	}

	account, err := serviceaccount.NewAccount(ctx, v.NewResourceName("model-account", "", 63), accountArgs, pulumi.Parent(v))
	if err != nil {
		return pulumi.StringOutput{}, fmt.Errorf("failed to create model service account: %w", err)
	}
	return account.Email, nil
}
// setupCustomModelIAM creates the custom model service account and grants it
// the project roles from grantModelIAMRoles. When
// args.EnablePrivateRegistryAccess is set, it also grants read access to the
// model image repository. It returns the account email, the project IAM
// members and the repository IAM member (nil when registry access is disabled).
func (v *AIBatch) setupCustomModelIAM(ctx *pulumi.Context, args *AIBatchArgs) (pulumi.StringOutput, []*projects.IAMMember, *artifactregistry.RepositoryIamMember, error) {
	email, err := v.createModelServiceAccount(ctx)
	if err != nil {
		return pulumi.StringOutput{}, nil, nil, fmt.Errorf("failed to create model service account: %w", err)
	}

	members, err := v.grantModelIAMRoles(ctx, email)
	if err != nil {
		return pulumi.StringOutput{}, nil, nil, fmt.Errorf("failed to grant model IAM roles: %w", err)
	}

	if !args.EnablePrivateRegistryAccess {
		return email, members, nil, nil
	}

	registryMember, err := v.grantRegistryIAMAccess(ctx, email)
	if err != nil {
		return pulumi.StringOutput{}, nil, nil, fmt.Errorf("failed to grant registry IAM access: %w", err)
	}
	return email, members, registryMember, nil
}
package gcp
import (
"github.com/pulumi/pulumi/sdk/v3/go/pulumi"
)
// Helpers that resolve optional Pulumi inputs to outputs with a fallback value.

// setDefaultString returns input as a StringOutput, or defaultValue as one
// when input is nil (the caller left the field unset).
func setDefaultString(input pulumi.StringInput, defaultValue string) pulumi.StringOutput {
	if input != nil {
		return input.ToStringOutput()
	}
	return pulumi.String(defaultValue).ToStringOutput()
}
// setDefaultInt returns input as an IntOutput, or defaultValue as one when
// input is nil (the caller left the field unset).
func setDefaultInt(input pulumi.IntInput, defaultValue int) pulumi.IntOutput {
	if input != nil {
		return input.ToIntOutput()
	}
	return pulumi.Int(defaultValue).ToIntOutput()
}
package gcp
import (
"fmt"
"strings"
"github.com/pulumi/pulumi-gcp/sdk/v8/go/gcp/artifactregistry"
"github.com/pulumi/pulumi/sdk/v3/go/pulumi"
)
// grantRegistryIAMAccess grants the SA read access (roles/artifactregistry.reader)
// to the Artifact Registry repository hosting the model docker image.
//
// The image URL is expected as LOCATION-docker.pkg.dev/PROJECT/REPOSITORY/IMAGE[:TAG].
// The repository is the third '/'-separated segment. The repository location is
// taken from the host when it ends in "-docker.pkg.dev" (e.g. "us" for
// "us-docker.pkg.dev"); otherwise v.Region is used. A URL with fewer than three
// segments fails the deployment with an error instead of panicking inside the
// apply.
func (v *AIBatch) grantRegistryIAMAccess(ctx *pulumi.Context, serviceAccountEmail pulumi.StringOutput) (*artifactregistry.RepositoryIamMember, error) {
	const registryHostSuffix = "-docker.pkg.dev"

	modelImageRepoName := v.ModelImageURL.ApplyT(func(url string) (string, error) {
		segments := strings.Split(url, "/")
		if len(segments) < 3 {
			return "", fmt.Errorf("model image URL %q has no repository segment, expected LOCATION-docker.pkg.dev/PROJECT/REPOSITORY/IMAGE", url)
		}
		return segments[2], nil
	}).(pulumi.StringOutput)

	// Artifact Registry encodes the repository location in the host name, which
	// may differ from the deployment region (e.g. multi-region "us").
	fallbackLocation := v.Region
	modelImageRepoLocation := v.ModelImageURL.ApplyT(func(url string) string {
		host, _, _ := strings.Cut(url, "/")
		if strings.HasSuffix(host, registryHostSuffix) {
			if location := strings.TrimSuffix(host, registryHostSuffix); location != "" {
				return location
			}
		}
		return fallbackLocation
	}).(pulumi.StringOutput)

	bindingName := v.NewResourceName("model-registry-access", "iam-member", 63)
	repoMember, err := artifactregistry.NewRepositoryIamMember(ctx, bindingName, &artifactregistry.RepositoryIamMemberArgs{
		Repository: modelImageRepoName,
		Location:   modelImageRepoLocation,
		Project:    pulumi.String(v.Project),
		Role:       pulumi.String("roles/artifactregistry.reader"),
		Member:     pulumi.Sprintf("serviceAccount:%s", serviceAccountEmail),
	}, pulumi.Parent(v))
	if err != nil {
		return nil, fmt.Errorf("failed to grant registry IAM access: %w", err)
	}

	return repoMember, nil
}