feat: overload mul!
#230
    @nospecialize(C::TracedRArray{T1,2}),
    @nospecialize(A::TracedRArray{T2,2}),
    @nospecialize(B::TracedRArray{T3,2}),
) where {T1,T2,T3}
Is it worth overriding the version with alpha and beta?
this is fine too tho, obviously
yeah we need to add that, else it does the for loop
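For reference, a minimal sketch of what the alpha/beta version computes, using plain `Matrix{Float64}` values from the standard `LinearAlgebra` library rather than `TracedRArray` (the concrete numbers are made up for illustration). The five-argument `mul!(C, A, B, α, β)` updates `C` in place to `A*B*α + C*β`, and the three-argument form behaves like `α = true, β = false`:

```julia
using LinearAlgebra

# Illustrative inputs only; plain arrays stand in for traced arrays here.
A = [1.0 2.0; 3.0 4.0]
B = [5.0 6.0; 7.0 8.0]
C = ones(2, 2)
α, β = 2.0, 3.0

# Compute the reference result before C is overwritten.
expected = A * B * α + C * β

# In-place multiply-add: C = A*B*α + C*β
mul!(C, A, B, α, β)
@assert C ≈ expected

# The three-argument form overwrites the output with A*B.
D = similar(C)
mul!(D, A, B)
@assert D ≈ A * B
```

An overload that only covers the three-argument method leaves the five-argument call to whatever generic method dispatch selects, which is the fallback mentioned above.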
835af61 to 54c7897
src/TracedRArray.jl
Outdated
if isone(α)
    C.mlir_data = res
else
    C.mlir_data = (TracedRArray{T1,2}((), res, size(C)) .* α).mlir_data
doing a broadcast here probably won't optimize well for now (or rather, we need to write optimizations on the batch op). Being able to directly emit the elementwise mul of a broadcast will likely be a nontrivial perf win as a result
lgtm, but probably worth not using broadcast internally for perf reasons below
partially addresses LuxDL/Lux.jl#1025