Skip to content

Commit b20b617

Browse files
fern-supportclaude
andcommitted
fix: guard SageMaker-only methods in Bedrock mode
Address review feedback: In Bedrock mode, `self._sess` was never set, so SageMaker-only methods would throw confusing AttributeErrors. Now: - Initialize `_sess=None` and `_endpoint_name=None` in Bedrock mode - Add `_require_sagemaker()` guard to connect_to_endpoint, create_endpoint, export_finetune, summarize, and delete_endpoint Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 065e010 commit b20b617

1 file changed

Lines changed: 10 additions & 1 deletion

File tree

  • src/cohere/manually_maintained/cohere_aws

src/cohere/manually_maintained/cohere_aws/client.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,12 @@ def __init__(
3737
elif self.mode == Mode.BEDROCK:
3838
self._client = lazy_boto3().client("bedrock-runtime", region_name=aws_region)
3939
self._service_client = lazy_boto3().client("bedrock", region_name=aws_region)
40+
self._sess = None
41+
self._endpoint_name = None
4042

41-
43+
def _require_sagemaker(self) -> None:
44+
if self.mode != Mode.SAGEMAKER:
45+
raise CohereError("This method is only supported in SageMaker mode.")
4246

4347
def _does_endpoint_exist(self, endpoint_name: str) -> bool:
4448
try:
@@ -56,6 +60,7 @@ def connect_to_endpoint(self, endpoint_name: str) -> None:
5660
Raises:
5761
CohereError: Connection to the endpoint failed.
5862
"""
63+
self._require_sagemaker()
5964
if not self._does_endpoint_exist(endpoint_name):
6065
raise CohereError(f"Endpoint {endpoint_name} does not exist.")
6166
self._endpoint_name = endpoint_name
@@ -143,6 +148,7 @@ def create_endpoint(
143148
will be used to get the role. This should work when one uses the client inside SageMaker. If this errors
144149
out, the default role "ServiceRoleSagemaker" will be used, which generally works outside of SageMaker.
145150
"""
151+
self._require_sagemaker()
146152
# First, check if endpoint already exists
147153
if self._does_endpoint_exist(endpoint_name):
148154
if recreate:
@@ -815,6 +821,7 @@ def export_finetune(
815821
This should work when one uses the client inside SageMaker. If this errors out,
816822
the default role "ServiceRoleSagemaker" will be used, which generally works outside SageMaker.
817823
"""
824+
self._require_sagemaker()
818825
if name == "model":
819826
raise ValueError("name cannot be 'model'")
820827

@@ -958,6 +965,7 @@ def summarize(
958965
additional_command: Optional[str] = "",
959966
variant: Optional[str] = None
960967
) -> Summary:
968+
self._require_sagemaker()
961969

962970
if self._endpoint_name is None:
963971
raise CohereError("No endpoint connected. "
@@ -999,6 +1007,7 @@ def summarize(
9991007

10001008

10011009
def delete_endpoint(self) -> None:
1010+
self._require_sagemaker()
10021011
if self._endpoint_name is None:
10031012
raise CohereError("No endpoint connected.")
10041013
try:

0 commit comments

Comments
 (0)