mirror of https://github.com/tracel-ai/burn.git
Force constant vec indexing
This commit is contained in:
parent
6c708527b9
commit
85264532d1
|
@ -249,6 +249,18 @@ impl Display for Instruction {
|
||||||
}
|
}
|
||||||
Instruction::Not { input, out } => f.write_fmt(format_args!("{out} = !{input};\n")),
|
Instruction::Not { input, out } => f.write_fmt(format_args!("{out} = !{input};\n")),
|
||||||
Instruction::Index { lhs, rhs, out } => {
|
Instruction::Index { lhs, rhs, out } => {
|
||||||
|
if let Variable::Local {
|
||||||
|
index: _,
|
||||||
|
item: _,
|
||||||
|
scope_depth: _,
|
||||||
|
} = out
|
||||||
|
{
|
||||||
|
match rhs {
|
||||||
|
Variable::GlobalScalar(_, _, _) => todo!(),
|
||||||
|
_ => panic!("Only constant indexing is supported, got {:?}", rhs),
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let item = out.item();
|
let item = out.item();
|
||||||
f.write_fmt(format_args!("{out} = {item}({lhs}[{rhs}]);\n"))
|
f.write_fmt(format_args!("{out} = {item}({lhs}[{rhs}]);\n"))
|
||||||
}
|
}
|
||||||
|
@ -379,57 +391,71 @@ for (var {i}: u32 = {start}; {i} < {end}; {i}++) {{
|
||||||
|
|
||||||
f.write_str("}\n")
|
f.write_str("}\n")
|
||||||
}
|
}
|
||||||
Instruction::IndexAssign { lhs, rhs, out } => match lhs.item() {
|
Instruction::IndexAssign { lhs, rhs, out } => {
|
||||||
Item::Vec4(elem) => {
|
if let Variable::Local {
|
||||||
let lhs0 = lhs.index(0);
|
index: _,
|
||||||
let lhs1 = lhs.index(1);
|
item: _,
|
||||||
let lhs2 = lhs.index(2);
|
scope_depth: _,
|
||||||
let lhs3 = lhs.index(3);
|
} = out
|
||||||
|
{
|
||||||
|
match lhs {
|
||||||
|
Variable::GlobalScalar(_, _, _) => todo!(),
|
||||||
|
_ => panic!("Only constant indexing is supported, got {:?}", lhs),
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let rhs0 = rhs.index(0);
|
match lhs.item() {
|
||||||
let rhs1 = rhs.index(1);
|
Item::Vec4(elem) => {
|
||||||
let rhs2 = rhs.index(2);
|
let lhs0 = lhs.index(0);
|
||||||
let rhs3 = rhs.index(3);
|
let lhs1 = lhs.index(1);
|
||||||
|
let lhs2 = lhs.index(2);
|
||||||
|
let lhs3 = lhs.index(3);
|
||||||
|
|
||||||
f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?;
|
let rhs0 = rhs.index(0);
|
||||||
f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n"))?;
|
let rhs1 = rhs.index(1);
|
||||||
f.write_fmt(format_args!("{out}[{lhs2}] = {elem}({rhs2});\n"))?;
|
let rhs2 = rhs.index(2);
|
||||||
f.write_fmt(format_args!("{out}[{lhs3}] = {elem}({rhs3});\n"))
|
let rhs3 = rhs.index(3);
|
||||||
|
|
||||||
|
f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?;
|
||||||
|
f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n"))?;
|
||||||
|
f.write_fmt(format_args!("{out}[{lhs2}] = {elem}({rhs2});\n"))?;
|
||||||
|
f.write_fmt(format_args!("{out}[{lhs3}] = {elem}({rhs3});\n"))
|
||||||
|
}
|
||||||
|
Item::Vec3(elem) => {
|
||||||
|
let lhs0 = lhs.index(0);
|
||||||
|
let lhs1 = lhs.index(1);
|
||||||
|
let lhs2 = lhs.index(2);
|
||||||
|
|
||||||
|
let rhs0 = rhs.index(0);
|
||||||
|
let rhs1 = rhs.index(1);
|
||||||
|
let rhs2 = rhs.index(2);
|
||||||
|
|
||||||
|
f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?;
|
||||||
|
f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n"))?;
|
||||||
|
f.write_fmt(format_args!("{out}[{lhs2}] = {elem}({rhs2});\n"))
|
||||||
|
}
|
||||||
|
Item::Vec2(elem) => {
|
||||||
|
let lhs0 = lhs.index(0);
|
||||||
|
let lhs1 = lhs.index(1);
|
||||||
|
|
||||||
|
let rhs0 = rhs.index(0);
|
||||||
|
let rhs1 = rhs.index(1);
|
||||||
|
|
||||||
|
f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?;
|
||||||
|
f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n"))
|
||||||
|
}
|
||||||
|
Item::Scalar(_elem) => {
|
||||||
|
let elem_out = out.elem();
|
||||||
|
let casting_type = match rhs.item() {
|
||||||
|
Item::Vec4(_) => Item::Vec4(elem_out),
|
||||||
|
Item::Vec3(_) => Item::Vec3(elem_out),
|
||||||
|
Item::Vec2(_) => Item::Vec2(elem_out),
|
||||||
|
Item::Scalar(_) => Item::Scalar(elem_out),
|
||||||
|
};
|
||||||
|
f.write_fmt(format_args!("{out}[{lhs}] = {casting_type}({rhs});\n"))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Item::Vec3(elem) => {
|
}
|
||||||
let lhs0 = lhs.index(0);
|
|
||||||
let lhs1 = lhs.index(1);
|
|
||||||
let lhs2 = lhs.index(2);
|
|
||||||
|
|
||||||
let rhs0 = rhs.index(0);
|
|
||||||
let rhs1 = rhs.index(1);
|
|
||||||
let rhs2 = rhs.index(2);
|
|
||||||
|
|
||||||
f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?;
|
|
||||||
f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n"))?;
|
|
||||||
f.write_fmt(format_args!("{out}[{lhs2}] = {elem}({rhs2});\n"))
|
|
||||||
}
|
|
||||||
Item::Vec2(elem) => {
|
|
||||||
let lhs0 = lhs.index(0);
|
|
||||||
let lhs1 = lhs.index(1);
|
|
||||||
|
|
||||||
let rhs0 = rhs.index(0);
|
|
||||||
let rhs1 = rhs.index(1);
|
|
||||||
|
|
||||||
f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?;
|
|
||||||
f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n"))
|
|
||||||
}
|
|
||||||
Item::Scalar(_elem) => {
|
|
||||||
let elem_out = out.elem();
|
|
||||||
let casting_type = match rhs.item() {
|
|
||||||
Item::Vec4(_) => Item::Vec4(elem_out),
|
|
||||||
Item::Vec3(_) => Item::Vec3(elem_out),
|
|
||||||
Item::Vec2(_) => Item::Vec2(elem_out),
|
|
||||||
Item::Scalar(_) => Item::Scalar(elem_out),
|
|
||||||
};
|
|
||||||
f.write_fmt(format_args!("{out}[{lhs}] = {casting_type}({rhs});\n"))
|
|
||||||
}
|
|
||||||
},
|
|
||||||
Instruction::If { cond, instructions } => {
|
Instruction::If { cond, instructions } => {
|
||||||
f.write_fmt(format_args!("if {cond} {{\n"))?;
|
f.write_fmt(format_args!("if {cond} {{\n"))?;
|
||||||
for i in instructions {
|
for i in instructions {
|
||||||
|
|
Loading…
Reference in New Issue