# Fix #8: Input validation regexes.
# These are applied with ``fullmatch`` by the validators below. With ``match``
# the trailing ``$`` would also succeed just before a final newline, so a value
# such as "<valid>\n" would slip through validation.
ADDR_RE = re.compile(r'^addr1[a-z0-9]{50,120}$')
ADDR_TEST_RE = re.compile(r'^addr_test1[a-z0-9]{50,120}$')
HEX64_RE = re.compile(r'^[a-fA-F0-9]{64}$')
POLICY_RE = re.compile(r'^[a-fA-F0-9]{56}$')


# ============ Input Validation (Fix #8) ============

def _strip_hex_prefix(value: str) -> str:
    """Lower-case *value* and drop a single leading ``0x`` prefix, if present.

    Only a prefix is removed: an ``0x`` embedded in the middle of the string is
    left in place so that the hex check rejects it.
    """
    return value.lower().removeprefix("0x")


def validate_address(address: str) -> bool:
    """Return True if *address* looks like a Cardano payment address.

    Accepts ``addr1`` or ``addr_test1`` followed by 50-120 lower-case
    alphanumeric characters. The whole string must match; trailing
    whitespace or newlines are rejected.
    """
    return (
        ADDR_RE.fullmatch(address) is not None
        or ADDR_TEST_RE.fullmatch(address) is not None
    )


def validate_tx_hash(tx_hash: str) -> bool:
    """Return True if *tx_hash* is exactly 64 hex characters.

    A single leading ``0x``/``0X`` prefix is tolerated, as before. Note that
    the caller's original string is not modified, so callers that accept a
    prefixed value must strip the prefix themselves before using it.
    """
    return HEX64_RE.fullmatch(_strip_hex_prefix(tx_hash)) is not None


def validate_policy_id(policy_id: str) -> bool:
    """Return True if *policy_id* is exactly 56 hex characters.

    A single leading ``0x``/``0X`` prefix is tolerated, as before.
    """
    return POLICY_RE.fullmatch(_strip_hex_prefix(policy_id)) is not None
Fix #3.""" while True: try: - await asyncio.sleep(3600) # Run every hour + await asyncio.sleep(600) # Run every 10 minutes (was 3600) await refresh_all_trp_tiers() except asyncio.CancelledError: break @@ -367,7 +417,7 @@ async def lifespan(app: FastAPI): app = FastAPI( title="Cardano Chain Data API", description="REST API for querying Cardano blockchain data via db-sync and cardano-node", - version="2.0.0", + version="2.1.0", # Bumped for security fixes lifespan=lifespan ) @@ -386,11 +436,11 @@ async def handle_undefined_table(request: Request, exc: UndefinedTableError): @app.exception_handler(PostgresError) async def handle_postgres_error(request: Request, exc: PostgresError): - """Handle general postgres errors.""" + """Handle general postgres errors. Fix #5: Don't leak internal details.""" logger.error(f"Database error: {exc}") return JSONResponse( status_code=503, - content={"error": "database_error", "message": str(exc)} + content={"error": "database_error", "message": "Internal database error"} ) @@ -445,11 +495,34 @@ class TxSubmitResponse(BaseModel): # ============ Helpers ============ async def get_api_key_info(key: str) -> Optional[dict]: - """Get API key info from Redis.""" + """ + Get API key info from Redis using hashed key lookup. Fix #4 and Fix #10. + Also enforces key expiry for TRP-gated keys. 
def get_client_ip(request: Request) -> str:
    """
    Extract the client IP from the request. Fix #2.

    X-Forwarded-For is honoured only when the directly connected peer is in
    TRUSTED_PROXIES. The header is then read from right to left and the first
    hop that is not itself a trusted proxy is returned.

    The left-most entry is deliberately NOT used: when a proxy appends to an
    existing header, everything to the left of the proxy's own entry was
    supplied by the client and can be forged to dodge per-IP rate limits.

    Returns:
        The resolved client IP, the connecting peer's address when the header
        is absent, empty, or made up solely of trusted proxies, or "unknown"
        when the request carries no client information.
    """
    client_host = request.client.host if request.client else "unknown"

    # Direct (untrusted) connections: ignore any forwarding header they send.
    if client_host not in TRUSTED_PROXIES:
        return client_host

    forwarded = request.headers.get("X-Forwarded-For")
    if not forwarded:
        return client_host

    hops = [hop.strip() for hop in forwarded.split(",")]
    for hop in reversed(hops):
        # Skip blanks and our own proxies; the first remaining hop is the
        # address the trusted proxy chain actually saw.
        if hop and hop not in TRUSTED_PROXIES:
            return hop

    return client_host
