Skip to content

Commit

Permalink
make trivial path unsafe again
Browse files: browse the repository at this point in the history
  • Loading branch information
kali committed Oct 11, 2024
1 parent 7e6a457 commit 0398ad9
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 37 deletions.
1 change: 1 addition & 0 deletions core/src/ops/cnn/conv/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,7 @@ impl Conv {
c_m_axis,
c_n_axis,
ops,
packing == 0 && self.group == 1,
)?,
&wires,
)
Expand Down
3 changes: 3 additions & 0 deletions core/src/ops/einsum/optimize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,8 @@ fn optimized_mat_mul(
let outputs =
mmms.iter().map(|(mmm, _packing)| unsafe { mmm.c_view(op.c_m(), op.c_n()) }).collect();
let (mmms, packings): (Vec<_>, Vec<_>) = mmms.into_iter().unzip();
let trivial_packing =
mmms.len() == 1 && packings[0] == 0 && patch.outlet_fact(a)?.opaque_fact.is_none();
let opt = OptMatMul::new(
mmms,
c_fact,
Expand All @@ -411,6 +413,7 @@ fn optimized_mat_mul(
ProtoFusedSpec::AddMatMul { geo, a: 0, b: 1, packings },
ProtoFusedSpec::Store(outputs),
],
trivial_packing,
)
.context("Creating OptMatMul")?;
let output = patch.wire_node(name, opt, &[a, b])?[0];
Expand Down
79 changes: 42 additions & 37 deletions core/src/ops/matmul/optimized.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ impl ProtoFusedSpec {
AsInputValue::Owned(Box::new(PanelExtractInput {
format,
data: data.clone(),
to: a_packing.downcast_ref::<PackedFormat>().unwrap().clone()
to: a_packing.downcast_ref::<PackedFormat>().unwrap().clone(),
}))
} else {
panic!("Un-matchable input and output for weights {:?} -> {a_packing}", a);
Expand Down Expand Up @@ -142,49 +142,43 @@ impl ProtoFusedSpec {
&'t self,
inputs: &'t [TValue],
output: &mut Tensor,
mmm: &dyn MatMatMul,
_mmm: &dyn MatMatMul,
scenario: usize,
) -> FusedSpec<'t> {
let fs = match self {
ProtoFusedSpec::AddMatMul { a, b, packings, .. } => {
let a = &inputs[*a];
let b = &inputs[*b];
let a = a.as_slice::<Opaque>().unwrap()[0]
.downcast_ref::<Box<dyn MMMInputValue>>()
.unwrap();
let b = b.as_slice::<Opaque>().unwrap()[0]
.downcast_ref::<Box<dyn MMMInputValue>>()
.unwrap();
let (a_packing, b_packing) = &mmm.packings()[packings[scenario]];
let pa = if a_packing.same_as(a.format()) {
AsInputValue::Borrowed(&**a)
} else if a_packing.is::<PackedFormat>()
&& a_packing.r() == a.format().r()
&& a.is::<EagerPackedInput>()
&& a.format().is::<PackedBlockQuantFormat>()
ProtoFusedSpec::AddMatMul { a, b, packings, .. } => unsafe {
debug_assert!(inputs.get(*a).is_some());
debug_assert!(inputs.get(*b).is_some());
let a = inputs.get_unchecked(*a);
let b = inputs.get_unchecked(*b);
debug_assert!(a.datum_type().is_opaque());
debug_assert!(a.len() == 1);
debug_assert!(b.datum_type().is_opaque());
debug_assert!(b.len() == 1);
let a = a.as_slice_unchecked::<Opaque>().get_unchecked(0);
let b = b.as_slice_unchecked::<Opaque>().get_unchecked(0);
debug_assert!(a.is::<Box<dyn MMMInputValue>>());
debug_assert!(b.is::<Box<dyn MMMInputValue>>());
let a = a.downcast_ref::<Box<dyn MMMInputValue>>().unwrap_unchecked();
let b = b.downcast_ref::<Box<dyn MMMInputValue>>().unwrap_unchecked();
#[cfg(debug_assertions)]
{
let format = PanelExtractFormat {
pbqf: a.format().downcast_ref::<PackedBlockQuantFormat>().unwrap().clone(),
};
let data = a.downcast_ref::<EagerPackedInput>().unwrap();
AsInputValue::Owned(Box::new(PanelExtractInput {
format,
data: data.clone(),
to: a_packing.downcast_ref::<PackedFormat>().unwrap().clone()
}))
} else {
panic!("Un-matchable input and output for weights {:?} -> {a_packing}", a);
};
assert!(
b_packing.same_as(b.format())
|| (b_packing.is::<PackedFormat>() && b_packing.r() == b.format().r())
);
let (a_packing, b_packing) = &_mmm.packings()[packings[scenario]];
debug_assert!(
a_packing.same_as(a.format())
|| (a_packing.is::<PackedFormat>() && a_packing.r() == a.format().r())
);
debug_assert!(
b_packing.same_as(b.format())
|| (b_packing.is::<PackedFormat>() && b_packing.r() == b.format().r())
);
}
FusedSpec::AddMatMul {
a: pa,
a: AsInputValue::Borrowed(&**a),
b: AsInputValue::Borrowed(&**b),
packing: packings[scenario],
}
}
},
ProtoFusedSpec::BinScalar(v, op) => FusedSpec::BinScalar(&inputs[*v], *op),
ProtoFusedSpec::LeakyRelu(v) => FusedSpec::LeakyRelu(&inputs[*v]),
ProtoFusedSpec::BinPerRow(v, op, _) => {
Expand Down Expand Up @@ -302,6 +296,7 @@ pub struct OptMatMul {
pub mmm: Vec<Box<dyn MatMatMul>>,
pub c_m_axis: usize,
pub c_n_axis: usize,
pub trivial_packing: bool,
pub trivial_path: bool,
}

Expand Down Expand Up @@ -591,10 +586,19 @@ impl OptMatMul {
c_m_axis: usize,
c_n_axis: usize,
micro_ops: Vec<ProtoFusedSpec>,
trivial_packing: bool,
) -> TractResult<Self> {
ensure!(c_m_axis < c_fact.rank());
ensure!(c_n_axis < c_fact.rank());
let mut it = OptMatMul { mmm, c_fact, c_m_axis, c_n_axis, micro_ops, trivial_path: false };
let mut it = OptMatMul {
mmm,
c_fact,
c_m_axis,
c_n_axis,
micro_ops,
trivial_path: false,
trivial_packing,
};
it.update_trivial_path();
Ok(it)
}
Expand Down Expand Up @@ -626,6 +630,7 @@ impl OptMatMul {
.iter()
.enumerate()
.all(|(ax, dim)| ax == self.c_m_axis || ax == self.c_n_axis || dim.is_one())
&& self.trivial_packing
&& self.micro_ops.iter().all(|o| o.is_trivial())
}

Expand Down

0 comments on commit 0398ad9

Please sign in to comment.