[go: nahoru, domu]

Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle AWS encryption details #12495

Merged
merged 10 commits into from
Jul 10, 2024
22 changes: 22 additions & 0 deletions mlflow/protos/databricks_uc_registry_messages.proto
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ message TemporaryCredentials {
optional int64 expiration_time = 1;

optional StorageMode storage_mode = 6;

optional EncryptionDetails encryption_details = 7;
}

// AWS temporary credentials for API authentication.
Expand Down Expand Up @@ -190,6 +192,26 @@ enum StorageMode {
DEFAULT_STORAGE = 2;
}

// Server-side encryption configuration attached to temporary credentials.
// Exactly one oneof variant is set, identifying which encryption scheme
// applies to the storage location the credentials grant access to.
message EncryptionDetails {
  oneof encryption_details_type {
    // Details for CLOUD_MANAGED_SSE_KEYS: AWS SSE parameters to apply when
    // uploading artifacts to the S3 location.
    SseEncryptionDetails sse_encryption_details = 1;
  }
}

// Server-side encryption algorithms supported for S3 artifact uploads.
enum SseEncryptionAlgorithm {
  // Default zero value; carries no business meaning.
  SSE_ENCRYPTION_ALGORITHM_UNSPECIFIED = 0;

  // AWS SSE-KMS: sent as "aws:kms" in the 'x-amz-server-side-encryption'
  // header of S3 upload requests.
  // NOTE(review): value 1 is skipped; if it previously belonged to a removed
  // algorithm, add "reserved 1;" to prevent accidental reuse.
  AWS_SSE_KMS = 2;
}

// Parameters for S3 server-side encryption (SSE) of uploaded artifacts.
message SseEncryptionDetails {
  // The server-side encryption algorithm in effect for the S3 location.
  optional SseEncryptionAlgorithm algorithm = 1;

  // Optional. The ARN of the SSE-KMS key used with the S3 location, when algorithm = "SSE-KMS".
  // Sets the value of the 'x-amz-server-side-encryption-aws-kms-key-id' header.
  optional string aws_kms_key_arn = 2;
}

