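transformers inference only supports greedy search
When you run inference with transformers without modifying any code, you will find that only greedy search works. If you try to enable do_sample for generate(), you may hit an ACL error like the following: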
RuntimeError: ACL stream synchronize failed, error code:507018
[W compiler_depend.ts:409] Warning: NPU warning, error code is 507018[Error]:
[Error]: The aicpu execution is abnormal.
Rectify the fault based on the error information in the ascend log.
EH9999: Inner Error!
rtDeviceSynchronize execute failed, reason=[aicpu exception][FUNC:FuncErrorReason][FILE:error_message_manage.cc][LINE:53]
EH9999: [PID: 37944] 2025-03-19-06:12:26.371.588 wait for compute device to finish failed, runtime result = 507018.[FUNC:ReportCallError][FILE:log_inner.cpp][LINE:161]
TraceBack (most recent call last):
(function npuSynchronizeUsedDevices)
[W compiler_depend.ts:392] Warning: NPU warning, error code is 507018[Error]:
[Error]: The aicpu execution is abnormal.
Rectify the fault based on the error information in the ascend log.
EH9999: Inner Error!
rtDeviceSynchronize execute failed, reason=[aicpu exception][FUNC:FuncErrorReason][FILE:error_message_manage.cc][LINE:53]
EH9999: [PID: 37944] 2025-03-19-06:12:26.373.863 wait for compute device to finish failed, runtime result = 507018.[FUNC:ReportCallError][FILE:log_inner.cpp][LINE:161]
TraceBack (most recent call last):
(function npuSynchronizeDevice)
[W compiler_depend.ts:392] Warning: NPU warning, error code is 507018[Error]:
[Error]: The aicpu execution is abnormal.
Rectify the fault based on the error information in the ascend log.
EH9999: Inner Error!
rtDeviceSynchronize execute failed, reason=[aicpu exception][FUNC:FuncErrorReason][FILE:error_message_manage.cc][LINE:53]
EH9999: [PID: 37944] 2025-03-19-06:12:26.375.369 wait for compute device to finish failed, runtime result = 507018.[FUNC:ReportCallError][FILE:log_inner.cpp][LINE:161]
TraceBack (most recent call last):
(function npuSynchronizeDevice)
[W compiler_depend.ts:392] Warning: NPU warning, error code is 507018[Error]:
[Error]: The aicpu execution is abnormal.
Rectify the fault based on the error information in the ascend log.
EH9999: Inner Error!
rtDeviceSynchronize execute failed, reason=[aicpu exception][FUNC:FuncErrorReason][FILE:error_message_manage.cc][LINE:53]
EH9999: [PID: 37944] 2025-03-19-06:12:26.394.159 wait for compute device to finish failed, runtime result = 507018.[FUNC:ReportCallError][FILE:log_inner.cpp][LINE:161]
TraceBack (most recent call last):
(function npuSynchronizeDevice)
[W compiler_depend.ts:392] Warning: NPU warning, error code is 507018[Error]:
[Error]: The aicpu execution is abnormal.
Rectify the fault based on the error information in the ascend log.
EH9999: Inner Error!
rtDeviceSynchronize execute failed, reason=[aicpu exception][FUNC:FuncErrorReason][FILE:error_message_manage.cc][LINE:53]
EH9999: [PID: 37944] 2025-03-19-06:12:26.395.596 wait for compute device to finish failed, runtime result = 507018.[FUNC:ReportCallError][FILE:log_inner.cpp][LINE:161]
TraceBack (most recent call last):
(function npuSynchronizeDevice)
[W compiler_depend.ts:392] Warning: NPU warning, error code is 507018[Error]:
[Error]: The aicpu execution is abnormal.
Rectify the fault based on the error information in the ascend log.
EH9999: Inner Error!
rtDeviceSynchronize execute failed, reason=[aicpu exception][FUNC:FuncErrorReason][FILE:error_message_manage.cc][LINE:53]
EH9999: [PID: 37944] 2025-03-19-06:12:26.397.006 wait for compute device to finish failed, runtime result = 507018.[FUNC:ReportCallError][FILE:log_inner.cpp][LINE:161]
TraceBack (most recent call last):
(function npuSynchronizeDevice)
[W compiler_depend.ts:392] Warning: NPU warning, error code is 507018[Error]:
[Error]: The aicpu execution is abnormal.
Rectify the fault based on the error information in the ascend log.
EH9999: Inner Error!
rtDeviceSynchronize execute failed, reason=[aicpu exception][FUNC:FuncErrorReason][FILE:error_message_manage.cc][LINE:53]
EH9999: [PID: 37944] 2025-03-19-06:12:26.398.441 wait for compute device to finish failed, runtime result = 507018.[FUNC:ReportCallError][FILE:log_inner.cpp][LINE:161]
TraceBack (most recent call last):
(function npuSynchronizeDevice)
[W compiler_depend.ts:392] Warning: NPU warning, error code is 507018[Error]:
[Error]: The aicpu execution is abnormal.
Rectify the fault based on the error information in the ascend log.
EH9999: Inner Error!
rtDeviceSynchronize execute failed, reason=[aicpu exception][FUNC:FuncErrorReason][FILE:error_message_manage.cc][LINE:53]
EH9999: [PID: 37944] 2025-03-19-06:12:26.399.848 wait for compute device to finish failed, runtime result = 507018.[FUNC:ReportCallError][FILE:log_inner.cpp][LINE:161]
TraceBack (most recent call last):
(function npuSynchronizeDevice)
[W compiler_depend.ts:392] Warning: NPU warning, error code is 507018[Error]:
[Error]: The aicpu execution is abnormal.
Rectify the fault based on the error information in the ascend log.
EH9999: Inner Error!
rtDeviceSynchronize execute failed, reason=[aicpu exception][FUNC:FuncErrorReason][FILE:error_message_manage.cc][LINE:53]
EH9999: [PID: 37944] 2025-03-19-06:12:26.401.257 wait for compute device to finish failed, runtime result = 507018.[FUNC:ReportCallError][FILE:log_inner.cpp][LINE:161]
TraceBack (most recent call last):
(function npuSynchronizeDevice)
/root/miniconda3/envs/MindIE_1.0.T65/lib/python3.10/tempfile.py:869: ResourceWarning: Implicitly cleaning up <TemporaryDirectory '/tmp/tmp8tky0pbq'>
_warnings.warn(warn_message, ResourceWarning)
[ERROR] 2025-03-19-06:12:27 (PID:37944, Device:1, RankID:-1) ERR99999 UNKNOWN application exception
This happens because the Ascend NPU (910A) does not correctly handle float('Inf') as the filter_value passed to torch.Tensor.masked_fill. When that happens, the logits_processor ends up writing 'nan' into next_token_scores.
# src/transformers/generation/utils.py, in _sample()
next_token_scores = logits_processor(input_ids, next_token_logits)
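According to the official documentation, the NPU does not support inf/nan as input or output for float16 data. However, in a simple test of my own, the NPU (910A) handled masked_fill with float('Inf') correctly as long as JIT compilation was left enabled, that is, as long as torch_npu.npu.set_compile_mode(jit_compile=False) had not been called beforehand.
Yet when I ran inference with transformers without calling torch_npu.npu.set_compile_mode(jit_compile=False), the program hung and appeared to wait indefinitely.
I have already reported this issue to the Ascend NPU support team.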
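As a rough illustration of the kind of masked_fill check described above (a sketch, not the exact script I ran), the snippet below fills part of a small tensor with an infinite value and reports whether any NaN shows up. The function name masked_fill_has_nan and the sample values are hypothetical, and the negative sign is an assumption based on the sampling warpers in transformers defaulting to -float("Inf").

```python
import torch


def masked_fill_has_nan(device: str = "cpu", dtype: torch.dtype = torch.float32) -> bool:
    """Fill masked positions with -inf and report whether any NaN appears."""
    scores = torch.tensor([[1.0, 2.0, 3.0, 4.0]], dtype=dtype, device=device)
    # Positions marked True are the ones a top-k / top-p warper would remove.
    remove = torch.tensor([[True, True, False, False]], device=device)
    filled = scores.masked_fill(remove, -float("inf"))
    return bool(torch.isnan(filled).any().item())


# On CPU this is expected to be NaN-free; it serves as the reference behaviour.
cpu_has_nan = masked_fill_has_nan("cpu")
print("NaN after masked_fill on CPU:", cpu_has_nan)
```

To try the same check on the device, one would import torch_npu first and call masked_fill_has_nan("npu", torch.float16), once with and once without torch_npu.npu.set_compile_mode(jit_compile=False), and compare the two results.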
For now, my temporary workaround is to edit logits_process.py and replace the infinite filter_value with a large finite number, such as float(6.5e4).
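To see why a finite filter value should not change the sampling result, here is a small pure-Python sketch (no tensors, no NPU) that applies top-k filtering and then softmax, once with -inf and once with a finite value. The helper names top_k_filter and softmax are hypothetical, and using the negated value -6.5e4 is my assumption about how the replacement is applied; the magnitude stays below the float16 maximum of 65504.

```python
import math

# Assumption: negated form of the 6.5e4 workaround value (fits in float16).
FINITE_FILTER = -6.5e4


def top_k_filter(logits, k, filter_value):
    """Keep the k largest logits and overwrite the rest with filter_value."""
    threshold = sorted(logits, reverse=True)[k - 1]
    return [x if x >= threshold else filter_value for x in logits]


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.5, -1.0]
with_inf = softmax(top_k_filter(logits, 2, -math.inf))
with_finite = softmax(top_k_filter(logits, 2, FINITE_FILTER))
print(with_inf)
print(with_finite)
```

In this sketch the filtered positions receive probability zero in both cases, because the exponential of such a large negative number underflows to zero, so the tokens that can be sampled and their probabilities are unchanged. The real code in transformers operates on tensors, so treat this only as an explanation of the idea.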
If I have made any mistakes, or if you know of a better solution, please let me know by email.