diff --git a/cmd/mapt/cmd/aws/hosts/rhelai.go b/cmd/mapt/cmd/aws/hosts/rhelai.go index c3a392ce1..e6e482082 100644 --- a/cmd/mapt/cmd/aws/hosts/rhelai.go +++ b/cmd/mapt/cmd/aws/hosts/rhelai.go @@ -72,6 +72,9 @@ func getRHELAICreate() *cobra.Command { HFToken: viper.GetString(params.RhelAIHFToken), APIKey: viper.GetString(params.RhelAIAPIKey), AutoStart: viper.IsSet(params.RhelAIAutoStart), + ToolCallParser: viper.GetString(params.RhelAIToolCallParser), + ChatTemplate: viper.GetString(params.RhelAIChatTemplate), + MaxModelLen: viper.GetInt(params.RhelAIMaxModelLen), ExposePorts: viper.GetIntSlice(params.RhelAIExposePorts), }) }, @@ -87,6 +90,9 @@ func getRHELAICreate() *cobra.Command { flagSet.StringP(params.RhelAIAPIKey, "", "", params.RhelAIAPIKeyDesc) flagSet.Bool(params.RhelAIAutoStart, false, params.RhelAIAutoStartDesc) flagSet.IntSlice(params.RhelAIExposePorts, nil, params.RhelAIExposePortsDesc) + flagSet.StringP(params.RhelAIToolCallParser, "", "", params.RhelAIToolCallParserDesc) + flagSet.StringP(params.RhelAIChatTemplate, "", "", params.RhelAIChatTemplateDesc) + flagSet.Int(params.RhelAIMaxModelLen, 0, params.RhelAIMaxModelLenDesc) flagSet.StringP(params.Timeout, "", "", params.TimeoutDesc) params.AddComputeRequestFlags(flagSet) params.AddSpotFlags(flagSet) diff --git a/cmd/mapt/cmd/params/params.go b/cmd/mapt/cmd/params/params.go index 26e513422..041e127e2 100644 --- a/cmd/mapt/cmd/params/params.go +++ b/cmd/mapt/cmd/params/params.go @@ -131,6 +131,12 @@ const ( RhelAIAutoStartDesc string = "automatically configure and start RHAIIS after provisioning" RhelAIExposePorts string = "expose-ports" RhelAIExposePortsDesc string = "comma-separated list of ports to expose through the load balancer and security group (e.g. 8000,8080)" + RhelAIToolCallParser string = "tool-call-parser" + RhelAIToolCallParserDesc string = "enable tool calling with the specified parser (e.g. 
llama3_json, hermes, mistral)" + RhelAIChatTemplate string = "chat-template" + RhelAIChatTemplateDesc string = "chat template jinja filename (e.g. tool_chat_template_llama3.2_json.jinja)" + RhelAIMaxModelLen string = "max-model-len" + RhelAIMaxModelLenDesc string = "maximum model context length in tokens (default 4096)" // Serverless Timeout string = "timeout" diff --git a/pkg/provider/aws/action/rhel-ai/rhelai.go b/pkg/provider/aws/action/rhel-ai/rhelai.go index 2863fecb1..8ffa428f1 100644 --- a/pkg/provider/aws/action/rhel-ai/rhelai.go +++ b/pkg/provider/aws/action/rhel-ai/rhelai.go @@ -46,6 +46,9 @@ type rhelAIRequest struct { hfToken *string apiKey *string autoStart bool + toolCallParser *string + chatTemplate *string + maxModelLen int exposePorts []int } @@ -85,6 +88,9 @@ func Create(mCtxArgs *mc.ContextArgs, args *apiRHELAI.RHELAIArgs) (err error) { hfToken: &args.HFToken, apiKey: &args.APIKey, autoStart: args.AutoStart, + maxModelLen: args.MaxModelLen, + toolCallParser: &args.ToolCallParser, + chatTemplate: &args.ChatTemplate, exposePorts: args.ExposePorts} if args.Spot != nil { r.spot = args.Spot.Spot @@ -373,6 +379,27 @@ func (r *rhelAIRequest) rhaiisSetupScript() string { ` && sudo sed -i 's|--model .*|--model %s \\|' %s/install.conf`, *r.model, confDir) } + // Guard the tensor-parallel rewrite with if/fi: a bare test joined by && in this chain would abort every later setup step when nvidia-smi lists no GPU. + script += fmt.Sprintf( + ` && GPU_COUNT=$(nvidia-smi -L 2>/dev/null | wc -l) && if [ "$GPU_COUNT" -gt 0 ]; then sudo sed -i "s|--tensor-parallel-size 1|--tensor-parallel-size $GPU_COUNT|" %s/install.conf; fi`, + confDir) + maxModelLen := 4096 + if r.maxModelLen > 0 { + maxModelLen = r.maxModelLen + } + if len(*r.toolCallParser) > 0 { + toolArgs := fmt.Sprintf(`--enable-auto-tool-choice \\\n --tool-call-parser %s`, *r.toolCallParser) + if len(*r.chatTemplate) > 0 { + toolArgs += fmt.Sprintf(` \\\n --chat-template /opt/app-root/template/%s`, *r.chatTemplate) + } + script += fmt.Sprintf( + ` && sudo sed -i 's|--max-model-len.*|--max-model-len %d \\\n %s|' %s/install.conf`, + maxModelLen, toolArgs, confDir) + } 
else if r.maxModelLen > 0 { + script += fmt.Sprintf( + ` && sudo sed -i 's|--max-model-len.*|--max-model-len %d|' %s/install.conf`, + maxModelLen, confDir) + } if len(*r.apiKey) > 0 { script += fmt.Sprintf( " && sudo sed -i '/\\[Install\\]/i Environment=VLLM_API_KEY=%s' %s/install.conf", diff --git a/pkg/target/host/rhelai/api.go b/pkg/target/host/rhelai/api.go index c313faf64..ac688a6ab 100644 --- a/pkg/target/host/rhelai/api.go +++ b/pkg/target/host/rhelai/api.go @@ -19,7 +19,10 @@ type RHELAIArgs struct { Timeout string Model string HFToken string - APIKey string - AutoStart bool - ExposePorts []int + APIKey string + AutoStart bool + ToolCallParser string + ChatTemplate string + MaxModelLen int + ExposePorts []int } diff --git a/tkn/infra-aws-rhel-ai.yaml b/tkn/infra-aws-rhel-ai.yaml index 983a1d864..5f98f3bde 100644 --- a/tkn/infra-aws-rhel-ai.yaml +++ b/tkn/infra-aws-rhel-ai.yaml @@ -155,6 +155,15 @@ spec: - name: expose-ports description: Comma-separated list of ports to expose through the load balancer and security group (e.g. 8000,8080). default: "" + - name: tool-call-parser + description: Enable tool calling with the specified parser (e.g. llama3_json, hermes, mistral). Automatically adds --enable-auto-tool-choice. + default: "" + - name: chat-template + description: Chat template jinja filename (e.g. tool_chat_template_llama3.2_json.jinja). + default: "" + - name: max-model-len + description: Maximum model context length in tokens (default 4096). Increase for tool calling or larger models. 
+ default: "0" # Network params - name: service-endpoints @@ -317,6 +326,15 @@ spec: if [[ "$(params.expose-ports)" != "" ]]; then cmd+="--expose-ports '$(params.expose-ports)' " fi + if [[ "$(params.tool-call-parser)" != "" ]]; then + cmd+="--tool-call-parser '$(params.tool-call-parser)' " + fi + if [[ "$(params.chat-template)" != "" ]]; then + cmd+="--chat-template '$(params.chat-template)' " + fi + if [[ "$(params.max-model-len)" != "0" ]]; then + cmd+="--max-model-len '$(params.max-model-len)' " + fi cmd+="--tags '$(params.tags)' " fi diff --git a/tkn/infra-azure-rhel-ai.yaml b/tkn/infra-azure-rhel-ai.yaml index 21c9e87b0..3fbe0c6f4 100644 --- a/tkn/infra-azure-rhel-ai.yaml +++ b/tkn/infra-azure-rhel-ai.yaml @@ -85,6 +85,12 @@ spec: - name: disk-size description: Disk size in GB for the cloud instance default: "200" + - name: gpus + description: Number of GPUs for the cloud instance (valid marketplace values are 1, 2, 4, 8) + default: "8" + - name: gpu-manufacturer + description: GPU manufacturer name for instance filtering (e.g. NVIDIA, AMD) + default: "" - name: compute-sizes description: Comma seperated list of sizes for the machines to be requested. 
If set this takes precedence over compute by args default: "Standard_ND96is_MI300X_v5,Standard_ND96isr_MI300X_v5" @@ -229,6 +235,12 @@ spec: if [[ "$(params.compute-sizes)" != "" ]]; then cmd+="--compute-sizes '$(params.compute-sizes)' " fi + if [[ "$(params.gpus)" != "" ]]; then + cmd+="--gpus '$(params.gpus)' " + fi + if [[ "$(params.gpu-manufacturer)" != "" ]]; then + cmd+="--gpu-manufacturer '$(params.gpu-manufacturer)' " + fi if [[ "$(params.marketplace)" == "true" ]]; then cmd+="--marketplace " cmd+="--accelerator '$(params.accelerator)' " diff --git a/tkn/template/infra-aws-rhel-ai.yaml b/tkn/template/infra-aws-rhel-ai.yaml index a3799ddc0..5681fac6a 100644 --- a/tkn/template/infra-aws-rhel-ai.yaml +++ b/tkn/template/infra-aws-rhel-ai.yaml @@ -155,6 +155,15 @@ spec: - name: expose-ports description: Comma-separated list of ports to expose through the load balancer and security group (e.g. 8000,8080). default: "" + - name: tool-call-parser + description: Enable tool calling with the specified parser (e.g. llama3_json, hermes, mistral). Automatically adds --enable-auto-tool-choice. + default: "" + - name: chat-template + description: Chat template jinja filename (e.g. tool_chat_template_llama3.2_json.jinja). + default: "" + - name: max-model-len + description: Maximum model context length in tokens (default 4096). Increase for tool calling or larger models. + default: "0" # Network params - name: service-endpoints @@ -317,6 +326,15 @@ spec: if [[ "$(params.expose-ports)" != "" ]]; then cmd+="--expose-ports '$(params.expose-ports)' " fi + if [[ "$(params.tool-call-parser)" != "" ]]; then + cmd+="--tool-call-parser '$(params.tool-call-parser)' " + fi + if [[ "$(params.chat-template)" != "" ]]; then + cmd+="--chat-template '$(params.chat-template)' " + fi + if [[ "$(params.max-model-len)" != "0" ]]; then + cmd+="--max-model-len '$(params.max-model-len)' " + fi cmd+="--tags '$(params.tags)' " fi