Skip to content

Commit

Permalink
[naga wgsl-in] Automatic conversions for local var initializers.
Browse files Browse the repository at this point in the history
  • Loading branch information
jimblandy committed Nov 25, 2023
1 parent 9e8228c commit 3320637
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 62 deletions.
7 changes: 4 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ Passing an owned value `window` to `Surface` will return a `Surface<'static>`. S
- Introduce a new `Scalar` struct type for use in Naga's IR, and update all frontend, middle, and backend code appropriately. By @jimblandy in [#4673](https://github.com/gfx-rs/wgpu/pull/4673).
- Add more metal keywords. By @fornwall in [#4707](https://github.com/gfx-rs/wgpu/pull/4707).

- Add partial support for WGSL abstract types (@jimblandy in [#4743](https://github.com/gfx-rs/wgpu/pull/4743)).
- Add partial support for WGSL abstract types (@jimblandy in [#4743](https://github.com/gfx-rs/wgpu/pull/4743), [#4755](https://github.com/gfx-rs/wgpu/pull/4755)).

Abstract types make numeric literals easier to use, by
automatically converting literals and other constant expressions
Expand All @@ -115,9 +115,10 @@ Passing an owned value `window` to `Surface` will return a `Surface<'static>`. S
Even though the literals are abstract integers, Naga recognizes
that it is safe and necessary to convert them to `f32` values in
order to build the vector. You can also use abstract values as
initializers for global constants, like this:
initializers for global constants and global and local variables,
like this:

const unit_x: vec2<f32> = vec2(1, 0);
var unit_x: vec2<f32> = vec2(1, 0);

The literals `1` and `0` are abstract integers, and the expression
`vec2(1, 0)` is an abstract vector. However, Naga recognizes that
Expand Down
70 changes: 37 additions & 33 deletions naga/src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1162,45 +1162,49 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
return Ok(());
}
ast::LocalDecl::Var(ref v) => {
let mut emitter = Emitter::default();
emitter.start(&ctx.function.expressions);

let initializer = match v.init {
Some(init) => Some(
self.expression(init, &mut ctx.as_expression(block, &mut emitter))?,
),
None => None,
};

let explicit_ty =
v.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx.as_global()))
v.ty.map(|ast| self.resolve_ast_type(ast, &mut ctx.as_global()))
.transpose()?;

let ty = match (explicit_ty, initializer) {
(Some(explicit), Some(initializer)) => {
let mut ctx = ctx.as_expression(block, &mut emitter);
let initializer_ty = resolve_inner!(ctx, initializer);
if !ctx.module.types[explicit]
.inner
.equivalent(initializer_ty, &ctx.module.types)
{
let gctx = &ctx.module.to_ctx();
return Err(Error::InitializationTypeMismatch {
let mut emitter = Emitter::default();
emitter.start(&ctx.function.expressions);
let mut ectx = ctx.as_expression(block, &mut emitter);

let ty;
let initializer;
match (v.init, explicit_ty) {
(Some(init), Some(explicit_ty)) => {
let init = self.expression_for_abstract(init, &mut ectx)?;
let ty_res = crate::proc::TypeResolution::Handle(explicit_ty);
let init = ectx
.try_automatic_conversions(init, &ty_res, v.name.span)
.map_err(|error| match error {
Error::AutoConversion {
dest_span: _,
dest_type,
source_span: _,
source_type,
} => Error::InitializationTypeMismatch {
name: v.name.span,
expected: explicit.to_wgsl(gctx),
got: initializer_ty.to_wgsl(gctx),
});
}
explicit
expected: dest_type,
got: source_type,
},
other => other,
})?;
ty = explicit_ty;
initializer = Some(init);
}
(Some(explicit), None) => explicit,
(None, Some(initializer)) => ctx
.as_expression(block, &mut emitter)
.register_type(initializer)?,
(None, None) => {
return Err(Error::MissingType(v.name.span));
(Some(init), None) => {
let concretized = self.expression(init, &mut ectx)?;
ty = ectx.register_type(concretized)?;
initializer = Some(concretized);
}
};
(None, Some(explicit_ty)) => {
ty = explicit_ty;
initializer = None;
}
(None, None) => return Err(Error::MissingType(v.name.span)),
}

let (const_initializer, initializer) = {
match initializer {
Expand Down
33 changes: 33 additions & 0 deletions naga/tests/in/abstract-types-var.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,36 @@ var<private> xafpaiai: array<i32, 2> = array(1, 2);
var<private> xafpaiaf: array<f32, 2> = array(1, 2.0);
var<private> xafpafai: array<f32, 2> = array(1.0, 2);
var<private> xafpafaf: array<f32, 2> = array(1.0, 2.0);

fn f() {
var xvipaiai: vec2<i32> = vec2(42, 43);
var xvupaiai: vec2<u32> = vec2(44, 45);
var xvfpaiai: vec2<f32> = vec2(46, 47);

var xvupuai: vec2<u32> = vec2(42u, 43);
var xvupaiu: vec2<u32> = vec2(42, 43u);

var xvuuai: vec2<u32> = vec2<u32>(42u, 43);
var xvuaiu: vec2<u32> = vec2<u32>(42, 43u);

var xmfpaiaiaiai: mat2x2<f32> = mat2x2(1, 2, 3, 4);
var xmfpafaiaiai: mat2x2<f32> = mat2x2(1.0, 2, 3, 4);
var xmfpaiafaiai: mat2x2<f32> = mat2x2(1, 2.0, 3, 4);
var xmfpaiaiafai: mat2x2<f32> = mat2x2(1, 2, 3.0, 4);
var xmfpaiaiaiaf: mat2x2<f32> = mat2x2(1, 2, 3, 4.0);

var xvispai: vec2<i32> = vec2(1);
var xvfspaf: vec2<f32> = vec2(1.0);
var xvis_ai: vec2<i32> = vec2<i32>(1);
var xvus_ai: vec2<u32> = vec2<u32>(1);
var xvfs_ai: vec2<f32> = vec2<f32>(1);
var xvfs_af: vec2<f32> = vec2<f32>(1.0);

var xafafaf: array<f32, 2> = array<f32, 2>(1.0, 2.0);
var xafaiai: array<f32, 2> = array<f32, 2>(1, 2);

var xafpaiai: array<i32, 2> = array(1, 2);
var xafpaiaf: array<f32, 2> = array(1, 2.0);
var xafpafai: array<f32, 2> = array(1.0, 2);
var xafpafaf: array<f32, 2> = array(1.0, 2.0);
}
28 changes: 28 additions & 0 deletions naga/tests/out/msl/abstract-types-var.msl
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,31 @@ struct type_5 {
struct type_7 {
int inner[2];
};

void f(
) {
metal::int2 xvipaiai = metal::int2(42, 43);
metal::uint2 xvupaiai = metal::uint2(44u, 45u);
metal::float2 xvfpaiai = metal::float2(46.0, 47.0);
metal::uint2 xvupuai = metal::uint2(42u, 43u);
metal::uint2 xvupaiu = metal::uint2(42u, 43u);
metal::uint2 xvuuai = metal::uint2(42u, 43u);
metal::uint2 xvuaiu = metal::uint2(42u, 43u);
metal::float2x2 xmfpaiaiaiai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0));
metal::float2x2 xmfpafaiaiai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0));
metal::float2x2 xmfpaiafaiai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0));
metal::float2x2 xmfpaiaiafai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0));
metal::float2x2 xmfpaiaiaiaf = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0));
metal::int2 xvispai = metal::int2(1);
metal::float2 xvfspaf = metal::float2(1.0);
metal::int2 xvis_ai = metal::int2(1);
metal::uint2 xvus_ai = metal::uint2(1u);
metal::float2 xvfs_ai = metal::float2(1.0);
metal::float2 xvfs_af = metal::float2(1.0);
type_5 xafafaf = type_5 {1.0, 2.0};
type_5 xafaiai = type_5 {1.0, 2.0};
type_7 xafpaiai = type_7 {1, 2};
type_5 xafpaiaf = type_5 {1.0, 2.0};
type_5 xafpafai = type_5 {1.0, 2.0};
type_5 xafpafaf = type_5 {1.0, 2.0};
}
41 changes: 39 additions & 2 deletions naga/tests/out/spv/abstract-types-var.spvasm
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 70
; Bound: 104
OpCapability Shader
OpCapability Linkage
%1 = OpExtInstImport "GLSL.std.450"
Expand Down Expand Up @@ -75,4 +75,41 @@ OpDecorate %12 ArrayStride 4
%65 = OpVariable %66 Private %39
%67 = OpVariable %63 Private %37
%68 = OpVariable %63 Private %37
%69 = OpVariable %63 Private %37
%69 = OpVariable %63 Private %37
%72 = OpTypeFunction %2
%74 = OpTypePointer Function %3
%76 = OpTypePointer Function %5
%78 = OpTypePointer Function %7
%84 = OpTypePointer Function %9
%96 = OpTypePointer Function %10
%99 = OpTypePointer Function %12
%71 = OpFunction %2 None %72
%70 = OpLabel
%101 = OpVariable %96 Function %37
%97 = OpVariable %96 Function %37
%93 = OpVariable %78 Function %34
%90 = OpVariable %78 Function %34
%87 = OpVariable %84 Function %31
%83 = OpVariable %84 Function %31
%80 = OpVariable %76 Function %24
%75 = OpVariable %76 Function %18
%100 = OpVariable %96 Function %37
%95 = OpVariable %96 Function %37
%92 = OpVariable %76 Function %36
%89 = OpVariable %74 Function %33
%86 = OpVariable %84 Function %31
%82 = OpVariable %76 Function %24
%79 = OpVariable %76 Function %24
%73 = OpVariable %74 Function %15
%102 = OpVariable %96 Function %37
%98 = OpVariable %99 Function %39
%94 = OpVariable %78 Function %34
%91 = OpVariable %74 Function %33
%88 = OpVariable %84 Function %31
%85 = OpVariable %84 Function %31
%81 = OpVariable %76 Function %24
%77 = OpVariable %78 Function %21
OpBranch %103
%103 = OpLabel
OpReturn
OpFunctionEnd
76 changes: 52 additions & 24 deletions naga/tests/out/wgsl/abstract-types-var.wgsl
Original file line number Diff line number Diff line change
@@ -1,25 +1,53 @@
var<private> xvipaiai: vec2<i32> = vec2<i32>(42, 43);
var<private> xvupaiai: vec2<u32> = vec2<u32>(44u, 45u);
var<private> xvfpaiai: vec2<f32> = vec2<f32>(46.0, 47.0);
var<private> xvupuai: vec2<u32> = vec2<u32>(42u, 43u);
var<private> xvupaiu: vec2<u32> = vec2<u32>(42u, 43u);
var<private> xvuuai: vec2<u32> = vec2<u32>(42u, 43u);
var<private> xvuaiu: vec2<u32> = vec2<u32>(42u, 43u);
var<private> xmfpaiaiaiai: mat2x2<f32> = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
var<private> xmfpafaiaiai: mat2x2<f32> = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
var<private> xmfpaiafaiai: mat2x2<f32> = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
var<private> xmfpaiaiafai: mat2x2<f32> = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
var<private> xmfpaiaiaiaf: mat2x2<f32> = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
var<private> xvispai: vec2<i32> = vec2(1);
var<private> xvfspaf: vec2<f32> = vec2(1.0);
var<private> xvis_ai: vec2<i32> = vec2(1);
var<private> xvus_ai: vec2<u32> = vec2(1u);
var<private> xvfs_ai: vec2<f32> = vec2(1.0);
var<private> xvfs_af: vec2<f32> = vec2(1.0);
var<private> xafafaf: array<f32, 2> = array<f32, 2>(1.0, 2.0);
var<private> xafaiai: array<f32, 2> = array<f32, 2>(1.0, 2.0);
var<private> xafpaiai: array<i32, 2> = array<i32, 2>(1, 2);
var<private> xafpaiaf: array<f32, 2> = array<f32, 2>(1.0, 2.0);
var<private> xafpafai: array<f32, 2> = array<f32, 2>(1.0, 2.0);
var<private> xafpafaf: array<f32, 2> = array<f32, 2>(1.0, 2.0);
var<private> xvipaiai_1: vec2<i32> = vec2<i32>(42, 43);
var<private> xvupaiai_1: vec2<u32> = vec2<u32>(44u, 45u);
var<private> xvfpaiai_1: vec2<f32> = vec2<f32>(46.0, 47.0);
var<private> xvupuai_1: vec2<u32> = vec2<u32>(42u, 43u);
var<private> xvupaiu_1: vec2<u32> = vec2<u32>(42u, 43u);
var<private> xvuuai_1: vec2<u32> = vec2<u32>(42u, 43u);
var<private> xvuaiu_1: vec2<u32> = vec2<u32>(42u, 43u);
var<private> xmfpaiaiaiai_1: mat2x2<f32> = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
var<private> xmfpafaiaiai_1: mat2x2<f32> = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
var<private> xmfpaiafaiai_1: mat2x2<f32> = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
var<private> xmfpaiaiafai_1: mat2x2<f32> = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
var<private> xmfpaiaiaiaf_1: mat2x2<f32> = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
var<private> xvispai_1: vec2<i32> = vec2(1);
var<private> xvfspaf_1: vec2<f32> = vec2(1.0);
var<private> xvis_ai_1: vec2<i32> = vec2(1);
var<private> xvus_ai_1: vec2<u32> = vec2(1u);
var<private> xvfs_ai_1: vec2<f32> = vec2(1.0);
var<private> xvfs_af_1: vec2<f32> = vec2(1.0);
var<private> xafafaf_1: array<f32, 2> = array<f32, 2>(1.0, 2.0);
var<private> xafaiai_1: array<f32, 2> = array<f32, 2>(1.0, 2.0);
var<private> xafpaiai_1: array<i32, 2> = array<i32, 2>(1, 2);
var<private> xafpaiaf_1: array<f32, 2> = array<f32, 2>(1.0, 2.0);
var<private> xafpafai_1: array<f32, 2> = array<f32, 2>(1.0, 2.0);
var<private> xafpafaf_1: array<f32, 2> = array<f32, 2>(1.0, 2.0);

