-
Notifications
You must be signed in to change notification settings - Fork 0
/
viterbi.py
111 lines (91 loc) · 3.03 KB
/
viterbi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# -*- coding: utf-8 -*-
"""Viterbi decoding over a two-state (NP / VP) hidden Markov model.

The script interactively reads emission and transition probabilities,
runs the Viterbi forward pass over a word sequence, and backtracks the
most probable state path.

Automatically generated by Colab from Viterbi.ipynb.
Original file is located at
    https://colab.research.google.com/drive/1Yug27ixq5-fdzgxewo1HB5KfFI4EepYr
"""
# --- Emission probabilities -------------------------------------------------
# For every vocabulary word, read one probability per state label and store
# it in hash_bag under the key "<word>-<state>".
x = int(input("Number of word to get input : "))
hash_bag = {}
words = []
for _ in range(x):
    inp = input("Enter the word : ")
    words.append(inp)
    # NOTE(review): the "<word>-start" value is collected, but
    # Dnode.calculateValue never reads it -- start/end nodes always use an
    # emission of 1.
    printer = [inp + "-" + "start", inp + "-" + "NP", inp + "-" + "VP"]
    for key in printer:
        hash_bag[key] = float(input("Enter " + key + " : "))
print(hash_bag)

# --- Transition probabilities -----------------------------------------------
# Read one probability for every (prev, next) pair and store it under the key
# "<prev>-<next>". "start" only appears as a source, "end" only as a target.
lister1 = ['start', 'NP', 'VP']
lister2 = ['NP', 'VP', 'end']
for i in lister1:
    for j in lister2:
        stry = i + "-" + j
        hash_bag[stry] = float(input("Enter " + stry + " : "))
class Dnode:
    """A single node (cell) of the NP/VP Viterbi trellis.

    Attributes:
        name: State label: "start", "NP", "VP" or "end".
        prev1, prev2: Candidate predecessor nodes; either may be None.
        preferred: The predecessor on the most probable path, chosen by
            calculateValue (stays None for the start node).
        hit: Emission probability of this node's word; 1 for "start" and
            "end". Set by calculateValue.
        state_prob: Probability of the best path reaching this node,
            excluding this node's own emission (a successor multiplies it
            in via ``hit``). Set by calculateValue.
    """

    def __init__(self, name, prev1, prev2):
        self.name = name
        self.prev1 = prev1
        self.prev2 = prev2
        self.preferred = None
        # Populated by calculateValue(); defined here so the attributes
        # always exist, even before the node has been evaluated.
        self.hit = None
        self.state_prob = None

    def __repr__(self):
        return (
            f"{type(self).__name__}(name={self.name!r}, "
            f"hit={self.hit!r}, state_prob={self.state_prob!r})"
        )

    def _path_prob(self, dicty, prev):
        """Return the probability of reaching this node through *prev*."""
        # transition(prev -> self) * emission(prev) * best path into prev
        return dicty[prev.name + "-" + self.name] * prev.hit * prev.state_prob

    def calculateValue(self, dicty, word):
        """Compute ``hit``, ``state_prob`` and ``preferred`` for this node.

        Predecessors must already have been evaluated (except a "start"
        predecessor, whose contribution is taken directly from *dicty*).

        Args:
            dicty: Probability table keyed by "<word>-<state>" (emissions)
                and "<prev>-<next>" (transitions).
            word: The word observed at this node; not used for the "start"
                and "end" nodes, so None is accepted there.

        Returns:
            The computed ``state_prob``.

        Raises:
            KeyError: If a required emission or transition key is missing.
        """
        if self.name in ("start", "end"):
            self.hit = 1
        else:
            self.hit = dicty[word + "-" + self.name]

        if self.name == "start":
            # The sequence is assumed to begin in "start" with certainty.
            self.state_prob = 1
        elif self.prev1.name == "start":
            # Directly after start: start's hit and state_prob are both 1,
            # so the path probability is just the start transition.
            self.state_prob = dicty[self.prev1.name + "-" + self.name]
            self.preferred = self.prev1
        elif self.prev2 is None:
            # A single predecessor: there is no choice to make.
            self.state_prob = self._path_prob(dicty, self.prev1)
            self.preferred = self.prev1
        else:
            prev1_prob = self._path_prob(dicty, self.prev1)
            prev2_prob = self._path_prob(dicty, self.prev2)
            # Strict ">" means prev2 wins ties.
            if prev1_prob > prev2_prob:
                self.state_prob = prev1_prob
                self.preferred = self.prev1
            else:
                self.state_prob = prev2_prob
                self.preferred = self.prev2
        return self.state_prob
# --- Forward pass -------------------------------------------------------------
# Build the trellis one word at a time. Each step creates an NP and a VP node
# whose candidate predecessors are the previous step's NP and VP nodes (or
# the start node for the first word); calculateValue keeps the better one.
# A word not entered with the emission probabilities raises KeyError there.
x = int(input("Enter no of words to find probs"))
l = [input(f"Enter word {i} : ") for i in range(0, x)]
nodes = []
startNode = Dnode(name="start", prev1=None, prev2=None)
startNode.calculateValue(hash_bag, None)
prevNP, prevVP = None, None
for i, val in enumerate(l):
    # The first word can only be reached from the start node.
    pred1, pred2 = (startNode, None) if i == 0 else (prevNP, prevVP)
    newNP = Dnode(name="NP", prev1=pred1, prev2=pred2)
    newNP.calculateValue(hash_bag, val)
    newVP = Dnode(name="VP", prev1=pred1, prev2=pred2)
    newVP.calculateValue(hash_bag, val)
    prevNP, prevVP = newNP, newVP
    # Record every step's nodes, not just the first word's.
    nodes.extend((newNP, newVP))

if prevNP is None:
    # Empty sentence: connect end straight to start. This uses the
    # "start-end" transition collected with the other transitions.
    endNode = Dnode("end", startNode, None)
else:
    endNode = Dnode("end", prevNP, prevVP)
endNode.calculateValue(hash_bag, None)
print("End node probability :", endNode.state_prob)
"""### BackTracking"""
a=endNode
finalOrder = []
finalOrderValues = []
while True:
finalOrder.append(a.name)
finalOrderValues.append(a.state_prob)
if a.name == "start":
break
else:
a = a.preferred
finalOrder = finalOrder[::-1]
finalOrderValues = finalOrderValues[::-1]
print ("Order :",finalOrder)
print ("Order values :",finalOrderValues)