Skip to content

Commit

Permalink
fix API
Browse files Browse the repository at this point in the history
  • Loading branch information
frank-dong-ms committed Oct 24, 2024
1 parent e6b86e6 commit a3e7314
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion include/onnxruntime/core/framework/op_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ class OpKernel {
// @param input_idx : The input index of the tensor in this kernel.
// @param pre_packed_tensor: The prepacked tensor read from the ONNX data file; it is used
// to restore the prepacked weight buffer.
virtual Status SetPrePackTensor(int /*input_idx*/, Tensor& /*pre_packed_tensor*/) {
virtual Status SetPrePackTensor(int /*input_idx*/, const Tensor& /*pre_packed_tensor*/) {
return Status::OK();
}

Expand Down
7 changes: 4 additions & 3 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class MatMulNBits final : public OpKernel {

std::optional<Tensor> GetPrePackTensor(int /*input_idx*/) override;

Status SetPrePackTensor(int input_idx, Tensor& pre_packed_tensor) override;
Status SetPrePackTensor(int input_idx, const Tensor& pre_packed_tensor) override;

private:
const size_t K_;
Expand Down Expand Up @@ -286,12 +286,13 @@ std::optional<Tensor> MatMulNBits<T1>::GetPrePackTensor(int input_idx) {
}

template <typename T1>
Status MatMulNBits<T1>::SetPrePackTensor(int input_idx, Tensor& pre_packed_tensor) {
Status MatMulNBits<T1>::SetPrePackTensor(int input_idx, const Tensor& pre_packed_tensor) {
if (input_idx == 1) {
// pre_packed_tensor is a constant initialized tensor whose lifetime is managed by session_state,
// which is responsible for releasing its memory. packed_b_ must not release that memory, so
// pass an empty/default buffer deleter here.
packed_b_ = BufferUniquePtr(pre_packed_tensor.MutableDataRaw(), BufferDeleter());
// TODO: the const_cast here is temporary and will be fixed in a follow-up PR.
packed_b_ = BufferUniquePtr(const_cast<void*>(pre_packed_tensor.DataRaw()), BufferDeleter());
}

return Status::OK();
Expand Down
3 changes: 1 addition & 2 deletions onnxruntime/core/framework/session_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -437,8 +437,7 @@ Status SessionState::PrepackConstantInitializedTensors(InlinedHashMap<std::strin
// If the prepacked weights were already read from the ONNX data file (this happens when ORT reads a data file
// with serialized prepacked weights), we only need to set the prepacked weights on the kernel once.
is_kernel_prepacked = true;
ORT_THROW_IF_ERROR(kernel->SetPrePackTensor(input_idx,
*(constant_initialized_tensors[ort_value_idx].GetMutable<Tensor>())));
ORT_THROW_IF_ERROR(kernel->SetPrePackTensor(input_idx, const_initialized_tensor));
}
// Caching pre-packed weights is limited to shared initializers associated with the CPU EP for now
else if (is_shared_initializer && should_cache_prepacked_weights_for_shared_initializers &&
Expand Down

0 comments on commit a3e7314

Please sign in to comment.