diff --git a/include/onnxruntime/core/framework/op_kernel.h b/include/onnxruntime/core/framework/op_kernel.h index 0a74e93baa4e5..a17da2a19bb99 100644 --- a/include/onnxruntime/core/framework/op_kernel.h +++ b/include/onnxruntime/core/framework/op_kernel.h @@ -147,7 +147,7 @@ class OpKernel { // @param input_idx : The input index of the tensor in this kernel. // @param pre_packed_tensor: The prepacked tensor read from onnx data file and use the prepacked tensor // to restore prepacked weight buffer. - virtual Status SetPrePackTensor(int /*input_idx*/, Tensor& /*pre_packed_tensor*/) { + virtual Status SetPrePackTensor(int /*input_idx*/, const Tensor& /*pre_packed_tensor*/) { return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 7c568a6f7547b..cee3dfc6b3f28 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -109,7 +109,7 @@ class MatMulNBits final : public OpKernel { std::optional GetPrePackTensor(int /*input_idx*/) override; - Status SetPrePackTensor(int input_idx, Tensor& pre_packed_tensor) override; + Status SetPrePackTensor(int input_idx, const Tensor& pre_packed_tensor) override; private: const size_t K_; @@ -286,12 +286,13 @@ std::optional MatMulNBits::GetPrePackTensor(int input_idx) { } template -Status MatMulNBits::SetPrePackTensor(int input_idx, Tensor& pre_packed_tensor) { +Status MatMulNBits::SetPrePackTensor(int input_idx, const Tensor& pre_packed_tensor) { if (input_idx == 1) { // pre_packed_tensor is constant initialized tensor and its lifecycle is managed by session_state, // session_state will release memory from pre_packed_tensor. packed_b_ will not release memory so // pass empty/default buffer deleter here. - packed_b_ = BufferUniquePtr(pre_packed_tensor.MutableDataRaw(), BufferDeleter()); + // const_cast here is temporary, will fix in follow up PR. + packed_b_ = BufferUniquePtr(const_cast(pre_packed_tensor.DataRaw()), BufferDeleter()); } return Status::OK(); diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index abb6e315f9cb1..943db091b341f 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -437,8 +437,7 @@ Status SessionState::PrepackConstantInitializedTensors(InlinedHashMapSetPrePackTensor(input_idx, - *(constant_initialized_tensors[ort_value_idx].GetMutable()))); + ORT_THROW_IF_ERROR(kernel->SetPrePackTensor(input_idx, const_initialized_tensor)); } // Caching pre-packed weights is limited to shared initializers associated with the CPU EP for now else if (is_shared_initializer && should_cache_prepacked_weights_for_shared_initializers &&