diff --git a/main.py b/main.py index a5aeabe..abba0a2 100644 --- a/main.py +++ b/main.py @@ -33,6 +33,14 @@ Security hardening applied 2026-03-21: - Fix #8: Input validation regex for addresses, tx hashes, policy IDs - Fix #9: Correct tx hash calculation (blake2b of tx body, not full tx) - Fix #10: Enforce key expiry globally in get_api_key_info + +Security hardening pass 2 (2026-03-21): +- Fix #11: Request body size limit (64KB) on /v1/tx/submit +- Fix #12: CIP-8 empty payload bypass fixed +- Fix #13: Pagination on /v1/address/{addr}/tokens and /v1/asset/{policy_id}/info +- Fix #14: cbor2 bumped to >=5.6.5 (CVE-2024-26134) +- Fix #15: Fixed holder count query (was using GROUP BY + COUNT DISTINCT incorrectly) +- Fix #16: Async lock for protocol params cache to prevent stampede """ import os @@ -52,6 +60,7 @@ from contextlib import asynccontextmanager from fastapi import FastAPI, Request, HTTPException, Query, Header, Depends, BackgroundTasks from fastapi.responses import JSONResponse +from starlette.middleware.base import BaseHTTPMiddleware from pydantic import BaseModel, Field import asyncpg from asyncpg.exceptions import UndefinedTableError, PostgresError @@ -132,6 +141,9 @@ db_pool: Optional[asyncpg.Pool] = None redis_client: Optional[redis.Redis] = None protocol_params_cache: dict = {"data": None, "expires": 0} +# Fix #16: Async lock for protocol params cache to prevent stampede +_params_lock = asyncio.Lock() + # ============ Input Validation (Fix #8) ============ @@ -264,18 +276,22 @@ def verify_cip8_signature(address: str, nonce: str, signature_hex: str, key_hex: protected, unprotected, payload, sig = cose_sign1 + # Fix #12: Reject empty payloads - nonce verification must happen + if not payload: + logger.warning("CIP-8 verification rejected: empty payload") + return False + # The payload should contain our nonce - if payload: - payload_decoded = payload if isinstance(payload, bytes) else bytes(payload) - # Payload might be hex-encoded nonce or raw bytes - try: - if payload_decoded.hex() != nonce and payload_decoded.decode('utf-8') != nonce: - logger.warning("Payload doesn't match nonce") - return False - except: - if payload_decoded.hex() != nonce: - logger.warning("Payload doesn't match nonce (hex check)") - return False + payload_decoded = payload if isinstance(payload, bytes) else bytes(payload) + # Payload might be hex-encoded nonce or raw bytes + try: + if payload_decoded.hex() != nonce and payload_decoded.decode('utf-8') != nonce: + logger.warning("Payload doesn't match nonce") + return False + except: + if payload_decoded.hex() != nonce: + logger.warning("Payload doesn't match nonce (hex check)") + return False # Build the Sig_structure for verification # Sig_structure = ["Signature1", protected, external_aad, payload] @@ -417,11 +433,31 @@ async def lifespan(app: FastAPI): app = FastAPI( title="Cardano Chain Data API", description="REST API for querying Cardano blockchain data via db-sync and cardano-node", - version="2.1.0", # Bumped for security fixes + version="2.2.0", # Bumped for security hardening pass 2 lifespan=lifespan ) +# ============ Fix #11: Request Body Size Limit Middleware ============ + +class LimitBodySizeMiddleware(BaseHTTPMiddleware): + """Limit request body size on tx submit to prevent DoS. Cardano max tx is ~16KB.""" + MAX_TX_SIZE = 65536 # 64KB - generous limit + + async def dispatch(self, request: Request, call_next): + if request.url.path == "/v1/tx/submit": + content_length = request.headers.get("content-length") + if content_length and int(content_length) > self.MAX_TX_SIZE: + return JSONResponse( + status_code=413, + content={"error": "payload_too_large", "message": "Transaction exceeds maximum size of 64KB"} + ) + return await call_next(request) + + +app.add_middleware(LimitBodySizeMiddleware) + + # ============ Exception Handlers ============ @app.exception_handler(UndefinedTableError) @@ -968,44 +1004,46 @@ async def get_protocol_params(auth: dict = Depends(require_standard_tier)): """ global protocol_params_cache - # Check cache - if protocol_params_cache["data"] and protocol_params_cache["expires"] > time.time(): - return protocol_params_cache["data"] - - # Query from node - success, stdout, stderr = run_cardano_cli([ - "query", "protocol-parameters", - "--mainnet" - ]) - - if not success: - if "Network.Socket.connect" in stderr or "does not exist" in stderr: + # Fix #16: Use async lock to prevent cache stampede + async with _params_lock: + # Check cache inside the lock + if protocol_params_cache["data"] and protocol_params_cache["expires"] > time.time(): + return protocol_params_cache["data"] + + # Query from node + success, stdout, stderr = run_cardano_cli([ + "query", "protocol-parameters", + "--mainnet" + ]) + + if not success: + if "Network.Socket.connect" in stderr or "does not exist" in stderr: + raise HTTPException( + status_code=503, + detail={"error": "node_unavailable", "message": "Cardano node not available"} + ) + # Fix #5: Don't leak stderr + logger.error(f"protocol-params query failed: {stderr}") raise HTTPException( - status_code=503, - detail={"error": "node_unavailable", "message": "Cardano node not available"} + status_code=500, + detail={"error": "node_error", "message": "Node command failed"} ) - # Fix #5: Don't leak stderr - logger.error(f"protocol-params query failed: {stderr}") - raise HTTPException( - status_code=500, - detail={"error": "node_error", "message": "Node command failed"} - ) - - try: - params = json.loads(stdout) - except json.JSONDecodeError: - raise HTTPException( - status_code=500, - detail={"error": "parse_error", "message": "Failed to parse protocol parameters"} - ) - - # Cache for 5 minutes - protocol_params_cache = { - "data": params, - "expires": time.time() + CACHE_TTLS["protocol_params"] - } - - return params + + try: + params = json.loads(stdout) + except json.JSONDecodeError: + raise HTTPException( + status_code=500, + detail={"error": "parse_error", "message": "Failed to parse protocol parameters"} + ) + + # Cache for 5 minutes + protocol_params_cache = { + "data": params, + "expires": time.time() + CACHE_TTLS["protocol_params"] + } + + return params # ============ Auth Endpoints (TRP-Gated) ============ @@ -1327,8 +1365,13 @@ async def get_address_balance(address: str, auth: dict = Depends(get_auth_contex @app.get("/v1/address/{address}/tokens") -async def get_address_tokens(address: str, auth: dict = Depends(get_auth_context)): - """Get native tokens held by an address.""" +async def get_address_tokens( + address: str, + page: int = Query(1, ge=1, description="Page number"), + limit: int = Query(100, ge=1, le=1000, description="Results per page (max 1000)"), + auth: dict = Depends(get_auth_context) +): + """Get native tokens held by an address. Fix #13: Now paginated.""" # Fix #8: Validate address if not validate_address(address): raise HTTPException( @@ -1336,12 +1379,26 @@ async def get_address_tokens(address: str, auth: dict = Depends(get_auth_context detail={"error": "invalid_address", "message": "Invalid Cardano address format"} ) - cache_key = f"tokens_{address}" + offset = (page - 1) * limit + cache_key = f"tokens_{address}_{page}_{limit}" cached = await get_cached(cache_key) if cached: return cached async with db_pool.acquire() as conn: + # Get total count for pagination info + count_result = await conn.fetchrow(""" + SELECT COUNT(DISTINCT ma.id) as total + FROM ma_tx_out mto + JOIN multi_asset ma ON ma.id = mto.ident + JOIN tx_out txo ON txo.id = mto.tx_out_id + LEFT JOIN tx_in txi ON txi.tx_out_id = txo.tx_id AND txi.tx_out_index = txo.index + WHERE txo.address = $1 AND txi.id IS NULL + """, address) + + total_count = count_result["total"] if count_result else 0 + + # Fix #13: Add LIMIT and OFFSET for pagination tokens = await conn.fetch(""" SELECT encode(ma.policy, 'hex') as policy_id, @@ -1357,10 +1414,15 @@ async def get_address_tokens(address: str, auth: dict = Depends(get_auth_context GROUP BY ma.id, ma.policy, ma.name, ma.fingerprint HAVING SUM(mto.quantity) > 0 ORDER BY quantity DESC - """, address) + LIMIT $2 OFFSET $3 + """, address, limit, offset) result = { "address": address, + "page": page, + "limit": limit, + "total_count": total_count, + "total_pages": (total_count + limit - 1) // limit if total_count > 0 else 0, "tokens": [ { "policy_id": t["policy_id"], @@ -1547,8 +1609,13 @@ async def get_transaction(tx_hash: str, auth: dict = Depends(get_auth_context)): # ============ Asset Endpoints ============ @app.get("/v1/asset/{policy_id}/info") -async def get_asset_info(policy_id: str, auth: dict = Depends(get_auth_context)): - """Get info about all assets under a policy ID.""" +async def get_asset_info( + policy_id: str, + page: int = Query(1, ge=1, description="Page number"), + limit: int = Query(100, ge=1, le=500, description="Results per page (max 500)"), + auth: dict = Depends(get_auth_context) +): + """Get info about all assets under a policy ID. Fix #13: Now paginated.""" # Fix #8: Validate policy ID if not validate_policy_id(policy_id): raise HTTPException( @@ -1556,7 +1623,8 @@ async def get_asset_info(policy_id: str, auth: dict = Depends(get_auth_context)) detail={"error": "invalid_policy_id", "message": "Invalid policy ID format (expected 56 hex chars)"} ) - cache_key = f"asset_info_{policy_id}" + offset = (page - 1) * limit + cache_key = f"asset_info_{policy_id}_{page}_{limit}" cached = await get_cached(cache_key) if cached: return cached @@ -1564,6 +1632,16 @@ async def get_asset_info(policy_id: str, auth: dict = Depends(get_auth_context)) clean_policy = policy_id.lower().replace("0x", "") async with db_pool.acquire() as conn: + # Get total count for pagination + count_result = await conn.fetchrow(""" + SELECT COUNT(*) as total + FROM multi_asset ma + WHERE ma.policy = decode($1, 'hex') + """, clean_policy) + + total_count = count_result["total"] if count_result else 0 + + # Fix #13: Add LIMIT and OFFSET for pagination assets = await conn.fetch(""" SELECT encode(ma.name, 'hex') as asset_name_hex, @@ -1573,13 +1651,19 @@ async def get_asset_info(policy_id: str, auth: dict = Depends(get_auth_context)) (SELECT COUNT(*) FROM ma_tx_mint WHERE ident = ma.id AND quantity > 0) as mint_count FROM multi_asset ma WHERE ma.policy = decode($1, 'hex') - """, clean_policy) + ORDER BY ma.id + LIMIT $2 OFFSET $3 + """, clean_policy, limit, offset) - if not assets: + if not assets and page == 1: raise HTTPException(status_code=404, detail={"error": "not_found", "message": f"Policy {policy_id} not found"}) result = { "policy_id": policy_id, + "page": page, + "limit": limit, + "total_count": total_count, + "total_pages": (total_count + limit - 1) // limit if total_count > 0 else 0, "assets": [ { "asset_name": a["asset_name"] or a["asset_name_hex"], @@ -1641,15 +1725,17 @@ async def get_asset_holders( LIMIT $2 """, asset["id"], limit) - # Count total holders + # Fix #15: Correct holder count query using subquery holder_count = await conn.fetchrow(""" - SELECT COUNT(DISTINCT txo.address) as count - FROM ma_tx_out mto - JOIN tx_out txo ON txo.id = mto.tx_out_id - LEFT JOIN tx_in txi ON txi.tx_out_id = txo.tx_id AND txi.tx_out_index = txo.index - WHERE mto.ident = $1 AND txi.id IS NULL - GROUP BY txo.address - HAVING SUM(mto.quantity) > 0 + SELECT COUNT(*) as count FROM ( + SELECT txo.address + FROM ma_tx_out mto + JOIN tx_out txo ON txo.id = mto.tx_out_id + LEFT JOIN tx_in txi ON txi.tx_out_id = txo.tx_id AND txi.tx_out_index = txo.index + WHERE mto.ident = $1 AND txi.id IS NULL + GROUP BY txo.address + HAVING SUM(mto.quantity) > 0 + ) sub """, asset["id"]) result = { diff --git a/requirements.txt b/requirements.txt index 163c600..3786021 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,5 +5,5 @@ redis==5.0.0 pydantic==2.9.0 python-dotenv==1.0.0 pycardano==0.11.0 -cbor2==5.6.0 +cbor2>=5.6.5 PyNaCl==1.5.0