Skip to content

Commit

Permalink
fix file handling in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mcmonkey4eva committed Sep 22, 2024
1 parent 3ac444c commit 6fbceea
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 43 deletions.
8 changes: 3 additions & 5 deletions model_filemanager/download_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,8 @@ async def update_progress():
await progress_callback(model_name, status)
last_update_time = time.time()

if os.path.exists(file_path + '.tmp'):
os.remove(file_path + '.tmp')

with open(file_path + '.tmp', 'wb') as f:
temp_file_path = file_path + '.tmp'
with open(temp_file_path, 'wb') as f:
chunk_iterator = response.content.iter_chunked(8192)
while True:
try:
Expand All @@ -174,7 +172,7 @@ async def update_progress():
if time.time() - last_update_time >= interval:
await update_progress()

os.rename(file_path + '.tmp', file_path)
os.rename(temp_file_path, file_path)

await update_progress()

Expand Down
78 changes: 40 additions & 38 deletions tests-unit/prompt_server_test/download_models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,19 +60,13 @@ async def test_download_model_success(temp_dir):
mock_make_request = AsyncMock(return_value=mock_response)
mock_progress_callback = AsyncMock()

# Mock file operations
mock_open = MagicMock()
mock_file = MagicMock()
mock_open.return_value.__enter__.return_value = mock_file
time_values = itertools.count(0, 0.1)

fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)}

with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'model.sft')), \
patch('model_filemanager.check_file_exists', return_value=None), \
patch('builtins.open', mock_open), \
patch('folder_paths.folder_names_and_paths', fake_paths), \
patch('folder_paths.get_folder_paths', return_value=[temp_dir]), \
patch('time.time', side_effect=time_values): # Simulate time passing

result = await download_model(
Expand Down Expand Up @@ -105,10 +99,11 @@ async def test_download_model_success(temp_dir):
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False)
)

# Verify file writing
mock_file.write.assert_any_call(b'a' * 500)
mock_file.write.assert_any_call(b'b' * 300)
mock_file.write.assert_any_call(b'c' * 200)
mock_file_path = os.path.join(temp_dir, 'model.sft')
assert os.path.exists(mock_file_path)
with open(mock_file_path, 'rb') as mock_file:
assert mock_file.read() == b''.join(chunks)
os.remove(mock_file_path)

# Verify request was made
mock_make_request.assert_called_once_with('http://example.com/model.sft')
Expand Down Expand Up @@ -192,15 +187,14 @@ async def test_download_model_invalid_folder_path():
mock_make_request = AsyncMock()
mock_progress_callback = AsyncMock()

with patch('folder_paths.get_folder_paths', return_value=['valid_path']):
result = await download_model(
mock_make_request,
'model.sft',
'http://example.com/model.sft',
'checkpoints',
'invalid_path',
mock_progress_callback
)
result = await download_model(
mock_make_request,
'model.sft',
'http://example.com/model.sft',
'checkpoints',
'invalid_path',
mock_progress_callback
)

# Assert the result
assert isinstance(result, DownloadModelStatus)
Expand Down Expand Up @@ -253,32 +247,40 @@ async def test_check_file_exists_when_file_does_not_exist(tmp_path):
mock_callback.assert_not_called()

@pytest.mark.asyncio
async def test_track_download_progress_no_content_length():
async def test_track_download_progress_no_content_length(temp_dir):
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
mock_response.headers = {} # No Content-Length header
mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 500, b'b' * 500])
chunks = [b'a' * 500, b'b' * 500]
mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks)

mock_callback = AsyncMock()
mock_open = MagicMock(return_value=MagicMock())

with patch('builtins.open', mock_open):
result = await track_download_progress(
mock_response, '/mock/path/model.sft', 'model.sft',
mock_callback, interval=0.1
)
full_path = os.path.join(temp_dir, 'model.sft')

result = await track_download_progress(
mock_response, full_path, 'model.sft',
mock_callback, interval=0.1
)

assert result.status == "completed"

assert os.path.exists(full_path)
with open(full_path, 'rb') as f:
assert f.read() == b''.join(chunks)
os.remove(full_path)

# Check that progress was reported even without knowing the total size
mock_callback.assert_any_call(
'model.sft',
DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False)
)

@pytest.mark.asyncio
async def test_track_download_progress_interval():
async def test_track_download_progress_interval(temp_dir):
mock_response = AsyncMock(spec=aiohttp.ClientResponse)
mock_response.headers = {'Content-Length': '1000'}
mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 100] * 10)
chunks = [b'a' * 100] * 10
mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks)

mock_callback = AsyncMock()
mock_open = MagicMock(return_value=MagicMock())
Expand All @@ -287,18 +289,18 @@ async def test_track_download_progress_interval():
mock_time = MagicMock()
mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks

with patch('builtins.open', mock_open), \
patch('time.time', mock_time):
full_path = os.path.join(temp_dir, 'model.sft')

with patch('time.time', mock_time):
await track_download_progress(
mock_response, '/mock/path/model.sft', 'model.sft',
mock_response, full_path, 'model.sft',
mock_callback, interval=1.0
)

# Print out the actual call count and the arguments of each call for debugging
print(f"mock_callback was called {mock_callback.call_count} times")
for i, call in enumerate(mock_callback.call_args_list):
args, kwargs = call
print(f"Call {i + 1}: {args[1].status}, Progress: {args[1].progress_percentage:.2f}%")

assert os.path.exists(full_path)
with open(full_path, 'rb') as f:
assert f.read() == b''.join(chunks)
os.remove(full_path)

# Assert that progress was updated at least 3 times (start, at least one interval, and end)
assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}"
Expand Down

0 comments on commit 6fbceea

Please sign in to comment.