Skip to content

Commit 5d3a643

Browse files
authored
Add support for multi-reference image editing with new image_urls parameter (#99)
- Add new examples for multi-reference image editing - Enforce mutual exclusivity of image_url and image_urls parameters in image generation methods - Add tests to ensure new `image_urls` parameter is passed correctly
1 parent 48adaab commit 5d3a643

10 files changed

Lines changed: 414 additions & 46 deletions

File tree

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
"""Multi-reference image editing example.
2+
3+
Example:
4+
uv run examples/aio/image_multi_reference_editing.py \
5+
--prompt "Blend these references into a cyberpunk skyline" \
6+
--image-url "https://example.com/one.jpg" \
7+
--image-url "https://example.com/two.jpg" \
8+
--format url
9+
"""
10+
11+
import asyncio
12+
from typing import Sequence, cast
13+
14+
from absl import app, flags
15+
16+
import xai_sdk
17+
from xai_sdk.aio.image import ImageResponse
18+
from xai_sdk.image import ImageFormat
19+
20+
N = flags.DEFINE_integer("n", 1, "Number of images to generate.")
21+
FORMAT = flags.DEFINE_enum("format", "base64", ["base64", "url"], "Image format used to return the result.")
22+
MODEL = flags.DEFINE_string("model", "grok-imagine-image", "Image generation model to use.")
23+
OUTPUT_DIR = flags.DEFINE_string("output-dir", None, "Directory to save the generated images.")
24+
PROMPT = flags.DEFINE_string("prompt", None, "Prompt to edit the input images.", required=True)
25+
IMAGE_URLS = flags.DEFINE_multi_string(
26+
"image-url",
27+
[],
28+
"Input image URL or base64-encoded string. Repeat for multiple images.",
29+
)
30+
31+
32+
async def edit_images(client: xai_sdk.AsyncClient, image_format: ImageFormat) -> Sequence[ImageResponse]:
33+
"""Multi-reference image editing using image URLs or base64 strings."""
34+
image_urls = list(IMAGE_URLS.value)
35+
if not image_urls:
36+
raise app.UsageError("At least one --image-url is required.")
37+
38+
if N.value == 1:
39+
response = await client.image.sample(
40+
PROMPT.value,
41+
model=MODEL.value,
42+
image_format=image_format,
43+
image_urls=image_urls,
44+
)
45+
return [response]
46+
47+
return await client.image.sample_batch(
48+
PROMPT.value,
49+
n=N.value,
50+
model=MODEL.value,
51+
image_format=image_format,
52+
image_urls=image_urls,
53+
)
54+
55+
56+
async def save_images(responses: Sequence[ImageResponse]) -> None:
57+
"""Save images to a file."""
58+
for i, image in enumerate(responses):
59+
with open(f"{OUTPUT_DIR.value}/image_{i}.jpg", "wb") as f:
60+
f.write(await image.image)
61+
62+
63+
async def main(argv: Sequence[str]) -> None:
64+
if len(argv) > 1:
65+
raise app.UsageError("Unexpected command line arguments.")
66+
67+
if FORMAT.value != "url" and not OUTPUT_DIR.value:
68+
raise app.UsageError("--output-dir is required when --format is not url.")
69+
70+
client = xai_sdk.AsyncClient()
71+
image_format: ImageFormat = cast(ImageFormat, FORMAT.value)
72+
responses = await edit_images(client, image_format)
73+
74+
if image_format == "url":
75+
for i, image in enumerate(responses):
76+
print(f"Image {i} URL: {image.url}")
77+
return
78+
79+
await save_images(responses)
80+
81+
82+
if __name__ == "__main__":
83+
app.run(lambda argv: asyncio.run(main(argv)))
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
"""Multi-reference image editing example.
2+
3+
Example:
4+
uv run examples/sync/image_multi_reference_editing.py \
5+
--prompt "Blend these references into a cyberpunk skyline" \
6+
--image-url "https://example.com/one.jpg" \
7+
--image-url "https://example.com/two.jpg" \
8+
--format url
9+
"""
10+
11+
from typing import Sequence, cast
12+
13+
from absl import app, flags
14+
15+
import xai_sdk
16+
from xai_sdk.image import ImageFormat
17+
from xai_sdk.sync.image import ImageResponse
18+
19+
N = flags.DEFINE_integer("n", 1, "Number of images to generate.")
20+
FORMAT = flags.DEFINE_enum("format", "base64", ["base64", "url"], "Image format used to return the result.")
21+
MODEL = flags.DEFINE_string("model", "grok-imagine-image", "Image generation model to use.")
22+
OUTPUT_DIR = flags.DEFINE_string("output-dir", None, "Directory to save the generated images.")
23+
PROMPT = flags.DEFINE_string("prompt", None, "Prompt to edit the input images.", required=True)
24+
IMAGE_URLS = flags.DEFINE_multi_string(
25+
"image-url",
26+
[],
27+
"Input image URL or base64-encoded string. Repeat for multiple images.",
28+
)
29+
30+
31+
def edit_images(client: xai_sdk.Client, image_format: ImageFormat) -> Sequence[ImageResponse]:
32+
"""Multi-reference image editing using image URLs or base64 strings."""
33+
image_urls = list(IMAGE_URLS.value)
34+
if not image_urls:
35+
raise app.UsageError("At least one --image-url is required.")
36+
37+
if N.value == 1:
38+
response = client.image.sample(
39+
PROMPT.value,
40+
model=MODEL.value,
41+
image_format=image_format,
42+
image_urls=image_urls,
43+
)
44+
return [response]
45+
46+
return client.image.sample_batch(
47+
PROMPT.value,
48+
n=N.value,
49+
model=MODEL.value,
50+
image_format=image_format,
51+
image_urls=image_urls,
52+
)
53+
54+
55+
def save_images(responses: Sequence[ImageResponse]) -> None:
56+
"""Save images to a file."""
57+
for i, image in enumerate(responses):
58+
with open(f"{OUTPUT_DIR.value}/image_{i}.jpg", "wb") as f:
59+
f.write(image.image)
60+
61+
62+
def main(argv: Sequence[str]) -> None:
63+
if len(argv) > 1:
64+
raise app.UsageError("Unexpected command line arguments.")
65+
66+
if FORMAT.value != "url" and not OUTPUT_DIR.value:
67+
raise app.UsageError("--output-dir is required when --format is not url.")
68+
69+
client = xai_sdk.Client()
70+
image_format: ImageFormat = cast(ImageFormat, FORMAT.value)
71+
responses = edit_images(client, image_format)
72+
73+
if image_format == "url":
74+
for i, image in enumerate(responses):
75+
print(f"Image {i} URL: {image.url}")
76+
return
77+
78+
save_images(responses)
79+
80+
81+
if __name__ == "__main__":
82+
app.run(main)

src/xai_sdk/aio/image.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,17 +31,23 @@ async def sample(
3131
model: str,
3232
*,
3333
image_url: Optional[str] = None,
34+
image_urls: Optional[Sequence[str]] = None,
3435
user: Optional[str] = None,
3536
image_format: Optional[ImageFormat] = None,
3637
aspect_ratio: Optional[ImageAspectRatio] = None,
3738
resolution: Optional[ImageResolution] = None,
3839
) -> "ImageResponse":
39-
"""Samples a single image asynchronously based on the provided prompt.
40+
"""Samples a single image asynchronously based on the provided prompt.
4041
4142
Args:
4243
prompt: The prompt to generate an image from.
4344
model: The model to use for image generation.
4445
image_url: The URL or base64-encoded string of an input image to use as a starting point for generation.
46+
This field cannot be set together with `image_urls`.
47+
Only supported for grok-imagine models.
48+
image_urls: Optional list of input images for multi-reference image editing.
49+
Each image is a URL or base64-encoded string, matching the `image_url` format.
50+
This field cannot be set together with `image_url`.
4551
Only supported for grok-imagine models.
4652
user: A unique identifier representing your end-user, which can help xAI to monitor and detect abuse.
4753
image_format: The format of the image to return. One of:
@@ -71,6 +77,9 @@ async def sample(
7177
Returns:
7278
An `ImageResponse` object allowing access to the generated image.
7379
"""
80+
if image_url is not None and image_urls is not None:
81+
raise ValueError("Only one of image_url or image_urls can be set for a request.")
82+
7483
image_format = image_format or "url"
7584
request = image_pb2.GenerateImageRequest(
7685
prompt=prompt,
@@ -86,6 +95,16 @@ async def sample(
8695
detail=image_pb2.ImageDetail.DETAIL_AUTO,
8796
)
8897
)
98+
if image_urls is not None:
99+
request.images.extend(
100+
[
101+
image_pb2.ImageUrlContent(
102+
image_url=url,
103+
detail=image_pb2.ImageDetail.DETAIL_AUTO,
104+
)
105+
for url in image_urls
106+
]
107+
)
89108
if aspect_ratio is not None:
90109
request.aspect_ratio = convert_image_aspect_ratio_to_pb(aspect_ratio)
91110
if resolution is not None:
@@ -108,6 +127,7 @@ async def sample_batch(
108127
n: int,
109128
*,
110129
image_url: Optional[str] = None,
130+
image_urls: Optional[Sequence[str]] = None,
111131
user: Optional[str] = None,
112132
image_format: Optional[ImageFormat] = None,
113133
aspect_ratio: Optional[ImageAspectRatio] = None,
@@ -120,6 +140,11 @@ async def sample_batch(
120140
model: The model to use for image generation.
121141
n: The number of images to generate.
122142
image_url: The URL or base64-encoded string of an input image to use as a starting point for generation.
143+
This field cannot be set together with `image_urls`.
144+
Only supported for grok-imagine models.
145+
image_urls: Optional list of input images for multi-reference image editing.
146+
Each image is a URL or base64-encoded string, matching the `image_url` format.
147+
This field cannot be set together with `image_url`.
123148
Only supported for grok-imagine models.
124149
user: A unique identifier representing your end-user, which can help xAI to monitor and detect abuse.
125150
image_format: The format of the image to return. One of:
@@ -149,6 +174,9 @@ async def sample_batch(
149174
Returns:
150175
A sequence of `ImageResponse` objects, one for each image generated.
151176
"""
177+
if image_url is not None and image_urls is not None:
178+
raise ValueError("Only one of image_url or image_urls can be set for a request.")
179+
152180
image_format = image_format or "url"
153181
request = image_pb2.GenerateImageRequest(
154182
prompt=prompt,
@@ -164,6 +192,16 @@ async def sample_batch(
164192
detail=image_pb2.ImageDetail.DETAIL_AUTO,
165193
)
166194
)
195+
if image_urls is not None:
196+
request.images.extend(
197+
[
198+
image_pb2.ImageUrlContent(
199+
image_url=url,
200+
detail=image_pb2.ImageDetail.DETAIL_AUTO,
201+
)
202+
for url in image_urls
203+
]
204+
)
167205
if aspect_ratio is not None:
168206
request.aspect_ratio = convert_image_aspect_ratio_to_pb(aspect_ratio)
169207
if resolution is not None:

src/xai_sdk/proto/v5/image_pb2.py

Lines changed: 20 additions & 20 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)