-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy path__init__.py
More file actions
276 lines (230 loc) · 10.2 KB
/
__init__.py
File metadata and controls
276 lines (230 loc) · 10.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
import importlib
import os
import sys
import warnings
from enum import Enum
from pathlib import Path
from typing import Union
import google.protobuf.internal.api_implementation
from google.protobuf import symbol_database as _symbol_database
from snet.sdk.exceptions import NoGroupsFoundError, GroupNotFoundError, ServiceMetadataMismatchError
from snet.sdk.registry.models import StorageType, FileURI
from snet.sdk.registry.organization_metadata import OrganizationMetadata
from snet.sdk.registry.registry_contract import RegistryContract
from snet.sdk.registry.service_metadata import ServiceMetadata, Group
with warnings.catch_warnings():
    # Suppress the eth-typing package's warnings about some newer networks
    # ("Network ... does not have a valid ChainId"). They are presumably
    # triggered (transitively) by the imports below, which is why those imports
    # live inside this block. `catch_warnings()` restores the previous warning
    # filters on exit, so the "ignore" rule does not leak to the rest of the
    # process.
    warnings.filterwarnings(
        "ignore",
        "Network .* does not have a valid ChainId. eth-typing should be "
        "updated with the latest networks.",
        UserWarning,
    )
    from snet.sdk.account import Account
    from snet.sdk.config import config
    from snet.sdk.client_lib_generator import ClientLibGenerator
    from snet.sdk.mpe.mpe_contract import MPEContract
    from snet.sdk.mpe.payment_channel_provider import PaymentChannelProvider
    from snet.sdk.payment_strategies import (
        DefaultPaymentStrategy,
        PaidCallPaymentStrategy,
        PrePaidPaymentStrategy,
        FreeCallPaymentStrategy,
        PaymentStrategy,
    )
    from snet.sdk.service_client import ServiceClient
    from snet.sdk.registry.storage_provider import StorageProvider
    from snet.sdk.types import ModuleName, ServiceStub
    from snet.sdk.utils.utils import (
        find_file_by_keyword,
        get_we3_object,
    )
# NOTE(review): the statements below monkeypatch protobuf at import time of
# this package; they affect every later user of protobuf in the process.
#
# Make `api_implementation.Type()` report "python" to any code that calls it
# after this point. This replaces the query function only; presumably the goal
# is to steer generated/loaded code toward the pure-Python protobuf backend.
# TODO confirm this has the intended effect, since the backend may already have
# been selected when `google.protobuf` was first imported above.
google.protobuf.internal.api_implementation.Type = lambda: "python"
# protobuf's default symbol database.
_sym_db = _symbol_database.Default()
# Turn message registration on the default symbol database into a no-op.
# Presumably this avoids registration clashes when generated *_pb2 modules of
# different services are imported in one process. TODO confirm.
_sym_db.RegisterMessage = lambda x: None
class PaymentStrategyType(Enum):
    """Selects which payment strategy class a service client should use.

    Each member's value is the strategy *class* itself, not an instance;
    calling ``member.value()`` creates a fresh strategy instance.
    """

    # NOTE: Enum requires distinct values. If two of these names ever referred
    # to the same class object, the later name would silently become an alias
    # of the earlier member.
    PAID_CALL = PaidCallPaymentStrategy
    FREE_CALL = FreeCallPaymentStrategy
    PREPAID_CALL = PrePaidPaymentStrategy
    # Used by SnetSDK.create_service_client when no strategy type is passed.
    DEFAULT = DefaultPaymentStrategy
