Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
from .chat_completion import * # noqa: F403
from .completion import * # noqa: F403
from .model_endpoints import * # noqa: F403
from .vllm import * # noqa: F403
Original file line number Diff line number Diff line change
Expand Up @@ -665,23 +665,41 @@ async def create_text_generation_inference_bundle(
).model_bundle_id

def load_model_weights_sub_commands(
self, framework, framework_image_tag, checkpoint_path, final_weights_folder
self,
framework,
framework_image_tag,
checkpoint_path,
final_weights_folder,
trust_remote_code: bool = False,
):
if checkpoint_path.startswith("s3://"):
return self.load_model_weights_sub_commands_s3(
framework, framework_image_tag, checkpoint_path, final_weights_folder
framework,
framework_image_tag,
checkpoint_path,
final_weights_folder,
trust_remote_code,
)
elif checkpoint_path.startswith("azure://") or "blob.core.windows.net" in checkpoint_path:
return self.load_model_weights_sub_commands_abs(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we update this too to keep things in sync?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kk done

framework, framework_image_tag, checkpoint_path, final_weights_folder
framework,
framework_image_tag,
checkpoint_path,
final_weights_folder,
trust_remote_code,
)
else:
raise ObjectHasInvalidValueException(
f"Only S3 and Azure Blob Storage paths are supported. Given checkpoint path: {checkpoint_path}."
)