({tier})") return JSONResponse( status_code=429, content={"error": "rate_limit_exceeded", "retry_after": retry_after} @@ -610,9 +691,10 @@ async def rate_limit_middleware(request: Request, call_next): # Process request response = await call_next(request) - # Log request + # Log request (don't log full key hash for security) elapsed = int((time.time() - start_time) * 1000) - logger.info(f"{request.method} {request.url.path} | {label or client_ip} | {elapsed}ms | {response.status_code}") + log_id = label or client_ip + logger.info(f"{request.method} {request.url.path} | {log_id} | {elapsed}ms | {response.status_code}") return response @@ -688,6 +770,13 @@ async def get_address_utxos(address: str, auth: dict = Depends(require_standard_ Requires: standard tier (50+ TRP) or higher. Anonymous users should use /v1/address/{address}/balance (db-sync) instead. """ + # Fix #8: Validate address format + if not validate_address(address): + raise HTTPException( + status_code=400, + detail={"error": "invalid_address", "message": "Invalid Cardano address format"} + ) + cache_key = f"utxos_{address}" cached = await get_cached(cache_key) if cached: @@ -707,9 +796,11 @@ async def get_address_utxos(address: str, auth: dict = Depends(require_standard_ status_code=503, detail={"error": "node_unavailable", "message": "Cardano node not available"} ) + # Fix #5: Don't leak stderr details + logger.error(f"cardano-cli utxo query failed: {stderr}") raise HTTPException( status_code=500, - detail={"error": "node_error", "message": stderr} + detail={"error": "node_error", "message": "Node command failed"} ) try: @@ -789,6 +880,27 @@ async def submit_transaction( detail={"error": "invalid_encoding", "message": str(e)} ) + # Fix #7: Validate CBOR before proceeding + try: + tx_cbor = cbor2.loads(tx_bytes) + except Exception: + raise HTTPException( + status_code=400, + detail={"error": "invalid_tx", "message": "Transaction is not valid CBOR"} + ) + + # Fix #9: Calculate correct tx hash from tx body 
(index 0), not full tx + try: + if isinstance(tx_cbor, (list, tuple)) and len(tx_cbor) > 0: + tx_body_cbor = cbor2.dumps(tx_cbor[0]) + tx_hash = hashlib.blake2b(tx_body_cbor, digest_size=32).hexdigest() + else: + # Fallback if structure is unexpected + tx_hash = hashlib.blake2b(tx_bytes, digest_size=32).hexdigest() + except Exception as e: + logger.error(f"Error calculating tx hash: {e}") + tx_hash = hashlib.blake2b(tx_bytes, digest_size=32).hexdigest() + # Write to temp file for cardano-cli with tempfile.NamedTemporaryFile(suffix=".signed", delete=False) as f: f.write(tx_bytes) @@ -808,7 +920,7 @@ async def submit_transaction( status_code=503, detail={"error": "node_unavailable", "message": "Cardano node not available"} ) - # Parse common errors + # Parse common errors (these are safe to expose) if "OutsideValidityIntervalUTxO" in stderr: raise HTTPException( status_code=400, @@ -824,21 +936,14 @@ async def submit_transaction( status_code=400, detail={"error": "value_not_conserved", "message": "Input/output value mismatch"} ) + # Fix #5: Don't leak full stderr for other errors + logger.error(f"tx submit failed: {stderr}") raise HTTPException( status_code=400, - detail={"error": "submit_failed", "message": stderr.strip()} + detail={"error": "submit_failed", "message": "Node command failed"} ) - # Calculate tx hash from the CBOR - # The hash is blake2b-256 of the transaction body - tx_hash = hashlib.blake2b(tx_bytes, digest_size=32).hexdigest() - - # Try to get actual hash from output - if "Transaction successfully submitted" in stdout or success: - # cardano-cli doesn't output the hash, but we calculated it - pass - - logger.info(f"Transaction submitted: {tx_hash[:16]}... by {auth['identifier']}") + logger.info(f"Transaction submitted: {tx_hash[:16]}... 
by {auth['identifier'][:16]}...") return TxSubmitResponse( tx_hash=tx_hash, @@ -879,9 +984,11 @@ async def get_protocol_params(auth: dict = Depends(require_standard_tier)): status_code=503, detail={"error": "node_unavailable", "message": "Cardano node not available"} ) + # Fix #5: Don't leak stderr + logger.error(f"protocol-params query failed: {stderr}") raise HTTPException( status_code=500, - detail={"error": "node_error", "message": stderr} + detail={"error": "node_error", "message": "Node command failed"} ) try: @@ -911,8 +1018,8 @@ async def create_auth_challenge(request: AuthChallengeRequest): """ address = request.address - # Validate address format - if not address.startswith(("addr1", "addr_test1")): + # Fix #8: Validate address format + if not validate_address(address): raise HTTPException( status_code=400, detail={"error": "invalid_address", "message": "Invalid Cardano address format"} @@ -944,22 +1051,32 @@ async def verify_auth(request: AuthVerifyRequest): - 0 TRP: anonymous (no key issued) - 50+ TRP: standard (100 req/min) - 500+ TRP: elevated (1000 req/min) + + TRP-gated keys expire after 48 hours and must be re-authenticated. 
""" address = request.address nonce = request.nonce - # Check nonce exists and hasn't expired - challenge_key = f"auth_challenge:{address}:{nonce}" - challenge_data = await redis_client.get(challenge_key) - - if not challenge_data: + # Fix #8: Validate address + if not validate_address(address): raise HTTPException( status_code=400, - detail={"error": "invalid_nonce", "message": "Nonce expired or invalid"} + detail={"error": "invalid_address", "message": "Invalid Cardano address format"} ) - # Delete nonce (one-time use) - await redis_client.delete(challenge_key) + # Fix #1: Atomic nonce check-and-delete using pipeline + challenge_key = f"auth_challenge:{address}:{nonce}" + + async with redis_client.pipeline(transaction=True) as pipe: + await pipe.get(challenge_key) + await pipe.delete(challenge_key) + challenge_data, deleted = await pipe.execute() + + if not challenge_data or not deleted: + raise HTTPException( + status_code=400, + detail={"error": "invalid_nonce", "message": "Nonce expired or already used"} + ) # Verify CIP-8 signature if not verify_cip8_signature(address, nonce, request.signature, request.key): @@ -987,17 +1104,24 @@ async def verify_auth(request: AuthVerifyRequest): # Generate API key new_key = f"capi_{secrets.token_hex(24)}" - # Store key info - await redis_client.hset(f"apikey:{new_key}", mapping={ + # Fix #3: Set expiry 48 hours from now for TRP-gated keys + expires_at = datetime.now(timezone.utc) + timedelta(hours=TRP_KEY_EXPIRY_HOURS) + + # Fix #4: Store using hashed key, not raw key + key_hash = hash_api_key(new_key) + + await redis_client.hset(f"apikey:{key_hash}", mapping={ "label": f"TRP-gated:{address[:20]}...", "tier": tier, "owner": address, "trp_balance": str(trp_balance), - "created_at": datetime.now(timezone.utc).isoformat() + "created_at": datetime.now(timezone.utc).isoformat(), + "expires_at": expires_at.isoformat() # Fix #3: Add expiry }) - logger.info(f"Issued {tier} API key for {address[:20]}... 
(TRP: {trp_balance})") + logger.info(f"Issued {tier} API key for {address[:20]}... (TRP: {trp_balance}, expires: {expires_at.isoformat()})") + # Return the raw key to the user (only time it's exposed) return AuthVerifyResponse( api_key=new_key, tier=tier, @@ -1012,6 +1136,10 @@ async def refresh_auth( ): """ Re-check TRP balance and upgrade/downgrade tier for an existing key. + + Fix #6: This endpoint is self-service only. The API key in the header + is both the authentication AND the key being refreshed. You cannot + use key A to refresh key B. """ if auth["tier"] == "anonymous": raise HTTPException( @@ -1037,10 +1165,17 @@ async def refresh_auth( new_tier = get_tier_from_trp_balance(trp_balance) old_tier = auth["tier"] + # Fix #3: Reset expiry on refresh + expires_at = datetime.now(timezone.utc) + timedelta(hours=TRP_KEY_EXPIRY_HOURS) + + # Use hashed key for storage (Fix #4) + key_hash = hash_api_key(x_api_key) + # Update key info - await redis_client.hset(f"apikey:{x_api_key}", mapping={ + await redis_client.hset(f"apikey:{key_hash}", mapping={ "tier": new_tier, - "trp_balance": str(trp_balance) + "trp_balance": str(trp_balance), + "expires_at": expires_at.isoformat() # Reset expiry }) tier_changed = new_tier != old_tier @@ -1052,6 +1187,7 @@ async def refresh_auth( "previous_tier": old_tier, "trp_balance": trp_balance, "tier_changed": tier_changed, + "expires_at": expires_at.isoformat(), "message": message } @@ -1129,6 +1265,13 @@ async def get_block(block_no: int, auth: dict = Depends(get_auth_context)): @app.get("/v1/address/{address}/balance") async def get_address_balance(address: str, auth: dict = Depends(get_auth_context)): """Get address balance including native tokens.""" + # Fix #8: Validate address + if not validate_address(address): + raise HTTPException( + status_code=400, + detail={"error": "invalid_address", "message": "Invalid Cardano address format"} + ) + cache_key = f"balance_{address}" cached = await get_cached(cache_key) if cached: @@ 
-1186,6 +1329,13 @@ async def get_address_balance(address: str, auth: dict = Depends(get_auth_contex @app.get("/v1/address/{address}/tokens") async def get_address_tokens(address: str, auth: dict = Depends(get_auth_context)): """Get native tokens held by an address.""" + # Fix #8: Validate address + if not validate_address(address): + raise HTTPException( + status_code=400, + detail={"error": "invalid_address", "message": "Invalid Cardano address format"} + ) + cache_key = f"tokens_{address}" cached = await get_cached(cache_key) if cached: @@ -1236,6 +1386,13 @@ async def get_address_transactions( auth: dict = Depends(get_auth_context) ): """Get transactions for an address.""" + # Fix #8: Validate address + if not validate_address(address): + raise HTTPException( + status_code=400, + detail={"error": "invalid_address", "message": "Invalid Cardano address format"} + ) + cache_key = f"txs_{address}_{page}_{limit}_{order}" cached = await get_cached(cache_key) if cached: @@ -1304,6 +1461,13 @@ async def get_address_transactions( @app.get("/v1/tx/{tx_hash}") async def get_transaction(tx_hash: str, auth: dict = Depends(get_auth_context)): """Get transaction details by hash.""" + # Fix #8: Validate tx hash + if not validate_tx_hash(tx_hash): + raise HTTPException( + status_code=400, + detail={"error": "invalid_tx_hash", "message": "Invalid transaction hash format (expected 64 hex chars)"} + ) + cache_key = f"tx_{tx_hash}" cached = await get_cached(cache_key) if cached: @@ -1385,6 +1549,13 @@ async def get_transaction(tx_hash: str, auth: dict = Depends(get_auth_context)): @app.get("/v1/asset/{policy_id}/info") async def get_asset_info(policy_id: str, auth: dict = Depends(get_auth_context)): """Get info about all assets under a policy ID.""" + # Fix #8: Validate policy ID + if not validate_policy_id(policy_id): + raise HTTPException( + status_code=400, + detail={"error": "invalid_policy_id", "message": "Invalid policy ID format (expected 56 hex chars)"} + ) + cache_key = 
f"asset_info_{policy_id}" cached = await get_cached(cache_key) if cached: @@ -1433,6 +1604,13 @@ async def get_asset_holders( auth: dict = Depends(get_auth_context) ): """Get top holders of a specific asset.""" + # Fix #8: Validate policy ID + if not validate_policy_id(policy_id): + raise HTTPException( + status_code=400, + detail={"error": "invalid_policy_id", "message": "Invalid policy ID format (expected 56 hex chars)"} + ) + cache_key = f"holders_{policy_id}_{asset_name}_{limit}" cached = await get_cached(cache_key) if cached: @@ -1542,19 +1720,24 @@ async def get_pool_info(pool_id: str, auth: dict = Depends(get_auth_context)): @app.post("/admin/keys") async def create_api_key(key_data: APIKeyCreate, auth: dict = Depends(require_master_key)): - """Create a new API key.""" + """Create a new API key (admin-created keys don't expire unless explicitly set).""" new_key = f"capi_{secrets.token_hex(24)}" - await redis_client.hset(f"apikey:{new_key}", mapping={ + # Fix #4: Store using hashed key + key_hash = hash_api_key(new_key) + + # Admin-created keys don't expire by default (Fix #10) + await redis_client.hset(f"apikey:{key_hash}", mapping={ "label": key_data.label, "tier": key_data.tier, "owner": key_data.owner or "", "trp_balance": str(key_data.trp_balance or 0), "created_at": datetime.now(timezone.utc).isoformat() + # No expires_at for admin-created keys }) return APIKeyResponse( - key=new_key, + key=new_key, # Return raw key only once label=key_data.label, tier=key_data.tier, owner=key_data.owner, @@ -1565,15 +1748,17 @@ async def create_api_key(key_data: APIKeyCreate, auth: dict = Depends(require_ma @app.delete("/admin/keys/{key}") async def revoke_api_key(key: str, auth: dict = Depends(require_master_key)): """Revoke an API key.""" - deleted = await redis_client.delete(f"apikey:{key}") + # Fix #4: Hash the key for lookup + key_hash = hash_api_key(key) + deleted = await redis_client.delete(f"apikey:{key_hash}") if not deleted: raise 
HTTPException(status_code=404, detail={"error": "not_found", "message": "API key not found"}) - return {"status": "revoked", "key": key} + return {"status": "revoked", "key": key[:16] + "..."} @app.get("/admin/keys") async def list_api_keys(auth: dict = Depends(require_master_key)): - """List all API keys.""" + """List all API keys (shows metadata, not raw keys since we only store hashes).""" keys = [] cursor = 0 while True: @@ -1581,13 +1766,16 @@ async def list_api_keys(auth: dict = Depends(require_master_key)): for key in found_keys: data = await redis_client.hgetall(key) if data: + # Key is stored as hash, so we show the hash (last 8 chars) for identification + key_hash = key.replace("apikey:", "") keys.append({ - "key": key.replace("apikey:", ""), + "key_hash_suffix": f"...{key_hash[-8:]}", "label": data.get("label"), "tier": data.get("tier"), "owner": data.get("owner") or None, "trp_balance": int(data.get("trp_balance", 0)), - "created_at": data.get("created_at") + "created_at": data.get("created_at"), + "expires_at": data.get("expires_at") }) if cursor == 0: break