mirror of https://github.com/tracel-ai/burn.git
Force constant vec indexing
This commit is contained in:
parent
6c708527b9
commit
85264532d1
|
@ -249,6 +249,18 @@ impl Display for Instruction {
|
|||
}
|
||||
Instruction::Not { input, out } => f.write_fmt(format_args!("{out} = !{input};\n")),
|
||||
Instruction::Index { lhs, rhs, out } => {
|
||||
if let Variable::Local {
|
||||
index: _,
|
||||
item: _,
|
||||
scope_depth: _,
|
||||
} = out
|
||||
{
|
||||
match rhs {
|
||||
Variable::GlobalScalar(_, _, _) => todo!(),
|
||||
_ => panic!("Only constant indexing is supported, got {:?}", rhs),
|
||||
}
|
||||
};
|
||||
|
||||
let item = out.item();
|
||||
f.write_fmt(format_args!("{out} = {item}({lhs}[{rhs}]);\n"))
|
||||
}
|
||||
|
@ -379,57 +391,71 @@ for (var {i}: u32 = {start}; {i} < {end}; {i}++) {{
|
|||
|
||||
f.write_str("}\n")
|
||||
}
|
||||
Instruction::IndexAssign { lhs, rhs, out } => match lhs.item() {
|
||||
Item::Vec4(elem) => {
|
||||
let lhs0 = lhs.index(0);
|
||||
let lhs1 = lhs.index(1);
|
||||
let lhs2 = lhs.index(2);
|
||||
let lhs3 = lhs.index(3);
|
||||
Instruction::IndexAssign { lhs, rhs, out } => {
|
||||
if let Variable::Local {
|
||||
index: _,
|
||||
item: _,
|
||||
scope_depth: _,
|
||||
} = out
|
||||
{
|
||||
match lhs {
|
||||
Variable::GlobalScalar(_, _, _) => todo!(),
|
||||
_ => panic!("Only constant indexing is supported, got {:?}", lhs),
|
||||
}
|
||||
};
|
||||
|
||||
let rhs0 = rhs.index(0);
|
||||
let rhs1 = rhs.index(1);
|
||||
let rhs2 = rhs.index(2);
|
||||
let rhs3 = rhs.index(3);
|
||||
match lhs.item() {
|
||||
Item::Vec4(elem) => {
|
||||
let lhs0 = lhs.index(0);
|
||||
let lhs1 = lhs.index(1);
|
||||
let lhs2 = lhs.index(2);
|
||||
let lhs3 = lhs.index(3);
|
||||
|
||||
f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?;
|
||||
f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n"))?;
|
||||
f.write_fmt(format_args!("{out}[{lhs2}] = {elem}({rhs2});\n"))?;
|
||||
f.write_fmt(format_args!("{out}[{lhs3}] = {elem}({rhs3});\n"))
|
||||
let rhs0 = rhs.index(0);
|
||||
let rhs1 = rhs.index(1);
|
||||
let rhs2 = rhs.index(2);
|
||||
let rhs3 = rhs.index(3);
|
||||
|
||||
f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?;
|
||||
f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n"))?;
|
||||
f.write_fmt(format_args!("{out}[{lhs2}] = {elem}({rhs2});\n"))?;
|
||||
f.write_fmt(format_args!("{out}[{lhs3}] = {elem}({rhs3});\n"))
|
||||
}
|
||||
Item::Vec3(elem) => {
|
||||
let lhs0 = lhs.index(0);
|
||||
let lhs1 = lhs.index(1);
|
||||
let lhs2 = lhs.index(2);
|
||||
|
||||
let rhs0 = rhs.index(0);
|
||||
let rhs1 = rhs.index(1);
|
||||
let rhs2 = rhs.index(2);
|
||||
|
||||
f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?;
|
||||
f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n"))?;
|
||||
f.write_fmt(format_args!("{out}[{lhs2}] = {elem}({rhs2});\n"))
|
||||
}
|
||||
Item::Vec2(elem) => {
|
||||
let lhs0 = lhs.index(0);
|
||||
let lhs1 = lhs.index(1);
|
||||
|
||||
let rhs0 = rhs.index(0);
|
||||
let rhs1 = rhs.index(1);
|
||||
|
||||
f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?;
|
||||
f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n"))
|
||||
}
|
||||
Item::Scalar(_elem) => {
|
||||
let elem_out = out.elem();
|
||||
let casting_type = match rhs.item() {
|
||||
Item::Vec4(_) => Item::Vec4(elem_out),
|
||||
Item::Vec3(_) => Item::Vec3(elem_out),
|
||||
Item::Vec2(_) => Item::Vec2(elem_out),
|
||||
Item::Scalar(_) => Item::Scalar(elem_out),
|
||||
};
|
||||
f.write_fmt(format_args!("{out}[{lhs}] = {casting_type}({rhs});\n"))
|
||||
}
|
||||
}
|
||||
Item::Vec3(elem) => {
|
||||
let lhs0 = lhs.index(0);
|
||||
let lhs1 = lhs.index(1);
|
||||
let lhs2 = lhs.index(2);
|
||||
|
||||
let rhs0 = rhs.index(0);
|
||||
let rhs1 = rhs.index(1);
|
||||
let rhs2 = rhs.index(2);
|
||||
|
||||
f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?;
|
||||
f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n"))?;
|
||||
f.write_fmt(format_args!("{out}[{lhs2}] = {elem}({rhs2});\n"))
|
||||
}
|
||||
Item::Vec2(elem) => {
|
||||
let lhs0 = lhs.index(0);
|
||||
let lhs1 = lhs.index(1);
|
||||
|
||||
let rhs0 = rhs.index(0);
|
||||
let rhs1 = rhs.index(1);
|
||||
|
||||
f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?;
|
||||
f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n"))
|
||||
}
|
||||
Item::Scalar(_elem) => {
|
||||
let elem_out = out.elem();
|
||||
let casting_type = match rhs.item() {
|
||||
Item::Vec4(_) => Item::Vec4(elem_out),
|
||||
Item::Vec3(_) => Item::Vec3(elem_out),
|
||||
Item::Vec2(_) => Item::Vec2(elem_out),
|
||||
Item::Scalar(_) => Item::Scalar(elem_out),
|
||||
};
|
||||
f.write_fmt(format_args!("{out}[{lhs}] = {casting_type}({rhs});\n"))
|
||||
}
|
||||
},
|
||||
}
|
||||
Instruction::If { cond, instructions } => {
|
||||
f.write_fmt(format_args!("if {cond} {{\n"))?;
|
||||
for i in instructions {
|
||||
|
|
Loading…
Reference in New Issue