diff --git a/crates/burn-wgpu/src/compiler/wgsl/instructions.rs b/crates/burn-wgpu/src/compiler/wgsl/instructions.rs index 0d9319ac5..67f808178 100644 --- a/crates/burn-wgpu/src/compiler/wgsl/instructions.rs +++ b/crates/burn-wgpu/src/compiler/wgsl/instructions.rs @@ -249,6 +249,18 @@ impl Display for Instruction { } Instruction::Not { input, out } => f.write_fmt(format_args!("{out} = !{input};\n")), Instruction::Index { lhs, rhs, out } => { + if let Variable::Local { + index: _, + item: _, + scope_depth: _, + } = out + { + match rhs { + Variable::GlobalScalar(_, _, _) => todo!(), + _ => panic!("Only constant indexing is supported, got {:?}", rhs), + } + }; + let item = out.item(); f.write_fmt(format_args!("{out} = {item}({lhs}[{rhs}]);\n")) } @@ -379,57 +391,71 @@ for (var {i}: u32 = {start}; {i} < {end}; {i}++) {{ f.write_str("}\n") } - Instruction::IndexAssign { lhs, rhs, out } => match lhs.item() { - Item::Vec4(elem) => { - let lhs0 = lhs.index(0); - let lhs1 = lhs.index(1); - let lhs2 = lhs.index(2); - let lhs3 = lhs.index(3); + Instruction::IndexAssign { lhs, rhs, out } => { + if let Variable::Local { + index: _, + item: _, + scope_depth: _, + } = out + { + match lhs { + Variable::GlobalScalar(_, _, _) => todo!(), + _ => panic!("Only constant indexing is supported, got {:?}", lhs), + } + }; - let rhs0 = rhs.index(0); - let rhs1 = rhs.index(1); - let rhs2 = rhs.index(2); - let rhs3 = rhs.index(3); + match lhs.item() { + Item::Vec4(elem) => { + let lhs0 = lhs.index(0); + let lhs1 = lhs.index(1); + let lhs2 = lhs.index(2); + let lhs3 = lhs.index(3); - f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?; - f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n"))?; - f.write_fmt(format_args!("{out}[{lhs2}] = {elem}({rhs2});\n"))?; - f.write_fmt(format_args!("{out}[{lhs3}] = {elem}({rhs3});\n")) + let rhs0 = rhs.index(0); + let rhs1 = rhs.index(1); + let rhs2 = rhs.index(2); + let rhs3 = rhs.index(3); + + f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?; + f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n"))?; + f.write_fmt(format_args!("{out}[{lhs2}] = {elem}({rhs2});\n"))?; + f.write_fmt(format_args!("{out}[{lhs3}] = {elem}({rhs3});\n")) + } + Item::Vec3(elem) => { + let lhs0 = lhs.index(0); + let lhs1 = lhs.index(1); + let lhs2 = lhs.index(2); + + let rhs0 = rhs.index(0); + let rhs1 = rhs.index(1); + let rhs2 = rhs.index(2); + + f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?; + f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n"))?; + f.write_fmt(format_args!("{out}[{lhs2}] = {elem}({rhs2});\n")) + } + Item::Vec2(elem) => { + let lhs0 = lhs.index(0); + let lhs1 = lhs.index(1); + + let rhs0 = rhs.index(0); + let rhs1 = rhs.index(1); + + f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?; + f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n")) + } + Item::Scalar(_elem) => { + let elem_out = out.elem(); + let casting_type = match rhs.item() { + Item::Vec4(_) => Item::Vec4(elem_out), + Item::Vec3(_) => Item::Vec3(elem_out), + Item::Vec2(_) => Item::Vec2(elem_out), + Item::Scalar(_) => Item::Scalar(elem_out), + }; + f.write_fmt(format_args!("{out}[{lhs}] = {casting_type}({rhs});\n")) + } } - Item::Vec3(elem) => { - let lhs0 = lhs.index(0); - let lhs1 = lhs.index(1); - let lhs2 = lhs.index(2); - - let rhs0 = rhs.index(0); - let rhs1 = rhs.index(1); - let rhs2 = rhs.index(2); - - f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?; - f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n"))?; - f.write_fmt(format_args!("{out}[{lhs2}] = {elem}({rhs2});\n")) - } - Item::Vec2(elem) => { - let lhs0 = lhs.index(0); - let lhs1 = lhs.index(1); - - let rhs0 = rhs.index(0); - let rhs1 = rhs.index(1); - - f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?; - f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n")) - } - Item::Scalar(_elem) => { - let elem_out = out.elem(); - let casting_type = match rhs.item() { - Item::Vec4(_) => Item::Vec4(elem_out), - Item::Vec3(_) => Item::Vec3(elem_out), - Item::Vec2(_) => Item::Vec2(elem_out), - Item::Scalar(_) => Item::Scalar(elem_out), - }; - f.write_fmt(format_args!("{out}[{lhs}] = {casting_type}({rhs});\n")) - } - }, + } Instruction::If { cond, instructions } => { f.write_fmt(format_args!("if {cond} {{\n"))?; for i in instructions {