Skip to content

Commit 772d817

Browse files
authored
Merge pull request #299 from bugout-dev/dropper-v0.2.0-api-improvements
Multiple Dropper v0.2.0 API improvements
2 parents 8706a6f + 9b9c201 commit 772d817

File tree

8 files changed

+283
-79
lines changed

8 files changed

+283
-79
lines changed
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
"""Fix unique constract on registered_contracts to include moonstream_user_id
2+
3+
Revision ID: dedd8a7d0624
4+
Revises: d1be5f227664
5+
Create Date: 2023-05-02 15:52:36.654980
6+
7+
"""
8+
from alembic import op
9+
import sqlalchemy as sa
10+
11+
12+
# revision identifiers, used by Alembic.
13+
revision = "dedd8a7d0624"
14+
down_revision = "d1be5f227664"
15+
branch_labels = None
16+
depends_on = None
17+
18+
19+
def upgrade():
20+
# ### commands auto generated by Alembic - please adjust! ###
21+
op.drop_constraint(
22+
"uq_registered_contracts_blockchain", "registered_contracts", type_="unique"
23+
)
24+
op.create_unique_constraint(
25+
op.f("uq_registered_contracts_blockchain"),
26+
"registered_contracts",
27+
["blockchain", "moonstream_user_id", "address", "contract_type"],
28+
)
29+
# ### end Alembic commands ###
30+
31+
32+
def downgrade():
33+
# ### commands auto generated by Alembic - please adjust! ###
34+
op.drop_constraint(
35+
op.f("uq_registered_contracts_blockchain"),
36+
"registered_contracts",
37+
type_="unique",
38+
)
39+
op.create_unique_constraint(
40+
"uq_registered_contracts_blockchain",
41+
"registered_contracts",
42+
["blockchain", "address", "contract_type"],
43+
)
44+
# ### end Alembic commands ###

api/engineapi/contracts_actions.py

Lines changed: 103 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
import argparse
2+
from datetime import timedelta
23
import json
34
import logging
45
import uuid
5-
from enum import Enum
66
from typing import Any, Dict, List, Optional
77

88
from sqlalchemy import func, text
99
from sqlalchemy.exc import IntegrityError, NoResultFound
1010
from sqlalchemy.orm import Session
1111
from web3 import Web3
1212

13+
from .data import ContractType
14+
1315
from . import data, db
1416
from .models import RegisteredContract, CallRequest
1517

@@ -21,11 +23,6 @@ class ContractAlreadyRegistered(Exception):
2123
pass
2224

2325

