diff --git a/project_milestone/models/project.py b/project_milestone/models/project.py index 4716e5d406..8a84ac9033 100644 --- a/project_milestone/models/project.py +++ b/project_milestone/models/project.py @@ -17,3 +17,25 @@ class Project(models.Model): def _onchange_use_milestones(self): if not self.use_milestones and self.milestones_required: self.milestones_required = False + + @api.returns("self", lambda value: value.id) + def copy(self, default=None): + project = super(Project, self).copy(default) + project._link_tasks_to_milestones() + return project + + def _link_tasks_to_milestones(self): + for task in self.with_context(active_test=False).task_ids.filtered( + "milestone_id" + ): + task.milestone_id = self._find_equivalent_milestone(task.milestone_id) + + def _find_equivalent_milestone(self, milestone): + return next( + ( + m + for m in self.with_context(active_test=False).milestone_ids + if m.name == milestone.name + ), + None, + ) diff --git a/project_milestone/tests/test_project_milestone.py b/project_milestone/tests/test_project_milestone.py index 7a63940faf..dafd4c9f7c 100644 --- a/project_milestone/tests/test_project_milestone.py +++ b/project_milestone/tests/test_project_milestone.py @@ -65,3 +65,12 @@ def test_sub_task(self): with Form(Task) as task: task.project_id = self.test_project task.name = "SubTask" + + def test_copy_project(self): + project = self.test_project.copy({}) + tasks = project.with_context(active_test=False).task_ids + milestone = project.milestone_ids.filtered( + lambda milestone: "2" not in milestone.name + ) + self.assertEqual(tasks[0].milestone_id, milestone) + self.assertEqual(tasks[1].milestone_id, milestone)