message CreateRegisteredModelRequest {
option (scalapb.message).extends = "com.databricks.rpc.RPC[CreateRegisteredModelResponse]";

Expand Down
504 changes: 268 additions & 236 deletions mlflow/protos/databricks_uc_registry_messages_pb2.py

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions mlflow/store/artifact/optimized_s3_artifact_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
credential_refresh_def=None,
addressing_style=None,
s3_endpoint_url=None,
s3_upload_extra_args=None,
):
super().__init__(artifact_uri)
self._access_key_id = access_key_id
Expand All @@ -62,6 +63,7 @@ def __init__(
self._s3_endpoint_url = s3_endpoint_url
self.bucket, self.bucket_path = self.parse_s3_compliant_uri(self.artifact_uri)
self._region_name = self._get_region_name()
self._s3_upload_extra_args = s3_upload_extra_args if s3_upload_extra_args else {}

def _refresh_credentials(self):
if not self._credential_refresh_def:
Expand All @@ -70,6 +72,7 @@ def _refresh_credentials(self):
self._access_key_id = new_creds["access_key_id"]
self._secret_access_key = new_creds["secret_access_key"]
self._session_token = new_creds["session_token"]
self._s3_upload_extra_args = new_creds["s3_upload_extra_args"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this key guaranteed to be there?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes: the credential refresh function found here ensures this key is present.

return self._get_s3_client()

def _get_region_name(self):
Expand Down Expand Up @@ -146,6 +149,7 @@ def get_s3_file_upload_extra_args():

def _upload_file(self, s3_client, local_file, bucket, key):
extra_args = {}
extra_args.update(self._s3_upload_extra_args)
guessed_type, guessed_encoding = guess_type(local_file)
if guessed_type is not None:
extra_args["ContentType"] = guessed_type
Expand Down
2 changes: 2 additions & 0 deletions mlflow/store/artifact/r2_artifact_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def __init__(
secret_access_key=None,
session_token=None,
credential_refresh_def=None,
s3_upload_extra_args=None,
):
# setup Cloudflare R2 backend to be endpoint_url, otherwise all s3 requests
# will go to AWS S3 by default
Expand All @@ -31,6 +32,7 @@ def __init__(
credential_refresh_def=credential_refresh_def,
addressing_style="virtual",
s3_endpoint_url=s3_endpoint_url,
s3_upload_extra_args=s3_upload_extra_args,
)

# Cloudflare implementation of head_bucket is not the same as AWS's, so we
Expand Down
29 changes: 28 additions & 1 deletion mlflow/utils/_unity_catalog_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
from mlflow.protos.databricks_uc_registry_messages_pb2 import (
RegisteredModelTag as ProtoRegisteredModelTag,
)
from mlflow.protos.databricks_uc_registry_messages_pb2 import TemporaryCredentials
from mlflow.protos.databricks_uc_registry_messages_pb2 import (
SseEncryptionAlgorithm,
TemporaryCredentials,
)
from mlflow.store.artifact.artifact_repo import ArtifactRepository

_STRING_TO_STATUS = {k: ProtoModelVersionStatus.Value(k) for k in ProtoModelVersionStatus.keys()}
Expand Down Expand Up @@ -155,14 +158,17 @@ def _get_artifact_repo_from_storage_info(
from mlflow.store.artifact.optimized_s3_artifact_repo import OptimizedS3ArtifactRepository

aws_creds = scoped_token.aws_temp_credentials
s3_upload_extra_args = _parse_aws_sse_credential(scoped_token)

def aws_credential_refresh():
new_scoped_token = base_credential_refresh_def()
new_aws_creds = new_scoped_token.aws_temp_credentials
new_s3_upload_extra_args = _parse_aws_sse_credential(new_scoped_token)
return {
"access_key_id": new_aws_creds.access_key_id,
"secret_access_key": new_aws_creds.secret_access_key,
"session_token": new_aws_creds.session_token,
"s3_upload_extra_args": new_s3_upload_extra_args,
}

return OptimizedS3ArtifactRepository(
Expand All @@ -171,6 +177,7 @@ def aws_credential_refresh():
secret_access_key=aws_creds.secret_access_key,
session_token=aws_creds.session_token,
credential_refresh_def=aws_credential_refresh,
s3_upload_extra_args=s3_upload_extra_args,
)
elif credential_type == "azure_user_delegation_sas":
from azure.core.credentials import AzureSasCredential
Expand Down Expand Up @@ -232,6 +239,26 @@ def r2_credential_refresh():
)


def _parse_aws_sse_credential(scoped_token: TemporaryCredentials):
    """Extract boto3 S3 upload extra args for AWS SSE-KMS from temporary credentials.

    Returns a dict of the form
    ``{"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": <key id>}`` when the
    scoped token carries SSE-KMS encryption details, and an empty dict in
    every other case (no details, a different oneof variant, or a
    non-SSE-KMS algorithm).
    """
    details = scoped_token.encryption_details
    # Bail out when there are no encryption details, or when the oneof holds
    # something other than the SSE variant.
    if not details or details.WhichOneof("encryption_details_type") != "sse_encryption_details":
        return {}

    sse_details = details.sse_encryption_details
    # Only the AWS SSE-KMS algorithm maps onto boto3 upload arguments.
    if sse_details.algorithm != SseEncryptionAlgorithm.AWS_SSE_KMS:
        return {}

    # The upload header expects a key id, so keep only the segment after the
    # final '/' of the KMS key ARN (a bare key id passes through unchanged).
    kms_key_id = sse_details.aws_kms_key_arn.split("/")[-1]
    return {
        "ServerSideEncryption": "aws:kms",
        "SSEKMSKeyId": kms_key_id,
    }


def get_full_name_from_sc(name, spark) -> str:
"""
Constructs the full name of a registered model using the active catalog and schema in a spark
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ def test_create_model_version_with_langchain_dependencies(store, langchain_local
secret_access_key=secret_access_key,
session_token=session_token,
credential_refresh_def=ANY,
s3_upload_extra_args={},
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any way to add a test/mock around this not being empty every time?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added tests in d8bd736

)
mock_artifact_repo.log_artifacts.assert_called_once_with(local_dir=ANY, artifact_path="")
_assert_create_model_version_endpoints_called(
Expand Down Expand Up @@ -377,6 +378,7 @@ def test_create_model_version_with_resources(store, langchain_local_model_dir_wi
secret_access_key=secret_access_key,
session_token=session_token,
credential_refresh_def=ANY,
s3_upload_extra_args={},
)
mock_artifact_repo.log_artifacts.assert_called_once_with(local_dir=ANY, artifact_path="")
_assert_create_model_version_endpoints_called(
Expand Down Expand Up @@ -434,6 +436,7 @@ def test_create_model_version_with_langchain_no_dependencies(
secret_access_key=secret_access_key,
session_token=session_token,
credential_refresh_def=ANY,
s3_upload_extra_args={},
)
mock_artifact_repo.log_artifacts.assert_called_once_with(local_dir=ANY, artifact_path="")
_assert_create_model_version_endpoints_called(
Expand Down Expand Up @@ -1019,6 +1022,7 @@ def test_create_model_version_aws(store, local_model_dir):
secret_access_key=secret_access_key,
session_token=session_token,
credential_refresh_def=ANY,
s3_upload_extra_args={},
)
mock_artifact_repo.log_artifacts.assert_called_once_with(local_dir=ANY, artifact_path="")
_assert_create_model_version_endpoints_called(
Expand Down
1 change: 1 addition & 0 deletions tests/store/artifact/test_optimized_s3_artifact_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ def credential_refresh_def():
"access_key_id": "my-id-2",
"secret_access_key": "my-key-2",
"session_token": "my-session-2",
"s3_upload_extra_args": {},
}

repo = OptimizedS3ArtifactRepository(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def test_uc_models_artifact_repo_download_artifacts_uses_temporary_creds_aws(mon
secret_access_key=fake_secret_access_key,
session_token=fake_session_token,
credential_refresh_def=ANY,
s3_upload_extra_args={},
)
mock_s3_repo.download_artifacts.assert_called_once_with("artifact_path", "dst_path")
request_mock.assert_called_with(
Expand Down
41 changes: 41 additions & 0 deletions tests/utils/test_unity_catalog_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
)
from mlflow.entities.model_registry.model_version_search import ModelVersionSearch
from mlflow.entities.model_registry.registered_model_search import RegisteredModelSearch
from mlflow.protos.databricks_uc_registry_messages_pb2 import (
EncryptionDetails,
SseEncryptionAlgorithm,
SseEncryptionDetails,
TemporaryCredentials,
)
from mlflow.protos.databricks_uc_registry_messages_pb2 import ModelVersion as ProtoModelVersion
from mlflow.protos.databricks_uc_registry_messages_pb2 import (
ModelVersionStatus as ProtoModelVersionStatus,
Expand All @@ -26,6 +32,7 @@
RegisteredModelTag as ProtoRegisteredModelTag,
)
from mlflow.utils._unity_catalog_utils import (
_parse_aws_sse_credential,
model_version_from_uc_proto,
model_version_search_from_uc_proto,
registered_model_from_uc_proto,
Expand Down Expand Up @@ -245,3 +252,37 @@ def test_registered_model_and_registered_model_search_equality():
registered_model_search_2 = RegisteredModelSearch(**kwargs)

assert registered_model_2 == registered_model_search_2


@pytest.mark.parametrize(
    ("temp_credentials", "parsed"),
    [
        # No encryption details set at all -> no extra upload args.
        (TemporaryCredentials(), {}),
        # SSE details present but the algorithm is left unspecified -> ignored.
        (
            TemporaryCredentials(
                encryption_details=EncryptionDetails(
                    sse_encryption_details=SseEncryptionDetails(
                        algorithm=SseEncryptionAlgorithm.SSE_ENCRYPTION_ALGORITHM_UNSPECIFIED
                    )
                )
            ),
            {},
        ),
        # AWS SSE-KMS with a key -> boto3 ServerSideEncryption extra args.
        (
            TemporaryCredentials(
                encryption_details=EncryptionDetails(
                    sse_encryption_details=SseEncryptionDetails(
                        algorithm=SseEncryptionAlgorithm.AWS_SSE_KMS,
                        aws_kms_key_arn="key_id",
                    )
                )
            ),
            {
                "ServerSideEncryption": "aws:kms",
                "SSEKMSKeyId": "key_id",
            },
        ),
    ],
)
def test_parse_aws_sse_credential(temp_credentials, parsed):
    # Verifies that proto encryption details are mapped to the upload
    # extra-args dict consumed by the S3 artifact repository.
    assert _parse_aws_sse_credential(temp_credentials) == parsed
Loading