Skip to content

Commit

Permalink
Another round of bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
zgornel committed Sep 17, 2018
1 parent f92e467 commit 6c0a6e1
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 17 deletions.
32 changes: 17 additions & 15 deletions src/conceptnet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ function init(::Type{ConceptNet})
DataDeps.unpack,
:multi),
("English",
"https://conceptnet.s3.amazonaws.com/downloads/2017/numberbatch/numberbatch-17.06.txt.gz",
"https://conceptnet.s3.amazonaws.com/downloads/2017/numberbatch/numberbatch-en-17.06.txt.gz",
"numberbatch-en-17.06.txt",
"72faf0a487c61b9a6a8c9ff0a1440d2f4936bb19102bddf27a833c2567620f2d",
DataDeps.unpack,
Expand Down Expand Up @@ -47,15 +47,10 @@ end

function _load_embeddings(::Type{<:ConceptNet}, embedding_file, max_vocab_size, keep_words)
local LL, indexed_words, index
if any(endswith.(embedding_file, [".h5", ".hdf5"]))
LL, indexed_words = _load_hdf5_embeddings(embedding_file,
max_vocab_size=max_vocab_size,
keep_words=keep_words)
else
LL, indexed_words = _load_txt_embeddings(embedding_file,
max_vocab_size=max_vocab_size,
keep_words=keep_words)
end
_loader = ifelse(any(endswith.(embedding_file, [".h5", ".hdf5"])),
_load_hdf5_embeddings,
_load_txt_embeddings)
LL, indexed_words = _loader(embedding_file, max_vocab_size, keep_words)
return LL, indexed_words
end

Expand All @@ -69,6 +64,7 @@ end


function _load_txt_embeddings(file::AbstractString, max_vocab_size, keep_words)
local LL, indexed_words
open(file, "r") do fid
vocab_size, vector_size = map(x->parse(Int,x), split(readline(fid)))
max_stored_vocab_size = _get_vocab_size(vocab_size, max_vocab_size)
Expand All @@ -84,18 +80,23 @@ function _load_txt_embeddings(file::AbstractString, max_vocab_size, keep_words)
return word, embedding
end

# TODO Improve performance of this bit
cnt = 0
indices = Int[]
for (index, row) in enumerate(data)
word, embedding = _parseline(row)
LL[:, index] = embedding
indexed_words[index] = word
if length(keep_words)==0 || word in keep_words
LL[:, index] = embedding
idexed_words[index] = word
push!(indices, index)
cnt+=1
if cnt > max_stored_vocab_size
if cnt > max_stored_vocab_size-1
break
end
end
end
LL = LL[:, indices]
indexed_words = indexed_words[indices]
end
return LL, indexed_words
end
Expand All @@ -106,14 +107,15 @@ function _load_hdf5_embeddings(file::AbstractString, max_vocab_size, keep_words)
payload = h5open(read, file)["mat"]
words = payload["axis1"]
vectors = payload["block0_values"]
max_vocab_size = _get_vocab_size(length(words), max_vocab_size)
max_stored_vocab_size = _get_vocab_size(length(words), max_vocab_size)

indices = Int[]
cnt = 0
for (index, word) in enumerate(words)
if length(keep_words)==0 || word in keep_words
push!(indices, index)
cnt+=1
if cnt > max_stored_vocab_size
if cnt > max_stored_vocab_size-1
break
end
end
Expand Down
4 changes: 2 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,9 @@ end
@test embs_mini.vocab == embs_full.vocab[1:100]

embs_specific = load_embeddings(ConceptNet{:compressed};
keep_words=Set(["red", "green", "blue"]))
keep_words=Set(["/c/en/red", "/c/en/green", "/c/en/blue"]))

@test size(embs_specific.embeddings) == (300, 3)
@test Set(embs_specific.vocab) == Set(["red", "green", "blue"])
@test Set(embs_specific.vocab) == Set(["/c/en/red", "/c/en/green", "/c/en/blue"])
end
end

0 comments on commit 6c0a6e1

Please sign in to comment.