diff --git a/sagemaker_studio_image_build/builder.py b/sagemaker_studio_image_build/builder.py index 18e8364..8bb136c 100644 --- a/sagemaker_studio_image_build/builder.py +++ b/sagemaker_studio_image_build/builder.py @@ -63,14 +63,14 @@ def delete_zip_file(bucket, key): s3 = boto3.session.Session().client("s3") s3.delete_object(Bucket=bucket, Key=key) - -def build_image(repository, role, bucket, compute_type, vpc_config, extra_args, log=True): +def build_image(repository, role, bucket, compute_type, vpc_config, environment, extra_args, log=True): bucket, key = upload_zip_file(repository, bucket, " ".join(extra_args)) try: from sagemaker_studio_image_build.codebuild import TempCodeBuildProject with TempCodeBuildProject(f"{bucket}/{key}", role, repository=repository, - compute_type=compute_type, vpc_config=vpc_config) as p: + compute_type=compute_type, vpc_config=vpc_config, + environment=environment) as p: p.build(log) finally: delete_zip_file(bucket, key) diff --git a/sagemaker_studio_image_build/cli.py b/sagemaker_studio_image_build/cli.py index 146ccd6..b395916 100644 --- a/sagemaker_studio_image_build/cli.py +++ b/sagemaker_studio_image_build/cli.py @@ -72,7 +72,7 @@ def build_image(args, extra_args): builder.build_image( args.repository, get_role(args), args.bucket, args.compute_type, - construct_vpc_config(args), extra_args, log=not args.no_logs + construct_vpc_config(args), args.environment, extra_args, log=not args.no_logs ) @@ -86,19 +86,25 @@ def main(): build_parser = subparsers.add_parser( "build", - help="Use AWS CodeBuild to build a Docker image and push to Amazon ECR", + help="Use AWS CodeBuild to build a Docker image and push to Amazon ECR.", ) build_parser.add_argument( "--repository", - help="The ECR repository:tag for the image (default: sagemaker-studio-${domain_id}:latest)", + help="The ECR repository:tag for the image (default: sagemaker-studio-${domain_id}:latest).", ) build_parser.add_argument( "--compute-type", - help="The CodeBuild compute type (default: BUILD_GENERAL1_SMALL)", + help="The CodeBuild compute type (default: BUILD_GENERAL1_SMALL) set to BUILD_GENERAL1_LARGE for LINUX_GPU_CONTAINER environment.", choices=["BUILD_GENERAL1_SMALL", "BUILD_GENERAL1_MEDIUM", "BUILD_GENERAL1_LARGE", "BUILD_GENERAL1_2XLARGE"], default="BUILD_GENERAL1_SMALL" ) + build_parser.add_argument( + "--environment", + help="The CodeBuild environment (default: LINUX_CONTAINER).", + choices=["LINUX_CONTAINER", "LINUX_GPU_CONTAINER"], + default="LINUX_CONTAINER" + ) build_parser.add_argument( "--role", help=f"The IAM role name for CodeBuild to use (default: the Studio execution role).", diff --git a/sagemaker_studio_image_build/codebuild.py b/sagemaker_studio_image_build/codebuild.py index a9c99b7..0e77566 100644 --- a/sagemaker_studio_image_build/codebuild.py +++ b/sagemaker_studio_image_build/codebuild.py @@ -11,14 +11,18 @@ class TempCodeBuildProject: - def __init__(self, s3_location, role, repository=None, compute_type=None, vpc_config=None): + def __init__(self, s3_location, role, repository=None, compute_type=None, vpc_config=None, environment=None): self.s3_location = s3_location self.role = role self.session = boto3.session.Session() self.domain_id, self.user_profile_name = self._get_studio_metadata() self.repo_name = None - self.compute_type = compute_type or "BUILD_GENERAL1_SMALL" + self.compute_type = compute_type + self.environment = environment + if self.environment=="LINUX_GPU_CONTAINER": + assert self.compute_type=="BUILD_GENERAL1_LARGE", \ + "LINUX_GPU_CONTAINER builds only available on BUILD_GENERAL1_LARGE. Please set `--compute-type BUILD_GENERAL1_LARGE`" self.vpc_config = vpc_config if repository: @@ -62,7 +66,7 @@ def __enter__(self): "source": {"type": "S3", "location": self.s3_location}, "artifacts": {"type": "NO_ARTIFACTS"}, "environment": { - "type": "LINUX_CONTAINER", + "type": self.environment, "image": "aws/codebuild/standard:4.0", "computeType": self.compute_type, "environmentVariables": [