class SnetSDK:
    """Entry point of the SDK.

    Bundles the MPE and Registry contracts, the metadata storage provider, the
    payment-channel provider and the signing account. Exposes helpers to build
    service clients and to publish or update services and organizations.
    """

    def __init__(self):
        # NOTE(review): the contract/provider/account objects are built without
        # explicit settings, so they presumably read the global `config` —
        # confirm.
        self.w3 = get_we3_object()
        self.mpe_contract = MPEContract()
        self.registry_contract = RegistryContract()
        self.storage_provider = StorageProvider()
        self.payment_channel_provider = PaymentChannelProvider(self.mpe_contract)
        self.account = Account(self.mpe_contract.contract.address)

    def create_service_client(
        self,
        org_id: str,
        service_id: str,
        group_name: Union[str, None] = None,
        payment_strategy: Union[PaymentStrategy, None] = None,
        payment_strategy_type: PaymentStrategyType = PaymentStrategyType.DEFAULT,
        options: Union[dict, None] = None,
        concurrent_calls: int = 1,
    ) -> ServiceClient:
        """Build a ServiceClient for ``org_id``/``service_id``.

        Fetches the service metadata, copying each group's ``payment`` from
        the organization metadata. Makes sure the generated gRPC client
        library exists locally, generating it when missing or when
        ``config.FORCE_UPDATE`` is set. Then imports the generated stubs and
        wires everything into a ServiceClient.

        Args:
            org_id: Organization id in the Registry.
            service_id: Service id within the organization.
            group_name: Service group to use; the first group when None.
            payment_strategy: Explicit strategy instance. When None, a new
                instance of ``payment_strategy_type.value`` is created.
            payment_strategy_type: Strategy class used when no explicit
                strategy is given.
            options: Extra client options. The mapping is copied, so the
                caller's object is never modified. ``"concurrency"`` (from
                ``config.CONCURRENCY``) and ``"concurrent_calls"`` are always
                set on the copy.
            concurrent_calls: Stored as ``options["concurrent_calls"]``.

        Returns:
            The configured ServiceClient.

        Raises:
            ServiceMetadataMismatchError: A service group is missing from the
                organization metadata.
            NoGroupsFoundError: The service metadata defines no groups.
            GroupNotFoundError: ``group_name`` is not one of the service groups.
        """
        service_metadata = self._enhance_service_metadata(org_id, service_id)
        lib_generator = ClientLibGenerator(self.storage_provider, org_id, service_id)

        # Prefer the explicit API source; fall back to the older model IPFS hash.
        raw_api_source = service_metadata.service_api_source
        if raw_api_source is None:
            raw_api_source = service_metadata.model_ipfs_hash
        service_api_source = FileURI.from_raw_uri(raw_api_source)
        self._ensure_client_library(lib_generator, service_api_source)

        # Work on a copy: the keys below must not leak into the caller's dict.
        options = dict(options) if options is not None else {}
        options["concurrency"] = config.CONCURRENCY
        options["concurrent_calls"] = concurrent_calls

        if payment_strategy is None:
            payment_strategy = payment_strategy_type.value()

        group = self._get_service_group(org_id, service_id, service_metadata, group_name)
        service_stubs = self.get_service_stub(lib_generator)
        pb2_module = self.get_module_by_keyword("pb2.py", lib_generator)
        return ServiceClient(
            org_id,
            service_id,
            group,
            service_stubs,
            payment_strategy,
            options,
            self.mpe_contract,
            self.account,
            self.w3,
            pb2_module,
            self.payment_channel_provider,
            lib_generator.proto_dir,
            lib_generator.training_added(),
        )

    def _ensure_client_library(self, lib_generator: ClientLibGenerator, service_api_source) -> None:
        """Generate the client library when forced or when generated files are missing."""
        if config.FORCE_UPDATE:
            lib_generator.generate_client_library(service_api_source)
            return
        proto_dir = lib_generator.proto_dir
        pb2_file = find_file_by_keyword(proto_dir, keyword="pb2.py", exclude=["training"])
        pb2_grpc_file = find_file_by_keyword(proto_dir, keyword="pb2_grpc.py", exclude=["training"])
        if not pb2_file or not pb2_grpc_file:
            print("Generating client library...")
            lib_generator.generate_client_library(service_api_source)

    def _enhance_service_metadata(self, org_id: str, service_id: str) -> ServiceMetadata:
        """Return the service metadata with each group's ``payment`` taken from the org.

        Raises:
            ServiceMetadataMismatchError: A service group has no same-named
                group in the organization metadata.
        """
        service_metadata = self.get_service_metadata(org_id, service_id)
        org_metadata = self.get_organization_metadata(org_id)
        org_group_map = {group.group_name: group for group in org_metadata.groups}
        for group in service_metadata.groups:
            try:
                group.payment = org_group_map[group.group_name].payment
            except KeyError as err:
                raise ServiceMetadataMismatchError(
                    f"Service group '{group.group_name}' does not exist "
                    f"in organization '{org_id}'!"
                ) from err
        return service_metadata

    def get_service_stub(self, lib_generator: ClientLibGenerator) -> list[ServiceStub]:
        """Import the generated ``*pb2_grpc`` module and wrap its stubs.

        Every module attribute whose name contains "Stub" is wrapped in a
        ServiceStub. The proto directory is added to ``sys.path`` (once) so the
        generated module can be imported by name.

        Raises:
            Exception: Importing the module or building the stubs failed; the
                original error is chained as ``__cause__``.
        """
        path_to_pb_files = str(lib_generator.proto_dir)
        module_name = self.get_module_by_keyword("pb2_grpc.py", lib_generator)
        # Without this check every created client appends another duplicate entry.
        if path_to_pb_files not in sys.path:
            sys.path.append(path_to_pb_files)
        try:
            grpc_file = importlib.import_module(module_name)
            return [
                ServiceStub(getattr(grpc_file, attr_name))
                for attr_name in dir(grpc_file)
                if "Stub" in attr_name
            ]
        except Exception as e:
            raise Exception(f"Error importing module: {e}") from e

    def get_module_by_keyword(self, keyword: str, lib_generator: ClientLibGenerator) -> ModuleName:
        """Return the module name (file name without extension) of the generated file matching ``keyword``.

        ``exclude=["training"]`` is passed to the file search, presumably to
        skip training-related generated files.

        Raises:
            FileNotFoundError: No matching file exists in the proto directory.
        """
        path_to_pb_files = lib_generator.proto_dir
        file_name = find_file_by_keyword(path_to_pb_files, keyword, exclude=["training"])
        if not file_name:
            raise FileNotFoundError(
                f"No file matching '{keyword}' found in {path_to_pb_files}"
            )
        module_name = os.path.splitext(file_name)[0]
        return ModuleName(module_name)

    def get_service_metadata(self, org_id: str, service_id: str) -> ServiceMetadata:
        """Fetch the service metadata referenced by the service's Registry entry."""
        service = self.registry_contract.get_service(org_id, service_id)
        return self.storage_provider.fetch_service_metadata(service.metadata_uri)

    def get_organization_metadata(self, org_id: str) -> OrganizationMetadata:
        """Fetch the organization metadata referenced by the org's Registry entry."""
        org = self.registry_contract.get_org(org_id)
        return self.storage_provider.fetch_org_metadata(org.metadata_uri)

    def _get_service_group(
        self,
        org_id: str,
        service_id: str,
        service_metadata: ServiceMetadata,
        group_name: Union[str, None],
    ) -> Group:
        """Pick the service group named ``group_name``, or the first group when None.

        Raises:
            NoGroupsFoundError: The service metadata has no groups.
            GroupNotFoundError: No group is named ``group_name``.
        """
        groups = service_metadata.groups
        if not groups:
            raise NoGroupsFoundError(org_id, service_id)
        if group_name is None:
            return groups[0]
        for group in groups:
            if group.group_name == group_name:
                return group
        raise GroupNotFoundError(org_id, service_id, group_name)

    def get_organization_list(self) -> list:
        """Return the organizations listed by the Registry contract."""
        return self.registry_contract.list_orgs()

    def get_services_list(self, org_id: str) -> list:
        """Return the services the Registry contract lists for ``org_id``."""
        return self.registry_contract.list_service_for_org(org_id)

    def publish_service_comprehensively(
        self,
        org_id: str,
        service_id: str,
        metadata: ServiceMetadata,
        proto_dir: Union[str, Path],
        storage_type: StorageType = StorageType.IPFS,
    ) -> bool:
        """Publish a new service end to end.

        1. Publish the .proto files from ``proto_dir`` into the storage.
        2. Set ``service_api_source`` and ``mpe_address`` on ``metadata``.
        3. Check that every service group exists in the organization and copy
           its ``group_id``.
        4. Publish the service metadata into the storage.
        5. Create the service in the Registry contract.

        Note that ``metadata`` is modified in place.

        Returns:
            True when the transaction receipt's ``status`` is non-zero.

        Raises:
            ServiceMetadataMismatchError: A service group is not in the org.
        """
        proto_uri = self.storage_provider.publish_proto(proto_dir, storage_type)
        metadata.service_api_source = str(proto_uri)
        metadata.mpe_address = self.mpe_contract.contract.address
        self._check_and_update_service_groups(org_id, metadata.groups)
        metadata_uri = self.storage_provider.publish_service_metadata(metadata, storage_type)
        receipt = self.registry_contract.create_service(
            self.account, org_id, service_id, metadata_uri
        )
        return receipt["status"] != 0

    def update_service(
        self,
        org_id: str,
        service_id: str,
        metadata: ServiceMetadata,
        proto_dir: Union[str, Path, None] = None,
        storage_type: StorageType = StorageType.IPFS,
    ) -> bool:
        """Re-publish the service metadata and point the Registry entry at it.

        New .proto files are published only when ``proto_dir`` is given.
        ``mpe_address`` is filled in only when it is empty. ``metadata`` is
        modified in place.

        Returns:
            True when the transaction receipt's ``status`` is non-zero.

        Raises:
            ServiceMetadataMismatchError: A service group is not in the org.
        """
        if proto_dir is not None:
            proto_uri = self.storage_provider.publish_proto(proto_dir, storage_type)
            metadata.service_api_source = str(proto_uri)
        if not metadata.mpe_address:
            metadata.mpe_address = self.mpe_contract.contract.address
        self._check_and_update_service_groups(org_id, metadata.groups)
        metadata_uri = self.storage_provider.publish_service_metadata(metadata, storage_type)
        receipt = self.registry_contract.update_service_metadata(
            self.account, org_id, service_id, metadata_uri
        )
        return receipt["status"] != 0

    def update_organization(
        self,
        org_id: str,
        organization_metadata: OrganizationMetadata,
        storage_type: StorageType = StorageType.IPFS,
    ) -> bool:
        """Publish new organization metadata and update the org's Registry entry.

        Returns:
            True when the transaction receipt's ``status`` is non-zero.
        """
        metadata_uri = self.storage_provider.publish_organization_metadata(
            organization_metadata, storage_type
        )
        receipt = self.registry_contract.update_org_metadata(self.account, org_id, metadata_uri)
        return receipt["status"] != 0

    def _check_and_update_service_groups(
        self, org_id: str, service_groups: list[Group]
    ) -> list[Group]:
        """Copy each organization group's ``group_id`` onto the same-named service group.

        ``service_groups`` is modified in place and also returned.

        Raises:
            ServiceMetadataMismatchError: A service group has no same-named
                group in the organization.
        """
        org_metadata = self.get_organization_metadata(org_id)
        org_groups_map = {g.group_name: g for g in org_metadata.groups}
        for group in service_groups:
            try:
                group.group_id = org_groups_map[group.group_name].group_id
            except KeyError as err:
                raise ServiceMetadataMismatchError(
                    "All groups added to the service must also exist in the organization!"
                ) from err
        return service_groups