fn f() {
var xvipaiai: vec2<i32> = vec2<i32>(42, 43);
var xvupaiai: vec2<u32> = vec2<u32>(44u, 45u);
var xvfpaiai: vec2<f32> = vec2<f32>(46.0, 47.0);
var xvupuai: vec2<u32> = vec2<u32>(42u, 43u);
var xvupaiu: vec2<u32> = vec2<u32>(42u, 43u);
var xvuuai: vec2<u32> = vec2<u32>(42u, 43u);
var xvuaiu: vec2<u32> = vec2<u32>(42u, 43u);
var xmfpaiaiaiai: mat2x2<f32> = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
var xmfpafaiaiai: mat2x2<f32> = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
var xmfpaiafaiai: mat2x2<f32> = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
var xmfpaiaiafai: mat2x2<f32> = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
var xmfpaiaiaiaf: mat2x2<f32> = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
var xvispai: vec2<i32> = vec2(1);
var xvfspaf: vec2<f32> = vec2(1.0);
var xvis_ai: vec2<i32> = vec2(1);
var xvus_ai: vec2<u32> = vec2(1u);
var xvfs_ai: vec2<f32> = vec2(1.0);
var xvfs_af: vec2<f32> = vec2(1.0);
var xafafaf: array<f32, 2> = array<f32, 2>(1.0, 2.0);
var xafaiai: array<f32, 2> = array<f32, 2>(1.0, 2.0);
var xafpaiai: array<i32, 2> = array<i32, 2>(1, 2);
var xafpaiaf: array<f32, 2> = array<f32, 2>(1.0, 2.0);
var xafpafai: array<f32, 2> = array<f32, 2>(1.0, 2.0);
var xafpafaf: array<f32, 2> = array<f32, 2>(1.0, 2.0);

}

0 comments on commit 3320637

Please sign in to comment.