From 6fbceea4a6213e2aa05849569b13c279c04245c5 Mon Sep 17 00:00:00 2001 From: "Alex \"mcmonkey\" Goodwin" Date: Sun, 22 Sep 2024 16:56:52 +0900 Subject: [PATCH] fix file handling in tests --- model_filemanager/download_models.py | 8 +- .../download_models_test.py | 78 ++++++++++--------- 2 files changed, 43 insertions(+), 43 deletions(-) diff --git a/model_filemanager/download_models.py b/model_filemanager/download_models.py index 5b0642e3666..ae3032ecb2f 100644 --- a/model_filemanager/download_models.py +++ b/model_filemanager/download_models.py @@ -158,10 +158,8 @@ async def update_progress(): await progress_callback(model_name, status) last_update_time = time.time() - if os.path.exists(file_path + '.tmp'): - os.remove(file_path + '.tmp') - - with open(file_path + '.tmp', 'wb') as f: + temp_file_path = file_path + '.tmp' + with open(temp_file_path, 'wb') as f: chunk_iterator = response.content.iter_chunked(8192) while True: try: @@ -174,7 +172,7 @@ async def update_progress(): if time.time() - last_update_time >= interval: await update_progress() - os.rename(file_path + '.tmp', file_path) + os.rename(temp_file_path, file_path) await update_progress() diff --git a/tests-unit/prompt_server_test/download_models_test.py b/tests-unit/prompt_server_test/download_models_test.py index c495f344f3e..8f633f8cf09 100644 --- a/tests-unit/prompt_server_test/download_models_test.py +++ b/tests-unit/prompt_server_test/download_models_test.py @@ -60,19 +60,13 @@ async def test_download_model_success(temp_dir): mock_make_request = AsyncMock(return_value=mock_response) mock_progress_callback = AsyncMock() - # Mock file operations - mock_open = MagicMock() - mock_file = MagicMock() - mock_open.return_value.__enter__.return_value = mock_file time_values = itertools.count(0, 0.1) fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)} with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'model.sft')), \ patch('model_filemanager.check_file_exists', return_value=None), \ - patch('builtins.open', mock_open), \ patch('folder_paths.folder_names_and_paths', fake_paths), \ - patch('folder_paths.get_folder_paths', return_value=[temp_dir]), \ patch('time.time', side_effect=time_values): # Simulate time passing result = await download_model( @@ -105,10 +99,11 @@ async def test_download_model_success(temp_dir): DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False) ) - # Verify file writing - mock_file.write.assert_any_call(b'a' * 500) - mock_file.write.assert_any_call(b'b' * 300) - mock_file.write.assert_any_call(b'c' * 200) + mock_file_path = os.path.join(temp_dir, 'model.sft') + assert os.path.exists(mock_file_path) + with open(mock_file_path, 'rb') as mock_file: + assert mock_file.read() == b''.join(chunks) + os.remove(mock_file_path) # Verify request was made mock_make_request.assert_called_once_with('http://example.com/model.sft') @@ -192,15 +187,14 @@ async def test_download_model_invalid_folder_path(): mock_make_request = AsyncMock() mock_progress_callback = AsyncMock() - with patch('folder_paths.get_folder_paths', return_value=['valid_path']): - result = await download_model( - mock_make_request, - 'model.sft', - 'http://example.com/model.sft', - 'checkpoints', - 'invalid_path', - mock_progress_callback - ) + result = await download_model( + mock_make_request, + 'model.sft', + 'http://example.com/model.sft', + 'checkpoints', + 'invalid_path', + mock_progress_callback + ) # Assert the result assert isinstance(result, DownloadModelStatus) @@ -253,21 +247,28 @@ async def test_check_file_exists_when_file_does_not_exist(tmp_path): mock_callback.assert_not_called() @pytest.mark.asyncio -async def test_track_download_progress_no_content_length(): +async def test_track_download_progress_no_content_length(temp_dir): mock_response = AsyncMock(spec=aiohttp.ClientResponse) mock_response.headers = {} # No Content-Length header - mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 500, b'b' * 500]) + chunks = [b'a' * 500, b'b' * 500] + mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks) mock_callback = AsyncMock() - mock_open = MagicMock(return_value=MagicMock()) - with patch('builtins.open', mock_open): - result = await track_download_progress( - mock_response, '/mock/path/model.sft', 'model.sft', - mock_callback, interval=0.1 - ) + full_path = os.path.join(temp_dir, 'model.sft') + + result = await track_download_progress( + mock_response, full_path, 'model.sft', + mock_callback, interval=0.1 + ) assert result.status == "completed" + + assert os.path.exists(full_path) + with open(full_path, 'rb') as f: + assert f.read() == b''.join(chunks) + os.remove(full_path) + # Check that progress was reported even without knowing the total size mock_callback.assert_any_call( 'model.sft', @@ -275,10 +276,11 @@ async def test_track_download_progress_no_content_length(): ) @pytest.mark.asyncio -async def test_track_download_progress_interval(): +async def test_track_download_progress_interval(temp_dir): mock_response = AsyncMock(spec=aiohttp.ClientResponse) mock_response.headers = {'Content-Length': '1000'} - mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 100] * 10) + chunks = [b'a' * 100] * 10 + mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks) mock_callback = AsyncMock() mock_open = MagicMock(return_value=MagicMock()) @@ -287,18 +289,18 @@ async def test_track_download_progress_interval(): mock_time = MagicMock() mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks - with patch('builtins.open', mock_open), \ - patch('time.time', mock_time): + full_path = os.path.join(temp_dir, 'model.sft') + + with patch('time.time', mock_time): await track_download_progress( - mock_response, '/mock/path/model.sft', 'model.sft', + mock_response, full_path, 'model.sft', mock_callback, interval=1.0 ) - - # Print out the actual call count and the arguments of each call for debugging - print(f"mock_callback was called {mock_callback.call_count} times") - for i, call in enumerate(mock_callback.call_args_list): - args, kwargs = call - print(f"Call {i + 1}: {args[1].status}, Progress: {args[1].progress_percentage:.2f}%") + + assert os.path.exists(full_path) + with open(full_path, 'rb') as f: + assert f.read() == b''.join(chunks) + os.remove(full_path) # Assert that progress was updated at least 3 times (start, at least one interval, and end) assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}"