This repository has been archived by the owner on Dec 1, 2021. It is now read-only.

Adding checker of hardware constraints for Conv #639

Open
wants to merge 7 commits into master
Changes from 3 commits
23 changes: 23 additions & 0 deletions dlk/python/dlk/core/operators.py
@@ -1017,7 +1017,30 @@ def __init__(self,
        # if kernel shape is not assigned, estimate kernel shape from input W's shape

    def _check_consistency(self) -> None:
        """
        This checks the following conditions:
        1. Kernel size must be 1x1 or 3x3.
        2. Max input channel size allowed is 1024.
        3. Input channel size must be a multiple of 32.
        """
        super()._check_consistency()
        if self.kernel_shape[0] != self.kernel_shape[1] or self.kernel_shape[0] not in (1, 3):
            warnings.warn(warning_sign +
Contributor
Since this violation is fatal, it would be better to exit here when this constraint violation happens.

Contributor Author
Yes, I will use assert upon detecting this violation.
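
A minimal sketch of what the assert-based version of the kernel check might look like, reusing the existing self._assert helper from this class (assumed to raise when the condition is False, as in the shape checks below); this is only an illustration, not the actual change in a35e4a8:

    # Sketch only: inside the Conv operator class, mirroring the existing
    # self._assert(...) calls; assumes self._assert raises on failure.
    def _check_consistency(self) -> None:
        super()._check_consistency()
        self._assert(
            self.kernel_shape[0] == self.kernel_shape[1] and self.kernel_shape[0] in (1, 3),
            f'Kernel size needs to be 1x1 or 3x3 but got '
            f'{self.kernel_shape[0]}x{self.kernel_shape[1]} for {self.name} of {self.op_type}')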

Contributor Author
Fixed in commit a35e4a8, but there are dlk tests that use a 2x2 conv, so those will need to be fixed as well.

f" Kernel size needs to be 1x1 or 3x3 but got "
f"{self.kernel_shape[0]}x{self.kernel_shape[1]} for {self.name} of {self.op_type}",
stacklevel=2)
if self.input_ops['X'].channel > 1024 or self.channel > 1024:
warnings.warn(warning_sign +
f" Input and output channel size need to be less than 1024, but got "
f"input: {self.input_ops['X'].channel} and output: {self.channel} "
f"for {self.name} of {self.op_type}",
stacklevel=2)
if self.input_ops['X'].channel % 32 != 0:
warnings.warn(warning_sign +
f" Input channel size need be multiple of 32, but got "
Contributor
If I remember correctly, when the channel size is not a multiple of 32, the remaining channels are zero padded, so it should still work, though it may not be efficient. How about changing this message to "Input channel size should be a multiple of 32"?
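
For illustration, a rough sketch of the zero padding described here, assuming an NHWC tensor (the layout and the function name are assumptions, not dlk code):

    import numpy as np

    # Pad the channel axis with zeros up to the next multiple of 32.
    def pad_channels_to_32(x: np.ndarray) -> np.ndarray:
        pad = (-x.shape[-1]) % 32  # 0 if already a multiple of 32
        return np.pad(x, [(0, 0)] * (x.ndim - 1) + [(0, pad)])

    # e.g. 40 input channels are padded to 64; the extra channels are all zero
    padded = pad_channels_to_32(np.ones((1, 8, 8, 40), dtype=np.float32))
    assert padded.shape[-1] == 64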

Contributor
Padding happens only if the conv is quantized, but whether the conv is quantized or not is not known at this point...

Contributor Author
I will change the message, but indeed at the importing stage it is hard to tell whether the conv is quantized. Maybe checking the quantizer operators around the conv could help?
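
A hypothetical sketch of that idea, inspecting the op_type of the operators feeding the conv; the quantizer type names below are assumptions for illustration, not confirmed dlk identifiers:

    # Sketch only: guess whether this Conv is quantized by inspecting its inputs.
    QUANTIZER_TYPES = {'QTZ_linear_mid_tread_half', 'QTZ_binary_mean_scaling'}

    def _looks_quantized(self) -> bool:
        return any(op.op_type in QUANTIZER_TYPES for op in self.input_ops.values())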

Contributor
How about changing the fp32 operator to pad automatically as well? (You don't need to implement this in this pull request; it's just a proposal for a future direction.)

Contributor
How about adding a padding feature to the fp32 conv later?

Contributor Author
Sorry, I overlooked this and didn't quite get it. Do you mean we would automatically pad the fp32 conv during importing?

f"{self.input_ops['X'].channel} for {self.name} of {self.op_type}",
stacklevel=2)

self._assert(len(self.shape) == self._num_dimensions + 2,
f'{self.name} has illegal shape {self.shape}')
self._assert(len(self.kernel_shape) == self._num_dimensions,
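
Since the new checks only emit warnings, a caller or test can still promote them to hard failures with the standard warnings machinery; a minimal sketch, not part of this pull request (import_model and model_path are hypothetical names):

    import warnings

    # Treat any warning raised during graph import as an error, so a constraint
    # violation such as a 2x2 kernel stops the conversion instead of only logging.
    with warnings.catch_warnings():
        warnings.simplefilter('error')
        graph = import_model(model_path)  # hypothetical import entry point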