def load_model_weights_sub_commands_s3(
self, framework, framework_image_tag, checkpoint_path, final_weights_folder
self,
framework,
framework_image_tag,
checkpoint_path,
final_weights_folder,
trust_remote_code: bool,
):
subcommands = []
s5cmd = "s5cmd"
Expand All @@ -700,14 +718,23 @@ def load_model_weights_sub_commands_s3(
validate_checkpoint_files(checkpoint_files)

# filter to configs ('*.model' and '*.json') and weights ('*.safetensors')
# For models that are not supported by transformers directly, we also need to include '*.py'
# to load the model. Only add that pattern when "trust_remote_code" is set to True.
file_selection_str = '--include "*.model" --include "*.json" --include "*.safetensors" --exclude "optimizer*"'
if trust_remote_code:
file_selection_str += ' --include "*.py"'
subcommands.append(
f"{s5cmd} --numworkers 512 cp --concurrency 10 {file_selection_str} {os.path.join(checkpoint_path, '*')} {final_weights_folder}"
)
return subcommands

def load_model_weights_sub_commands_abs(
self, framework, framework_image_tag, checkpoint_path, final_weights_folder
self,
framework,
framework_image_tag,
checkpoint_path,
final_weights_folder,
trust_remote_code: bool,
):
subcommands = []

Expand All @@ -729,9 +756,8 @@ def load_model_weights_sub_commands_abs(
]
)
else:
file_selection_str = (
'--include-pattern "*.model;*.json;*.safetensors" --exclude-pattern "optimizer*"'
)
additional_pattern = ";*.py" if trust_remote_code else ""
file_selection_str = f'--include-pattern "*.model;*.json;*.safetensors{additional_pattern}" --exclude-pattern "optimizer*"'
subcommands.append(
f"azcopy copy --recursive {file_selection_str} {os.path.join(checkpoint_path, '*')} {final_weights_folder}"
)
Expand Down Expand Up @@ -861,6 +887,8 @@ def _create_vllm_bundle_command(
subcommands = []

checkpoint_path = get_checkpoint_path(model_name, checkpoint_path)
additional_args = infer_addition_engine_args_from_model_name(model_name)

# added as workaround since transformers doesn't support mistral yet, vllm expects "mistral" in model weights folder
if "mistral" in model_name:
final_weights_folder = "mistral_files"
Expand All @@ -871,6 +899,7 @@ def _create_vllm_bundle_command(
framework_image_tag,
checkpoint_path,
final_weights_folder,
trust_remote_code=additional_args.trust_remote_code or False,
)

if multinode and not is_worker:
Expand Down Expand Up @@ -905,8 +934,6 @@ def _create_vllm_bundle_command(
if hmi_config.sensitive_log_mode: # pragma: no cover
vllm_cmd += " --disable-log-requests"

additional_args = infer_addition_engine_args_from_model_name(model_name)

for field in VLLMModelConfig.model_fields.keys():
config_value = getattr(additional_args, field, None)
if config_value is not None:
Expand Down
7 changes: 4 additions & 3 deletions model-engine/model_engine_server/inference/vllm/vllm_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,9 @@ async def dummy_receive() -> MutableMapping[str, Any]:
)


async def download_model(checkpoint_path: str, target_dir: str) -> None:
s5cmd = f"./s5cmd --numworkers 512 sync --concurrency 10 --include '*.model' --include '*.json' --include '*.bin' --include '*.safetensors' --exclude 'optimizer*' --exclude 'train*' {os.path.join(checkpoint_path, '*')} {target_dir}"
async def download_model(checkpoint_path: str, target_dir: str, trust_remote_code: bool) -> None:
additional_include = "--include '*.py'" if trust_remote_code else ""
s5cmd = f"./s5cmd --numworkers 512 sync --concurrency 10 --include '*.model' --include '*.json' --include '*.safetensors' {additional_include} --exclude 'optimizer*' --exclude 'train*' {os.path.join(checkpoint_path, '*')} {target_dir}"
env = os.environ.copy()
env["AWS_PROFILE"] = os.getenv("S3_WRITE_AWS_PROFILE", "default")
# Need to override these env vars so s5cmd uses AWS_PROFILE
Expand Down Expand Up @@ -319,11 +320,11 @@ async def handle_batch_job(request: CreateBatchCompletionsEngineRequest) -> None
metrics_gateway = DatadogInferenceMonitoringMetricsGateway()

model = get_model_name(request.model_cfg)

if request.model_cfg.checkpoint_path:
await download_model(
checkpoint_path=request.model_cfg.checkpoint_path,
target_dir=MODEL_WEIGHTS_FOLDER,
trust_remote_code=request.model_cfg.trust_remote_code or False,
)

content = load_batch_content(request)
Expand Down
22 changes: 22 additions & 0 deletions model-engine/tests/unit/domain/test_llm_use_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,16 @@ def test_load_model_weights_sub_commands(
]
assert expected_result == subcommands

trust_remote_code = True
subcommands = llm_bundle_use_case.load_model_weights_sub_commands(
framework, framework_image_tag, checkpoint_path, final_weights_folder, trust_remote_code
)

expected_result = [
'./s5cmd --numworkers 512 cp --concurrency 10 --include "*.model" --include "*.json" --include "*.safetensors" --exclude "optimizer*" --include "*.py" s3://fake-checkpoint/* test_folder',
]
assert expected_result == subcommands

framework = LLMInferenceFramework.TEXT_GENERATION_INFERENCE
framework_image_tag = "1.0.0"
checkpoint_path = "s3://fake-checkpoint"
Expand Down Expand Up @@ -555,6 +565,18 @@ def test_load_model_weights_sub_commands(
]
assert expected_result == subcommands

trust_remote_code = True
subcommands = llm_bundle_use_case.load_model_weights_sub_commands(
framework, framework_image_tag, checkpoint_path, final_weights_folder, trust_remote_code
)

expected_result = [
"export AZCOPY_AUTO_LOGIN_TYPE=WORKLOAD",
"curl -L https://aka.ms/downloadazcopy-v10-linux | tar --strip-components=1 -C /usr/local/bin --no-same-owner --exclude=*.txt -xzvf - && chmod 755 /usr/local/bin/azcopy",
'azcopy copy --recursive --include-pattern "*.model;*.json;*.safetensors;*.py" --exclude-pattern "optimizer*" azure://fake-checkpoint/* test_folder',
]
assert expected_result == subcommands


@pytest.mark.asyncio
async def test_create_model_endpoint_trt_llm_use_case_success(
Expand Down