Skip to content
Merged
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ dynamic = ["version"]

[tool.poetry]
name = "cohere"
version = "5.20.5"
version = "5.20.6"
description = ""
readme = "README.md"
authors = []
Expand Down
1 change: 1 addition & 0 deletions src/cohere/aws_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ def _event_hook(request: httpx.Request) -> None:
)
request.url = URL(url)
request.headers["host"] = request.url.host
headers["host"] = request.url.host

if endpoint == "rerank":
body["api_version"] = get_api_version(version=api_version)
Expand Down
4 changes: 2 additions & 2 deletions src/cohere/core/client_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ def get_headers(self) -> typing.Dict[str, str]:
import platform

headers: typing.Dict[str, str] = {
"User-Agent": "cohere/5.20.5",
"User-Agent": "cohere/5.20.6",
"X-Fern-Language": "Python",
"X-Fern-Runtime": f"python/{platform.python_version()}",
"X-Fern-Platform": f"{platform.system().lower()}/{platform.release()}",
"X-Fern-SDK-Name": "cohere",
"X-Fern-SDK-Version": "5.20.5",
"X-Fern-SDK-Version": "5.20.6",
**(self.get_custom_headers() or {}),
}
if self._client_name is not None:
Expand Down
41 changes: 33 additions & 8 deletions src/cohere/manually_maintained/cohere_aws/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,29 @@ class Client:
def __init__(
self,
aws_region: typing.Optional[str] = None,
mode: Mode = Mode.SAGEMAKER,
):
"""
By default we assume the region configured in the AWS CLI (`aws configure get region`). You can change the region with
`aws configure set region us-west-2` or override it with the `aws_region` parameter.
"""
self._client = lazy_boto3().client("sagemaker-runtime", region_name=aws_region)
self._service_client = lazy_boto3().client("sagemaker", region_name=aws_region)
self.mode = mode
if os.environ.get('AWS_DEFAULT_REGION') is None:
os.environ['AWS_DEFAULT_REGION'] = aws_region
self._sess = lazy_sagemaker().Session(sagemaker_client=self._service_client)
self.mode = Mode.SAGEMAKER

if self.mode == Mode.SAGEMAKER:
self._client = lazy_boto3().client("sagemaker-runtime", region_name=aws_region)
self._service_client = lazy_boto3().client("sagemaker", region_name=aws_region)
self._sess = lazy_sagemaker().Session(sagemaker_client=self._service_client)
elif self.mode == Mode.BEDROCK:
self._client = lazy_boto3().client("bedrock-runtime", region_name=aws_region)
self._service_client = lazy_boto3().client("bedrock", region_name=aws_region)
self._sess = None
self._endpoint_name = None

def _require_sagemaker(self) -> None:
if self.mode != Mode.SAGEMAKER:
raise CohereError("This method is only supported in SageMaker mode.")

def _does_endpoint_exist(self, endpoint_name: str) -> bool:
try:
Expand All @@ -50,6 +60,7 @@ def connect_to_endpoint(self, endpoint_name: str) -> None:
Raises:
CohereError: Connection to the endpoint failed.
"""
self._require_sagemaker()
if not self._does_endpoint_exist(endpoint_name):
raise CohereError(f"Endpoint {endpoint_name} does not exist.")
self._endpoint_name = endpoint_name
Expand Down Expand Up @@ -137,6 +148,7 @@ def create_endpoint(
will be used to get the role. This should work when one uses the client inside SageMaker. If this errors
out, the default role "ServiceRoleSagemaker" will be used, which generally works outside of SageMaker.
"""
self._require_sagemaker()
# First, check if endpoint already exists
if self._does_endpoint_exist(endpoint_name):
if recreate:
Expand Down Expand Up @@ -550,11 +562,15 @@ def embed(
variant: Optional[str] = None,
input_type: Optional[str] = None,
model_id: Optional[str] = None,
) -> Embeddings:
output_dimension: Optional[int] = None,
embedding_types: Optional[List[str]] = None,
) -> Union[Embeddings, Dict[str, List]]:
json_params = {
'texts': texts,
'truncate': truncate,
"input_type": input_type
"input_type": input_type,
"output_dimension": output_dimension,
"embedding_types": embedding_types,
}
for key, value in list(json_params.items()):
if value is None:
Expand Down Expand Up @@ -591,7 +607,10 @@ def _sagemaker_embed(self, json_params: Dict[str, Any], variant: str):
# ValidationError, e.g. when variant is bad
raise CohereError(str(e))

return Embeddings(response['embeddings'])
embeddings = response['embeddings']
if isinstance(embeddings, dict):
return embeddings
return Embeddings(embeddings)

def _bedrock_embed(self, json_params: Dict[str, Any], model_id: str):
if not model_id:
Expand All @@ -612,7 +631,10 @@ def _bedrock_embed(self, json_params: Dict[str, Any], model_id: str):
# ValidationError, e.g. when variant is bad
raise CohereError(str(e))

return Embeddings(response['embeddings'])
embeddings = response['embeddings']
if isinstance(embeddings, dict):
return embeddings
return Embeddings(embeddings)


def rerank(self,
Expand Down Expand Up @@ -805,6 +827,7 @@ def export_finetune(
This should work when one uses the client inside SageMaker. If this errors out,
the default role "ServiceRoleSagemaker" will be used, which generally works outside SageMaker.
"""
self._require_sagemaker()
if name == "model":
raise ValueError("name cannot be 'model'")

Expand Down Expand Up @@ -948,6 +971,7 @@ def summarize(
additional_command: Optional[str] = "",
variant: Optional[str] = None
) -> Summary:
self._require_sagemaker()

if self._endpoint_name is None:
raise CohereError("No endpoint connected. "
Expand Down Expand Up @@ -989,6 +1013,7 @@ def summarize(


def delete_endpoint(self) -> None:
self._require_sagemaker()
if self._endpoint_name is None:
raise CohereError("No endpoint connected.")
try:
Expand Down
Loading