Skip to content
Merged
1 change: 1 addition & 0 deletions src/cohere/aws_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ def _event_hook(request: httpx.Request) -> None:
)
request.url = URL(url)
request.headers["host"] = request.url.host
headers["host"] = request.url.host

if endpoint == "rerank":
body["api_version"] = get_api_version(version=api_version)
Expand Down
20 changes: 15 additions & 5 deletions src/cohere/manually_maintained/cohere_aws/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,23 @@ class Client:
def __init__(
self,
aws_region: typing.Optional[str] = None,
mode: Mode = Mode.SAGEMAKER,
):
"""
By default we assume region configured in AWS CLI (`aws configure get region`). You can change the region with
`aws configure set region us-west-2` or override it with `region_name` parameter.
"""
self._client = lazy_boto3().client("sagemaker-runtime", region_name=aws_region)
self._service_client = lazy_boto3().client("sagemaker", region_name=aws_region)
self.mode = mode
if os.environ.get('AWS_DEFAULT_REGION') is None:
os.environ['AWS_DEFAULT_REGION'] = aws_region
self._sess = lazy_sagemaker().Session(sagemaker_client=self._service_client)
self.mode = Mode.SAGEMAKER

if self.mode == Mode.SAGEMAKER:
self._client = lazy_boto3().client("sagemaker-runtime", region_name=aws_region)
self._service_client = lazy_boto3().client("sagemaker", region_name=aws_region)
self._sess = lazy_sagemaker().Session(sagemaker_client=self._service_client)
elif self.mode == Mode.BEDROCK:
self._client = lazy_boto3().client("bedrock-runtime", region_name=aws_region)
self._service_client = lazy_boto3().client("bedrock", region_name=aws_region)
Comment thread
cursor[bot] marked this conversation as resolved.



Expand Down Expand Up @@ -550,11 +556,15 @@ def embed(
variant: Optional[str] = None,
input_type: Optional[str] = None,
model_id: Optional[str] = None,
output_dimension: Optional[int] = None,
embedding_types: Optional[List[str]] = None,
) -> Embeddings:
json_params = {
'texts': texts,
'truncate': truncate,
"input_type": input_type
"input_type": input_type,
"output_dimension": output_dimension,
"embedding_types": embedding_types,
Comment thread
fern-support marked this conversation as resolved.
}
for key, value in list(json_params.items()):
if value is None:
Expand Down