diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index bd20420e5d7d0..33f2fbe652d31 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -674,5 +674,14 @@ WebGpuContext& WebGpuContextFactory::GetContext(int context_id) { return *it->second; } +void WebGpuContextFactory::Cleanup() { + std::lock_guard lock(mutex_); + contexts_.clear(); +} + +void CleanupWebGpuContexts() { + WebGpuContextFactory::Cleanup(); +} + } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h index cbabf778e1338..be05b06523b9c 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.h +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -34,6 +34,8 @@ class WebGpuContextFactory { ValidationMode validation_mode); static WebGpuContext& GetContext(int context_id); + static void Cleanup(); + private: WebGpuContextFactory() {} diff --git a/onnxruntime/core/session/ort_env.cc b/onnxruntime/core/session/ort_env.cc index ef84875df18a3..335ebbf203e7c 100644 --- a/onnxruntime/core/session/ort_env.cc +++ b/onnxruntime/core/session/ort_env.cc @@ -17,6 +17,14 @@ using namespace onnxruntime; using namespace onnxruntime::logging; +#ifdef USE_WEBGPU +namespace onnxruntime { +namespace webgpu { +void CleanupWebGpuContexts(); +} // namespace webgpu +} // namespace onnxruntime +#endif + std::unique_ptr OrtEnv::p_instance_; int OrtEnv::ref_count_ = 0; std::mutex OrtEnv::m_; @@ -26,6 +34,10 @@ OrtEnv::OrtEnv(std::unique_ptr value1) } OrtEnv::~OrtEnv() { +#ifdef USE_WEBGPU + webgpu::CleanupWebGpuContexts(); +#endif + // We don't support any shared providers in the minimal build yet #if !defined(ORT_MINIMAL_BUILD) UnloadSharedProviders();