Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions integration-tests/tests/bool_input_output.txtar
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Test bool as a direct predict input and output type

# Build the image
cog build -t $TEST_IMAGE

# Bool input and output works via JSON (true -> false)
cog predict $TEST_IMAGE --json '{"flag": true}'
stdout '"output": false'

# Bool input and output works via JSON (false -> true)
cog predict $TEST_IMAGE --json '{"flag": false}'
stdout '"output": true'

-- cog.yaml --
build:
python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
def predict(self, flag: bool) -> bool:
return not flag
27 changes: 27 additions & 0 deletions integration-tests/tests/concatenate_iterator_output.txtar
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Test ConcatenateIterator[str] as predict output type
#
# ConcatenateIterator is the primary streaming text output type for LLMs.
# cog predict renders each yielded token as an array element.

# Build the image
cog build -t $TEST_IMAGE

# Streaming output yields individual tokens
cog predict $TEST_IMAGE -i prompt=hello
stdout '"hello"'
stdout '" world"'
stdout '" !"'

-- cog.yaml --
build:
python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor, ConcatenateIterator


class Predictor(BasePredictor):
def predict(self, prompt: str) -> ConcatenateIterator[str]:
for token in [prompt, " world", " !"]:
yield token
22 changes: 22 additions & 0 deletions integration-tests/tests/dict_output.txtar
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Test bare dict return type works for predict output

# Build the image
cog build -t $TEST_IMAGE

# Predict returns a dict
cog predict $TEST_IMAGE -i name=alice
stdout '"greeting": "hello alice"'
stdout '"length": 5'

-- cog.yaml --
build:
python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
def predict(self, name: str) -> dict:
return {"greeting": "hello " + name, "length": len(name)}
28 changes: 28 additions & 0 deletions integration-tests/tests/iterator_string_output.txtar
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Test Iterator[str] as predict output type
#
# Iterator[str] yields individual string items as an array.

# Build the image
cog build -t $TEST_IMAGE

# Iterator output returns items
cog predict $TEST_IMAGE -i count=3
stdout 'item-0'
stdout 'item-1'
stdout 'item-2'

-- cog.yaml --
build:
python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from typing import Iterator

from cog import BasePredictor


class Predictor(BasePredictor):
def predict(self, count: int) -> Iterator[str]:
for i in range(count):
yield f"item-{i}"
25 changes: 25 additions & 0 deletions integration-tests/tests/list_int_input_output.txtar
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Test list[int] as predict input and output types

# Build the image
cog build -t $TEST_IMAGE

# List of ints works as input and output
cog predict $TEST_IMAGE --json '{"numbers": [1, 2, 3]}'
stdout '"status": "succeeded"'
stdout '"output":'
stdout '2'
stdout '4'
stdout '6'

-- cog.yaml --
build:
python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
def predict(self, numbers: list[int]) -> list[int]:
return [n * 2 for n in numbers]
23 changes: 23 additions & 0 deletions integration-tests/tests/list_string_output.txtar
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Test list[str] as predict output type

# Build the image
cog build -t $TEST_IMAGE

# List output returns items
cog predict $TEST_IMAGE -i text='hello world foo'
stdout 'hello'
stdout 'world'
stdout 'foo'

-- cog.yaml --
build:
python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
def predict(self, text: str) -> list[str]:
return text.split()
12 changes: 7 additions & 5 deletions mise.toml
Original file line number Diff line number Diff line change
Expand Up @@ -99,20 +99,22 @@ run = [
]

[tasks.install]
description = "Symlink cog CLI to /usr/local/bin (or PREFIX). Run build:cog first."
depends = ["build:cog"]
description = "Build and symlink cog CLI"
usage = 'arg "[dest]" help="Directory to symlink into (e.g. ~/.local/bin)" default="~/.local/bin"'
run = """
#!/usr/bin/env bash
set -e
PREFIX="${PREFIX:-/usr/local}"
DEST="${usage_dest/#\\~/$HOME}"
BINARY=$(ls dist/go/*/cog 2>/dev/null | head -1)
if [ -z "$BINARY" ]; then
echo "Error: no cog binary found in dist/go/. Run 'mise run build:cog' first." >&2
exit 1
fi
BINARY="$(cd "$(dirname "$BINARY")" && pwd)/$(basename "$BINARY")"
mkdir -p "$PREFIX/bin"
ln -sf "$BINARY" "$PREFIX/bin/cog"
echo "Installed $PREFIX/bin/cog -> $BINARY"
mkdir -p "$DEST"
ln -sf "$BINARY" "$DEST/cog"
echo "Installed $DEST/cog -> $BINARY"
"""

[tasks."build:cog"]
Expand Down
9 changes: 9 additions & 0 deletions python/cog/_adt.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,15 @@ def from_type(tpe: type) -> "FieldType":
tpe = Set[Any]
origin = typing.get_origin(tpe)

if origin is dict:
# dict / Dict[K, V] → opaque JSON object, consistent with the
# static Go schema generator's SchemaAnyType().
return FieldType(
primitive=PrimitiveType.ANY,
repetition=Repetition.REQUIRED,
coder=None,
)

if origin in (list, List):
t_args = typing.get_args(tpe)
if t_args:
Expand Down