Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions parseq/export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@

from executorch.parseq.export_utils import get_dummy_input, prepare_export_model

from executorch.examples.qualcomm.utils import build_executorch_binary


if __name__ == '__main__':

model_name = "parseq"
export_mode = 'executorch'
model = prepare_export_model(model_name, export_mode)
image = get_dummy_input()
inputs = (image,)

build_executorch_binary(
model,
inputs,
"SM8650",
"parseq_qualcomm.pte",
[inputs],
skip_node_op_set={"aten.full.default", "aten.where.self"},
skip_node_id_set={"aten_view_copy_default_235"},
)
51 changes: 51 additions & 0 deletions parseq/export_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import torchvision
import yaml

from executorch.parseq.parseq import PARSeq, Tokenizer
from PIL import Image


def get_transform(img_size=(32, 128)):
return torchvision.transforms.Compose(
[
torchvision.transforms.Resize(
img_size, torchvision.transforms.InterpolationMode.BICUBIC
),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(0.5, 0.5),
]
)


# Gray dummy image
def get_dummy_input(img_h=32, img_w=128, n_channels=3):
transform = get_transform((img_h, img_w))
image = Image.new("RGB", (img_w, img_h), color=128)
image = transform(image)
image = image.view(1, *image.size())
return image


def prepare_export_model(model_name, export_mode):
with open(f"parseq/{model_name}.yaml", "r") as f:
cfg = yaml.load(f, Loader=yaml.FullLoader)

print(cfg)
cfg.pop("name")
cfg.pop("_target_")
cfg.pop("lr")
cfg.pop("perm_num")
cfg.pop("perm_forward")
cfg.pop("perm_mirrored")
charset = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
tokenizer = Tokenizer(charset)
lightning_model = PARSeq(
tokenizer,
25,
(32, 128),
**cfg,
)

model = lightning_model
model.export_mode = export_mode
return model.eval()
Loading
Loading