24-
class ContractType(Enum):
25-
raw = "raw"
26-
dropper = "dropper-v0.2.0"
27-
28-
2926
def validate_method_and_params(
3027
contract_type: ContractType, method: str, parameters: Dict[str, Any]
3128
) -> None:
@@ -71,12 +68,6 @@ def register_contract(
7168
"""
7269
Register a contract against the Engine instance
7370
"""
74-
75-
# TODO(zomglings): Make it so that contract_type is passed as a string. Convert to
76-
# ContractType here. That will mean there is a single point at which the validation is
77-
# performed rather than relying on each entrypoint to register_contract having to implement
78-
# their own validation.
79-
8071
try:
8172
contract = RegisteredContract(
8273
blockchain=blockchain,
@@ -100,6 +91,46 @@ def register_contract(
10091
return render_registered_contract(contract)
10192

10293

94+
def update_registered_contract(
95+
db_session: Session,
96+
moonstream_user_id: uuid.UUID,
97+
contract_id: uuid.UUID,
98+
title: Optional[str] = None,
99+
description: Optional[str] = None,
100+
image_uri: Optional[str] = None,
101+
ignore_nulls: bool = True,
102+
) -> data.RegisteredContract:
103+
"""
104+
Update the registered contract with the given contract ID provided that the user with moonstream_user_id
105+
has access to it.
106+
"""
107+
query = db_session.query(RegisteredContract).filter(
108+
RegisteredContract.id == contract_id,
109+
RegisteredContract.moonstream_user_id == moonstream_user_id,
110+
)
111+
112+
contract = query.one()
113+
114+
if not (title is None and ignore_nulls):
115+
contract.title = title
116+
if not (description is None and ignore_nulls):
117+
contract.description = description
118+
if not (image_uri is None and ignore_nulls):
119+
contract.image_uri = image_uri
120+
121+
try:
122+
db_session.add(contract)
123+
db_session.commit()
124+
except Exception as err:
125+
logger.error(
126+
f"update_registered_contract -- error storing update in database: {repr(err)}"
127+
)
128+
db_session.rollback()
129+
raise
130+
131+
return render_registered_contract(contract)
132+
133+
103134
def lookup_registered_contracts(
104135
db_session: Session,
105136
moonstream_user_id: uuid.UUID,
@@ -164,7 +195,8 @@ def delete_registered_contract(
164195
def request_calls(
165196
db_session: Session,
166197
moonstream_user_id: uuid.UUID,
167-
registered_contract_id: uuid.UUID,
198+
registered_contract_id: Optional[uuid.UUID],
199+
contract_address: Optional[str],
168200
call_specs: List[data.CallSpecification],
169201
ttl_days: Optional[int] = None,
170202
) -> int:
@@ -174,21 +206,31 @@ def request_calls(
174206
# TODO(zomglings): Do not pass raw ttl_days into SQL query - could be subject to SQL injection
175207
# For now, in the interest of speed, let us just be super cautious with ttl_days.
176208
# Check that the ttl_days is indeed an integer
209+
if registered_contract_id is None and contract_address is None:
210+
raise ValueError(
211+
"At least one of registered_contract_id or contract_address is required"
212+
)
213+
177214
if ttl_days is not None:
178215
assert ttl_days == int(ttl_days), "ttl_days must be an integer"
179216
if ttl_days <= 0:
180217
raise ValueError("ttl_days must be positive")
181218

182-
# Check that the moonstream_user_id matches the RegisteredContract
183-
try:
184-
registered_contract = (
185-
db_session.query(RegisteredContract)
186-
.filter(
187-
RegisteredContract.id == registered_contract_id,
188-
RegisteredContract.moonstream_user_id == moonstream_user_id,
189-
)
190-
.one()
219+
# Check that the moonstream_user_id matches a RegisteredContract with the given id or address
220+
query = db_session.query(RegisteredContract).filter(
221+
RegisteredContract.moonstream_user_id == moonstream_user_id
222+
)
223+
224+
if registered_contract_id is not None:
225+
query = query.filter(RegisteredContract.id == registered_contract_id)
226+
227+
if contract_address is not None:
228+
query = query.filter(
229+
RegisteredContract.address == Web3.toChecksumAddress(contract_address)
191230
)
231+
232+
try:
233+
registered_contract = query.one()
192234
except NoResultFound:
193235
raise ValueError("Invalid registered_contract_id or moonstream_user_id")
194236

@@ -202,18 +244,17 @@ def request_calls(
202244
contract_type, specification.method, specification.parameters
203245
)
204246

205-
# Calculate the expiration time (if ttl_days is specified)
206-
expires_at_sql = None
247+
expires_at = None
207248
if ttl_days is not None:
208-
expires_at_sql = text(f"(NOW() + INTERVAL '{ttl_days} days')")
249+
expires_at = func.now() + timedelta(days=ttl_days)
209250

210251
request = CallRequest(
211252
registered_contract_id=registered_contract.id,
212253
caller=normalized_caller,
213254
moonstream_user_id=moonstream_user_id,
214255
method=specification.method,
215256
parameters=specification.parameters,
216-
expires_at=expires_at_sql,
257+
expires_at=expires_at,
217258
)
218259

219260
db_session.add(request)
@@ -229,21 +270,42 @@ def request_calls(
229270

230271
def list_call_requests(
231272
db_session: Session,
232-
registered_contract_id: uuid.UUID,
233-
caller: str,
273+
contract_id: Optional[uuid.UUID],
274+
contract_address: Optional[str],
275+
caller: Optional[str],
234276
limit: int = 10,
235277
offset: Optional[int] = None,
236278
show_expired: bool = False,
237279
) -> List[data.CallRequest]:
238280
"""
239281
List call requests for the given moonstream_user_id
240282
"""
283+
if caller is None:
284+
raise ValueError("caller must be specified")
285+
286+
if contract_id is None and contract_address is None:
287+
raise ValueError(
288+
"At least one of contract_id or contract_address must be specified"
289+
)
290+
241291
# If show_expired is False, filter out expired requests using current time on database server
242-
query = db_session.query(CallRequest).filter(
243-
CallRequest.registered_contract_id == registered_contract_id,
244-
CallRequest.caller == Web3.toChecksumAddress(caller),
292+
query = (
293+
db_session.query(CallRequest, RegisteredContract)
294+
.join(
295+
RegisteredContract,
296+
CallRequest.registered_contract_id == RegisteredContract.id,
297+
)
298+
.filter(CallRequest.caller == Web3.toChecksumAddress(caller))
245299
)
246300

301+
if contract_id is not None:
302+
query = query.filter(CallRequest.registered_contract_id == contract_id)
303+
304+
if contract_address is not None:
305+
query = query.filter(
306+
RegisteredContract.address == Web3.toChecksumAddress(contract_address)
307+
)
308+
247309
if not show_expired:
248310
query = query.filter(
249311
CallRequest.expires_at > func.now(),
@@ -254,7 +316,10 @@ def list_call_requests(
254316

255317
query = query.limit(limit)
256318
results = query.all()
257-
return [render_call_request(call_request) for call_request in results]
319+
return [
320+
render_call_request(call_request, registered_contract)
321+
for call_request, registered_contract in results
322+
]
258323

259324

260325
# TODO(zomglings): What should the delete functionality for call requests look like?
@@ -282,10 +347,13 @@ def render_registered_contract(contract: RegisteredContract) -> data.RegisteredC
282347
)
283348

284349

285-
def render_call_request(call_request: CallRequest) -> data.CallRequest:
350+
def render_call_request(
351+
call_request: CallRequest, registered_contract: RegisteredContract
352+
) -> data.CallRequest:
286353
return data.CallRequest(
287354
id=call_request.id,
288-
registered_contract_id=call_request.registered_contract_id,
355+
contract_id=call_request.registered_contract_id,
356+
contract_address=registered_contract.address,
289357
moonstream_user_id=call_request.moonstream_user_id,
290358
caller=call_request.caller,
291359
method=call_request.method,
@@ -404,7 +472,7 @@ def handle_list_requests(args: argparse.Namespace) -> None:
404472
with db.yield_db_session_ctx() as db_session:
405473
call_requests = list_call_requests(
406474
db_session=db_session,
407-
registered_contract_id=args.registered_contract_id,
475+
contract_id=args.registered_contract_id,
408476
caller=args.caller,
409477
limit=args.limit,
410478
offset=args.offset,

api/engineapi/data.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from datetime import datetime
2+
from enum import Enum
23
from typing import Any, Dict, List, Optional
3-
4-
from pydantic import BaseModel, Field, validator
54
from uuid import UUID
65

6+
from pydantic import BaseModel, Field, validator, root_validator
7+
from web3 import Web3
8+
79

810
class PingResponse(BaseModel):
911
"""
@@ -53,7 +55,6 @@ class DropperBlockchainResponse(BaseModel):
5355

5456

5557
class DropRegisterRequest(BaseModel):
56-
5758
dropper_contract_id: UUID
5859
title: Optional[str] = None
5960
description: Optional[str] = None
@@ -177,13 +178,25 @@ class DropUpdatedResponse(BaseModel):
177178
active: bool = True
178179

179180

181+
class ContractType(Enum):
182+
raw = "raw"
183+
dropper = "dropper-v0.2.0"
184+
185+
180186
class RegisterContractRequest(BaseModel):
181187
blockchain: str
182188
address: str
183-
contract_type: str
189+
contract_type: ContractType
190+
title: Optional[str] = None
191+
description: Optional[str] = None
192+
image_uri: Optional[str] = None
193+
194+
195+
class UpdateContractRequest(BaseModel):
184196
title: Optional[str] = None
185197
description: Optional[str] = None
186198
image_uri: Optional[str] = None
199+
ignore_nulls: bool = True
187200

188201

189202
class RegisteredContract(BaseModel):
@@ -214,28 +227,45 @@ class CallSpecification(BaseModel):
214227

215228

216229
class CreateCallRequestsAPIRequest(BaseModel):
230+
contract_id: Optional[UUID] = None
231+
contract_address: Optional[str] = None
217232
specifications: List[CallSpecification] = Field(default_factory=list)
218233
ttl_days: Optional[int] = None
219234

235+
# Solution found thanks to https://github.com/pydantic/pydantic/issues/506
236+
@root_validator
237+
def at_least_one_of_contract_id_and_contract_address(cls, values):
238+
if values.get("contract_id") is None and values.get("contract_address") is None:
239+
raise ValueError(
240+
"At least one of contract_id and contract_address must be provided"
241+
)
242+
return values
243+
220244

221245
class CallRequest(BaseModel):
222246
id: UUID
223-
registered_contract_id: UUID
247+
contract_id: UUID
248+
contract_address: str
224249
moonstream_user_id: UUID
225250
caller: str
226251
method: str
227252
parameters: Dict[str, Any]
228-
expires_at: datetime
253+
expires_at: Optional[datetime]
229254
created_at: datetime
230255
updated_at: datetime
231256

232-
@validator("id", "registered_contract_id", "moonstream_user_id")
257+
@validator("id", "contract_id", "moonstream_user_id")
233258
def validate_uuids(cls, v):
234259
return str(v)
235260

236261
@validator("created_at", "updated_at", "expires_at")
237262
def validate_datetimes(cls, v):
238-
return v.isoformat()
263+
if v is not None:
264+
return v.isoformat()
265+
266+
@validator("contract_address", "caller")
267+
def validate_web3_adresses(cls, v):
268+
return Web3.toChecksumAddress(v)
239269

240270

241271
class QuartilesResponse(BaseModel):

0 commit comments

Comments
 (0)