1
1
import argparse
2
+ from datetime import timedelta
2
3
import json
3
4
import logging
4
5
import uuid
5
- from enum import Enum
6
6
from typing import Any , Dict , List , Optional
7
7
8
8
from sqlalchemy import func , text
9
9
from sqlalchemy .exc import IntegrityError , NoResultFound
10
10
from sqlalchemy .orm import Session
11
11
from web3 import Web3
12
12
13
+ from .data import ContractType
14
+
13
15
from . import data , db
14
16
from .models import RegisteredContract , CallRequest
15
17
@@ -21,11 +23,6 @@ class ContractAlreadyRegistered(Exception):
21
23
pass
22
24
23
25
24
- class ContractType (Enum ):
25
- raw = "raw"
26
- dropper = "dropper-v0.2.0"
27
-
28
-
29
26
def validate_method_and_params (
30
27
contract_type : ContractType , method : str , parameters : Dict [str , Any ]
31
28
) -> None :
@@ -71,12 +68,6 @@ def register_contract(
71
68
"""
72
69
Register a contract against the Engine instance
73
70
"""
74
-
75
- # TODO(zomglings): Make it so that contract_type is passed as a string. Convert to
76
- # ContractType here. That will mean there is a single point at which the validation is
77
- # performed rather than relying on each entrypoint to register_contract having to implement
78
- # their own validation.
79
-
80
71
try :
81
72
contract = RegisteredContract (
82
73
blockchain = blockchain ,
@@ -100,6 +91,46 @@ def register_contract(
100
91
return render_registered_contract (contract )
101
92
102
93
94
+ def update_registered_contract (
95
+ db_session : Session ,
96
+ moonstream_user_id : uuid .UUID ,
97
+ contract_id : uuid .UUID ,
98
+ title : Optional [str ] = None ,
99
+ description : Optional [str ] = None ,
100
+ image_uri : Optional [str ] = None ,
101
+ ignore_nulls : bool = True ,
102
+ ) -> data .RegisteredContract :
103
+ """
104
+ Update the registered contract with the given contract ID provided that the user with moonstream_user_id
105
+ has access to it.
106
+ """
107
+ query = db_session .query (RegisteredContract ).filter (
108
+ RegisteredContract .id == contract_id ,
109
+ RegisteredContract .moonstream_user_id == moonstream_user_id ,
110
+ )
111
+
112
+ contract = query .one ()
113
+
114
+ if not (title is None and ignore_nulls ):
115
+ contract .title = title
116
+ if not (description is None and ignore_nulls ):
117
+ contract .description = description
118
+ if not (image_uri is None and ignore_nulls ):
119
+ contract .image_uri = image_uri
120
+
121
+ try :
122
+ db_session .add (contract )
123
+ db_session .commit ()
124
+ except Exception as err :
125
+ logger .error (
126
+ f"update_registered_contract -- error storing update in database: { repr (err )} "
127
+ )
128
+ db_session .rollback ()
129
+ raise
130
+
131
+ return render_registered_contract (contract )
132
+
133
+
103
134
def lookup_registered_contracts (
104
135
db_session : Session ,
105
136
moonstream_user_id : uuid .UUID ,
@@ -164,7 +195,8 @@ def delete_registered_contract(
164
195
def request_calls (
165
196
db_session : Session ,
166
197
moonstream_user_id : uuid .UUID ,
167
- registered_contract_id : uuid .UUID ,
198
+ registered_contract_id : Optional [uuid .UUID ],
199
+ contract_address : Optional [str ],
168
200
call_specs : List [data .CallSpecification ],
169
201
ttl_days : Optional [int ] = None ,
170
202
) -> int :
@@ -174,21 +206,31 @@ def request_calls(
174
206
# TODO(zomglings): Do not pass raw ttl_days into SQL query - could be subject to SQL injection
175
207
# For now, in the interest of speed, let us just be super cautious with ttl_days.
176
208
# Check that the ttl_days is indeed an integer
209
+ if registered_contract_id is None and contract_address is None :
210
+ raise ValueError (
211
+ "At least one of registered_contract_id or contract_address is required"
212
+ )
213
+
177
214
if ttl_days is not None :
178
215
assert ttl_days == int (ttl_days ), "ttl_days must be an integer"
179
216
if ttl_days <= 0 :
180
217
raise ValueError ("ttl_days must be positive" )
181
218
182
- # Check that the moonstream_user_id matches the RegisteredContract
183
- try :
184
- registered_contract = (
185
- db_session .query (RegisteredContract )
186
- .filter (
187
- RegisteredContract .id == registered_contract_id ,
188
- RegisteredContract .moonstream_user_id == moonstream_user_id ,
189
- )
190
- .one ()
219
+ # Check that the moonstream_user_id matches a RegisteredContract with the given id or address
220
+ query = db_session .query (RegisteredContract ).filter (
221
+ RegisteredContract .moonstream_user_id == moonstream_user_id
222
+ )
223
+
224
+ if registered_contract_id is not None :
225
+ query = query .filter (RegisteredContract .id == registered_contract_id )
226
+
227
+ if contract_address is not None :
228
+ query = query .filter (
229
+ RegisteredContract .address == Web3 .toChecksumAddress (contract_address )
191
230
)
231
+
232
+ try :
233
+ registered_contract = query .one ()
192
234
except NoResultFound :
193
235
raise ValueError ("Invalid registered_contract_id or moonstream_user_id" )
194
236
@@ -202,18 +244,17 @@ def request_calls(
202
244
contract_type , specification .method , specification .parameters
203
245
)
204
246
205
- # Calculate the expiration time (if ttl_days is specified)
206
- expires_at_sql = None
247
+ expires_at = None
207
248
if ttl_days is not None :
208
- expires_at_sql = text ( f"(NOW( ) + INTERVAL ' { ttl_days } days')" )
249
+ expires_at = func . now ( ) + timedelta ( days = ttl_days )
209
250
210
251
request = CallRequest (
211
252
registered_contract_id = registered_contract .id ,
212
253
caller = normalized_caller ,
213
254
moonstream_user_id = moonstream_user_id ,
214
255
method = specification .method ,
215
256
parameters = specification .parameters ,
216
- expires_at = expires_at_sql ,
257
+ expires_at = expires_at ,
217
258
)
218
259
219
260
db_session .add (request )
@@ -229,21 +270,42 @@ def request_calls(
229
270
230
271
def list_call_requests (
231
272
db_session : Session ,
232
- registered_contract_id : uuid .UUID ,
233
- caller : str ,
273
+ contract_id : Optional [uuid .UUID ],
274
+ contract_address : Optional [str ],
275
+ caller : Optional [str ],
234
276
limit : int = 10 ,
235
277
offset : Optional [int ] = None ,
236
278
show_expired : bool = False ,
237
279
) -> List [data .CallRequest ]:
238
280
"""
239
281
List call requests for the given moonstream_user_id
240
282
"""
283
+ if caller is None :
284
+ raise ValueError ("caller must be specified" )
285
+
286
+ if contract_id is None and contract_address is None :
287
+ raise ValueError (
288
+ "At least one of contract_id or contract_address must be specified"
289
+ )
290
+
241
291
# If show_expired is False, filter out expired requests using current time on database server
242
- query = db_session .query (CallRequest ).filter (
243
- CallRequest .registered_contract_id == registered_contract_id ,
244
- CallRequest .caller == Web3 .toChecksumAddress (caller ),
292
+ query = (
293
+ db_session .query (CallRequest , RegisteredContract )
294
+ .join (
295
+ RegisteredContract ,
296
+ CallRequest .registered_contract_id == RegisteredContract .id ,
297
+ )
298
+ .filter (CallRequest .caller == Web3 .toChecksumAddress (caller ))
245
299
)
246
300
301
+ if contract_id is not None :
302
+ query = query .filter (CallRequest .registered_contract_id == contract_id )
303
+
304
+ if contract_address is not None :
305
+ query = query .filter (
306
+ RegisteredContract .address == Web3 .toChecksumAddress (contract_address )
307
+ )
308
+
247
309
if not show_expired :
248
310
query = query .filter (
249
311
CallRequest .expires_at > func .now (),
@@ -254,7 +316,10 @@ def list_call_requests(
254
316
255
317
query = query .limit (limit )
256
318
results = query .all ()
257
- return [render_call_request (call_request ) for call_request in results ]
319
+ return [
320
+ render_call_request (call_request , registered_contract )
321
+ for call_request , registered_contract in results
322
+ ]
258
323
259
324
260
325
# TODO(zomglings): What should the delete functionality for call requests look like?
@@ -282,10 +347,13 @@ def render_registered_contract(contract: RegisteredContract) -> data.RegisteredC
282
347
)
283
348
284
349
285
- def render_call_request (call_request : CallRequest ) -> data .CallRequest :
350
+ def render_call_request (
351
+ call_request : CallRequest , registered_contract : RegisteredContract
352
+ ) -> data .CallRequest :
286
353
return data .CallRequest (
287
354
id = call_request .id ,
288
- registered_contract_id = call_request .registered_contract_id ,
355
+ contract_id = call_request .registered_contract_id ,
356
+ contract_address = registered_contract .address ,
289
357
moonstream_user_id = call_request .moonstream_user_id ,
290
358
caller = call_request .caller ,
291
359
method = call_request .method ,
@@ -404,7 +472,7 @@ def handle_list_requests(args: argparse.Namespace) -> None:
404
472
with db .yield_db_session_ctx () as db_session :
405
473
call_requests = list_call_requests (
406
474
db_session = db_session ,
407
- registered_contract_id = args .registered_contract_id ,
475
+ contract_id = args .registered_contract_id ,
408
476
caller = args .caller ,
409
477
limit = args .limit ,
410
478
offset = args .offset ,
0 commit comments