clients/python: apply audit findings (90e158f → next)
- H1: quote slug in revoke_token - H2: redact AppToken.token in repr/str - M1-M6: wrap stdlib exceptions in ForgeError, validate timeouts, document uploads - L1/L5/L7: type-strict, immutable ip_cidrs, validate ok field - Bump requests floor to 2.32 Audit: memory/clawdforge-audits/python-90e158f.md
This commit is contained in:
parent
cc54cfbe6c
commit
1b097a21be
5 changed files with 421 additions and 50 deletions
|
|
@ -20,7 +20,7 @@ classifiers = [
|
|||
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||
]
|
||||
dependencies = [
|
||||
"requests>=2.28",
|
||||
"requests>=2.32",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
|
|
|||
|
|
@ -8,6 +8,8 @@ from __future__ import annotations
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from .exceptions import ForgeAPIError, ForgeError
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class RunResult:
|
||||
|
|
@ -24,12 +26,30 @@ class RunResult:
|
|||
stop_reason: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_response(cls, payload: dict) -> "RunResult":
|
||||
def from_response(cls, payload: dict[str, Any]) -> "RunResult":
|
||||
try:
|
||||
ok = bool(payload.get("ok", True))
|
||||
duration_ms = int(payload.get("duration_ms", 0))
|
||||
stop_reason = payload.get("stop_reason")
|
||||
result = payload.get("result")
|
||||
except (TypeError, ValueError, KeyError) as e:
|
||||
raise ForgeError(f"malformed /run response: {e}") from e
|
||||
|
||||
# Server contract: failures come back as 502 and never reach this
|
||||
# parser. If `ok=False` slips through anyway, surface it loudly so
|
||||
# callers don't silently treat it as success.
|
||||
if not ok:
|
||||
raise ForgeAPIError(
|
||||
"server returned ok=False on /run (contract violation)",
|
||||
status_code=200,
|
||||
body=payload,
|
||||
)
|
||||
|
||||
return cls(
|
||||
ok=bool(payload.get("ok", True)),
|
||||
result=payload.get("result"),
|
||||
duration_ms=int(payload.get("duration_ms", 0)),
|
||||
stop_reason=payload.get("stop_reason"),
|
||||
ok=ok,
|
||||
result=result,
|
||||
duration_ms=duration_ms,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -48,12 +68,15 @@ class FileToken:
|
|||
size: int
|
||||
|
||||
@classmethod
|
||||
def from_response(cls, payload: dict) -> "FileToken":
|
||||
return cls(
|
||||
file_token=payload["file_token"],
|
||||
ttl_secs=int(payload["ttl_secs"]),
|
||||
size=int(payload["size"]),
|
||||
)
|
||||
def from_response(cls, payload: dict[str, Any]) -> "FileToken":
|
||||
try:
|
||||
return cls(
|
||||
file_token=payload["file_token"],
|
||||
ttl_secs=int(payload["ttl_secs"]),
|
||||
size=int(payload["size"]),
|
||||
)
|
||||
except (KeyError, TypeError, ValueError) as e:
|
||||
raise ForgeError(f"malformed /files response: {e}") from e
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
|
|
@ -66,33 +89,53 @@ class AppToken:
|
|||
|
||||
name: str
|
||||
token: str | None = None
|
||||
ip_cidrs: list[str] = field(default_factory=list)
|
||||
ip_cidrs: tuple[str, ...] = field(default_factory=tuple)
|
||||
created_at: int | None = None
|
||||
last_used: int | None = None
|
||||
enabled: bool = True
|
||||
|
||||
@classmethod
|
||||
def from_create_response(cls, payload: dict) -> "AppToken":
|
||||
return cls(
|
||||
name=payload["name"],
|
||||
token=payload.get("token"),
|
||||
ip_cidrs=list(payload.get("ip_cidrs") or []),
|
||||
def __repr__(self) -> str:
|
||||
# Redact the plaintext bearer in repr so accidental log.info(token)
|
||||
# calls don't leak the secret. Keep a flag indicating presence so
|
||||
# debugging "is token populated" still works without revealing it.
|
||||
token_repr = "'<redacted>'" if self.token is not None else "None"
|
||||
return (
|
||||
f"AppToken(name={self.name!r}, token={token_repr}, "
|
||||
f"ip_cidrs={self.ip_cidrs!r}, created_at={self.created_at!r}, "
|
||||
f"last_used={self.last_used!r}, enabled={self.enabled!r})"
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.__repr__()
|
||||
|
||||
@classmethod
|
||||
def from_list_row(cls, row: dict) -> "AppToken":
|
||||
def from_create_response(cls, payload: dict[str, Any]) -> "AppToken":
|
||||
try:
|
||||
return cls(
|
||||
name=payload["name"],
|
||||
token=payload.get("token"),
|
||||
ip_cidrs=tuple(payload.get("ip_cidrs") or ()),
|
||||
)
|
||||
except (KeyError, TypeError) as e:
|
||||
raise ForgeError(f"malformed /admin/tokens create response: {e}") from e
|
||||
|
||||
@classmethod
|
||||
def from_list_row(cls, row: dict[str, Any]) -> "AppToken":
|
||||
# `ip_cidrs` from the server's list endpoint is a comma-joined string
|
||||
# (see store.list_tokens) — we normalize to list[str] here.
|
||||
raw = row.get("ip_cidrs", "")
|
||||
if isinstance(raw, str):
|
||||
cidrs = [s for s in raw.split(",") if s]
|
||||
else:
|
||||
cidrs = list(raw or [])
|
||||
return cls(
|
||||
name=row["name"],
|
||||
token=None,
|
||||
ip_cidrs=cidrs,
|
||||
created_at=row.get("created_at"),
|
||||
last_used=row.get("last_used"),
|
||||
enabled=bool(row.get("enabled", 1)),
|
||||
)
|
||||
# (see store.list_tokens) — we normalize to tuple[str, ...] here.
|
||||
try:
|
||||
raw = row.get("ip_cidrs", "")
|
||||
if isinstance(raw, str):
|
||||
cidrs: tuple[str, ...] = tuple(s for s in raw.split(",") if s)
|
||||
else:
|
||||
cidrs = tuple(raw or ())
|
||||
return cls(
|
||||
name=row["name"],
|
||||
token=None,
|
||||
ip_cidrs=cidrs,
|
||||
created_at=row.get("created_at"),
|
||||
last_used=row.get("last_used"),
|
||||
enabled=bool(row.get("enabled", 1)),
|
||||
)
|
||||
except (KeyError, TypeError) as e:
|
||||
raise ForgeError(f"malformed /admin/tokens list row: {e}") from e
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from __future__ import annotations
|
|||
import os
|
||||
from pathlib import Path
|
||||
from typing import IO, Any
|
||||
from urllib.parse import quote
|
||||
|
||||
import requests
|
||||
|
||||
|
|
@ -33,6 +34,10 @@ _DEFAULT_RUN_TIMEOUT_SECS = 120
|
|||
# pattern in cauldron's inline Forge wrapper.
|
||||
_HTTP_TIMEOUT_MARGIN_SECS = 30
|
||||
_HEALTHZ_TIMEOUT_SECS = 10
|
||||
# Server-side range for /run timeout_secs — mirrored locally so we fail
|
||||
# fast before hitting requests with absurd values.
|
||||
_RUN_TIMEOUT_MIN = 5
|
||||
_RUN_TIMEOUT_MAX = 600
|
||||
|
||||
|
||||
class Forge:
|
||||
|
|
@ -85,7 +90,12 @@ class Forge:
|
|||
# -- lifecycle ---------------------------------------------------------
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the underlying ``requests.Session`` if we own it."""
|
||||
"""Close the underlying ``requests.Session`` if we own it.
|
||||
|
||||
Idempotent. Note: after `__exit__` (or `close()`), this Forge
|
||||
instance is no longer usable for further requests if it owned the
|
||||
session — construct a fresh `Forge` instead of reusing.
|
||||
"""
|
||||
if self._owns_session:
|
||||
self._session.close()
|
||||
|
||||
|
|
@ -111,9 +121,9 @@ class Forge:
|
|||
method: str,
|
||||
path: str,
|
||||
*,
|
||||
json_body: dict | None = None,
|
||||
data: dict | None = None,
|
||||
files: dict | None = None,
|
||||
json_body: dict[str, Any] | None = None,
|
||||
data: dict[str, Any] | None = None,
|
||||
files: dict[str, Any] | None = None,
|
||||
timeout: float | tuple[float, float] | None = None,
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
) -> Any:
|
||||
|
|
@ -143,7 +153,7 @@ class Forge:
|
|||
def _parse(resp: requests.Response) -> Any:
|
||||
# Try JSON first; fall back to text. Some error responses (502 in
|
||||
# particular) carry JSON bodies that we want to surface verbatim.
|
||||
body: dict | str | None
|
||||
body: dict[str, Any] | str | None
|
||||
try:
|
||||
body = resp.json()
|
||||
except ValueError:
|
||||
|
|
@ -164,16 +174,21 @@ class Forge:
|
|||
|
||||
# -- /healthz ----------------------------------------------------------
|
||||
|
||||
def healthz(self) -> dict:
|
||||
def healthz(self) -> dict[str, Any]:
|
||||
"""``GET /healthz``.
|
||||
|
||||
Returns:
|
||||
``{"ok": True, "claude_present": bool, "claude_version": str | None}``
|
||||
|
||||
Raises:
|
||||
ForgeTransportError, ForgeAPIError
|
||||
ForgeTransportError, ForgeAPIError, ForgeError (non-dict response)
|
||||
"""
|
||||
return self._request("GET", "/healthz", timeout=_HEALTHZ_TIMEOUT_SECS)
|
||||
payload = self._request("GET", "/healthz", timeout=_HEALTHZ_TIMEOUT_SECS)
|
||||
if not isinstance(payload, dict):
|
||||
raise ForgeError(
|
||||
f"unexpected /healthz response type: {type(payload).__name__}"
|
||||
)
|
||||
return payload
|
||||
|
||||
# -- /run --------------------------------------------------------------
|
||||
|
||||
|
|
@ -206,11 +221,20 @@ class Forge:
|
|||
``stderr``, ``duration_ms``, ``stop_reason``.
|
||||
ForgeAuthError: bad token / IP not allowed.
|
||||
ForgeTransportError: connection-level failure.
|
||||
ValueError: empty prompt.
|
||||
ValueError: empty prompt or out-of-range timeout_secs.
|
||||
"""
|
||||
if not prompt:
|
||||
raise ValueError("prompt must be non-empty")
|
||||
|
||||
if timeout_secs is not None:
|
||||
if not isinstance(timeout_secs, int) or isinstance(timeout_secs, bool):
|
||||
raise ValueError("timeout_secs must be int")
|
||||
if timeout_secs < _RUN_TIMEOUT_MIN or timeout_secs > _RUN_TIMEOUT_MAX:
|
||||
raise ValueError(
|
||||
f"timeout_secs out of range "
|
||||
f"({_RUN_TIMEOUT_MIN}..{_RUN_TIMEOUT_MAX}), got {timeout_secs}"
|
||||
)
|
||||
|
||||
body: dict[str, Any] = {
|
||||
"prompt": prompt,
|
||||
"model": model or self.default_model,
|
||||
|
|
@ -222,7 +246,12 @@ class Forge:
|
|||
if timeout_secs is not None:
|
||||
body["timeout_secs"] = timeout_secs
|
||||
|
||||
effective_run_timeout = timeout_secs or self.default_timeout_secs
|
||||
# Use `is not None` instead of `or` so that the (already-rejected
|
||||
# above) `timeout_secs=0` case can't silently fall through to the
|
||||
# default; preserves intent against future range loosening.
|
||||
effective_run_timeout = (
|
||||
timeout_secs if timeout_secs is not None else self.default_timeout_secs
|
||||
)
|
||||
http_timeout = effective_run_timeout + self.http_timeout_margin
|
||||
|
||||
payload = self._request("POST", "/run", json_body=body, timeout=http_timeout)
|
||||
|
|
@ -239,6 +268,7 @@ class Forge:
|
|||
ttl_secs: int = 3600,
|
||||
filename: str | None = None,
|
||||
content_type: str | None = None,
|
||||
follow_symlinks: bool = True,
|
||||
) -> FileToken:
|
||||
"""``POST /files``: upload a file, get back a ``ff_...`` token.
|
||||
|
||||
|
|
@ -252,16 +282,44 @@ class Forge:
|
|||
``"upload"`` when given a raw file object without ``.name``.
|
||||
content_type: optional MIME type. If None, requests/servers
|
||||
default to ``application/octet-stream``.
|
||||
follow_symlinks: if True (default), a path that is a symlink is
|
||||
opened and its target uploaded. Set False to refuse symlinks
|
||||
explicitly (raises ForgeError). Has no effect when
|
||||
``path_or_fileobj`` is a file-like object.
|
||||
|
||||
Caveats:
|
||||
- The entire file body is materialized in memory before the
|
||||
request fires (a `requests` library limitation when using the
|
||||
`files=` multipart kwarg). For large files (>~100 MB) consider
|
||||
streaming via your own multipart encoder. The clawdforge server
|
||||
caps uploads at 25 MiB by default in any case.
|
||||
- Symlinks are followed by default — see ``follow_symlinks`` to
|
||||
opt out.
|
||||
- Non-existent paths surface as ``ForgeError``, not the underlying
|
||||
``FileNotFoundError``.
|
||||
|
||||
Returns:
|
||||
FileToken
|
||||
|
||||
Raises:
|
||||
ForgeError: path missing, not a file, or a refused symlink.
|
||||
ForgeAPIError / ForgeAuthError / ForgeTransportError: server-side.
|
||||
"""
|
||||
# Resolve to (filename, fileobj, opened-by-us-flag)
|
||||
opened_here = False
|
||||
fileobj: IO[bytes]
|
||||
if isinstance(path_or_fileobj, (str, os.PathLike)):
|
||||
p = Path(path_or_fileobj)
|
||||
fileobj = p.open("rb")
|
||||
if p.is_symlink() and not follow_symlinks:
|
||||
raise ForgeError(
|
||||
f"refusing to upload symlink (follow_symlinks=False): {p}"
|
||||
)
|
||||
try:
|
||||
fileobj = p.open("rb")
|
||||
except FileNotFoundError as e:
|
||||
raise ForgeError(f"file not found: {p}") from e
|
||||
except OSError as e:
|
||||
raise ForgeError(f"failed to open {p}: {e}") from e
|
||||
opened_here = True
|
||||
resolved_name = filename or p.name
|
||||
else:
|
||||
|
|
@ -319,8 +377,19 @@ class Forge:
|
|||
Returns True on success. Raises ForgeAPIError(404) if no such token
|
||||
exists rather than returning False — that matches the server's
|
||||
contract and lets callers distinguish "missing" from "revoked".
|
||||
|
||||
Raises:
|
||||
ValueError: empty `name`.
|
||||
ForgeAPIError: 4xx/5xx from server.
|
||||
ForgeAuthError, ForgeTransportError: as usual.
|
||||
"""
|
||||
payload = self._request("DELETE", f"/admin/tokens/{name}", timeout=10)
|
||||
if not name:
|
||||
raise ValueError("name must be non-empty")
|
||||
# `safe=''` so '/', '..', and other reserved chars get percent-encoded
|
||||
# — defends against path traversal like revoke_token("../../healthz")
|
||||
# which would otherwise issue DELETE /healthz.
|
||||
slug = quote(name, safe="")
|
||||
payload = self._request("DELETE", f"/admin/tokens/{slug}", timeout=10)
|
||||
if isinstance(payload, dict):
|
||||
return bool(payload.get("ok", True))
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ without sniffing status codes by hand.
|
|||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ForgeError(Exception):
|
||||
"""Base exception for everything the clawdforge SDK raises.
|
||||
|
|
@ -37,7 +39,7 @@ class ForgeAPIError(ForgeError):
|
|||
message: str,
|
||||
*,
|
||||
status_code: int,
|
||||
body: dict | str | None = None,
|
||||
body: dict[str, Any] | str | None = None,
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
|
|
|
|||
|
|
@ -6,8 +6,10 @@ from __future__ import annotations
|
|||
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
import responses
|
||||
|
|
@ -315,7 +317,7 @@ class TestAdminTokens(unittest.TestCase):
|
|||
self.assertIsInstance(t, AppToken)
|
||||
self.assertEqual(t.name, "cauldron")
|
||||
self.assertEqual(t.token, "cf_brandnew_xxx")
|
||||
self.assertEqual(t.ip_cidrs, ["172.24.0.0/16"])
|
||||
self.assertEqual(t.ip_cidrs, ("172.24.0.0/16",))
|
||||
|
||||
@responses.activate
|
||||
def test_list_tokens(self) -> None:
|
||||
|
|
@ -346,10 +348,10 @@ class TestAdminTokens(unittest.TestCase):
|
|||
toks = f.list_tokens()
|
||||
self.assertEqual(len(toks), 2)
|
||||
self.assertEqual(toks[0].name, "cauldron")
|
||||
self.assertEqual(toks[0].ip_cidrs, ["172.24.0.0/16"])
|
||||
self.assertEqual(toks[0].ip_cidrs, ("172.24.0.0/16",))
|
||||
self.assertTrue(toks[0].enabled)
|
||||
self.assertIsNone(toks[0].token)
|
||||
self.assertEqual(toks[1].ip_cidrs, [])
|
||||
self.assertEqual(toks[1].ip_cidrs, ())
|
||||
self.assertFalse(toks[1].enabled)
|
||||
|
||||
@responses.activate
|
||||
|
|
@ -407,5 +409,260 @@ class TestForgeConstruction(unittest.TestCase):
|
|||
sess.close()
|
||||
|
||||
|
||||
class TestRevokeTokenSlugQuoting(unittest.TestCase):
|
||||
"""H1: revoke_token must percent-encode the slug to defeat path traversal."""
|
||||
|
||||
def test_revoke_token_empty_name_rejected(self) -> None:
|
||||
with _forge() as f, self.assertRaises(ValueError):
|
||||
f.revoke_token("")
|
||||
|
||||
@responses.activate
|
||||
def test_revoke_token_path_traversal_is_quoted(self) -> None:
|
||||
# Pre-fix this would issue DELETE /healthz; post-fix it must hit the
|
||||
# admin/tokens endpoint with the slug percent-encoded.
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
def cb(request: requests.PreparedRequest) -> tuple[int, dict[str, Any], str]:
|
||||
captured["url"] = request.url
|
||||
return (404, {}, json.dumps({"detail": "no such token"}))
|
||||
|
||||
# Match any URL under /admin/tokens/ — the bug shape would route
|
||||
# somewhere else entirely.
|
||||
responses.add_callback(
|
||||
responses.DELETE,
|
||||
f"{BASE_URL}/admin/tokens/..%2F..%2Fhealthz",
|
||||
callback=cb,
|
||||
)
|
||||
with _forge() as f, self.assertRaises(ForgeAPIError):
|
||||
f.revoke_token("../../healthz")
|
||||
# Must have been routed under /admin/tokens/, not /healthz.
|
||||
self.assertIn("/admin/tokens/", captured["url"])
|
||||
self.assertNotIn("/healthz", captured["url"].split("/admin/tokens/")[0])
|
||||
|
||||
@responses.activate
|
||||
def test_revoke_token_slash_in_name_quoted(self) -> None:
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
def cb(request: requests.PreparedRequest) -> tuple[int, dict[str, Any], str]:
|
||||
captured["url"] = request.url
|
||||
return (200, {}, json.dumps({"ok": True}))
|
||||
|
||||
responses.add_callback(
|
||||
responses.DELETE,
|
||||
f"{BASE_URL}/admin/tokens/foo%2Fbar",
|
||||
callback=cb,
|
||||
)
|
||||
with _forge() as f:
|
||||
self.assertTrue(f.revoke_token("foo/bar"))
|
||||
self.assertIn("foo%2Fbar", captured["url"])
|
||||
|
||||
|
||||
class TestAppTokenRedaction(unittest.TestCase):
|
||||
"""H2: AppToken repr/str must redact the plaintext bearer."""
|
||||
|
||||
def test_repr_redacts_token(self) -> None:
|
||||
t = AppToken(name="x", token="cf_secret_xxxxxxxx", ip_cidrs=("10.0.0.0/8",))
|
||||
r = repr(t)
|
||||
self.assertNotIn("cf_secret_xxxxxxxx", r)
|
||||
self.assertIn("<redacted>", r)
|
||||
self.assertIn("name='x'", r)
|
||||
|
||||
def test_str_redacts_token(self) -> None:
|
||||
t = AppToken(name="x", token="cf_secret_xxxxxxxx")
|
||||
s = str(t)
|
||||
self.assertNotIn("cf_secret_xxxxxxxx", s)
|
||||
self.assertIn("<redacted>", s)
|
||||
|
||||
def test_repr_token_none_shows_none(self) -> None:
|
||||
t = AppToken(name="x", token=None)
|
||||
r = repr(t)
|
||||
self.assertIn("token=None", r)
|
||||
self.assertNotIn("<redacted>", r)
|
||||
|
||||
def test_format_string_doesnt_leak(self) -> None:
|
||||
# `log.info("token: %s", t)` is the worry case — uses __str__.
|
||||
t = AppToken(name="cauldron", token="cf_super_secret")
|
||||
formatted = f"token: {t}"
|
||||
self.assertNotIn("cf_super_secret", formatted)
|
||||
|
||||
|
||||
class TestModelExceptionWrapping(unittest.TestCase):
|
||||
"""M1, M2: stdlib exceptions in *.from_response must surface as ForgeError."""
|
||||
|
||||
def test_run_result_malformed_duration(self) -> None:
|
||||
with self.assertRaises(ForgeError):
|
||||
RunResult.from_response(
|
||||
{"ok": True, "result": "x", "duration_ms": "not-an-int"}
|
||||
)
|
||||
|
||||
def test_run_result_ok_false_raises_api_error(self) -> None:
|
||||
# L7: ok=False should raise ForgeAPIError, not silently parse.
|
||||
with self.assertRaises(ForgeAPIError) as ctx:
|
||||
RunResult.from_response(
|
||||
{"ok": False, "result": None, "duration_ms": 0}
|
||||
)
|
||||
self.assertEqual(ctx.exception.status_code, 200)
|
||||
|
||||
def test_file_token_missing_field(self) -> None:
|
||||
with self.assertRaises(ForgeError):
|
||||
FileToken.from_response({"ttl_secs": 60, "size": 5}) # no file_token
|
||||
|
||||
def test_file_token_bad_int(self) -> None:
|
||||
with self.assertRaises(ForgeError):
|
||||
FileToken.from_response(
|
||||
{"file_token": "ff_x", "ttl_secs": "lots", "size": 5}
|
||||
)
|
||||
|
||||
def test_app_token_create_missing_name(self) -> None:
|
||||
with self.assertRaises(ForgeError):
|
||||
AppToken.from_create_response({"token": "cf_x"})
|
||||
|
||||
def test_app_token_list_row_missing_name(self) -> None:
|
||||
with self.assertRaises(ForgeError):
|
||||
AppToken.from_list_row({"ip_cidrs": ""})
|
||||
|
||||
@responses.activate
|
||||
def test_run_payload_with_bad_duration_surfaces_forge_error(self) -> None:
|
||||
responses.add(
|
||||
responses.POST,
|
||||
f"{BASE_URL}/run",
|
||||
json={"ok": True, "result": "x", "duration_ms": "junk"},
|
||||
status=200,
|
||||
)
|
||||
with _forge() as f, self.assertRaises(ForgeError):
|
||||
f.run(prompt="hi")
|
||||
|
||||
|
||||
class TestUploadFileMissingPath(unittest.TestCase):
|
||||
"""M3: missing path must surface as ForgeError, not FileNotFoundError."""
|
||||
|
||||
def test_missing_path_raises_forge_error(self) -> None:
|
||||
with _forge() as f, self.assertRaises(ForgeError):
|
||||
f.upload_file("/nonexistent/path/to/file.bin")
|
||||
|
||||
def test_missing_path_does_not_leak_filenotfounderror(self) -> None:
|
||||
# FileNotFoundError is OSError; ForgeError is not. Verify caller can
|
||||
# rely on `except ForgeError` to catch this.
|
||||
with _forge() as f:
|
||||
try:
|
||||
f.upload_file("/nonexistent/path/to/file.bin")
|
||||
except ForgeError:
|
||||
pass
|
||||
except FileNotFoundError:
|
||||
self.fail("FileNotFoundError leaked through ForgeError boundary")
|
||||
|
||||
|
||||
class TestUploadFileSymlinkOption(unittest.TestCase):
|
||||
"""M5: optional follow_symlinks=False kwarg refuses symlinks."""
|
||||
|
||||
@responses.activate
|
||||
def test_follow_symlinks_false_refuses_symlink(self) -> None:
|
||||
import tempfile
|
||||
|
||||
with tempfile.NamedTemporaryFile("wb", delete=False, suffix=".txt") as tf:
|
||||
tf.write(b"target")
|
||||
target = tf.name
|
||||
link = target + ".link"
|
||||
try:
|
||||
os.symlink(target, link)
|
||||
with _forge() as f, self.assertRaises(ForgeError):
|
||||
f.upload_file(link, follow_symlinks=False)
|
||||
finally:
|
||||
Path(link).unlink(missing_ok=True)
|
||||
Path(target).unlink(missing_ok=True)
|
||||
|
||||
@responses.activate
|
||||
def test_follow_symlinks_true_uploads_target(self) -> None:
|
||||
import tempfile
|
||||
|
||||
responses.add(
|
||||
responses.POST,
|
||||
f"{BASE_URL}/files",
|
||||
json={"file_token": "ff_sym", "ttl_secs": 60, "size": 6},
|
||||
status=200,
|
||||
)
|
||||
with tempfile.NamedTemporaryFile("wb", delete=False, suffix=".txt") as tf:
|
||||
tf.write(b"target")
|
||||
target = tf.name
|
||||
link = target + ".link"
|
||||
try:
|
||||
os.symlink(target, link)
|
||||
with _forge() as f:
|
||||
ft = f.upload_file(link, ttl_secs=60, follow_symlinks=True)
|
||||
self.assertEqual(ft.file_token, "ff_sym")
|
||||
finally:
|
||||
Path(link).unlink(missing_ok=True)
|
||||
Path(target).unlink(missing_ok=True)
|
||||
|
||||
|
||||
class TestRunTimeoutValidation(unittest.TestCase):
|
||||
"""M4: timeout_secs is range-validated locally."""
|
||||
|
||||
def test_negative_timeout_rejected(self) -> None:
|
||||
with _forge() as f, self.assertRaises(ValueError):
|
||||
f.run(prompt="hi", timeout_secs=-30)
|
||||
|
||||
def test_zero_timeout_rejected(self) -> None:
|
||||
# 0 used to be falsy-substituted with the default; now it's a hard error.
|
||||
with _forge() as f, self.assertRaises(ValueError):
|
||||
f.run(prompt="hi", timeout_secs=0)
|
||||
|
||||
def test_excessive_timeout_rejected(self) -> None:
|
||||
with _forge() as f, self.assertRaises(ValueError):
|
||||
f.run(prompt="hi", timeout_secs=10_000)
|
||||
|
||||
@responses.activate
|
||||
def test_min_boundary_accepted(self) -> None:
|
||||
# Smoke: 5 is the minimum and must NOT trip local validation.
|
||||
responses.add(
|
||||
responses.POST,
|
||||
f"{BASE_URL}/run",
|
||||
json={"ok": True, "result": "x", "duration_ms": 1, "stop_reason": "end_turn"},
|
||||
status=200,
|
||||
)
|
||||
with _forge() as f:
|
||||
r = f.run(prompt="hi", timeout_secs=5)
|
||||
self.assertTrue(r.ok)
|
||||
|
||||
@responses.activate
|
||||
def test_max_boundary_accepted(self) -> None:
|
||||
responses.add(
|
||||
responses.POST,
|
||||
f"{BASE_URL}/run",
|
||||
json={"ok": True, "result": "x", "duration_ms": 1, "stop_reason": "end_turn"},
|
||||
status=200,
|
||||
)
|
||||
with _forge() as f:
|
||||
r = f.run(prompt="hi", timeout_secs=600)
|
||||
self.assertTrue(r.ok)
|
||||
|
||||
|
||||
class TestHealthzValidation(unittest.TestCase):
|
||||
"""M6: non-dict /healthz response raises ForgeError."""
|
||||
|
||||
@responses.activate
|
||||
def test_healthz_string_response_raises(self) -> None:
|
||||
responses.add(
|
||||
responses.GET,
|
||||
f"{BASE_URL}/healthz",
|
||||
body="OK", # plain string, not JSON
|
||||
status=200,
|
||||
content_type="text/plain",
|
||||
)
|
||||
with _forge() as f, self.assertRaises(ForgeError):
|
||||
f.healthz()
|
||||
|
||||
@responses.activate
|
||||
def test_healthz_json_list_raises(self) -> None:
|
||||
responses.add(
|
||||
responses.GET,
|
||||
f"{BASE_URL}/healthz",
|
||||
json=["not", "a", "dict"],
|
||||
status=200,
|
||||
)
|
||||
with _forge() as f, self.assertRaises(ForgeError):
|
||||
f.healthz()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue