Skip to content

Commit

Permalink
clip
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Jul 31, 2023
1 parent 8118dc6 commit d1e20b8
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 18 deletions.
1 change: 1 addition & 0 deletions core/src/model/rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ type GenRewriteRule<Ctx> =
Box<dyn Fn(&Ctx, &TypedModel, &TypedNode) -> TractResult<Option<TypedModelPatch>>>;

/// A registry of named rewrite rules, grouped by a `TypeId` key.
/// Each rule is a boxed closure (`GenRewriteRule<Ctx>`) taking a context, a
/// `TypedModel` and a `TypedNode`, and producing an optional `TypedModelPatch`.
/// NOTE(review): the key is presumably the `TypeId` of the op/node type the
/// rules target — confirm against the registration/apply methods (not visible
/// in this view).
#[derive(Default)]
// The boxed-Fn alias inside the HashMap value trips clippy's type_complexity
// lint; the alias itself already factors the complexity out.
#[allow(clippy::type_complexity)]
pub struct Rewriter<Ctx> {
/// For each `TypeId`, the list of (rule name, rule closure) pairs registered.
/// The `Cow<'static, str>` name allows both literal and runtime-built names.
rules: HashMap<TypeId, Vec<(Cow<'static, str>, GenRewriteRule<Ctx>)>>,
}
Expand Down
5 changes: 4 additions & 1 deletion test-rt/infra/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,10 @@ where
failure_persistence: Some(Box::new(FileFailurePersistence::Off)),
..Config::default()
});
runner.run(&any_with::<A>(self.0.clone()), |v| Ok(v.run(runtime).unwrap()))?;
runner.run(&any_with::<A>(self.0.clone()), |v| {
v.run(runtime).unwrap();
Ok(())
})?;
Ok(())
}
}
6 changes: 2 additions & 4 deletions test-rt/suite-conv/src/conv_f32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,8 @@ impl ConvProblem {

// pytorch semantics diverge from onnx (and onnx are super weird)
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Pool.h#L48C2-L54C6
if *ceil {
if (out - 1) * stride >= input + l {
out -= 1;
}
if *ceil && (out - 1) * stride >= input + l {
out -= 1;
}
(out, *l)
})
Expand Down
2 changes: 1 addition & 1 deletion test-rt/suite-conv/src/conv_q.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use infra::{Test, TestSuite};
use proptest::collection::vec;
use proptest::prelude::*;
use proptest::*;
// use proptest::*;
use tract_core::internal::*;
use tract_core::ops::cnn::KernelFormat::*;
use tract_core::ops::cnn::{ConvUnary, KernelFormat, PaddingSpec, PoolSpec};
Expand Down
7 changes: 3 additions & 4 deletions tflite/src/ops/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ fn de_concat(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
let axis =
if options.axis() < 0 { rank as i32 + options.axis() } else { options.axis() } as usize;
let dt = DatumType::super_type_for(op.facts()?.iter().map(|f| f.datum_type)).unwrap();
let inputs = wire_cast(&op.prefix, &mut op.ctx.target, &op.inputs, dt)?;
let inputs = wire_cast(op.prefix, op.ctx.target, op.inputs, dt)?;
let wires = op.ctx.target.wire_node(op.prefix, TypedConcat::new(axis), &inputs)?;
wire_fused_activation(op, &wires, &options.fused_activation_function())
}
Expand Down Expand Up @@ -91,8 +91,7 @@ fn de_squeeze(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
let rank = op.facts()?[0].rank();
for (ix, axis) in options.squeeze_dims().unwrap().iter().sorted().enumerate() {
let axis = if axis < 0 { rank as i32 + axis } else { axis } as usize;
wire =
op.ctx.target.wire_node(format!("{prefix}.{ix}"), AxisOp::Rm(axis as usize), &wire)?;
wire = op.ctx.target.wire_node(format!("{prefix}.{ix}"), AxisOp::Rm(axis), &wire)?;
}
Ok(wire)
}
Expand All @@ -107,7 +106,7 @@ fn de_strided_slice(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
optional_axes_input: None,
optional_steps_input: Some(3),
};
op.ctx.target.wire_node(op.prefix, slice, &op.inputs)
op.ctx.target.wire_node(op.prefix, slice, op.inputs)
}

fn de_transpose(op: &mut DeserOp) -> TractResult<TVec<OutletId>> {
Expand Down
1 change: 1 addition & 0 deletions tflite/src/rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ fn padding(
let fact = model.outlet_fact(node.inputs[0])?;
let shape = conv.pool_spec.data_format.shape(&fact.shape)?;
let actual = conv.pool_spec.computed_padding(shape.hw_dims());
#[allow(clippy::single_element_loop)]
for pad in [PaddingSpec::Valid /*, PaddingSpec::SameUpper*/] {
let found = pad.compute(
shape.hw_dims(),
Expand Down
20 changes: 12 additions & 8 deletions tflite/src/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ impl<'f, 'b> ModelBuilder<'f, 'b> {
let subgraph = subgraph.finish(model)?;
let subgraphs = vec![subgraph];
let subgraphs = self.builder.create_vector(&subgraphs);
let buffers = self.builder.create_vector(&mut self.buffers);
let buffers = self.builder.create_vector(self.buffers);
let operator_codes = self
.op_codes
.iter()
.map(|code| {
OperatorCode::create(
&mut self.builder,
self.builder,
&OperatorCodeArgs {
deprecated_builtin_code: code.deprecated_builtin_code,
custom_code: None,
Expand All @@ -51,7 +51,7 @@ impl<'f, 'b> ModelBuilder<'f, 'b> {
.collect_vec();
let operator_codes = self.builder.create_vector(&operator_codes);
let model = Model::create(
&mut self.builder,
self.builder,
&ModelArgs {
version: 3,
operator_codes: Some(operator_codes),
Expand Down Expand Up @@ -98,14 +98,18 @@ impl<'f, 'b, 'mb> SubgraphBuilder<'f, 'b, 'mb> {
where
'f: 'short,
{
&mut self.model.builder
self.model.builder
}

pub fn write_fact(&mut self, name: impl AsRef<str>, fact: impl Into<TypedFact>) -> TractResult<i32> {
pub fn write_fact(
&mut self,
name: impl AsRef<str>,
fact: impl Into<TypedFact>,
) -> TractResult<i32> {
let fact = fact.into();
let buffer = if let Some(k) = &fact.konst {
let data = self.fb().create_vector(unsafe { k.as_bytes() });
let buffer = Buffer::create(&mut self.fb(), &BufferArgs { data: Some(data) });
let buffer = Buffer::create(self.fb(), &BufferArgs { data: Some(data) });
self.model.buffers.push(buffer);
self.model.buffers.len() as u32 - 1
} else {
Expand Down Expand Up @@ -169,8 +173,8 @@ impl<'f, 'b, 'mb> SubgraphBuilder<'f, 'b, 'mb> {
builtin_options: WIPOffset<UnionWIPOffset>,
) -> TractResult<()> {
let opcode_index = self.model.operator_code_index(op);
let inputs = self.fb().create_vector(&inputs);
let outputs = self.fb().create_vector(&outputs);
let inputs = self.fb().create_vector(inputs);
let outputs = self.fb().create_vector(outputs);
let operator = Operator::create(
self.fb(),
&OperatorArgs {
Expand Down

0 comments on commit d1e20b8

Please sign in to comment.